import os
import time
import json
import torch
import random
import argparse
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from models.falcon_mamba.modeling_falcon_mamba import FalconMambaForCausalLM

def get_attention_score(model, tokenizer, prompts, results_path, token_length=50, device=torch.device("cuda")):
    """Average the attention maps returned by ``model.generate`` over prompts.

    Every prompt is tokenized, truncated to its first ``token_length`` tokens
    and run through ``model.generate`` for a single new token with
    ``output_attentions=True``. The per-layer maps of that single generation
    step are summed across prompts and divided by the number of prompts.

    Args:
        model: Model exposing ``config.num_hidden_layers`` and a ``generate``
            method whose output contains an ``'attentions'`` entry.
        tokenizer: Callable returning a mapping of tensors (with a ``.to``
            method) when called with ``return_tensors="pt"``.
        prompts: Sized iterable of prompt strings.
        results_path: Unused here; kept so existing callers keep working.
        token_length: Number of leading tokens kept from every prompt.
        device: Device for the tokenized inputs and for the accumulator.

    Returns:
        A float tensor on ``device`` of shape
        ``(num_layers, num_heads, token_length, token_length)``, where
        ``num_heads`` is read from the first attention tensor returned by
        the model.

    Raises:
        ValueError: If ``prompts`` is empty, or a prompt tokenizes to fewer
            than ``token_length`` tokens.
    """
    num_prompts = len(prompts)
    if num_prompts == 0:
        # Averaging over zero prompts would silently produce an all-NaN map.
        raise ValueError("prompts must contain at least one prompt")

    num_layers = model.config.num_hidden_layers
    # Allocated lazily so the head dimension follows what the model returns
    # instead of a hard-coded constant.
    attention_sum = None

    for prompt in tqdm(prompts):
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        for key in list(inputs.keys()):
            seq_len = inputs[key].shape[1]
            if seq_len < token_length:
                raise ValueError(
                    f"prompt has only {seq_len} tokens in '{key}', "
                    f"but token_length={token_length} are required"
                )
            inputs[key] = inputs[key][:, :token_length]

        outputs = model.generate(
            **inputs,
            output_attentions=True,
            return_dict_in_generate=True,
            max_new_tokens=1
        )

        attentions = outputs['attentions']
        # One generated token -> exactly one generation step is expected.
        assert len(attentions) == 1, "expected attentions for a single generation step"
        step_attentions = attentions[0]

        if attention_sum is None:
            # Each layer entry is expected to be
            # (batch_size=1, num_heads, seq_len, seq_len).
            num_heads = step_attentions[0].shape[1]
            attention_sum = torch.zeros(
                (num_layers, num_heads, token_length, token_length), device=device
            )

        for layer_idx in range(num_layers):
            layer_attention = step_attentions[layer_idx][0]
            # Layers may live on another device or dtype than the accumulator.
            attention_sum[layer_idx] += layer_attention.to(
                device=attention_sum.device, dtype=attention_sum.dtype
            )

    # Divide by the number of prompts to get the average over samples.
    attention_sum /= num_prompts
    return attention_sum

