import os
import sys
from io import BytesIO
from urllib.request import urlopen

import cv2
import matplotlib.pyplot as plt
import nltk
import numpy as np
import torch
import torch.nn.functional as F
from nltk.corpus import stopwords
from PIL import Image

from models.llava.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from models.llava.conversation import conv_templates
from models.llava.model.builder import load_pretrained_model
from models.llava.utils import disable_torch_init
from models.llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

# Download the NLTK "stopwords" corpus when the script starts (module-level
# side effect; needs network access unless the corpus is already cached).
# NOTE(review): nothing in the visible part of this script reads `stopwords`
# — confirm the download is still needed.
nltk.download('stopwords')

def load_image(image_file):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        urllib.error.URLError: If the URL cannot be fetched.
        OSError: If the file/payload cannot be opened as an image.
    """
    if image_file.startswith(('http://', 'https://')):
        # Fetched with the standard library so that no extra HTTP client is
        # required; the timeout keeps a dead server from hanging the script.
        with urlopen(image_file, timeout=30) as response:
            payload = response.read()
        return Image.open(BytesIO(payload)).convert('RGB')
    return Image.open(image_file).convert('RGB')

def aggregate_llm_attention(attentions):
    """Average one attention tensor over its leading axis and renormalise it.

    The input is squeezed on dim 0 and averaged over the resulting first
    dimension (presumably the attention heads — TODO confirm with the caller).
    The last row of the averaged matrix is dropped. For every remaining row
    except the first, the weight in column 0 is set to zero and the row is
    rescaled so that it sums to 1.

    Args:
        attentions: Attention tensor; after ``squeeze(0)`` and the mean it
            must still be at least 2-D.

    Returns:
        A CPU copy holding the processed rows; the input is left untouched.
    """
    head_mean = attentions.squeeze(0).mean(dim=0)
    # Clone before editing so the in-place writes below never reach the input.
    result = head_mean[:-1].cpu().clone()
    later_rows = result[1:]  # view into `result`, so writes land in `result`
    later_rows[:, 0] = 0.
    result[1:] = later_rows / later_rows.sum(-1, keepdim=True)
    return result

def aggregate_vit_attention(attentions, select_layer=-1, all_prev_layers=False):
    """Combine vision-transformer attention maps into one joint attention map.

    Args:
        attentions: Sequence of per-layer attention tensors whose last two
            dimensions are square (``N x N``).
        select_layer: Index of the layer to use (negative values count from
            the end, as in normal indexing).
        all_prev_layers: When true, the layers from the first one up to and
            including ``select_layer`` are multiplied element-wise; otherwise
            only ``select_layer`` is used.

    Returns:
        The last entry of the chained product computed below, i.e. a tensor
        with the shape of the selected attention minus its leading dimension.

    Raises:
        IndexError: If ``select_layer`` points before the first layer.
    """
    if all_prev_layers:
        # Resolve negative indices before slicing: a plain
        # ``[:select_layer + 1]`` turns into ``[:0]`` (an empty slice) for
        # ``select_layer=-1`` and would leave nothing to stack.
        last = select_layer + len(attentions) if select_layer < 0 else select_layer
        if last < 0:
            raise IndexError("select_layer is out of range")
        attentions = torch.stack(list(attentions[:last + 1]))
        attentions = torch.prod(attentions, dim=0)
    else:
        attentions = attentions[select_layer]

    # Add the identity (residual connection) and renormalise every row.
    # Helper tensors are created on the input's device so GPU inputs work.
    residual_att = torch.eye(attentions.size(-1), device=attentions.device)
    aug_att_mat = attentions + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(-1, keepdim=True)

    # Chain-multiply the normalised matrices along the leading dimension.
    # NOTE(review): for (batch, heads, N, N) inputs that leading dimension is
    # the batch axis rather than a layer axis — confirm this is intended.
    joint_attentions = torch.zeros_like(aug_att_mat)
    joint_attentions[0] = aug_att_mat[0]
    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])
    return joint_attentions[-1]

def heterogenous_stack(tensors):
    """Zero-pad tensors to a common shape, stack them and average the stack.

    Every tensor is padded with zeros at the end of each dimension up to the
    largest size found in that dimension, so 1-D vectors of different lengths
    and non-square matrices are handled as well as square matrices.

    NOTE(review): despite the name, the result is averaged over the stack
    axis, so callers get one dimension fewer than ``torch.stack`` would
    return — confirm this is what callers expect.

    Args:
        tensors: Non-empty iterable of tensors with the same number of
            dimensions.

    Returns:
        The element-wise mean of the padded tensors.

    Raises:
        ValueError: If ``tensors`` is empty or the ranks differ.
    """
    tensors = list(tensors)
    if not tensors:
        raise ValueError("heterogenous_stack() needs at least one tensor")
    rank = tensors[0].dim()
    if any(t.dim() != rank for t in tensors):
        raise ValueError("all tensors must have the same number of dimensions")

    target = [max(t.shape[d] for t in tensors) for d in range(rank)]
    stacked = []
    for t in tensors:
        # F.pad takes (before, after) pairs starting from the LAST dimension.
        pad = []
        for d in reversed(range(rank)):
            pad.extend((0, target[d] - t.shape[d]))
        stacked.append(F.pad(t, pad) if any(pad) else t)
    return torch.stack(stacked).mean(0)

def show_mask_on_image(img, mask):
    """Blend a JET-coloured heat map of ``mask`` onto ``img``.

    Args:
        img: Image convertible with ``np.array``; it must have the same
            height, width and channel count as the generated heat map
            (``cv2.addWeighted`` requires matching arrays).
        mask: 2-D numeric array; it is min-max normalised before colouring.

    Returns:
        ``(img_with_heatmap, heatmap)``: the 50/50 blend and the heat map
        alone. The heat map is converted to RGB channel order, so ``img``
        should be RGB as well for the blend to show correct colours.
    """
    mask = mask - np.min(mask)
    peak = np.max(mask)
    # A constant mask has zero range; skip the division so it stays all-zero
    # instead of becoming NaN before the uint8 conversion.
    if peak > 0:
        mask = mask / peak
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    img_with_heatmap = cv2.addWeighted(np.array(img), 0.5, heatmap, 0.5, 0)
    return img_with_heatmap, heatmap

# Restrict this process to GPU index 1, then pick CUDA when it is available.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = "cuda" if torch.cuda.is_available() else "cpu"
disable_torch_init()

# Checkpoint to load; both quantisation switches are off.
model_path = "liuhaotian/llava-v1.6-7b"
load_8bit = False
load_4bit = False

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path), load_8bit, load_4bit, device=device
)

# Explicitly move the language model, the vision tower (and its inner
# `vision_model`, when present) and the projector onto the chosen device.
model.to(device)
_llava_core = model.get_model()
_tower = _llava_core.vision_tower
_tower.to(device)
if hasattr(_tower, 'vision_model'):
    _tower.vision_model.to(device)
_llava_core.mm_projector.to(device)

# Input image (a local path here; `load_image` also accepts a URL).
image_path_or_url = "data/test/plots/agriculture_95.jpeg"
# User instruction sent together with the image. The text is passed to the
# model verbatim, including its leading/trailing whitespace and blank lines.
prompt_text = """
I will give you a time series about the yearly Aggregated input index (2015=100) in the country of El Salvador, from 2009 to 2018. 
The Aggregated input index (2015=100) comprises the following components: land, labor, capital, materials.
    Here is the time series: 
 95.38, 95.71, 97.55, 98.41, 99.3, 96.97, 100.0, 101.86, 102.48, 99.83
    
 Describe this time series by focusing on trends and patterns. 
    Discuss concrete numbers you see and pay attention to the dates.
    For numerical values, ensure consistency with the provided time series. If making percentage comparisons, round to the nearest whole number. Report the dates when things happened.
          
    Compare the trends in this time series to global or regional norms, explaining whether they are higher, lower, or follow expected seasonal patterns.
    When making comparisons, clearly state whether differences are minor, moderate, or significant.
    Use descriptive language to create engaging, natural-sounding text.
    Avoid repetitive phrasing and overused expressions.
    Answer in a single paragraph of four sentences at most, without bullet points or any formatting.
    
