"""
This script loads model checkpoints directly and computes attention-sink statistics.
"""

import glob
import math
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union
import math
import lightning as L
import torch
import torch.nn as nn
from lightning.fabric.strategies import FSDPStrategy, XLAStrategy
from torch.utils.data import DataLoader
from functools import partial
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
# from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually
from lit_gpt.model_infer import GPT, Config
from lit_gpt.model_sigmoid_infer import GPTsigmoid
from lit_gpt.model_kv_bias_infer import GPTkvbias
from lit_gpt.packed_dataset import CombinedDataset, PackedDataset
from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor
from lit_gpt.speed_monitor import estimate_flops, measure_flops
from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load
from pytorch_lightning.loggers import WandbLogger
from lit_gpt import FusedCrossEntropyLoss
import random
import yaml
import os
import json
import numpy as np
from lit_gpt import Tokenizer


def load_data():
    """Load the probe prompts and tokenize each one.

    Reads ``datasets/probe_valid.jsonl`` (one JSON object per line; only its
    ``"text"`` field is used) and encodes every prompt with the tokenizer
    stored under ``./preprocess/tokenizer/gptneox``. Both paths are relative
    to the current working directory.

    Returns:
        list: one ``torch.long`` tensor per prompt, each with an extra leading
        dimension of size 1 added via ``unsqueeze(dim=0)``.
    """
    tokenizer = Tokenizer(Path("./preprocess/tokenizer/gptneox"))

    # Explicit UTF-8 so decoding does not depend on the platform locale;
    # blank lines (e.g. a trailing newline) are skipped instead of crashing
    # json.loads.
    with open("datasets/probe_valid.jsonl", "r", encoding="utf-8") as f:
        prompts = [json.loads(line)["text"] for line in f if line.strip()]

    print(f"Tokenizer: BOS token {tokenizer.bos_id}, EOS token {tokenizer.eos_id}, vocabulary size {tokenizer.vocab_size}")
    print(len(prompts))
    if prompts:
        # Show one example as a sanity check; guarded so an empty file does
        # not raise IndexError here.
        print(prompts[0])

    return [
        tokenizer.encode(prompt, eos=False).unsqueeze(dim=0).to(torch.long)
        for prompt in prompts
    ]



