import os
from datasets import load_dataset
import torch
import json
from transformers import AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM,AutoConfig
from tqdm import tqdm
import numpy as np
import random
import argparse
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers.models.llama.modeling_llama import (
    repeat_kv,
)
from collections import Counter
from typing import List
from asym_kv.util.pred_utils import greedy_generate, post_process
from asym_kv.method.compress_433 import KVCompressCache
from asym_kv.method.compress_forward.compress_llama_433 import enable_llama_pos_shift_compress_attention_433

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from torch.nn import CrossEntropyLoss
def parse_args(args=None):
    """Parse the command-line options of the prediction script.

    Args:
        args: Optional list of argument strings; ``None`` makes argparse
            read ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``model``, ``method`` and ``e``.
    """
    supported_models = [
        "gemma-1.1-2b",
        "gemma-1.1-7b",
        "Mistral-7B-Instruct-v0.3",
        "Llama-3.1-8B-Instruct",
        "Llama-3-8B-Instruct",
        "llama2-7b-chat-4k",
    ]
    supported_methods = ["llm", "compress"]

    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--model',
        type=str,
        default="Llama-3-8B-Instruct",
        choices=supported_models,
    )
    cli.add_argument(
        '--method',
        type=str,
        default="compress",
        choices=supported_methods,
    )
    cli.add_argument(
        '--e',
        action='store_true',
        help="Evaluate on LongBench-E",
    )
    namespace = cli.parse_args(args)
    return namespace

