import os
import time
import json
from fontTools.merge.util import first
from fontTools.ttLib.tables.grUtils import num2tag
from sympy.physics.tests.test_secondquant import att
import torch
import random
import argparse
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from models.falcon_mamba.modeling_falcon_mamba import FalconMambaForCausalLM
import pandas as pd

NUM_CHANNELS = 200

def get_attention_score(model, tokenizer, prompts, results_path, token_length=50, device=torch.device("cuda")):
    """Run each prompt through ``model`` and collect the attention maps of one generation step.

    Each prompt is tokenized and truncated to its first ``token_length`` tokens. It is then
    passed to ``model.generate`` with ``max_new_tokens=1`` and ``output_attentions=True``.
    The per-layer attention tensors of that single step are concatenated along dim 0.

    Args:
        model: causal LM whose ``generate`` accepts ``output_attentions`` and
            ``return_dict_in_generate``.
        tokenizer: tokenizer matching ``model``.
        prompts: iterable of prompt strings; each must tokenize to at least
            ``token_length`` tokens.
        results_path: unused; kept for interface compatibility with callers.
        token_length: number of leading tokens kept from every prompt.
        device: device the tokenized inputs are moved to.

    Returns:
        A CPU tensor stacking one entry per prompt along dim 0. Assuming each per-layer
        map is (1, num_heads, seq_len, seq_len), this is
        (num_samples, num_layers, num_heads, seq_len, seq_len).

    Raises:
        ValueError: if a prompt is shorter than ``token_length`` tokens, or if
            ``prompts`` is empty.
    """
    per_sample_maps = []

    for prompt in tqdm(prompts):
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        for key in list(inputs.keys()):
            # Explicit check instead of `assert` so it survives `python -O`; all samples
            # must share one sequence length for the final concatenation.
            if inputs[key].shape[1] < token_length:
                raise ValueError(
                    f"prompt tokenizes to {inputs[key].shape[1]} tokens, "
                    f"fewer than token_length={token_length}"
                )
            inputs[key] = inputs[key][:, :token_length]

        outputs = model.generate(
            **inputs,
            output_attentions=True,
            return_dict_in_generate=True,
            max_new_tokens=1
        )

        attentions = outputs['attentions']
        # One generated token -> exactly one generation step of attentions.
        assert len(attentions) == 1
        # With device_map="auto", layers may live on different GPUs and torch.cat refuses
        # mixed devices. Moving to CPU also keeps every sample's maps out of GPU memory.
        layer_maps = [layer_attn.detach().cpu() for layer_attn in attentions[0]]
        per_sample_maps.append(torch.cat(layer_maps, dim=0).unsqueeze(0))

    if not per_sample_maps:
        raise ValueError("no prompts were given, nothing to compute attention for")

    return torch.cat(per_sample_maps, dim=0)

def measure_attention_sink(model, tokenizer, prompts, results_path, token_length=50, device=torch.device("cuda"), recompute_attention=False):
    """Compute (or reuse cached) attention maps and cache them as a ``.npy`` file.

    The cache file is ``<results_path>/attention_scores_per_layers.npy``.

    If that file exists and ``recompute_attention`` is False, the model is not run. The
    cached array is returned memory-mapped (read-only), so a possibly multi-GB file is
    not read eagerly just to be discarded.

    Otherwise ``get_attention_score`` is called. Its result is converted to NumPy,
    written to the cache file (``results_path`` is created if needed) and returned.

    Args:
        model, tokenizer, prompts, token_length, device: forwarded to
            ``get_attention_score`` when the maps have to be (re)computed.
        results_path: directory holding the cache file.
        recompute_attention: ignore any existing cache file and recompute.

    Returns:
        The attention-score array, an ``np.memmap`` on a cache hit. Previously this
        function returned ``None``; callers that ignore the return value are unaffected.
    """
    attention_path = os.path.join(results_path, "attention_scores_per_layers.npy")

    if os.path.exists(attention_path) and not recompute_attention:
        return np.load(attention_path, mmap_mode="r")

    scores = get_attention_score(model, tokenizer, prompts, results_path, token_length, device)
    scores = scores.detach().cpu()
    if scores.dtype == torch.bfloat16:
        # NumPy has no bfloat16 dtype, so .numpy() would raise on bf16 models.
        scores = scores.float()
    scores = scores.numpy()

    os.makedirs(results_path, exist_ok=True)
    np.save(attention_path, scores)
    return scores

