import argparse
import time
import os
import json
import random
import numpy as np
from tqdm import tqdm
import torch
import transformers
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
from datasets import load_dataset

from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()

def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                    inputs_embeds, past_key_values_length):
    """Replacement for ``LlamaModel._prepare_decoder_attention_mask``.

    Flash attention needs the attention mask to stay identical to its
    key_padding_mask ([bsz, seq_len]). So on the encode/prefill step (more
    than one input token and nothing cached yet) the mask is returned
    untouched. Every other step falls back to the BART decoder's
    mask-building routine, invoked with this LlamaModel instance as ``self``.
    """
    is_encode_step = input_shape[-1] > 1 and past_key_values_length == 0
    if not is_encode_step:
        bart_decoder_cls = transformers.models.bart.modeling_bart.BartDecoder
        return bart_decoder_cls._prepare_decoder_attention_mask(
            self,
            attention_mask,
            input_shape,
            inputs_embeds,
            past_key_values_length,
        )
    return attention_mask


# Install the override on the class itself, so every LlamaModel instance uses it.
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask


@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, max_gen_len):
    """Greedily decode up to ``max_gen_len`` tokens, one forward pass per token.

    Every call passes ``use_cache=False``. Any state carried between steps must
    therefore live inside the model's own modules.
    NOTE(review): the sink attention presumably keeps its own cache; the main
    script resets it via ``_clean_cache`` after each sample -- confirm.

    Args:
        model: causal LM whose output exposes ``logits`` (3-D, indexed as
            ``[:, -1, :]``) and ``past_key_values``.
        tokenizer: only ``eos_token_id`` is used.
        input_ids: prompt token ids. Batch size must be 1, because each step
            calls ``.item()`` on the predicted token.
        max_gen_len: maximum number of tokens to generate. At least one token
            is always produced.

    Returns:
        list[int]: ONLY the generated token ids (the prompt is not included).
        The list ends with EOS if EOS was produced.
    """
    outputs = model(input_ids=input_ids, past_key_values=None, use_cache=False)
    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]
    prompt_len = input_ids.shape[-1]

    # The EOS check sits in the loop condition, so an EOS emitted as the very
    # first token also stops generation. The old loop only checked after the
    # second token.
    while len(generated_ids) < max_gen_len and generated_ids[-1] != tokenizer.eos_token_id:
        outputs = model(
            input_ids=pred_token_idx,
            # Absolute position of the token being fed back in.
            position_ids=torch.tensor([[prompt_len + len(generated_ids) - 1]], device=input_ids.device),
            past_key_values=past_key_values,
            use_cache=False,
        )
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())
    return generated_ids


def build_chat(tokenizer, prompt, model_name):
    """Wrap ``prompt`` in the chat template for the model family named by ``model_name``.

    Matching is a case-sensitive substring test, checked in this order:

    * ``"llama"`` -> ``[INST]<prompt>[/INST]``
    * ``"longchat"`` or ``"vicuna"`` -> FastChat's "vicuna" conversation
      template: a user turn with ``prompt`` and an empty assistant turn.

    ``tokenizer`` is accepted for interface compatibility but is not used.

    Raises:
        NotImplementedError: for any other model family.
    """
    if "llama" in model_name:
        return f"[INST]{prompt}[/INST]"
    if not ("longchat" in model_name or "vicuna" in model_name):
        raise NotImplementedError
    # Imported lazily so fastchat is only needed for these model families.
    from fastchat.model import get_conversation_template
    template = get_conversation_template("vicuna")
    template.append_message(template.roles[0], prompt)
    template.append_message(template.roles[1], None)
    return template.get_prompt()


