#!/usr/bin/env python3
# attention_manipulation.py

import os
import csv
import argparse
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
import random

from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info

# ------------------------------- CLI -------------------------------- #
# Command-line interface. The parsed ``args`` namespace is read by the
# module-level configuration that follows (paths, seeding, hook setup).
p = argparse.ArgumentParser(
    description="Run Qwen2.5-Omni with per-head temperature on a SELECTED subset of heads; save only CSV outputs."
)
p.add_argument("--type", default="independent")
p.add_argument("--mod_order", type=str, default=None,
               help="Permutation of I/A/T, e.g. 'IAT','AIT','TAI' (fixes which slot each modality occupies).")
p.add_argument("--setting", choices=["vanilla", "layer"], default="vanilla",
               help="Which subset of heads to modify.")
p.add_argument("--layer", choices=["bottom", "middle", "top"], default="bottom", help="select layers when setting=layer")
p.add_argument("--temp_mode", choices=["trust", "anti"], default="trust",
               help="Mapping of importance -> temperature for non-vanilla settings.")
p.add_argument("--scale", type=float, default=0.0, help="Manipulate amount")
# The script seeds ``random`` and ``numpy`` from ``args.random_state``; without
# this option that attribute does not exist and start-up fails with
# AttributeError.
p.add_argument("--random_state", type=int, default=42,
               help="Seed for the Python and NumPy random number generators.")

args = p.parse_args()

# Dataset inputs: the question sheet and the asset sheet for the chosen --type.
BASE_DIR = "/path/to/dataset"
QUEST_PATH = os.path.join(BASE_DIR, f"reasoning_meta/reasoning_{args.type}_dataset.csv")
ASSET_PATH = os.path.join(BASE_DIR, f"assets/multimodal_datasets_{args.type}.csv")

# One output directory per run configuration. The modality order becomes part
# of the directory name only when it was fixed on the command line.
_order_tag = "" if args.mod_order is None else f"{args.mod_order}_"
OUT_DIR = (f"/path/to/output/"
           f"{args.type}_{_order_tag}{args.setting}_{args.temp_mode}_{args.layer}_{args.scale}")
os.makedirs(OUT_DIR, exist_ok=True)

# Checkpoint directory handed to from_pretrained for both model and processor.
MODEL_PATH = "/path/tp/Qwen2.5-Omni-7B"

# System prompt sent with every conversation. The pieces are joined by implicit
# string concatenation, so each continuation line carries its own leading space;
# otherwise consecutive sentences run together ("diagrams.Use only ...").
SYSTEM_MSG = (
    "You are an assistant tasked with solving multiple-choice questions that require logical"
    " reasoning over the supplied knowledge diagrams."
    " Use only the information explicitly given—do not rely on outside or commonsense knowledge."
    " Read the question and given information, think step-by-step and answer the question."
    " At the end of your answer, answer precisely in the format 'Answer: X' where X is the chosen letter A / B / C / D."
)


# Seed Python's and NumPy's global generators (in that order) with the same
# value so the random.sample / random.shuffle calls in this script repeat
# across runs.
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(args.random_state)

def _slots_from_mod_order(mod_order):
    """Translate a modality order such as 'AIT' into [img_slot, wav_slot, txt_slot] (1-based)."""
    index_of = {"I": 0, "A": 1, "T": 2}
    for ch in mod_order:
        if ch not in index_of:
            raise ValueError(f"Unknown modality character: {ch}")
    # Every modality must appear exactly once; otherwise a slot would stay 0
    # and the lookup below would fail on a non-existent 'modality0_*' column.
    if sorted(mod_order) != ["A", "I", "T"]:
        raise ValueError(
            f"mod_order must be a permutation of 'I', 'A', 'T'; got {mod_order!r}"
        )
    slots = [0, 0, 0]
    for slot, ch in enumerate(mod_order, start=1):
        slots[index_of[ch]] = slot
    return slots


