#!/usr/bin/env python3
# extract_attention.py

import os
import csv
import argparse
from pathlib import Path
import torch
import numpy as np
from tqdm import tqdm
import random

from transformers import (
    Qwen2_5OmniForConditionalGeneration,
    Qwen2_5OmniProcessor
)
from qwen_omni_utils import process_mm_info

# Seed the RNG so the random modality assignment in load_asset_dict and the
# shuffling in make_conversation are repeatable from run to run.
random.seed(42)

# Command-line arguments
p = argparse.ArgumentParser(
    description="Extract per-layer attention vectors for fact parts and save as NPZ",
)
p.add_argument("--type", default="independent")
p.add_argument(
    "--pooling",
    default="none",
    choices=["mean", "max", "none"],
)
p.add_argument(
    "--mod_order",
    default=None,
    type=str,
    help="Fixed modality order in assets: any permutation of I (image), A (audio), T (text); e.g. 'IAT', 'AIT', 'TAI'",
)
args = p.parse_args()

# Dataset, model and output locations
BASE_DIR = "/path/to/dataset"
MODEL_PATH = "/path/to/Qwen2.5-Omni-7B"
QUEST_PATH = os.path.join(BASE_DIR, f"reasoning_meta/reasoning_{args.type}_dataset.csv")

# The "recognition" type reads the asset table of the "independent" type.
_asset_variant = "independent" if args.type == "recognition" else args.type
ASSET_PATH = os.path.join(BASE_DIR, f"assets/multimodal_datasets_{_asset_variant}.csv")

# One output directory per (type, pooling[, modality order]) configuration.
_run_tag = f"{args.type}_{args.pooling}"
if args.mod_order is not None:
    _run_tag += f"_{args.mod_order}"
OUT_DIR = f"/path/to/output/{_run_tag}"
os.makedirs(OUT_DIR, exist_ok=True)

# System prompt placed in front of every conversation.
# NOTE(review): the sentences are joined without a separating space
# (e.g. "...knowledge diagrams.Use only..."). That looks unintended, but the
# text is kept byte-for-byte here because altering it changes the model input.
SYSTEM_MSG = "".join([
    "You are an assistant tasked with solving multiple-choice questions that require logical",
    " reasoning over the supplied knowledge diagrams.",
    "Use only the information explicitly given—do not rely on outside or commonsense knowledge.",
    "Read the question and given information, think step-by-step and answer the question.",
    "At the end of your answer, answer precisely in the format 'Answer: X' where X is the chosen letter A / B / C / D.",
])

# Segment labels saved alongside the vectors; a label's id is its position in
# the vocabulary array.
MODALITY_VOCAB = np.array(
    ["text_fact", "image_fact", "audio_fact", "rule", "question", "other"],
    dtype=str,
)
MODALITY_TO_ID = {label: index for index, label in enumerate(MODALITY_VOCAB.tolist())}