"""

# Choose the conversation template from markers in the model name. The pairs
# are checked in order and the first match wins, so "v1.6-34b" must come
# before the more general "v1".
model_name = get_model_name_from_path(model_path)
_model_name_lc = model_name.lower()
_CONV_MODE_BY_MARKER = (
    ("llama-2", "llava_llama_2"),
    ("mistral", "mistral_instruct"),
    ("v1.6-34b", "chatml_direct"),
    ("v1", "llava_v1"),
    ("mpt", "mpt"),
)
conv_mode = next(
    (mode for marker, mode in _CONV_MODE_BY_MARKER if marker in _model_name_lc),
    "llava_v0",
)

# Work on a copy so the shared template is not modified.
conv = conv_templates[conv_mode].copy()
roles = conv.roles

# Load the picture and preprocess it; from here on `image` is the first
# entry of the list returned by `process_images`.
image = load_image(image_path_or_url)
image_tensor, images = process_images([image], image_processor, model.config)
image = images[0]
image_size = image.size

# Move the pixel data to the model's device as float16 (tensor or list of
# tensors, whichever `process_images` returned).
if type(image_tensor) is list:
    image_tensor = [t.to(model.device, dtype=torch.float16) for t in image_tensor]
else:
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

# Put the image placeholder (optionally wrapped in start/end markers) in
# front of the user prompt.
image_prefix = DEFAULT_IMAGE_TOKEN
if model.config.mm_use_im_start_end:
    image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
inp = image_prefix + '\n' + prompt_text

# One user turn followed by an empty assistant turn for the model to fill in.
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)

# Remove the template's system preamble from the rendered prompt.
_system_preamble = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. "
prompt = conv.get_prompt().replace(_system_preamble, "")

input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    .unsqueeze(0)
    .to(model.device)
)

# Greedy decoding of at most 200 new tokens; the result is returned as a
# dict-like object that also carries the attention weights.
generation_kwargs = dict(
    images=image_tensor,
    image_sizes=[image_size],
    do_sample=False,
    max_new_tokens=200,
    use_cache=True,
    return_dict_in_generate=True,
    output_attentions=True,
)
with torch.inference_mode():
    outputs = model.generate(input_ids, **generation_kwargs)

# Decode the first (only) returned sequence and show it.
generated_ids = outputs["sequences"][0]
text = tokenizer.decode(generated_ids).strip()
print(text)

def _newest_token_attention(step_layers):
    """Return the layer-averaged attention row of the newest token of one step.

    ``step_layers`` holds one attention tensor per layer for a single
    generation step. For each layer the heads are averaged, the last query
    row is taken, the weight on the first key is zeroed (as is done for the
    prompt rows) and the row is renormalised; the rows are then averaged.
    """
    per_layer = []
    for layer in step_layers:
        row = layer.squeeze(0).mean(dim=0)[-1].float().cpu().clone()
        row[0] = 0.
        per_layer.append(row / row.sum())
    return torch.stack(per_layer).mean(dim=0)


# `outputs["attentions"]` has one entry per generation step, each entry being
# a sequence of per-layer tensors. Step 0 covers the whole prompt.
aggregated_prompt_attention = torch.stack(
    [aggregate_llm_attention(layer) for layer in outputs["attentions"][0]]
).mean(dim=0)

# One attention row per token: a dummy row for the first token, the prompt
# rows, then one row per generation step. Rows get longer as the sequence
# grows, so they are right-padded with zeros to a common width and stacked
# (NOT averaged) — the code below indexes this matrix by row.
_attn_rows = (
    [torch.tensor([1.])]
    + list(aggregated_prompt_attention)
    + [_newest_token_attention(step) for step in outputs["attentions"]]
)
_attn_width = max(row.shape[0] for row in _attn_rows)
llm_attn_matrix = torch.stack(
    [F.pad(row.float(), (0, _attn_width - row.shape[0])) for row in _attn_rows]
)

# The key dimension of the step-0 attention is the real prompt length after
# the single image placeholder in `input_ids` has been expanded, so the
# number of vision tokens follows from it instead of being assumed to equal
# the vision tower's `num_patches`.
input_token_len = outputs["attentions"][0][0].shape[-1]
num_vision_tokens = input_token_len - (len(input_ids[0]) - 1)
vision_token_start = len(tokenizer(prompt.split("<image>")[0], return_tensors='pt')["input_ids"][0])
vision_token_end = vision_token_start + num_vision_tokens
output_token_len = len(outputs["sequences"][0])
output_token_start = input_token_len

# Total attention mass that each generated token puts on the vision tokens.
overall_attn_weights_over_vis_tokens = [
    llm_attn_matrix[i][vision_token_start:vision_token_end].sum().item()
    for i in range(output_token_start, output_token_len + output_token_start - 1)
]

mean_variance = torch.tensor(overall_attn_weights_over_vis_tokens).var().item()
print(f"Mean variance across tokens: {mean_variance:.6f}")

# Merge sub-word tokens of the generated sequence into display groups and
# sum the per-token vision attention inside each group.
tokens = tokenizer.convert_ids_to_tokens(outputs["sequences"][0], skip_special_tokens=False)
grouped_tokens = []
token_groups = []
current_group = ""
current_indices = []

for i, token in enumerate(tokens):
    # A token opens a new group when it starts with "▁", or when its first
    # character is alphabetic and it does not start with "##". Slicing with
    # `[:1]` keeps an empty token from raising IndexError.
    # NOTE(review): alphabetic continuation pieces therefore also start a new
    # group — confirm this matches the tokenizer's sub-word convention.
    starts_group = token.startswith("▁") or (
        token[:1].isalpha() and not token.startswith("##")
    )
    if starts_group:
        if current_group:
            grouped_tokens.append(current_group.strip())
            token_groups.append(current_indices)
        current_group = token.replace("▁", "")
        current_indices = [i]
    else:
        current_group += token.replace("▁", "")
        current_indices.append(i)

# Flush the group that was still open when the loop ended.
if current_group:
    grouped_tokens.append(current_group.strip())
    token_groups.append(current_indices)

# Indices without an attention value (past the end of the list) are ignored.
grouped_attn_values = [
    sum(
        overall_attn_weights_over_vis_tokens[i]
        for i in group
        if i < len(overall_attn_weights_over_vis_tokens)
    )
    for group in token_groups
]

# Clamp k so that short generations (fewer than `top_n` groups) do not make
# `topk` raise. These indices refer to the unfiltered group lists.
top_n = 12
top_n_indices = (
    torch.tensor(grouped_attn_values)
    .topk(min(top_n, len(grouped_attn_values)))
    .indices.tolist()
)

# Drop groups that should not be displayed.
tokens_to_exclude = {"(2015=100)", "2009,", ""}
filtered = [
    (tok, grp, val)
    for tok, grp, val in zip(grouped_tokens, token_groups, grouped_attn_values)
    if tok not in tokens_to_exclude
]

# When everything would be filtered out, the unfiltered lists are kept.
if filtered:
    grouped_tokens, token_groups, grouped_attn_values = (
        list(column) for column in zip(*filtered)
    )

# Line plot of the summed vision attention per token group; the figure gets
# two inches of width per group, capped at 50.
fig_width = min(2 * len(grouped_tokens), 50)
fig, ax = plt.subplots(figsize=(fig_width, 10))
ax.plot(grouped_attn_values, marker="o")

tick_positions = range(len(grouped_tokens))
ax.set_xticks(tick_positions)
ax.set_xticklabels(grouped_tokens, rotation=75, ha='right', fontsize=8)

ax.set_ylabel("Attention to Vision Tokens")
ax.set_title("Top-N Grouped Token Attention over Vision Tokens")

fig.tight_layout()
ax.grid(True)
plt.show()

# Re-run the vision encoder on the preprocessed image to get its attentions.
vision_model = model.get_model().vision_tower.vision_tower
with torch.inference_mode():
    vit_outputs = vision_model(image_tensor, output_attentions=True, return_dict=True)

vis_attn_matrix = aggregate_vit_attention(
    vit_outputs.attentions,
    select_layer=-2,
    all_prev_layers=True
)

grid_size = model.get_vision_tower().num_patches_per_side
image_ratio = image_size[0] / image_size[1]
num_image_per_row = 4
# Ceiling division; at least one row so `plt.subplots` always gets a valid grid.
num_rows = max(1, -(-len(grouped_tokens) // num_image_per_row))

# `squeeze=False` keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(
    num_rows, num_image_per_row,
    figsize=(10, (10 / num_image_per_row) * image_ratio * num_rows),
    dpi=150,
    squeeze=False
)
plt.subplots_adjust(wspace=0.05, hspace=0.2)

# Rows of the LLM attention matrix that belong to the generated tokens.
output_token_inds = list(range(llm_attn_matrix.shape[0] - output_token_len, llm_attn_matrix.shape[0]))

# PIL reports (width, height) while F.interpolate and the pixel array use
# (height, width); swap the order so the heat map matches the image.
overlay_size = (image.size[1], image.size[0])
# Kept in RGB: the heat map from `show_mask_on_image` is RGB and `imshow`
# interprets arrays as RGB, so a BGR image would show swapped colours.
rgb_img = np.array(image)

for i, ax in enumerate(axes.flatten()):
    # Unused cells of the grid are simply hidden.
    if i >= len(token_groups):
        ax.axis("off")
        continue

    target_indices = [output_token_inds[j] for j in token_groups[i] if j < len(output_token_inds)]
    if not target_indices:
        ax.axis("off")
        continue

    # Attention of this token group over the vision tokens, summed per token.
    attn_weights_over_vis_tokens = sum(
        llm_attn_matrix[tidx][vision_token_start:vision_token_end]
        for tidx in target_indices
    )

    # Weight each vision token's own attention map by the LLM attention it
    # receives and add the maps up.
    # NOTE(review): this requires `vis_attn_matrix` to hold one row of
    # grid_size**2 weights per vision token — confirm that
    # `aggregate_vit_attention` returns that shape for this vision tower.
    attn_over_image = []
    for weight, vis_attn in zip(attn_weights_over_vis_tokens, vis_attn_matrix):
        vis_attn = vis_attn.reshape(grid_size, grid_size)
        attn_over_image.append(vis_attn * weight)
    attn_over_image = torch.stack(attn_over_image).sum(dim=0)
    attn_over_image = attn_over_image / attn_over_image.max()

    attn_over_image = F.interpolate(
        attn_over_image.unsqueeze(0).unsqueeze(0),
        size=overlay_size,
        mode='nearest'
    ).squeeze()

    # `.cpu()` is required before `.numpy()` when the maps live on the GPU.
    img_with_attn, heatmap = show_mask_on_image(rgb_img, attn_over_image.float().cpu().numpy())
    ax.imshow(img_with_attn)
    ax.set_title(grouped_tokens[i], fontsize=7, pad=1)
    ax.axis("off")
