import os
import json
import random
from PIL import Image, ImageFile
from pycocotools.coco import COCO
from eval.chair import CHAIR
from PIL import Image,ImageDraw
from math import ceil
import numpy as np
from tqdm import tqdm
import torch
from datasets import *
from utils import get_model, add_diffusion_noise, get_layers
import matplotlib.pyplot as plt
import seaborn as sns
from ilvad import AttentionEnhancerContext




def get_image_position(num_img_tokens, image_token_id, input_ids):
    """Locate the image span and the response start for a LLaVA-style prompt.

    Args:
        num_img_tokens: Number of positions the image occupies once expanded.
            ``None`` means "use the number of ``image_token_id`` matches already
            present in ``input_ids``".
        image_token_id: Token id of the image placeholder.
        input_ids: 2-D token id tensor indexed ``[batch, position]``. Only the
            position column of the matches is used, so presumably batch size 1
            (TODO confirm for batched inputs).

    Returns:
        ``(img_start, img_end, respond_start, img_token_end)`` where
        ``img_start`` is the position of the first placeholder match,
        ``img_end = img_start + num_img_tokens``,
        ``respond_start`` is the sequence length after swapping the placeholder
        matches for ``num_img_tokens`` positions, and
        ``img_token_end`` is one past the last placeholder match in the
        unexpanded ``input_ids``.

    Raises:
        IndexError: if ``input_ids`` contains no ``image_token_id``.
    """
    # Rows of (batch_index, position) for every placeholder occurrence.
    hits = (input_ids == image_token_id).nonzero()
    placeholder_count = hits.shape[0]

    if num_img_tokens is None:
        num_img_tokens = placeholder_count

    # llava
    first_pos = hits[0, 1].item()
    last_pos_exclusive = hits[-1, 1].item() + 1
    seq_len = input_ids.shape[1]

    return (
        first_pos,
        first_pos + num_img_tokens,
        seq_len - placeholder_count + num_img_tokens,
        last_pos_exclusive,
    )


def compute_attention_heatmap(attn: torch.Tensor, scale: float) -> torch.Tensor:
    """Turn a saliency tensor into a multiplicative heatmap.

    Values are clipped to mean +/- 3 standard deviations. They are then
    min-max normalised to ``[0, scale]`` and exponentiated. For ``scale >= 0``
    the weakest element maps to 1 and the strongest to ``exp(scale)``.

    Args:
        attn: Floating-point tensor of saliency values (any shape).
        scale: Upper end of the normalised range before ``exp``.

    Returns:
        Tensor with the same shape as ``attn``. A constant or single-element
        input has no spread to normalise. It maps to all ones (``exp(0)``)
        rather than the NaNs that ``0 / 0`` would produce. NaN inputs still
        propagate to the output.
    """
    heatmap = attn
    # torch.std is unbiased by default, so it is NaN for a single element;
    # the outlier clip is meaningless there and is skipped.
    if heatmap.numel() > 1:
        mean, std = heatmap.mean(), heatmap.std()
        lower, upper = mean - 3 * std, mean + 3 * std
        heatmap = torch.clamp(heatmap, min=lower.item(), max=upper.item())

    low = heatmap.min()
    span = heatmap.max() - low
    if span.item() == 0:
        # Uniform input: every element is the minimum, i.e. normalises to 0.
        return torch.ones_like(heatmap)

    heatmap = ((heatmap - low) / span) * scale
    return torch.exp(heatmap)

