import argparse
import torch
import os
import json
from tqdm import tqdm

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

from PIL import Image
import math

from AttnGrad import AttnGrad, compute_significance
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
import os

def save_heatmap(norm_sig: np.ndarray,
                 out_path: str = "attn_grad_heatmap.png",
                 cmap: str = "magma",
                 clip_percent: float = 1.0) -> None:
    """
    Render a (layer x head) significance map as a heatmap image and save it.

    norm_sig : (n_layer, n_head) numpy array, values in [0, 1] (or any range)
    out_path : output file path (PNG); a bare file name is written to the
               current working directory
    cmap     : matplotlib colormap name
    clip_percent : percent to clip on each tail for outlier suppression
                   (e.g. 1 -> clip 1% low & 1% high); 0 or less disables clipping
    """

    # 1. optional outlier clipping
    if clip_percent > 0:
        low_p = np.percentile(norm_sig, clip_percent)
        high_p = np.percentile(norm_sig, 100 - clip_percent)
        norm_sig = np.clip(norm_sig, low_p, high_p)

    # 2. create output dir
    # os.path.dirname() returns "" for a bare file name (including the default
    # out_path), and os.makedirs("") raises FileNotFoundError, so only create
    # the directory when there actually is one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # 3. plot
    n_layer, n_head = norm_sig.shape
    fig = plt.figure(figsize=(n_head * 0.25 + 2, n_layer * 0.25 + 2))
    try:
        plt.imshow(norm_sig, aspect='equal', cmap=cmap)
        plt.colorbar(label='Normalized significance (clipped)')
        plt.xlabel('Head')
        plt.ylabel('Layer')
        plt.title(f'Normalized |grad × attn| (cmap={cmap}, clip={clip_percent}%)')

        # 4. save
        plt.tight_layout()
        plt.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"Heatmap saved to {out_path}")


def split_list(lst, n):
    """Split a list into at most n (roughly) equal-sized, contiguous chunks.

    Every chunk has ``ceil(len(lst) / n)`` items except possibly the last one,
    which holds the remainder. Fewer than ``n`` chunks are returned when the
    list is too short to fill them all; an empty list yields an empty result.

    Raises:
        ValueError: if ``n`` is not a positive integer.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")
    if not lst:
        # A zero chunk size would make range() fail on its step argument.
        return []
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

def get_chunk(lst, n, k):
    """Return the k-th chunk (0-based) of ``lst`` after splitting it with ``split_list(lst, n)``."""
    return split_list(lst, n)[k]

def eval_model(args):
    """Build a (layer x head) attention-significance map and save it as a heatmap.

    Loads the model, optionally replaces every decoder layer's self-attention
    with ``AttnGrad``, and then, for each question in the selected chunk, runs
    one forward pass, back-propagates the cross-entropy between the
    last-position logits and the label token, and accumulates
    ``compute_significance(attn_grads, attn_weights)`` per layer. The per-layer
    mean is min-max normalized and handed to ``save_heatmap``.

    Args:
        args: parsed CLI namespace. Reads ``model_path``, ``model_base``,
            ``use_attn_back``, ``question_file``, ``num_chunks``, ``chunk_idx``,
            ``conv_mode``, ``image_folder`` and ``heatmap_path``.

    Raises:
        ValueError: if ``args.heatmap_path`` is missing or empty.
    """
    # Fail fast: without an output path the whole run would be wasted and only
    # crash at the very end when the heatmap is written.
    if not args.heatmap_path:
        raise ValueError("heatmap_path must be set (pass --heatmap-path)")

    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)

    n_layers = len(model.model.layers)
    n_heads = model.config.num_attention_heads
    # Running sum of significance per (layer, head) and number of samples per layer.
    sig_sum = torch.zeros(n_layers, n_heads, dtype=torch.float64)
    sig_count = torch.zeros(n_layers, dtype=torch.long)

    if args.use_attn_back:
        # Swap in AttnGrad modules carrying the original attention weights.
        for layer in model.model.layers:
            attn_adap = AttnGrad(layer.self_attn.config)
            attn_adap.load_state_dict(layer.self_attn.state_dict())
            layer.self_attn = attn_adap.half().cuda()

    # JSON-lines question file; close the handle deterministically and
    # tolerate blank lines (e.g. a trailing newline).
    with open(os.path.expanduser(args.question_file), "r") as question_fh:
        questions = [json.loads(q) for q in question_fh if q.strip()]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)

    for line in tqdm(questions):
        image_file = line["image"]
        qs = line["text"]

        # Target token id(s) for the loss, looked up from the label token string.
        label = tokenizer.convert_tokens_to_ids(line["label"])
        label = torch.tensor(label, dtype=torch.int64).cuda()

        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        qs = qs + "\nAnswer the question using a single word or phrase."

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        image = Image.open(os.path.join(args.image_folder, image_file))
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

        model.zero_grad()
        outputs = model.forward(
            input_ids,
            images=image_tensor.unsqueeze(0).half().cuda(),
            use_cache=True,
            output_attentions=True,
            return_dict=True,
        )

        # Loss on the prediction at the last sequence position only.
        pred_logit = outputs['logits'][:, -1, :].squeeze(0)
        loss = F.cross_entropy(pred_logit, label)
        loss.backward()

        for l_idx, layer in enumerate(model.model.layers):
            # Check for missing tensors BEFORE touching them: the attributes
            # may be absent (stock attention module) or None.
            attn_module = layer.self_attn
            grads = getattr(attn_module, "attn_grads", None)
            attns = getattr(attn_module, "attn_weights", None)
            if grads is None or attns is None:
                continue
            # (bsz, head, q, kv) per the original note -- TODO confirm against AttnGrad
            sig_b = compute_significance(grads.cpu(), attns.detach().cpu())
            # Presumably sig_b is (bsz, head): summed over dim 0 into the per-head row.
            sig_sum[l_idx] += sig_b.sum(dim=0).double()
            sig_count[l_idx] += sig_b.size(0)

    # Layers that never contributed keep a count of 0; clamp avoids 0/0.
    mean_sig = sig_sum / sig_count[:, None].clamp(min=1)

    # # --- per-layer sum over all heads ---
    # layer_sum = mean_sig.sum(dim=1)           # (n_layer,)

    # for idx, val in enumerate(layer_sum.tolist()):
    #     print(f"Layer {idx:02d}  total Σ|grad*attn| = {val:.6e}")

    # Min-max normalize to [0, 1]. A constant map (e.g. no samples processed)
    # has zero range; emit zeros instead of dividing by zero into NaNs.
    min_v, max_v = mean_sig.min(), mean_sig.max()
    value_range = max_v - min_v
    if value_range == 0:
        norm_sig = torch.zeros_like(mean_sig)
    else:
        norm_sig = (mean_sig - min_v) / value_range

    save_heatmap(
        norm_sig.cpu().numpy(),
        out_path=args.heatmap_path,
        cmap="hot",
        clip_percent=1.0,
    )


if __name__ == "__main__":
    # CLI entry point: parse the options and run the significance evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    # NOTE: --answers-file is accepted but not read by eval_model; kept so
    # existing launch scripts that pass it keep working.
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--use-attn-back", action='store_true', default=False)
    # Required: the heatmap is the only output of this script, and a missing
    # path would otherwise surface only as a crash after the full evaluation.
    parser.add_argument("--heatmap-path", type=str, required=True)
    args = parser.parse_args()

    eval_model(args)