def measure_attention_sink(model, tokenizer, prompts, results_path, token_length=50, device=torch.device("cuda"), recompute_attention=False):
    """Plot one attention heatmap per (layer, head) and save each as a PDF.

    Averaged attention scores are loaded from
    ``<results_path>/attention_scores.npy`` when that file exists and
    ``recompute_attention`` is false. Otherwise they are computed with
    ``get_attention_score`` and written to that file. For every layer and
    head a heatmap is saved to
    ``<results_path>/attention_layer{l}_head{h}.pdf`` (1-based indices), with
    the cells strictly above the diagonal drawn in grey.

    Args:
        model: Forwarded to ``get_attention_score`` when scores are computed.
        tokenizer: Forwarded to ``get_attention_score``.
        prompts: Forwarded to ``get_attention_score``.
        results_path: Directory holding the cached scores and the PDFs.
        token_length: Forwarded to ``get_attention_score``; also sets the
            tick range of both axes.
        device: Forwarded to ``get_attention_score``.
        recompute_attention: If true, ignore any cached ``.npy`` file.

    Returns:
        None. The results are the files written under ``results_path``.
    """
    import copy

    attention_path = os.path.join(results_path, "attention_scores.npy")
    if os.path.exists(attention_path) and not recompute_attention:
        attention_scores_at_layers = np.load(attention_path)
    else:
        attention_scores_at_layers = get_attention_score(model, tokenizer, prompts, results_path, token_length, device)
        attention_scores_at_layers = attention_scores_at_layers.cpu().numpy()
        os.makedirs(results_path, exist_ok=True)
        # Cache the averaged scores so later runs can skip the model pass.
        np.save(attention_path, attention_scores_at_layers)

    # Take the loop bounds from the data itself so a cached array and the
    # plotting loops can never disagree.
    num_layers, num_heads = attention_scores_at_layers.shape[:2]

    # Work on a copy so the shared coolwarm colormap is not modified.
    cmap = copy.copy(plt.cm.coolwarm)
    cmap.set_bad(color='#808080')

    # Positional mask of the strictly-upper triangle. Masking by value
    # (== 0) would also hide genuine zeros below the diagonal.
    upper_triangle = np.triu(
        np.ones(attention_scores_at_layers.shape[-2:], dtype=bool), k=1
    )
    tick_positions = np.arange(0, token_length, 2)

    for l in range(num_layers):
        for h in range(num_heads):
            plt.figure(figsize=(10, 8))
            # NaN cells are rendered with the colormap's "bad" colour.
            attention_data = np.where(upper_triangle, np.nan, attention_scores_at_layers[l, h])
            plt.imshow(attention_data, cmap=cmap, aspect='auto')
            cbar = plt.colorbar()
            cbar.ax.tick_params(labelsize=22)
            plt.tick_params(axis='both', labelsize=32)
            plt.xticks(tick_positions)
            plt.yticks(tick_positions)
            plt.savefig(os.path.join(results_path, f'attention_layer{l+1}_head{h+1}.pdf'),
                        bbox_inches='tight', facecolor='white')
            # Close the figure so figures do not pile up across iterations.
            plt.close()


def measure_open_sourced_lms():
    """Run the attention-sink measurement for Falcon3-Mamba-7B-Instruct.

    Loads the ``tiiuae/Falcon3-Mamba-7B-Instruct`` model and tokenizer, reads
    prompts from ``probe_valid.jsonl`` (one JSON object per line, using its
    ``"text"`` field) and calls ``measure_attention_sink`` with a token
    length of 16. Outputs go to ``results/<model_name>_token<token_length>``.
    """
    device = torch.device("cuda")
    os.makedirs("results", exist_ok=True)

    model_path = "tiiuae/Falcon3-Mamba-7B-Instruct"
    model_name = model_path.split("/")[-1]
    # NOTE: this directory is created but nothing below writes into it;
    # kept so the on-disk layout matches earlier runs.
    os.makedirs(f"results/{model_name}", exist_ok=True)

    model = FalconMambaForCausalLM.from_pretrained(
        model_path,
        device_map="auto"
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_path
    )

    # Load the probe prompts and feed them to the model.
    file_path = "probe_valid.jsonl"
    token_length = 16

    results_path = f"results/{model_name}_token{token_length}"
    # Explicit UTF-8 so decoding does not depend on the platform locale;
    # blank lines are skipped because json.loads rejects them.
    with open(file_path, 'r', encoding="utf-8") as f:
        prompts = [json.loads(line)["text"] for line in f if line.strip()]
    measure_attention_sink(model, tokenizer, prompts, results_path, token_length, device)
    


# Run the measurement only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    measure_open_sourced_lms()