# Utilities
def load_asset_dict(asset_csv, mod_order=None):
    """
    Read the asset CSV and pick, per subgraph, which column set feeds each modality.

    Every row is expected to provide ``subgraph_id`` plus the columns
    ``modality{1,2,3}_img``, ``modality{1,2,3}_wav`` and ``modality{1,2,3}_txt``.
    Image, audio and text are each assigned one of the slots 1..3.

    Args:
        asset_csv: Path of the asset CSV file (read as UTF-8).
        mod_order: ``None`` to draw a random slot assignment for every row, or
            a permutation of the characters ``I`` (image), ``A`` (audio) and
            ``T`` (text); the 1-based position of a character is the slot used
            for that modality (``"TAI"`` -> text 1, audio 2, image 3).

    Returns:
        Dict mapping ``subgraph_id`` to a dict with the chosen ``img``/``wav``/
        ``txt`` cell values and the slots used (``order_img``, ``order_wav``,
        ``order_txt``).

    Raises:
        ValueError: If ``mod_order`` contains a character other than I/A/T or
            is not a permutation of exactly those three characters.
    """
    fixed_order = None
    if mod_order is not None:
        # Validate once, up front: a missing or repeated letter would otherwise
        # leave a slot at 0 (or out of range) and surface as a confusing
        # KeyError on a non-existent "modality0_*" column.
        for m in mod_order:
            if m not in ("I", "A", "T"):
                raise ValueError(f"Unknown modality character: {m}")
        if sorted(mod_order) != ["A", "I", "T"]:
            raise ValueError(
                f"mod_order must be a permutation of 'I', 'A', 'T'; got {mod_order!r}"
            )
        # Slots in (image, audio, text) order; identical for every row.
        fixed_order = [mod_order.index(m) + 1 for m in ("I", "A", "T")]

    asset_dict = {}
    with open(asset_csv, newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            if fixed_order is None:
                # One draw per row, so the RNG stream matches a seeded run.
                img_slot, wav_slot, txt_slot = random.sample([1, 2, 3], k=3)
            else:
                img_slot, wav_slot, txt_slot = fixed_order
            asset_dict[row["subgraph_id"]] = {
                "img": row[f"modality{img_slot}_img"],
                "wav": row[f"modality{wav_slot}_wav"],
                "txt": row[f"modality{txt_slot}_txt"],
                "order_img": img_slot,
                "order_wav": wav_slot,
                "order_txt": txt_slot,
            }
    return asset_dict

# subgraph_id -> asset entry, loaded once at import time from ASSET_PATH using
# the --mod_order given on the command line (None means a random order per row).
ASSETS = load_asset_dict(ASSET_PATH, args.mod_order)

def make_conversation(row, use_image, use_audio, use_text):
    """
    Assemble the Qwen-style chat messages for one dataset row.

    The enabled fact modalities (image / audio / text) come first, in a
    randomly shuffled order, followed by the rules segment (a bare newline when
    ``args.type`` is "recognition") and finally the question.

    Returns a 4-tuple ``(conversation, facts_txt, rules_txt, ques_txt)``:
    the list of system + user message dicts, and the text of the fact, rules
    and question segments.
    """
    # (enabled?, content type, row column); the column is read only if enabled.
    fact_sources = (
        (use_image, "image", "info_img"),
        (use_audio, "audio", "info_wav"),
        (use_text, "text", "info_text"),
    )
    user_content = [
        {"type": kind, kind: row[column]}
        for enabled, kind, column in fact_sources
        if enabled
    ]
    random.shuffle(user_content)

    if args.type == "recognition":
        rules_txt = "\n"
    else:
        rules = row["rules"]
        rules_txt = f"\nRules are as follows: {rules}\n"
    ques_txt = row["question_text"]
    facts_txt = row["info_text"]

    user_content.append({"type": "text", "text": rules_txt})
    user_content.append({"type": "text", "text": ques_txt})

    conversation = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_MSG}]},
        {"role": "user", "content": user_content},
    ]
    return conversation, facts_txt, rules_txt, ques_txt

def tokenize_plain(tokenizer, text: str):
    """Encode ``text`` with ``tokenizer`` (no special tokens added) and return its ``input_ids``."""
    encode_options = {
        "add_special_tokens": False,
        "return_attention_mask": False,
        "return_token_type_ids": False,
    }
    return tokenizer(text, **encode_options)["input_ids"]

def find_subsequence(haystack, needle):
    """
    Return the index of the first occurrence of ``needle`` inside ``haystack``.

    Both arguments are sequences (token-id lists in this file). Returns -1 when
    ``needle`` is empty, longer than ``haystack``, or does not occur.
    """
    if not needle:
        return -1
    span = len(needle)
    last_start = len(haystack) - span
    if last_start < 0:
        return -1

    anchor = needle[0]
    cursor = 0
    while cursor <= last_start:
        # Jump straight to the next position whose first element matches;
        # positions past last_start cannot hold a full match.
        try:
            hit = haystack.index(anchor, cursor, last_start + 1)
        except ValueError:
            break
        if haystack[hit:hit + span] == needle:
            return hit
        cursor = hit + 1
    return -1