def compute_attention_sink(score_path, epsilon):
    """Return, per layer and key position, the percentage of (sample, head) pairs that sink.

    The array at ``score_path`` is unpacked as
    (samples, layers, heads, query_tokens, key_tokens); the two token axes must match.

    For key position j, the attention it receives is summed over all queries and divided
    by ``num_tokens - j``. That divisor is presumably the number of queries able to see
    key j under a causal mask. A (sample, head) pair counts as a sink at (layer, j) when
    this averaged score exceeds ``epsilon``.

    Returns:
        A float tensor of shape (layers, key_tokens) holding percentages in [0, 100].
    """
    raw_scores = np.load(score_path)
    n_samples, n_layers, n_heads, n_queries, n_keys = raw_scores.shape
    assert n_queries == n_keys
    scores = torch.from_numpy(raw_scores)

    # [T, T-1, ..., 1] broadcast along the key axis (last dim).
    visible_queries = torch.arange(n_queries, 0, -1).to(scores)
    received = torch.sum(scores / visible_queries, dim=-2)  # (samples, layers, heads, keys)

    exceeds = received > epsilon
    fraction = exceeds.to(torch.float).mean(dim=(0, 2))  # average over samples and heads
    return fraction * 100


def _plot_layer_metric(layer_metric, out_path, window_size=12):
    """Save a line plot of a per-layer percentage plus its centred moving average."""
    num_layers = len(layer_metric)
    layers = np.arange(1, num_layers + 1)
    plt.figure(figsize=(12, 8))

    # Raw data in blue, matching the heatmap colour scheme.
    plt.plot(layers, layer_metric,
             marker='o', markersize=7, linewidth=2.5,
             color='#4A90E2', alpha=0.9, label='Raw Data')

    # Centred rolling mean; min_periods=1 lets the window shrink at both ends,
    # so no padding is needed and the curve covers every layer.
    moving_avg = pd.Series(layer_metric).rolling(
        window=window_size, center=True, min_periods=1
    ).mean().values
    plt.plot(layers, moving_avg,
             color='#E74C3C', linewidth=3,
             label='Moving Average')

    plt.ylim(0, 100)
    plt.xlabel('Layer', fontsize=32)
    plt.ylabel('Attention Sink (%)', fontsize=32)
    plt.xticks(np.arange(0, num_layers + 8, 8), fontsize=28)
    plt.xlim(0.5, num_layers + 0.5)
    plt.yticks(fontsize=28)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=28)
    plt.tight_layout()

    plt.savefig(out_path, bbox_inches='tight', facecolor='white', dpi=300)
    plt.close()


def _plot_token_metric(token_metric, out_path):
    """Save a bar chart of a per-token-position percentage (positions numbered from 1)."""
    num_tokens = len(token_metric)
    plt.figure(figsize=(12, 8))
    plt.bar(np.arange(1, num_tokens + 1), token_metric, color='#4A90E2', alpha=0.9)
    plt.xlabel('Token Position', fontsize=32)
    plt.ylabel('Attention Sink (%)', fontsize=32)
    plt.xlim(0.5, num_tokens + 0.5)
    plt.xticks(np.arange(2, num_tokens + 1, 2), fontsize=28)
    plt.yticks(fontsize=28)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(out_path, bbox_inches='tight', facecolor='white', dpi=300)
    plt.close()


def measure_open_sourced_lms():
    """Measure and plot the attention-sink metric for Falcon3-Mamba-7B-Instruct.

    Steps:
        1. Load the model and tokenizer.
        2. Read the ``"text"`` field of each non-blank line in ``probe_valid.jsonl``.
        3. Compute or reuse cached attention maps on the first ``token_length`` tokens.
        4. Write two PDFs into ``results/<model>_token<token_length>``:
           - the metric for token 0 across layers;
           - the metric per token position, averaged over layers.
    """
    device = torch.device("cuda")

    model_path = "tiiuae/Falcon3-Mamba-7B-Instruct"
    model_name = model_path.split("/")[-1]

    file_path = "probe_valid.jsonl"
    token_length = 16
    epsilon = 0.3

    results_path = f"results/{model_name}_token{token_length}"
    # Create the directory the plots are written to (this also creates "results/").
    os.makedirs(results_path, exist_ok=True)

    model = FalconMambaForCausalLM.from_pretrained(
        model_path,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Load the data and feed it into the LLM; blank lines (e.g. a trailing newline) are skipped.
    with open(file_path, 'r', encoding='utf-8') as f:
        prompts = [json.loads(line)["text"] for line in f if line.strip()]
    measure_attention_sink(model, tokenizer, prompts, results_path, token_length, device, recompute_attention=False)
    metric = compute_attention_sink(os.path.join(results_path, "attention_scores_per_layers.npy"), epsilon=epsilon)

    # (num_layers, num_tokens) percentages. Convert once to NumPy so pandas and
    # matplotlib receive plain arrays rather than torch tensors.
    metric = np.asarray(metric)

    _plot_layer_metric(
        metric[:, 0],
        os.path.join(results_path, f'attention_sink_metric_token0_e{epsilon}.pdf'),
    )
    _plot_token_metric(
        metric.mean(axis=0),
        os.path.join(results_path, f'attention_sink_metric_all_tokens_e{epsilon}.pdf'),
    )


if __name__ == "__main__":
    # Script entry point: run the full measurement and plotting pipeline.
    measure_open_sourced_lms()