def seed_everything(seed):
    """Seed every RNG the script touches and make cuDNN deterministic.

    Covers Python's ``random``, NumPy's global RNG, and torch: the CPU
    generator, the current CUDA device, and all CUDA devices. Also turns off
    cuDNN autotuning in favour of deterministic kernels.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True


def post_process(response, model_name):
    """Apply model-specific clean-up to a decoded response.

    * xgen: strip surrounding whitespace, then remove every "Assistant:" marker.
    * internlm: keep only the text before the first "<eoa>" marker.
    * any other model: return the response unchanged.
    """
    if "xgen" in model_name:
        return response.strip().replace("Assistant:", "")
    if "internlm" in model_name:
        head, _sep, _tail = response.partition("<eoa>")
        return head
    return response


if __name__ == '__main__':
    # Run LongBench-E with the selected KV-cache method and write one JSONL
    # file of predictions per dataset under ``output_dir``.
    seed_everything(42)
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--datasets', type=str, default='all', help='')
    # Only these methods are implemented below. Previously an unknown value
    # dropped into an interactive pdb session after the tokenizer had loaded.
    parser.add_argument('--kv_method', type=str, default='exact', help='',
                        choices=['exact', 'sink', 'clustergen'])
    parser.add_argument('--hh_size', type=int, default=1024, help='')
    parser.add_argument('--recent_size', type=int, default=1024, help='')
    parser.add_argument('--max_length', type=int, default=20000, help='')
    parser.add_argument('--dist', type=str, default='euclidean', help='', choices=['euclidean', 'angle'])
    parser.add_argument('--seed_select', type=str, default='greedy', help='', choices=['greedy', 'kpp'])
    args = parser.parse_args()

    for n_, v_ in args.__dict__.items():
        print(f"{n_:<10}, {v_}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_name = "lmsys/longchat-7b-v1.5-32k"
    dtype = torch.bfloat16
    kv_method = args.kv_method
    hh_size = args.hh_size
    recent_size = args.recent_size
    # Prompt budget in tokens; longer prompts are truncated in the middle.
    # This used to be hard-coded to 20000, silently ignoring --max_length.
    max_length = args.max_length

    output_dir = f"results_longbench/{model_name}_{max_length}_{kv_method}"
    if kv_method != "exact":
        output_dir += f"_hh{hh_size}_recent{recent_size}"
    if kv_method == 'clustergen' and args.dist != 'euclidean':
        output_dir += f"_{args.dist}"
    if kv_method == 'clustergen' and args.seed_select != 'greedy':
        output_dir += f"_{args.seed_select}"

    print(f"kv_method: {kv_method}")
    print(f"output_dir: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    config = LlamaConfig.from_pretrained(model_name)
    config._flash_attn_2_enabled = True
    if 'longchat' in model_name:
        tokenizer = AutoTokenizer.from_pretrained(model_name,
                                                  use_fast=False,
                                                  trust_remote_code=True,
                                                  tokenizer_type='llama')
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Attention class whose internal cache is reset after every sample.
    # Stays None for 'exact', where no reset is performed.
    attn_cls = None
    tic = time.time()
    if kv_method == 'exact':
        model = LlamaForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_name,
            config=config,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map="auto",
        )
    elif kv_method == 'sink':
        from modeling_llama_sink import SinkLlamaForCausalLM, SinkLlamaAttention
        config.hh_size = hh_size
        config.recent_size = recent_size
        config.cache_size = hh_size + recent_size
        model = SinkLlamaForCausalLM.from_pretrained(model_name, config=config, torch_dtype=dtype, device_map="auto")
        attn_cls = SinkLlamaAttention
    elif kv_method == 'clustergen':
        from modeling_llama_clustergen import ClustergenLlamaForCausalLM, ClustergenLlamaAttention
        config.hh_size = hh_size
        config.recent_size = recent_size
        config.cache_size = hh_size + recent_size
        config.method = 'kcenter'
        config.dist = args.dist
        config.seed_select = args.seed_select
        model = ClustergenLlamaForCausalLM.from_pretrained(model_name, config=config, torch_dtype=dtype, device_map="auto")
        attn_cls = ClustergenLlamaAttention
    else:
        raise ValueError(f"unsupported kv_method: {kv_method!r}")
    print(f"model loaded ... {time.time() - tic:.4f} sec")

    with open("./config/dataset2prompt.json", "r") as f:
        dataset2prompt = json.load(f)
    with open("./config/dataset2maxlen.json", "r") as f:
        dataset2maxlen = json.load(f)

    if args.datasets == 'all':
        datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"]
    else:
        datasets = [args.datasets]

    # build_chat / post_process match on the short repo name
    # (e.g. "longchat-7b-v1.5-32k"). This was recomputed every dataset iteration.
    model_name = model_name.split("/")[-1]
    # Chat models do better without the chat template on these tasks.
    no_chat_datasets = {"trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"}

    for dataset in datasets:
        print(f"=================== dataset: {dataset} ===================")
        prompt_format = dataset2prompt[dataset]
        max_gen = dataset2maxlen[dataset]

        out_path = f"{output_dir}/{dataset}.jsonl"
        # Check before load_dataset so finished datasets are not fetched again.
        if os.path.exists(out_path):
            print(f"{out_path} exists... pass {dataset} dataset")
            continue
        data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')

        preds = []
        for json_obj in tqdm(data):
            prompt = prompt_format.format(**json_obj)
            # Truncate in the middle to fit max_length: the start and end of
            # the prompt may contain crucial instructions.
            tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
            if len(tokenized_prompt) > max_length:
                half = max_length // 2
                prompt = (tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)
                          + tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True))
            if dataset not in no_chat_datasets:
                prompt = build_chat(tokenizer, prompt, model_name)
            inputs = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
            context_length = inputs.input_ids.shape[-1]

            if dataset == "samsum":
                # Also stop on "\n": the model otherwise endlessly repeats
                # "\nDialogue" (possibly a prompting issue).
                output = model.generate(
                    **inputs,
                    max_new_tokens=max_gen,
                    num_beams=1,
                    do_sample=False,
                    temperature=1.0,
                    min_length=context_length + 1,
                    eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]],
                )[0]
                generated_ids = output[context_length:]
            elif kv_method == 'sink':
                # greedy_generate returns ONLY the new token ids. Slicing by
                # context_length, as for generate() output, used to discard them.
                generated_ids = greedy_generate(model, tokenizer, inputs.input_ids, max_gen)
            else:
                output = model.generate(
                    **inputs,
                    max_new_tokens=max_gen,
                    num_beams=1,
                    do_sample=False,
                    temperature=1.0,
                )[0]
                # generate() output is prompt + continuation; keep the continuation.
                generated_ids = output[context_length:]
            pred = tokenizer.decode(generated_ids, skip_special_tokens=True)
            pred = post_process(pred, model_name)
            preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]})

            # Reset per-sample caches held inside the custom attention modules.
            if attn_cls is not None:
                for module in model.modules():
                    if isinstance(module, attn_cls):
                        module._clean_cache()

        os.makedirs(output_dir, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            for pred in preds:
                json.dump(pred, f, ensure_ascii=False)
                f.write('\n')