def build_prompt_labels_from_inputs(tok, seq_prompt_ids, facts_txt, rules_txt, ques_txt):
    """
    Label every prompt token with the segment it belongs to.

    Tokens strictly between ``<|vision_bos|>``/``<|vision_eos|>`` become
    "image_fact" and those between ``<|audio_bos|>``/``<|audio_eos|>`` become
    "audio_fact". The fact, rules and question texts are then tokenized on
    their own and located in the prompt as a token subsequence; within the
    first match, tokens that are still "other" receive "text_fact", "rule" or
    "question" respectively. Everything else stays "other".

    Returns an object-dtype array of label strings, one per entry of
    ``seq_prompt_ids``.
    """
    n_tokens = len(seq_prompt_ids)
    labels = np.array(["other"] * n_tokens, dtype=object)

    def tag_delimited(open_id, close_id, tag):
        # Single pass: everything after an opening marker and before the next
        # closing marker is tagged; an unclosed region runs to the prompt end.
        inside = False
        for pos, token_id in enumerate(seq_prompt_ids):
            if not inside:
                inside = token_id == open_id
            elif token_id == close_id:
                inside = False
            else:
                labels[pos] = tag

    def tag_text(text, tag):
        ids = tokenize_plain(tok, text)
        start = find_subsequence(seq_prompt_ids, ids)
        if start < 0:
            return
        stop = min(start + len(ids), n_tokens)
        window = labels[start:stop]  # view into labels, so the write below sticks
        # Never overwrite a label assigned by an earlier pass.
        window[window == "other"] = tag

    to_id = tok.convert_tokens_to_ids
    delimited_segments = (
        ("<|vision_bos|>", "<|vision_eos|>", "image_fact"),
        ("<|audio_bos|>", "<|audio_eos|>", "audio_fact"),
    )
    for open_token, close_token, tag in delimited_segments:
        tag_delimited(to_id(open_token), to_id(close_token), tag)

    for text, tag in ((facts_txt, "text_fact"), (rules_txt, "rule"), (ques_txt, "question")):
        tag_text(text, tag)

    return labels


# Model setup
# The model and its processor are loaded once at import time; the functions
# below use the module-level `model` and `processor` objects.
# NOTE(review): run_qwen later requests attention weights with
# output_attentions=True while the model is loaded with
# attn_implementation="sdpa". Whether weights are returned in that combination
# (e.g. via a fallback to eager attention) presumably depends on the installed
# transformers version — confirm against the version in use.
print("Loading Qwen-2.5-Omni …")
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="sdpa",
)
processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_PATH)
print("✓ model ready")