def evaluate_model(model_name, model_class, load_from, fixed_token=0):
    """Run one checkpoint over the probe prompts and print attention-sink statistics.

    For up to 100 prompts, the first 64 tokens are fed through the model with
    ``output_attention=True``. The function then prints:

    * the mean next-token loss over the evaluated prompts,
    * for the first 10 token positions, the percentage of (layer, head) pairs
      whose importance score exceeds ``epsilon = 0.3``,
    * the mean L2 norm of the hidden state at every token position, for each
      of ``n_layer + 1`` hidden-state entries.

    Args:
        model_name: Config name passed to ``Config.from_name``.
        model_class: One of ``"base"``, ``"sigmoid"`` or ``"kv_bias"``;
            selects the model implementation.
        load_from: Checkpoint path given to ``torch.load``, or ``None`` to
            evaluate the freshly constructed model.
        fixed_token: Unused; kept so existing callers keep working.

    Raises:
        ValueError: If ``model_class`` is not a known variant.
        RuntimeError: If no prompt was long enough to be evaluated.
    """
    device = torch.device('cuda')
    all_inputs = load_data()
    model_config = Config.from_name(model_name)

    # Dispatch table instead of an if/elif chain, so an unknown variant fails
    # immediately with a clear message rather than a NameError further down.
    model_classes = {"sigmoid": GPTsigmoid, "base": GPT, "kv_bias": GPTkvbias}
    if model_class not in model_classes:
        raise ValueError(
            f"unknown model_class {model_class!r}; expected one of {sorted(model_classes)}"
        )
    model = model_classes[model_class](model_config)

    if load_from is not None:
        print("loading model from {}".format(load_from))
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(load_from, map_location=device)
        if "model" in state_dict:
            state_dict = state_dict["model"]
        model.load_state_dict(state_dict, strict=True, assign=True)

    # Move explicitly to the target device as well as casting: without a
    # checkpoint the parameters would otherwise stay where they were created
    # while the inputs below are sent to CUDA.
    model.to(device=device, dtype=torch.bfloat16)
    # Inference only: put train-time-only modules (if any) into eval mode.
    model.eval()

    token_length = 64
    size = 100
    count = 0
    skipped = 0
    num_layers = model_config.n_layer
    loss_func = FusedCrossEntropyLoss()
    attention_scores_all_sample = []
    hidden_states_all_sample = []
    all_loss = []

    # No gradients are needed; this avoids building an autograd graph for
    # every forward pass.
    with torch.no_grad():
        for data in all_inputs:
            # One extra token is needed for the shifted labels; shorter
            # prompts would give inputs/labels of different lengths.
            if data.shape[1] < token_length + 1:
                skipped += 1
                continue

            inputs = data[:, :token_length].to(device).contiguous()
            labels = data[:, 1:token_length + 1].to(device).contiguous()
            outputs, all_attns, all_hiddens = model(
                inputs,
                output_attention=True,
            )

            loss = loss_func(outputs, labels)
            all_loss.append(loss.item())

            # Concatenate the per-layer attention maps along dim 0, then add a
            # leading sample axis.
            # NOTE(review): this yields a (layers, heads, q, k) tensor only if
            # each layer's map has batch size 1 — confirm against the model.
            attention_scores_all_layer = torch.cat(
                [all_attns[layer] for layer in range(num_layers)], dim=0
            )
            attention_scores_all_sample.append(attention_scores_all_layer.unsqueeze(dim=0))

            hidden_states_all_layer = torch.cat(all_hiddens, dim=0)
            hidden_states_all_sample.append(hidden_states_all_layer.unsqueeze(dim=0))

            # Check the sample budget only after BOTH attention and hidden
            # states were recorded, so the two statistics cover the same set
            # of samples.
            count += data.shape[0]
            if count >= size:
                break

    if skipped:
        print(f"skipped {skipped} prompts shorter than {token_length + 1} tokens")
    if not all_loss:
        raise RuntimeError(
            f"no prompt with at least {token_length + 1} tokens was available for evaluation"
        )

    # Average over the samples actually evaluated, which may be fewer than `size`.
    print(f"averaged loss: {sum(all_loss) / len(all_loss)}")

    # Unpacked below as (num_samples, num_layers, num_heads, num_tokens1, num_tokens2).
    attention_scores = torch.cat(attention_scores_all_sample, dim=0).detach().cpu()
    num_samples, num_layers, num_heads, num_tokens1, num_tokens2 = attention_scores.shape

    # Divide column k (last axis) by (T - k), then sum over the second-to-last
    # axis. Broadcasting along the last axis is equivalent to the explicit
    # expand for square attention maps.
    # NOTE(review): presumably T - k is the number of causal queries that can
    # attend to key k, making this a per-key average — confirm the masking.
    ratios2 = torch.arange(num_tokens2, 0, -1).to(attention_scores)

    epsilon = 0.3
    importance_scores = (attention_scores / ratios2).sum(dim=-2)  # (num_samples, num_layers, num_heads, num_tokens)
    # Fraction of (layer, head) pairs whose importance exceeds epsilon, per
    # sample and token position.
    metric1 = (importance_scores > epsilon).to(torch.float).mean(dim=(1, 2))

    # Percentages for the first 10 token positions, averaged over samples.
    print(metric1[:, :10].mean(dim=0) * 100)

    # L2 norm of the hidden states over the last (feature) axis.
    hidden_states_all_sample = torch.cat(hidden_states_all_sample, dim=0)  # sample, layer, sequence, dim
    hidden_states_all_sample_norm = hidden_states_all_sample.norm(p=2, dim=-1)
    for l in range(num_layers + 1):
        print(f"layer: {l}, norm for each token: {hidden_states_all_sample_norm[:, l].mean(dim=0)}")



def evaluate_all_models():
    """Evaluate the base, sigmoid and kv-bias checkpoints one after another.

    Each run loads ``./checkpoints/<dir>/iter-020000-ckpt.pth`` and hands it
    to ``evaluate_model`` together with the config name and model variant.
    """
    steps = "020000"

    # (config name, checkpoint directory, model variant), in evaluation order.
    runs = (
        ("tinyllama_60M", "tinyllama_60M", "base"),
        ("tinyllama_60M", "tinyllama_60M_sigmoid", "sigmoid"),
        ("tinyllama_60M_k_head_bias", "tinyllama_60M_k_head_bias", "kv_bias"),
    )

    for config_name, ckpt_dir, variant in runs:
        checkpoint = f"./checkpoints/{ckpt_dir}/iter-{steps}-ckpt.pth"
        print(f"load model from {checkpoint}")
        evaluate_model(config_name, variant, checkpoint)


# Script entry point: evaluate every configured checkpoint when run directly.
if __name__ == "__main__":
    evaluate_all_models()