def load_asset_dict(asset_csv, mod_order=None):
    """Load the asset sheet and pick, per subgraph, one slot for each modality.

    Each row must provide ``subgraph_id`` plus the columns
    ``modality{k}_img``, ``modality{k}_wav`` and ``modality{k}_txt`` for the
    slots ``k`` that get selected (1, 2 or 3).

    Args:
        asset_csv: Path of the CSV file to read (UTF-8).
        mod_order: ``None`` to draw a fresh random slot assignment for every
            row with ``random.sample`` (consumes the global ``random`` state),
            or a permutation of the letters ``I``, ``A``, ``T``. The 1-based
            position of each letter is the slot that the image / audio / text
            entry is taken from, e.g. ``'TAI'`` reads text from slot 1, audio
            from slot 2 and the image from slot 3.

    Returns:
        dict mapping ``subgraph_id`` to ``{"img": ..., "wav": ..., "txt": ...}``.

    Raises:
        ValueError: if ``mod_order`` contains a character other than I/A/T or
            is not a permutation of exactly those three letters.
    """
    # The fixed assignment does not depend on the row, so resolve (and
    # validate) it once up front instead of once per CSV line.
    fixed_slots = None if mod_order is None else _slots_from_mod_order(mod_order)

    asset_dict = {}
    with open(asset_csv, newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            if fixed_slots is None:
                img_slot, wav_slot, txt_slot = random.sample([1, 2, 3], k=3)
            else:
                img_slot, wav_slot, txt_slot = fixed_slots
            asset_dict[row["subgraph_id"]] = {
                "img": row[f"modality{img_slot}_img"],
                "wav": row[f"modality{wav_slot}_wav"],
                "txt": row[f"modality{txt_slot}_txt"],
            }
    return asset_dict

# Asset lookup keyed by subgraph id, built once at start-up.
ASSETS = load_asset_dict(asset_csv=ASSET_PATH, mod_order=args.mod_order)

def make_conversation(row, use_image=True, use_audio=True, use_text=True):
    """Assemble the system + user chat messages for one question row.

    The enabled modality entries (image from ``row["info_img"]``, audio from
    ``row["info_wav"]``, text from ``row["info_text"]``) are placed first in a
    random order (``random.shuffle``), followed by the rules and then the
    question text. Only the columns of enabled modalities are read.

    Returns:
        A two-element list: the system message carrying ``SYSTEM_MSG`` and the
        user message with the assembled content list.
    """
    modality_specs = (
        (use_image, "image", "info_img"),
        (use_audio, "audio", "info_wav"),
        (use_text, "text", "info_text"),
    )
    # For each enabled modality the payload key equals the entry's type name.
    parts = [
        {"type": kind, kind: row[column]}
        for enabled, kind, column in modality_specs
        if enabled
    ]
    random.shuffle(parts)

    rules = row["rules"]
    parts += [
        {"type": "text", "text": f"\nRules are as follows: {rules}\n"},
        {"type": "text", "text": row["question_text"]},
    ]

    system_turn = {"role": "system", "content": [{"type": "text", "text": SYSTEM_MSG}]}
    user_turn = {"role": "user", "content": parts}
    return [system_turn, user_turn]

def _extract_L_H(model):
    """Return ``(num_hidden_layers, num_attention_heads)`` as ints.

    Both values are read from ``model.config.thinker_config.text_config``.
    """
    text_cfg = model.config.thinker_config.text_config
    return int(text_cfg.num_hidden_layers), int(text_cfg.num_attention_heads)

def _select_heads_mask(L, H, setting, layer=None):
    """Build the boolean (L, H) mask of heads whose temperature gets modified.

    Args:
        L: Number of layers (rows of the mask).
        H: Number of heads per layer (columns of the mask).
        setting: ``"vanilla"`` selects nothing; ``"layer"`` selects every head
            of a band of four layers chosen by ``layer``.
        layer: ``"bottom"`` (layers 0-3), ``"middle"`` (layers 12-15) or
            ``"top"`` (last four layers). Required when ``setting == "layer"``.

    Returns:
        ``numpy.ndarray`` of dtype bool and shape ``(L, H)``.

    Raises:
        ValueError: for an unknown ``setting``, a missing ``layer`` or an
            unknown ``layer`` value.
    """
    mask = np.zeros((L, H), dtype=bool)

    if setting == "vanilla":
        return mask
    if setting != "layer":
        raise ValueError(f"Unknown setting: {setting}")

    if layer is None:
        raise ValueError("Layer is not provided when setting is layer")
    if layer == "bottom":
        mask[:4, :] = True
    elif layer == "middle":
        # NOTE(review): fixed indices; 12..15 is the exact centre only for a
        # 28-layer stack - confirm against the model in use.
        mask[12:16, :] = True
    elif layer == "top":
        mask[-4:, :] = True
    else:
        # Previously an unrecognised value fell through and silently produced
        # an empty mask, i.e. a run that modifies nothing.
        raise ValueError(f"Unknown layer: {layer}")
    return mask


def _map_importance_to_temps_selected(mask, mode, scale):
    """Return the per-head temperature table for the selected heads.

    Unselected heads keep temperature 1. Selected heads get ``1 - scale`` in
    ``"trust"`` mode and ``1 + scale`` in ``"anti"`` mode.

    Args:
        mask: Boolean array marking the selected heads; the result has the
            same shape.
        mode: ``"trust"`` or ``"anti"``.
        scale: Amount subtracted from / added to the base temperature of 1.

    Returns:
        ``numpy.ndarray`` of dtype float32 with the shape of ``mask``.

    Raises:
        ValueError: if ``mode`` is unknown, or if the resulting temperature
            for the selected heads is not strictly positive.
    """
    temps = np.ones(mask.shape, dtype=np.float32)
    if not mask.any():
        return temps

    if mode == "trust":
        selected_temp = 1 - scale
    elif mode == "anti":
        selected_temp = 1 + scale
    else:
        raise ValueError(f"Unknown temp_mode: {mode!r}")

    # The temperatures are later used as divisors of the query projections, so
    # zero would produce inf/NaN and a negative value would flip the sign of
    # the attention logits. `not (x > 0)` also rejects NaN.
    if not (selected_temp > 0):
        raise ValueError(
            f"scale={scale} with temp_mode={mode!r} gives a non-positive "
            f"temperature ({selected_temp}); temperatures must be > 0"
        )

    temps[mask] = selected_temp
    return temps

def _iter_text_layers(model):
    """Yield ``(index, self_attn)`` for the layers found under ``model.thinker``.

    The layer list is taken from ``model.thinker.model.layers``, or from
    ``model.thinker.layers`` when the thinker has no ``model`` attribute.
    ``index`` is the position in that list; layers without a ``self_attn``
    attribute (or with it set to ``None``) are skipped.

    Raises:
        RuntimeError: if ``model.thinker`` or the layer list cannot be found
            (raised when iteration starts, since this is a generator).
    """
    thinker = getattr(model, "thinker", None)
    if thinker is None:
        raise RuntimeError("Could not find model.thinker (text backbone)")

    backbone = getattr(thinker, "model", thinker)
    decoder_layers = getattr(backbone, "layers", None)
    if decoder_layers is None:
        raise RuntimeError("Could not locate text layers under model.thinker.model.layers")

    for index, layer in enumerate(decoder_layers):
        attention = getattr(layer, "self_attn", None)
        if attention is not None:
            yield index, attention

def _make_temperature_hook(head_temps):
    """Build a forward hook that divides a projection's output by a per-head temperature."""
    num_heads = head_temps.numel()
    # Device/dtype-specific copies of the temperatures, created on first use so
    # the host-to-device transfer is not repeated on every forward call.
    divisor_cache = {}

    def hook(module, inputs, output):
        batch, seq_len, width = output.shape
        if width % num_heads:
            raise RuntimeError(
                f"q_proj output width {width} is not divisible by the number "
                f"of temperatures ({num_heads})"
            )
        key = (output.device, output.dtype)
        divisor = divisor_cache.get(key)
        if divisor is None:
            divisor = head_temps.to(output.device, dtype=output.dtype).view(1, 1, num_heads, 1)
            divisor_cache[key] = divisor
        per_head = output.reshape(batch, seq_len, num_heads, width // num_heads)
        # Divide queries by T per head: softmax(QK^T / sqrt(d) / T)
        return (per_head / divisor).reshape(batch, seq_len, width)

    return hook


def _register_q_proj_hooks(model, tempsLH):
    """Attach per-head temperature hooks to the ``q_proj`` of each text layer.

    Args:
        model: Model whose layers are enumerated by ``_iter_text_layers``.
        tempsLH: Array of shape (L, H); row ``l`` holds the temperatures of
            layer ``l``. The hook splits the last dimension of the ``q_proj``
            output into H equal chunks and divides chunk ``h`` by
            ``tempsLH[l, h]``.

    Returns:
        List of the handles returned by ``register_forward_hook``, one per
        hooked layer. Layers with an index >= L or without a ``q_proj``
        attribute are skipped.
    """
    num_layers = tempsLH.shape[0]
    handles = []
    for layer_idx, attn in _iter_text_layers(model):
        if layer_idx >= num_layers:
            continue
        q_proj = getattr(attn, "q_proj", None)
        if q_proj is None:
            continue
        layer_temps = torch.tensor(tempsLH[layer_idx], dtype=torch.float32)  # (H,)
        handles.append(q_proj.register_forward_hook(_make_temperature_hook(layer_temps)))
    return handles

print("Loading Qwen2.5-Omni …")
# Loading options for the checkpoint: float16 weights, automatic device
# placement, and the "sdpa" attention implementation.
_load_options = {
    "torch_dtype": torch.float16,
    "device_map": "auto",
    "attn_implementation": "sdpa",
}
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(MODEL_PATH, **_load_options)
# The processor is loaded from the same checkpoint directory as the model.
processor = Qwen2_5OmniProcessor.from_pretrained(MODEL_PATH)
print("✓ model ready")

# Install the temperature hooks unless the unmodified ("vanilla") model was
# requested. `_hooks` keeps the handles returned by the registration call.
L, H = _extract_L_H(model)
_hooks = []

if args.setting == "vanilla":
    print("✓ vanilla: no temperature hooks registered")
else:
    maskLH = _select_heads_mask(L, H, args.setting, args.layer)  # (L,H)
    tempsLH = _map_importance_to_temps_selected(maskLH, mode=args.temp_mode, scale=args.scale)
    _hooks = _register_q_proj_hooks(model, tempsLH)
    num_sel = int(maskLH.sum())
    print(f"✓ {args.setting}: selected {num_sel} heads; temp_mode={args.temp_mode}; hooks={len(_hooks)}")

# ------------------------------- Run -------------------------------- #
def run_qwen(row, use_image=True, use_audio=True, use_text=True):
    """Generate the model's text reply for one question row.

    Builds the conversation with ``make_conversation``, renders it through the
    processor's chat template, runs ``model.generate`` without gradients and
    decodes only the tokens produced after the prompt.

    Returns:
        The decoded reply of the first sequence, stripped of surrounding
        whitespace.
    """
    conversation = make_conversation(row, use_image, use_audio, use_text)
    prompt_text = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )
    # The third value returned by process_mm_info is not used here.
    audio_inputs, image_inputs, _ = process_mm_info(conversation, use_audio_in_video=False)

    batch = processor(
        text=prompt_text,
        images=image_inputs,
        audio=audio_inputs,
        return_tensors="pt",
        padding=True,
        use_audio_in_video=False,
    )
    batch = batch.to(model.device).to(model.dtype)

    with torch.no_grad():
        generated = model.generate(**batch, use_audio_in_video=False, return_audio=False)

    # The returned sequences start with the prompt tokens; slice them off.
    n_prompt_tokens = batch["input_ids"].shape[1]
    decoded = processor.batch_decode(
        generated[:, n_prompt_tokens:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0].strip()

def evaluate(use_image=True, use_audio=True, use_text=True):
    """Run the model on every question that has assets and save the answers.

    Reads the question sheet at ``QUEST_PATH``, looks each row's ``id`` up in
    ``ASSETS``, queries the model through ``run_qwen`` and, after all rows are
    processed, writes ``model_results.csv`` into ``OUT_DIR``. Rows whose id
    has no asset entry are skipped; the number of skipped rows is reported.

    Args:
        use_image: Forwarded to ``run_qwen``.
        use_audio: Forwarded to ``run_qwen``.
        use_text: Forwarded to ``run_qwen``.
    """
    out_csv = os.path.join(OUT_DIR, "model_results.csv")
    results = []
    skipped = 0

    with open(QUEST_PATH, newline="", encoding="utf-8") as f:
        data = list(csv.DictReader(f))

    is_contradictory = args.type == "contradictory"

    for row in tqdm(data, desc="Running"):
        asset = ASSETS.get(row["id"])
        if asset is None:
            # Counted and reported below instead of being dropped silently.
            skipped += 1
            continue
        row["info_img"] = asset["img"]
        row["info_wav"] = asset["wav"]
        row["info_text"] = asset["txt"]

        pred = run_qwen(row, use_image, use_audio, use_text).strip()

        # Key insertion order defines the CSV column order, so the two dataset
        # flavours are filled in separately after the shared leading columns.
        record = {
            "id": row["id"],
            "rules": row["rules"],
            "question": row["questions"],
        }
        if is_contradictory:
            record["option_role_map"] = row.get("option_role_map", "")
            record["options"] = row["options"]
        else:
            record["options"] = row["options"]
            record["gt_answer"] = row["correct_answer"]
        record["model_answer"] = pred
        results.append(record)

    if skipped:
        print(f"Skipped {skipped} of {len(data)} rows with no matching asset entry.")

    if results:
        with open(out_csv, "w", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=list(results[0].keys()))
            w.writeheader()
            w.writerows(results)
        print(f"Saved {len(results)} results to {out_csv}")
    else:
        print("No results to save.")

if __name__ == "__main__":
    # Evaluate with all three modalities (image, audio, text) enabled.
    evaluate(use_image=True, use_audio=True, use_text=True)