def run_qwen(row, use_image, use_audio, use_text, pooling='mean'):
    """
    Answer one question and measure how much the reply attends to each prompt segment.

    The prompt built by ``make_conversation`` is answered with
    ``model.generate``; prompt + reply are then passed through
    ``model.thinker`` once more with ``output_attentions=True``. For each
    segment label (text/image/audio fact, rule, question, other) the attention
    from the generated tokens (queries) to that segment's prompt tokens (keys)
    is reduced per layer.

    Args:
        row: Dataset row (dict) carrying the fields ``make_conversation`` reads.
        use_image, use_audio, use_text: Which fact modalities to include.
        pooling: How heads are combined. ``'mean'`` averages over heads,
            ``'max'`` takes the maximum over heads and renormalises each query
            row to sum to one, ``'none'`` keeps every head separate.

    Returns:
        ``(sum_vecs, mean_vecs, reply)``. Both dicts map a segment label to a
        float32 vector of length L (pooled heads) or L*H (``pooling='none'``,
        layer-major). ``sum_vecs`` is the attention mass on the segment
        averaged over generated tokens; ``mean_vecs`` is that value further
        divided by the number of tokens in the segment. Vectors are all zero
        when there are no generated tokens or the segment is empty. ``reply``
        is the decoded, stripped answer text.

    Raises:
        ValueError: If ``pooling`` is not one of the supported modes.
    """
    if pooling not in ('mean', 'max', 'none'):
        # Fail before the expensive generation instead of after it.
        raise ValueError(f"Unknown pooling: {pooling}")

    conversation, facts_txt, rules_txt, ques_txt = make_conversation(row, use_image, use_audio, use_text)
    prompt_template = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)

    audios, images, _ = process_mm_info(conversation, use_audio_in_video=False)
    inputs = processor(
        text=prompt_template, images=images, audio=audios,
        return_tensors="pt", padding=True, use_audio_in_video=False
    ).to(model.device).to(model.dtype)

    tok = processor.tokenizer

    with torch.no_grad():
        gen = model.generate(**inputs,
                                use_audio_in_video=False, return_audio=False)

    prompt_len = inputs["input_ids"].shape[1]
    all_ids = gen
    reply_ids = all_ids[:, prompt_len:]
    reply = processor.batch_decode(reply_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0].strip()

    # Second pass over prompt + reply to obtain the attention maps. The
    # attention mask is extended with ones for the generated positions.
    full_inputs = dict(inputs)
    full_inputs["input_ids"] = all_ids
    attn_mask = inputs["attention_mask"]
    pad = torch.ones((attn_mask.shape[0], all_ids.shape[1] - attn_mask.shape[1]),
                     dtype=attn_mask.dtype, device=attn_mask.device)
    full_inputs["attention_mask"] = torch.cat([attn_mask, pad], dim=1)
    # Inference only: without no_grad the forward pass keeps autograd state for
    # every layer on top of the already large attention tensors.
    with torch.no_grad():
        out_full = model.thinker(**full_inputs, use_audio_in_video=False, output_attentions=True)

    # With device_map="auto" layers may live on different devices; gather the
    # per-layer attentions on one device before stacking.
    attn_device = out_full.attentions[0].device
    raw_attn = torch.stack([a.to(attn_device) for a in out_full.attentions], dim=0).squeeze(1)
    L, H, S, _ = raw_attn.shape

    # Pool heads, keeping a head axis of size 1 so that pooled and per-head
    # attention share the reduction code below.
    if pooling == 'mean':
        pooled = raw_attn.mean(dim=1, keepdim=True)
    elif pooling == 'max':
        attn_max = raw_attn.amax(dim=1, keepdim=True)
        row_sums = attn_max.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        pooled = attn_max / row_sums
    else:  # 'none'
        pooled = raw_attn
    n_heads = pooled.shape[1]

    # --- labels over the ACTUAL prompt token IDs (this gets image/audio spans >1) ---
    seq_prompt_ids = inputs["input_ids"][0][:prompt_len].tolist()
    labels_prompt = build_prompt_labels_from_inputs(tok, seq_prompt_ids, facts_txt, rules_txt, ques_txt)

    # extend to full sequence (generated = "other")
    labels_all = np.array(["other"] * S, dtype=object)
    labels_all[:prompt_len] = labels_prompt

    # Query mask: generated positions only.
    gen_mask = torch.zeros(S, dtype=torch.bool, device=attn_device)
    gen_mask[prompt_len:] = True
    qn = int(gen_mask.sum().item())

    def to_mask(tag):
        # Key mask: prompt positions carrying `tag`; generated positions are never keys.
        m = (labels_all[:prompt_len] == tag)
        out = torch.zeros(S, dtype=torch.bool, device=attn_device)
        if m.size:
            out[:prompt_len] = torch.from_numpy(m).to(device=attn_device)
        return out

    def reduce_over(key_mask):
        """Return (per-query sum, per-token mean) vectors of length L * n_heads."""
        kn = int(key_mask.sum().item())
        if qn == 0 or kn == 0:
            # Two separate arrays so callers cannot alias one through the other.
            return (np.zeros((L * n_heads,), dtype=np.float32),
                    np.zeros((L * n_heads,), dtype=np.float32))
        sums_list = []
        means_list = []
        for layer in range(L):
            # block: (n_heads, qn, kn); accumulate in float32 for both the
            # pooled and the per-head case rather than in the model's fp16.
            block = pooled[layer][:, gen_mask][:, :, key_mask].float()
            totals = block.sum(dim=(1, 2))
            sums_list.append((totals / float(qn)).cpu().numpy())
            means_list.append((totals / float(qn * kn)).cpu().numpy())
        sums = np.concatenate(sums_list, axis=0).astype(np.float32)
        means = np.concatenate(means_list, axis=0).astype(np.float32)
        return sums, means

    sum_vecs = {}
    mean_vecs = {}
    for name in ("text_fact", "image_fact", "audio_fact", "rule", "question", "other"):
        sum_vecs[name], mean_vecs[name] = reduce_over(to_mask(name))

    return sum_vecs, mean_vecs, reply

