# eval_with_attn.py
import os,sys
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import re
import json
import random
import torch
import numpy as np
from tqdm import tqdm
from PIL import Image
from transformers import LlamaForCausalLM, LlamaTokenizer
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, IMAGE_TOKEN_INDEX
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.utils import disable_torch_init
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from FICD.get_attn.attn_scores import LlamaAttentionWithLogits

# ========== Configuration ==========
# Filesystem locations (absolute paths; adjust for the local environment).
MODEL_PATH = "/llava-v1.5-7b"
IMAGE_FOLDER = "/coco_val2014"
OUTPUT_DIR = "/results"

# Seed for the `random` module, which picks the distractor description per image.
SEED = 42

# Half-open key-position range [REGION_START, REGION_END) that is sliced out of
# each attention row as the "image region" (611 - 35 = 576 positions).
# NOTE(review): presumably the LLaVA-1.5 image-token span that follows the
# vicuna_v1 system prompt — re-check if the prompt template or vision tower changes.
REGION_START = 35
REGION_END = REGION_START + 576

# Import-time setup: create the output directory, call llava's
# disable_torch_init(), and seed the RNG.
os.makedirs(OUTPUT_DIR, exist_ok=True)
disable_torch_init()
random.seed(SEED)


def load_image(img_path):
    """Open the image at ``img_path`` and return an RGB copy of it.

    The file is opened inside a ``with`` block so its handle is released
    deterministically (also for multi-frame formats, which Pillow does not
    auto-close). ``convert`` loads the pixel data before the block exits, so
    the returned image does not depend on the closed file.
    """
    with Image.open(img_path) as img:
        return img.convert("RGB")

