import json
import math

import torch
from torch import nn
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaAttention, apply_rotary_pos_emb, repeat_kv
from transformers.models.mistral.modeling_mistral import MistralAttention
# from transformers.models.mixtral.modeling_mixtral import MixtralAttention
from datasets import load_dataset
from functools import partial
import tqdm
import matplotlib.pyplot as plt
import numpy as np

def get_calib_dataset(tokenizer=None, n_samples=256, block_size=512, *,
                      dataset_name="mit-han-lab/pile-val-backup",
                      split="validation", seed=42):
    """Build fixed-length token blocks for calibration from a text dataset.

    The dataset is shuffled with ``seed``. Each entry's ``"text"`` field is
    stripped and tokenized. Up to ``n_samples`` non-empty lines of at most
    ``block_size`` tokens are kept. They are concatenated along the sequence
    dimension and cut into consecutive ``block_size`` chunks; the trailing
    remainder shorter than one block is dropped.

    Args:
        tokenizer: Object with an ``encode(str)`` method returning a list of
            token ids. Required despite the ``None`` default (the default is
            kept only for signature compatibility).
        n_samples: Maximum number of dataset lines to collect.
        block_size: Maximum tokens per kept line and length of each block.
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Dataset split to load.
        seed: Shuffle seed, so the same lines are picked on every run.

    Returns:
        A list of ``[1, block_size]`` integer tensors. The list may be empty
        if fewer than ``block_size`` tokens were collected in total.

    Raises:
        RuntimeError: if no dataset line produced a usable sample.
    """
    dataset = load_dataset(dataset_name, split=split)
    dataset = dataset.shuffle(seed=seed)

    samples = []
    for data in dataset:
        line_encoded = tokenizer.encode(data["text"].strip())
        # Skip lines that tokenize to nothing or would not fit in one block.
        if not line_encoded or len(line_encoded) > block_size:
            continue
        samples.append(torch.tensor([line_encoded]))
        if len(samples) == n_samples:
            break

    if not samples:
        # torch.cat([]) would otherwise fail with an unhelpful message.
        raise RuntimeError(
            f"No calibration samples of at most {block_size} tokens "
            f"found in {dataset_name!r} ({split})"
        )

    # now concatenate all samples and split according to block size
    cat_samples = torch.cat(samples, dim=1)
    n_split = cat_samples.shape[1] // block_size
    print(f" * Split into {n_split} blocks")
    return [cat_samples[:, i * block_size:(i + 1) * block_size] for i in range(n_split)]


def _mean_head_cosine(attn_a, attn_b):
    """Return per-head cosine similarity between two attention maps.

    Both inputs are assumed to be ``[batch_size, num_heads, seq_len, seq_len]``
    attention tensors (layout not checked here). Each query row of ``attn_a``
    is compared with the same row of ``attn_b`` over the key dimension. The
    result is averaged over batch and query positions, giving a
    ``[num_heads]`` float32 tensor on the CPU.
    """
    # With device_map="auto", consecutive layers may be placed on different
    # GPUs, so move both maps onto one device before combining them.
    attn_b = attn_b.to(attn_a.device)
    # Upcast from the model's float16: fp16 can only resolve values near 1.0
    # to roughly 5e-4, which is too coarse for high layer-to-layer similarity.
    cos_sim = F.cosine_similarity(attn_a.float(), attn_b.float(), dim=-1)  # [batch, heads, seq]
    return cos_sim.mean(dim=(0, 2)).cpu()


@torch.no_grad()
def attn_similarity_across_layer(model, tokenizer, samples=None):
    """Measure how similar each layer's attention maps are to the next layer's.

    Runs ``model`` on each calibration block with ``output_attentions=True``.
    For every pair of consecutive layers ``(i, i + 1)``, it records the
    per-head mean cosine similarity of their attention maps.

    Args:
        model: Causal LM whose forward accepts ``output_attentions`` and
            ``return_dict``.
        tokenizer: Tokenizer used to build calibration data when ``samples``
            is not given.
        samples: Optional iterable of ``input_ids`` tensors. Defaults to
            ``get_calib_dataset(tokenizer)``.

    Returns:
        dict mapping layer index ``i`` to a list with one ``[num_heads]``
        tensor per sample. Each tensor holds the similarity between layer
        ``i`` and layer ``i + 1``.

    Raises:
        RuntimeError: if the model output contains no attention weights.
    """
    output_dict = dict()

    print("Collecting attention similarity...")
    # Inputs go where the model reports its device (first parameter's device
    # under device_map="auto").
    device = model.device

    if samples is None:
        samples = get_calib_dataset(tokenizer)
    pbar = tqdm.tqdm(samples)
    for input_ids in pbar:
        input_ids = input_ids.to(device)
        outputs = model(input_ids, output_attentions=True, return_dict=True)
        attentions = outputs.attentions
        if not attentions:
            # Some attention backends may not return weights; fail clearly.
            raise RuntimeError(
                "Model returned no attention weights; try loading it with "
                "attn_implementation=\"eager\"."
            )
        print(attentions[0].shape)

        # Pair each layer with its successor: (0, 1), (1, 2), ..., (L-2, L-1).
        for layer_id, (attn_cur, attn_next) in enumerate(zip(attentions, attentions[1:])):
            mean_cos_sim = _mean_head_cosine(attn_cur, attn_next)  # [num_heads]
            print(mean_cos_sim)
            output_dict.setdefault(layer_id, []).append(mean_cos_sim)

    return output_dict

# Checkpoint to analyse; alternatives previously used with this script are
# kept below for quick switching.
model_path = "meta-llama/Llama-2-7b-hf"
# model_path = "/home/ec2-user/.cache/huggingface/hub/models--huggyllama--llama-7b/snapshots/8416d3fefb0cb3ff5775a7b13c1692d10ff1aa16"
# model_path = "/home/ec2-user/.cache/huggingface/hub/models--facebook--opt-6.7b/snapshots/a45aa65bbeb77c1558bc99bedc6779195462dab0"
# model_path = "/home/ec2-user/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-chat-hf/snapshots/94b07a6e30c3292b8265ed32ffdeccfdadf434a8"
# model_path = "/home/ec2-user/.cache/huggingface/hub/models--lmsys--vicuna-7b-v1.5-16k/snapshots/9a93d7d11fac7f3f9074510b80092b53bc1a5bec"
# model_path = "mistralai/Mistral-7B-v0.1"
# model_path = "meta-llama/Llama-2-70b-chat-hf"
# model_path = "meta-llama/Llama-2-70b-hf"
# model_path = "mistralai/Mixtral-8x7B-v0.1"

# Half precision, with weights spread across available devices by accelerate.
kwargs = dict(torch_dtype=torch.float16, device_map="auto")

model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path)

output_dict = attn_similarity_across_layer(model, tokenizer)

# Average the per-sample [num_heads] similarity vectors of each layer pair
# into a plain list of floats (JSON-serialisable).
sim_dict = {
    layer_id: torch.stack(per_sample).mean(dim=0).tolist()
    for layer_id, per_sample in output_dict.items()
}

# To persist the result:
# with open("llama2-7b-attn-sim.json", "w") as f:
#     json.dump(sim_dict, f)