def get_enhanced_map(output_attentions, img_start, img_end, scale, tau):
    """Build a per-image-token enhancement map from a probe generation's attentions.

    For every layer, take the attention that each decoding step's newest
    query (last query row) pays to the image keys ``[img_start, img_end)``.
    Average it over the top half of heads, ranked by total image attention,
    and then over decoding steps. Two binary masks follow from these
    saliencies:

    * per layer, tokens whose saliency exceeds ``tau`` times that layer's
      mean. The 0 -> 1 transitions of this mask between consecutive layers
      are counted per token.
    * globally, tokens whose layer-averaged saliency is above its mean.

    The transition counts, restricted to the global mask, are passed to
    ``compute_attention_heatmap`` together with ``scale``.

    Args:
        output_attentions: ``outputs.attentions`` from ``model.generate(...,
            output_attentions=True, return_dict_in_generate=True)``, indexed
            ``[step][layer]``. Step 0 is skipped (presumably the prompt pass).
            Each entry is assumed to be ``(batch, heads, query, key)``, and
            only batch 0 is read (TODO confirm for other models).
        img_start: First key position of the image span.
        img_end: One past the last key position of the image span.
        scale: Forwarded to ``compute_attention_heatmap``.
        tau: Multiple of a layer's mean saliency that a token must exceed to
            be marked salient in that layer.

    Returns:
        Tensor of shape ``(img_end - img_start,)`` from
        ``compute_attention_heatmap``. With only one step there is nothing to
        accumulate, and the map passed on is all zeros.
    """
    num_steps = len(output_attentions)
    num_layers = len(output_attentions[0])
    img_num = img_end - img_start
    device = output_attentions[0][0].device

    num_heads = output_attentions[0][0].shape[1]
    # Top half of the heads, but at least one: an empty topk selection would
    # make the mean below NaN.
    k = max(1, num_heads // 2)

    # Every threshold below is invariant to a positive rescaling of the
    # saliencies, so this divisor only has to be positive. Averaging over the
    # decoding steps actually accumulated keeps it positive, whereas a divisor
    # tied to the prompt layout can be zero (no text after the image), which
    # produces inf/NaN.
    num_decode_steps = max(num_steps - 1, 1)

    layer_saliency = []
    layer_masks = []
    for layer in range(num_layers):
        attn = torch.zeros(img_num, device=device)

        for step in range(1, num_steps):
            img_attn = output_attentions[step][layer][0, :, -1, img_start:img_end]
            top_heads = torch.topk(img_attn.sum(dim=-1), k=k).indices
            # .to(device): layers may live on different devices when the model is sharded.
            attn += img_attn[top_heads].mean(dim=0).to(device)

        attn = attn / num_decode_steps
        layer_saliency.append(attn)
        layer_masks.append((attn > tau * attn.mean()).float())

    avg_attn = torch.stack(layer_saliency).mean(dim=0)
    bi_attn = (avg_attn > avg_attn.mean()).float()

    masks = torch.stack(layer_masks)
    # Masks are 0/1, so the positive part of consecutive-layer differences
    # counts, per token, how often it turns salient going deeper.
    temp_flow = (masks[1:] - masks[:-1]).clamp(min=0).sum(dim=0)
    enhanced_map = temp_flow * bi_attn

    return compute_attention_heatmap(enhanced_map, scale)


model_name = "llava-v1.5-7b-hf"

model_path = ""  # TODO: set to the checkpoint path expected by utils.get_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): presumably "eager" so attention weights are materialised for
# output_attentions=True; confirm in utils.get_model.
attn_impl = "eager"
max_new_tokens = 512
# NOTE: Alpha, Gamma and TAU are not read by the sweep below; alpha, tau and
# gamma are set per dataset inside the loop.
Alpha = [5]
beta = 2
Gamma = [1]
TAU = [5]
# max_new_tokens of the short probe generation whose attentions build the map.
T = [10]

generate_kwargs = {
    "num_beams": 1,
    "do_sample": False,
    "max_new_tokens": max_new_tokens,
}
processor, model, template, ans_start, num_img_tokens, image_token_id = get_model(model_path, device, attn_impl)

dataset_names = ["chair", "pope_popular", "pope_random", "pope_adversarial"]

for dataset_name in dataset_names:
    for t in T:
        alpha = 5
        tau = 5
        gamma = 0.2 if "chair" in dataset_name else 1
        scale = alpha

        output_file_path = f""  # TODO: per-(dataset_name, t) results path
        output_dir = os.path.dirname(output_file_path)
        if output_dir:  # os.makedirs("") raises FileNotFoundError
            os.makedirs(output_dir, exist_ok=True)

        seed = 0
        random.seed(seed)
        if "chair" in dataset_name:
            dataset = CHAIRBench(500)
        else:
            dataset = POPE(dataset_name)

        # One results file per (dataset, t); closed only after every question
        # has been written.
        with open(output_file_path, "w") as output_file:
            for qid in tqdm(dataset.all):
                img, qu, gt = dataset[qid]

                prompt = template.format(question=qu)
                inputs = processor(prompt, img, return_tensors="pt").to(model.device)

                img_start, img_end, respond_start, img_token_end = get_image_position(
                    num_img_tokens, image_token_id, inputs["input_ids"]
                )

                # Probe generation: its attentions drive the enhancement map.
                with torch.inference_mode():
                    probe = model.generate(
                        **inputs,
                        max_new_tokens=t,
                        output_attentions=True,
                        output_scores=True,
                        return_dict_in_generate=True,
                    )

                enhanced_map = get_enhanced_map(probe.attentions, img_start, img_end, scale, tau)
                # Drop the probe's per-step attentions before the long generation.
                del probe

                layers_to_modify = list(range(31))  # layers 0..30

                enhancer = AttentionEnhancerContext(
                    model=model,
                    img_start=img_start,
                    img_end=img_end,
                    respond_start=respond_start,
                    enhanced_map=enhanced_map,
                    layers_to_modify=layers_to_modify,
                    beta=beta,
                    gamma=gamma,
                    model_name=model_name,
                )

                with enhancer:
                    # Inside this block, model.generate() runs with the modified attention mechanism.
                    with torch.inference_mode():
                        outputs = model.generate(
                            **inputs,
                            **generate_kwargs,
                            output_attentions=True,
                            output_scores=True,
                            return_dict_in_generate=True,
                        )

                gen_answer = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
                gen_answer = gen_answer.split(ans_start)[1].strip() if ans_start is not None else gen_answer.strip()
                # Confidence = top probability at the first generated step.
                probs = outputs.scores[0][0].softmax(-1)
                ans_conf = probs.topk(1)[0].cpu().numpy()[0]
                ans_conf = str(ans_conf * 100)[:5]
                output_file.write(json.dumps({"question_id": qid, "text": gen_answer, "gt": gt, "ans_conf": ans_conf, "question": qu}) + "\n")