def patch_model_for_logits(model):
    """Replace every decoder layer's ``self_attn`` with ``LlamaAttentionWithLogits``.

    Each replacement is built from the original module's ``config`` with its
    layer index. It receives a copy of the original module's state dict and is
    cast to the dtype/device of the model's first parameter. ``model`` is
    modified in place.

    Layers are looked up at ``model.model.layers`` or, failing that,
    ``model.model.decoder.layers``.

    Raises:
        ValueError: if neither layer list exists on ``model.model``.
    """
    import warnings

    first_param = next(model.parameters())
    dtype, device = first_param.dtype, first_param.device

    inner = model.model
    if hasattr(inner, "layers"):
        layers = inner.layers
    elif hasattr(inner, "decoder") and hasattr(inner.decoder, "layers"):
        layers = inner.decoder.layers
    else:
        raise ValueError("Can not get model")

    for i, layer in enumerate(layers):
        orig = layer.self_attn
        new = LlamaAttentionWithLogits(orig.config, layer_idx=i)
        # strict=False is kept so a key mismatch does not abort the run. However,
        # a mismatch means some parameters keep the values the new module was
        # constructed with, so report it instead of ignoring it silently.
        incompatible = new.load_state_dict(orig.state_dict(), strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            warnings.warn(
                f"layer {i}: state_dict mismatch while patching self_attn "
                f"(missing={list(incompatible.missing_keys)}, "
                f"unexpected={list(incompatible.unexpected_keys)})"
            )
        layer.self_attn = new.to(dtype=dtype, device=device)


def _last_query_row(attn):
    """Return the ``[heads, kv_len]`` attention row of the final query position.

    A 4-D ``attn`` is treated as ``[batch, heads, q_len, kv_len]`` and batch
    element 0 is used. Indexing the last query explicitly (instead of squeezing
    ``q_len``) also handles full-sequence tensors with ``q_len > 1``, such as
    those produced when generating with ``use_cache=False``. Tensors of any
    other rank fall back to the original ``squeeze(0).squeeze(1)`` reduction.
    """
    if attn.dim() == 4:
        return attn[0, :, -1, :]
    return attn.squeeze(0).squeeze(1)


def _image_region_head_means(attn, v_st, v_end):
    """Per-head mean (in float32) of the last query row over key positions ``[v_st, v_end)``."""
    row = _last_query_row(attn)
    return row[:, v_st:v_end].float().mean(dim=1).cpu().numpy()


def extract_image_region_attn(model, attentions, v_st, v_end):
    """Average attention mass on the image-token region for the final generation step.

    Args:
        model: the patched model. For each layer, ``self_attn.get_attn_scores()``
            is read. It is assumed to return that layer's pre-softmax scores from
            the most recent forward pass, or ``None`` (TODO confirm in
            ``FICD.get_attn.attn_scores``). Layers are taken from
            ``model.model.layers``, or from ``model.model.decoder.layers`` when the
            former does not exist.
        attentions: per-step sequence of per-layer attention tensors, as returned
            by ``generate(..., output_attentions=True, return_dict_in_generate=True)``.
            Only the last step is used. Tensors are expected as
            ``[batch, heads, q_len, kv_len]``. ``q_len`` may be 1 (KV cache) or
            the full length (no cache); in both cases the last query row is used.
        v_st, v_end: half-open key-position range ``[v_st, v_end)`` of the image
            tokens. ``v_end`` is exclusive, as in Python slicing.

    Returns:
        ``(scores, weights)``: two float32 tensors of shape ``[layers, heads]``.
        They hold per-head means over the image region of the raw scores and of
        the post-softmax attention weights. A layer whose scores (or weights)
        are ``None`` stays all-zero in the corresponding output.
    """
    last_step = attentions[-1]  # per-layer tensors of the final generation step
    num_layers = len(last_step)
    num_heads = last_step[0].shape[1]

    inner = model.model
    decoder_layers = inner.layers if hasattr(inner, "layers") else inner.decoder.layers

    out_attn_scores = np.zeros((num_layers, num_heads), dtype=np.float32)
    out_attn_weights = np.zeros((num_layers, num_heads), dtype=np.float32)

    for layer_idx, layer_weights in enumerate(last_step):
        attn_scores = decoder_layers[layer_idx].self_attn.get_attn_scores()
        if attn_scores is not None:
            out_attn_scores[layer_idx] = _image_region_head_means(attn_scores, v_st, v_end)
        # Weights come from `attentions`, not from the patched module, so they are
        # recorded even when the module has no stored scores.
        if layer_weights is not None:
            out_attn_weights[layer_idx] = _image_region_head_means(layer_weights, v_st, v_end)

    return (
        torch.tensor(out_attn_scores, dtype=torch.float32),
        torch.tensor(out_attn_weights, dtype=torch.float32),
    )  # each [layers, heads]


def _build_input_ids(model, tokenizer, qs):
    """Wrap ``qs`` with the image token(s) and the ``vicuna_v1`` template; return input ids.

    If ``qs`` contains ``IMAGE_PLACEHOLDER``, the placeholder is replaced
    literally by the image token(s). Otherwise the image token(s) plus a newline
    are prepended. The start/end-wrapped image token is used when
    ``model.config.mm_use_im_start_end`` is set. The ids come from
    ``tokenizer_image_token``, get a leading batch dimension, and are moved to
    ``model.device``.
    """
    if model.config.mm_use_im_start_end:
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN

    # Literal replacement, consistent with the literal `in` membership test.
    if IMAGE_PLACEHOLDER in qs:
        qs = qs.replace(IMAGE_PLACEHOLDER, image_token)
    else:
        qs = image_token + "\n" + qs

    conv = conv_templates["vicuna_v1"].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    return (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )


def _generate_with_attn(model, tokenizer, input_ids, image_tensor, image_size, max_new_tokens):
    """Greedy-decode one prompt and return ``(text, image_attn_scores, image_attn_weights)``.

    The full ``generate`` output (including the per-step attention tuple) is
    local to this helper. It can therefore be released as soon as the
    image-region summary has been extracted.
    """
    with torch.inference_mode():
        out = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=image_size,
            do_sample=False,
            output_attentions=True,
            return_dict_in_generate=True,
            max_new_tokens=max_new_tokens,
            use_cache=False,
        )
    text = tokenizer.batch_decode(out.sequences, skip_special_tokens=True)[0]
    attn_s, attn_w = extract_image_region_attn(model, out.attentions, REGION_START, REGION_END)
    return text, attn_s, attn_w