def evaluate(use_image, use_audio, use_text, pooling='mean'):
    """
    Run the model over every question in QUEST_PATH and save answers and attention vectors.

    Rows whose ``id`` has no entry in ASSETS are skipped (and counted). For
    each remaining row the model answer is recorded and six attention vectors
    (one per segment label) are collected.

    Outputs, both written into OUT_DIR:
        * ``model_results.csv`` - one line per processed row.
        * ``attention_vectors.npz`` - ``sum_vectors`` / ``mean_vectors`` of
          shape (N, D), plus per-vector ``modality`` id, ``sample_id``,
          ``fact_slot`` (-1 for non-fact segments) and ``modality_vocab``.
          D is whatever vector length ``run_qwen`` returns for ``pooling``.

    Args:
        use_image, use_audio, use_text: Which fact modalities go into the prompt.
        pooling: Head pooling mode forwarded to ``run_qwen``.
    """
    out_csv = os.path.join(OUT_DIR, "model_results.csv")
    output_npz = os.path.join(OUT_DIR, "attention_vectors.npz")
    segment_names = ("text_fact", "image_fact", "audio_fact", "rule", "question", "other")

    with open(QUEST_PATH, newline="", encoding="utf-8") as f:
        data = list(csv.DictReader(f))

    results = []
    sum_vectors = []
    mean_vectors = []
    modality = []
    sample_id = []
    fact_slots = []
    skipped = 0

    for row in tqdm(data, desc="Running"):
        asset = ASSETS.get(row["id"])
        if asset is None:
            skipped += 1
            continue

        row["info_img"] = asset["img"]
        row["info_wav"] = asset["wav"]
        row["info_text"] = asset["txt"]

        row["ord_img"] = asset["order_img"]
        row["ord_wav"] = asset["order_wav"]
        row["ord_txt"] = asset["order_txt"]

        sum_vecs, mean_vecs, pred = run_qwen(row, use_image, use_audio, use_text, pooling=pooling)
        pred = pred.strip()

        # NOTE(review): "rules" is read here for every task type, while
        # make_conversation skips it for "recognition" — confirm that dataset
        # actually has the column.
        record = {
            "id": row["id"],
            "rules": row["rules"],
            "question": row["questions"],
        }
        if args.type == "contradictory":
            record["option_role_map"] = row["option_role_map"]
            record["options"] = row["options"]
        else:
            record["options"] = row["options"]
            record["gt_answer"] = row["correct_answer"]
        record["model_answer"] = pred
        results.append(record)

        # Asset slot each fact modality was taken from; other segments get -1.
        slot_of = {
            "text_fact": row["ord_txt"],
            "image_fact": row["ord_img"],
            "audio_fact": row["ord_wav"],
        }
        for name in segment_names:
            sum_vectors.append(sum_vecs[name])
            mean_vectors.append(mean_vecs[name])
            modality.append(MODALITY_TO_ID[name])
            sample_id.append(row["id"])
            fact_slots.append(slot_of.get(name, -1))

    if skipped:
        # Make dropped rows visible instead of silently shrinking the dataset.
        print(f"Skipped {skipped} rows without a matching asset entry.")

    # Save predictions
    if results:
        with open(out_csv, "w", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=list(results[0].keys()))
            w.writeheader()
            w.writerows(results)
        print(f"Saved {len(results)} results to {out_csv}")
    else:
        print("No results to save.")

    # Stack & save vectors
    if sum_vectors:
        X_sum = np.stack(sum_vectors, axis=0).astype(np.float32)    # (N, D)
        X_mean = np.stack(mean_vectors, axis=0).astype(np.float32)  # (N, D)
        y = np.asarray(modality, dtype=np.int8)                     # (N,)
        sids = np.asarray(sample_id, dtype=str)                     # (N,)
        slots = np.asarray(fact_slots, dtype=np.int8)               # (N,)

        out_path = Path(output_npz)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        np.savez_compressed(
            out_path,
            sum_vectors=X_sum,
            mean_vectors=X_mean,
            modality=y,
            sample_id=sids,
            modality_vocab=MODALITY_VOCAB,
            fact_slot=slots,
        )
        print(f"Saved {X_sum.shape[0]} vectors × {X_sum.shape[1]} layers to {out_path}")

if __name__ == "__main__":
    # (use_image, use_audio, use_text) settings to run; currently only the
    # configuration with all three modalities enabled.
    for modality_flags in [(True, True, True)]:
        evaluate(*modality_flags, pooling=args.pooling)