def build_chat(prompt, model_name,tokenizer):
    """Wrap ``prompt`` in the chat markup selected by ``model_name``.

    * names containing ``"llama2"`` get an ``[INST]...[/INST]`` wrapper;
    * names containing ``"Llama-3"`` get the hand-written header/eot markup;
    * names containing ``"Qwen"`` or ``"Mistral"`` go through
      ``tokenizer.apply_chat_template`` with a fixed system message;
    * any other name returns ``prompt`` unchanged.

    ``tokenizer`` is only touched in the Qwen/Mistral case.
    """
    if "llama2" in model_name:
        return f"[INST]{prompt}[/INST]"

    if "Llama-3" in model_name:
        return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    if any(family in model_name for family in ("Qwen", "Mistral")):
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    return prompt
def seed_everything(seed):
    """Seed the torch, CUDA, NumPy and ``random`` generators with ``seed``.

    Also turns cuDNN benchmarking off and sets the cuDNN deterministic flag.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    torch.cuda.manual_seed_all(seed)

def load_model_and_tokenizer(path, model_name, device,method):
    """Load the causal LM and tokenizer stored at ``path``.

    The model config is forced to the ``"eager"`` attention implementation,
    the weights are loaded in bfloat16, moved to ``device`` and put in eval
    mode. A tokenizer without a pad token gets its EOS id as pad id, or ``0``
    when it has no EOS id either.

    ``model_name`` and ``method`` are accepted but not used here.

    Returns:
        ``(model, tokenizer)``
    """
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    # NOTE(review): eager attention is presumably required because get_pred
    # requests attention maps (output_attentions=True) — confirm.
    model_config = AutoConfig.from_pretrained(path, trust_remote_code=True)
    model_config._attn_implementation = "eager"

    model = AutoModelForCausalLM.from_pretrained(
        path,
        config=model_config,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    model = model.to(device)

    if tokenizer.pad_token_id is None:
        eos_id = tokenizer.eos_token_id
        tokenizer.pad_token_id = 0 if eos_id is None else eos_id

    return model.eval(), tokenizer

def get_pred(model, tokenizer, rank, world_size, data_all, max_gen, prompt_format, dataset, device, model_name, out_path,method):
    """Predict on this rank's shard of ``data_all`` and append results to ``out_path``.

    Each prompt is fed to the model in windows of ``input_window`` tokens so
    that the KV cache is built incrementally. With ``method == "compress"``,
    every window that already has a cache is first run with gradients
    enabled: the loss gradients w.r.t. the cached keys/values and the
    attention maps over the cached positions are handed to
    ``KVCompressCache``, whose returned cache is then used for the regular
    no-grad pass over the same window. Generation starts from the argmax of
    the last prompt position and is delegated to ``greedy_generate``.

    Args:
        model: Causal LM; in compress mode it is patched in place by
            ``enable_llama_pos_shift_compress_attention_433``.
        tokenizer: Tokenizer matching ``model``.
        rank: Index of this process; selects samples ``rank::world_size``.
        world_size: Number of processes; a barrier is issued when > 1.
        data_all: Sequence of sample dicts (must contain ``"answers"``).
        max_gen: Maximum number of tokens to generate per sample.
        prompt_format: ``str.format`` template filled with each sample.
        dataset: Dataset name; decides whether the chat wrapper is applied.
        device: Device the tokenized prompt is moved to.
        model_name: Passed to ``build_chat`` and ``post_process``.
        out_path: JSON-lines file opened in append mode, one line per sample.
        method: ``"compress"`` enables cache compression; anything else
            runs the plain windowed forward pass.
    """
    data = data_all[rank::world_size]
    k_seq_dim = v_seq_dim = 2
    start_size = 32
    recent_size = 2048
    input_window = 512
    # Weights of the running blend of key/value gradients across windows:
    # new = 1/delta_idx + alpha, old = (delta_idx-1)/delta_idx - alpha.
    # delta_idx == 0 disables blending (the newest gradients are used as-is)
    # instead of dividing by zero.
    # NOTE(review): nothing in this function ever changes delta_idx/alpha —
    # confirm whether blending was meant to be active.
    delta_idx = 0
    alpha = 0
    # These datasets are sent to the model without the chat wrapper.
    raw_prompt_datasets = ("lsht", "trec", "triviaqa", "samsum", "lcc", "repobench-p")

    compress = method == "compress"
    if compress:
        enable_llama_pos_shift_compress_attention_433(model)

    for json_obj in tqdm(data, desc=f"Processing dataset {dataset} on rank {rank}"):
        prompt = prompt_format.format(**json_obj)
        if dataset not in raw_prompt_datasets:
            prompt = build_chat(prompt, model_name, tokenizer)

        prompt_ids = tokenizer(prompt, truncation=False, return_tensors="pt").to(device).input_ids
        context_length = prompt_ids.shape[-1]

        if compress:
            kv_cache = KVCompressCache(
                start_size=start_size,
                recent_size=recent_size,
                k_seq_dim=k_seq_dim,
                v_seq_dim=v_seq_dim,
            )
        else:
            kv_cache = None

        past_key_values = None
        outputs = None
        # Per layer: (d loss / d keys, d loss / d values) of the cached entries.
        hessian_diagnoal = []

        # The last window runs to the end of the prompt and absorbs a
        # would-be trailing single token: every prompt token is fed (the
        # logits used to start generation belong to the true last token) and
        # no gradient pass sees a one-token window, which would leave the
        # labelled loss without any target position. max(..., 1) keeps a
        # one-token prompt from skipping the loop altogether.
        for idx in range(0, max(context_length - 1, 1), input_window):
            if idx + input_window >= context_length - 1:
                input_ids = prompt_ids[:, idx:]
            else:
                input_ids = prompt_ids[:, idx:idx + input_window]

            if kv_cache is not None and past_key_values is not None:
                num_key_value_groups = model.model.layers[0].self_attn.num_key_value_groups

                # Re-create the cached tensors as leaves so backward() fills .grad.
                past_key_values = tuple(
                    tuple(kv.detach().requires_grad_(True) for kv in layer)
                    for layer in past_key_values
                )
                outputs = model(input_ids=input_ids,
                                past_key_values=past_key_values,
                                use_cache=True,
                                output_attentions=True,
                                labels=input_ids)
                loss = outputs.loss
                attns = outputs.attentions
                # backward() runs even for a NaN loss so the autograd graph is
                # released before the no-grad pass below.
                loss.backward()

                if loss.isnan():
                    hessian_diagnoal_local = None
                    print("loss is Nan for idx", idx)
                else:
                    # `not hessian_diagnoal` also covers None, which is what
                    # the previous window leaves behind when the cache
                    # returned an empty list.
                    reseed = (
                        not hessian_diagnoal
                        or delta_idx == 0
                        or hessian_diagnoal[0][0].shape[-2] < recent_size + input_window + start_size
                    )
                    if reseed:
                        hessian_diagnoal = [
                            tuple(kv.grad for kv in layer[:2])  # keys and values only
                            for layer in past_key_values
                        ]
                    else:
                        new_weight = (1 / delta_idx) + alpha
                        old_weight = ((delta_idx - 1) / delta_idx) - alpha
                        blended = []
                        for i, layer in enumerate(past_key_values):
                            layer_gradients = []
                            for kv, previous in zip(layer[:2], hessian_diagnoal[i][:2]):
                                seen = previous.shape[2]
                                # Positions with history are blended; newer
                                # positions take the fresh gradient unchanged.
                                head = new_weight * kv.grad[:, :, :seen, :] + old_weight * previous
                                tail = kv.grad[:, :, seen:, :]
                                layer_gradients.append(torch.cat((head, tail), dim=2))
                            blended.append(tuple(layer_gradients))
                        hessian_diagnoal = blended
                    hessian_diagnoal_local = hessian_diagnoal
                torch.cuda.empty_cache()

                past_key_values = tuple(
                    tuple(kv.detach().requires_grad_(False) for kv in layer)
                    for layer in past_key_values
                )
                # Drop the columns of the current window; only attention paid
                # to already-cached positions is passed on.
                attns = [attn[:, :, :, :-input_ids.shape[-1]] for attn in attns]

                past_key_values, hessian_diagnoal_local = kv_cache(
                    past_key_values, attns, num_key_value_groups,
                    hessian_diagnoal_local, return_Cache=True,
                )
                hessian_diagnoal = None if hessian_diagnoal_local == [] else hessian_diagnoal_local

            with torch.no_grad():
                # No labels here: the loss of this pass is never read.
                outputs = model(
                    input_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                past_key_values = outputs.past_key_values

        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        pred = greedy_generate(
            model, tokenizer, pred_token_idx, past_key_values, max_gen_len=max_gen
        )
        pred = post_process(pred, model_name)
        # NOTE(review): "length" is a fixed placeholder, not the sample's own
        # length — confirm that no length-bucketed scoring reads this field.
        record = {
            "pred": pred,
            "answers": json_obj["answers"],
            "all_classes": json_obj.get("all_classes"),
            "length": 10000,
        }
        with open(out_path, "a", encoding="utf-8") as f:
            json.dump(record, f, ensure_ascii=False)
            f.write('\n')

    if world_size > 1:
        # Keep ranks in step before the caller moves on.
        dist.barrier()

if __name__ == '__main__':
    # Entry point: run get_pred over every selected LongBench dataset.
    # LOCAL_RANK / WORLD_SIZE are read from the environment; their defaults
    # describe a single-process run.
    seed_everything(42)
    args = parse_args()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1:
        dist.init_process_group(backend='nccl')

    # Context managers close the config files right after parsing.
    with open("config/model2path.json", "r") as f:
        model2path = json.load(f)
    with open("config/dataset2prompt.json", "r") as f:
        dataset2prompt = json.load(f)
    with open("config/dataset2maxlen.json", "r") as f:
        dataset2maxlen = json.load(f)

    model_name = args.model
    if args.e:
        datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", \
            "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"]
    else:
        datasets = ["narrativeqa", "qasper", "multifieldqa_en", \
                    "hotpotqa", "2wikimqa", "musique", \
                     "gov_report", "qmsum", "multi_news", \
                    "trec", "triviaqa", "samsum",  \
                    "passage_count", "passage_retrieval_en", "lcc", "repobench-p"]

    if local_rank == 0:
        os.makedirs("pred_delta", exist_ok=True)
        os.makedirs("pred_e", exist_ok=True)

    torch.cuda.set_device(local_rank)
    device = torch.device(f'cuda:{local_rank}')
    model, tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device,args.method)

    # predict on each dataset
    for dataset in datasets:
        if args.e:
            data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
            out_dir = f"pred_e/{model_name}"
        else:
            data = load_dataset('THUDM/LongBench', dataset, split='test')
            out_dir = f"pred_delta/{args.method}/{model_name}"
        # Every rank reaches this line; exist_ok avoids FileExistsError when
        # two ranks create the same directory at the same time.
        os.makedirs(out_dir, exist_ok=True)
        out_path = f"{out_dir}/{dataset}.jsonl"
        # get_pred appends, so rank 0 clears any file left by an earlier run.
        if local_rank == 0 and os.path.exists(out_path):
            os.remove(out_path)
        prompt_format = dataset2prompt[dataset]
        max_gen = dataset2maxlen[dataset]
        data_all = list(data)
        # For a quick smoke test, subsample here, e.g. random.sample(data_all, 50).
        get_pred(model, tokenizer, local_rank, world_size, data_all, max_gen, prompt_format, dataset, device, model_name, out_path,args.method)
    if world_size > 1:
        dist.destroy_process_group()