def run_two_prompts(model, tokenizer, image_tensor, image_size, base_prompt, pert_prompt,
                    max_new_tokens=512):
    """Generate once for ``base_prompt`` and once for ``pert_prompt`` on the same image.

    Args:
        image_tensor: preprocessed image batch, forwarded as ``images``.
        image_size: forwarded unchanged to ``generate`` as ``image_sizes``.
        max_new_tokens: generation length cap (default 512, as before).

    Returns:
        ``(texts, scores, weights)``, each a ``(base, perturbed)`` pair:
          - texts: decoded generations;
          - scores: ``[layers, heads]`` image-region means of the raw attention scores;
          - weights: ``[layers, heads]`` image-region means of the attention weights.
    """
    # Both prompts are tokenized before any generation, as in the original flow.
    base_ids, pert_ids = (
        _build_input_ids(model, tokenizer, qs) for qs in (base_prompt, pert_prompt)
    )

    text1, attn1_s, attn1_w = _generate_with_attn(
        model, tokenizer, base_ids, image_tensor, image_size, max_new_tokens
    )
    text2, attn2_s, attn2_w = _generate_with_attn(
        model, tokenizer, pert_ids, image_tensor, image_size, max_new_tokens
    )

    return (text1, text2), (attn1_s, attn2_s), (attn1_w, attn2_w)


def main():
    """Run base vs. perturbed prompting over every record in the input JSON and save outputs.

    For each record, the image is described once with a plain prompt. It is
    described a second time with the same prompt followed by the
    ``original_description`` of a randomly chosen record whose ``image``
    differs. Texts are written to a JSON file in ``OUTPUT_DIR``. Per-image
    ``[layers, heads]`` attention-score and attention-weight summaries are
    written with ``torch.save``.
    """
    # 1. Load the model and swap in the logit-recording attention modules.
    model_name = get_model_name_from_path(MODEL_PATH)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path=MODEL_PATH,
        model_base=None,
        model_name=model_name,
        torch_dtype=torch.float16,
    )
    model.to(model.device)
    patch_model_for_logits(model)

    # 2. Load the input records (each must provide "image" and "original_description").
    input_json = "original_reslult.json"
    with open(input_json, "r", encoding="utf-8") as f:
        data = json.load(f)

    base = "Describe the image in detail."
    results = []
    attn_dict = {}
    attn_weight = {}

    for item in tqdm(data, desc="Inference"):
        img_name = item["image"]
        img = load_image(os.path.join(IMAGE_FOLDER, img_name))
        img_tensor = process_images([img], image_processor, model.config).to(
            model.device, dtype=torch.float16
        )
        # One (width, height) entry per image in the batch.
        # NOTE(review): LLaVA presumably indexes image_sizes per image (needed
        # for anyres merging); a bare tuple would be indexed as (w, h) instead.
        img_size = [img.size]

        # Perturbation: append the description of a random *other* image.
        cand = [d for d in data if d["image"] != img_name]
        rand = random.choice(cand)
        pert = f"{base} {rand['original_description']}"

        # 3. Generate for both prompts and collect image-region attention.
        (t1, t2), (a1, a2), (w1, w2) = run_two_prompts(
            model, tokenizer, img_tensor, img_size, base, pert
        )

        results.append({
            "image": img_name,
            "rand": rand['original_description'],
            "base_output": t1,
            "pert_output": t2,
        })
        attn_dict[img_name] = {
            "base_attn": a1,
            "pert_attn": a2,
        }
        attn_weight[img_name] = {
            "base_attn": w1,
            "pert_attn": w2,
        }

    # 4. Save results; `with` guarantees the JSON file is flushed and closed.
    results_path = os.path.join(OUTPUT_DIR, "results_of_1k_whole_change.json")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    torch.save(attn_dict, os.path.join(OUTPUT_DIR, "attn_scores_output_1k.pt"))
    torch.save(attn_weight, os.path.join(OUTPUT_DIR, "attn_weights_output_1k.pt"))
    print(f"Done.files saved in {OUTPUT_DIR}")


if __name__ == "__main__":
    main()
