import sys

sys.path.append(".")

import argparse
import os
import time
import json
import logging
import pprint
from tqdm import tqdm
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from datasets import load_dataset
from take import config_logging
from eval.longbench.longbench_utils import calculate_task_len


def build_chat(tokenizer, prompt, model_name):
    """Wrap a raw prompt in the chat format used for the given model.

    Args:
        tokenizer: Object exposing ``apply_chat_template``; only used when the
            model name does not contain ``"Llama-2"``.
        prompt: The plain-text user prompt.
        model_name: Model identifier used to choose the wrapping scheme.

    Returns:
        ``"[INST]{prompt}[/INST]"`` when ``"Llama-2"`` occurs in ``model_name``,
        otherwise whatever ``tokenizer.apply_chat_template`` returns for a
        single user message.
    """
    # Llama-2 checkpoints get the bare [INST] wrapper instead of a tokenizer template.
    if "Llama-2" in model_name:
        return f"[INST]{prompt}[/INST]"

    conversation = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        return_tensors="pt",
    )


@torch.inference_mode()
def generate_longbench(data, max_length, max_gen, prompt_format,
                       dataset, model_name, model, tokenizer,
                       out_path, args):
    """Generate one prediction per sample and append it to ``out_path`` as JSON lines.

    For every sample the prompt is built from ``prompt_format``, truncated in
    the middle when it exceeds ``max_length`` tokens, optionally wrapped in a
    chat template, and fed to ``model.generate`` with greedy decoding.

    Args:
        data: Iterable of sample dicts. Each must provide the fields referenced
            by the prompt templates plus ``answers``, ``all_classes`` and
            ``length`` (and ``language`` for the code datasets when the model
            exposes ``tao_kwargs``).
        max_length: Maximum number of prompt tokens kept before generation.
        max_gen: Value passed to ``generate`` as ``max_new_tokens``.
        prompt_format: Dict with ``"context"`` and ``"instruction"`` templates.
        dataset: Name of the dataset being evaluated.
        model_name: Model identifier, forwarded to ``build_chat``.
        model: Causal LM. If it exposes a non-``None`` ``tao_kwargs`` mapping,
            its ``separators`` and ``task_query_len`` entries are updated per
            sample; otherwise the model is used as-is.
        tokenizer: Tokenizer matching ``model``.
        out_path: File that receives one JSON object per line (append mode).
        args: Parsed arguments. ``task_query_len`` is read directly;
            ``dynamic_task_len`` and ``extra_task_len`` are optional and
            default to ``False`` and ``0`` when absent.
    """
    device = model.device
    # Stock Hugging Face models (``--mode full_kv``) carry no ``tao_kwargs``;
    # only update it when the model actually exposes one.
    uses_take = getattr(model, "tao_kwargs", None) is not None
    # Neither option is a command-line flag, so they may be missing from ``args``.
    use_dynamic_task_len = getattr(args, "dynamic_task_len", False)
    extra_task_len = getattr(args, "extra_task_len", 0)

    for json_obj in tqdm(data, desc="Generating Responses..."):
        if use_dynamic_task_len:
            # Measure the instruction part of this sample to size the task query.
            task_prompt = prompt_format['instruction'].format(**json_obj)
            prompt = prompt_format["context"].format(**json_obj) + task_prompt
            task_query_len = calculate_task_len(tokenizer, task_prompt)
        else:
            prompt = (prompt_format["context"] + prompt_format["instruction"]).format(**json_obj)
            task_query_len = args.task_query_len
        task_query_len += extra_task_len

        if uses_take:
            # Code datasets pass the sample's language (upper-cased) as separators.
            if dataset in ["lcc", "repobench-p"]:
                model.tao_kwargs["separators"] = json_obj["language"].upper()
            else:
                model.tao_kwargs["separators"] = None

        # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions)
        tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
        if len(tokenized_prompt) > max_length:
            half = int(max_length / 2)
            prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + tokenizer.decode(
                tokenized_prompt[-half:], skip_special_tokens=True)

        # chat models are better off without build prompts on these tasks
        if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:
            prompt = build_chat(tokenizer, prompt, model_name)

        model_inputs = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
        context_length = model_inputs.input_ids.shape[-1]

        if uses_take:
            # Adjust the task length for this sample.
            model.tao_kwargs["task_query_len"] = task_query_len
            logging.info(f"task_query_len: {task_query_len}")

        output = model.generate(
            **model_inputs,
            num_beams=1,
            do_sample=False,
            temperature=1.0,
            top_p=1.0,
            # One more than the prompt length, i.e. at least one new token is requested.
            min_length=context_length + 1,
            max_new_tokens=max_gen,
        )[0]
        # Drop the prompt tokens so only the continuation is decoded.
        pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)

        # Append right away so finished samples survive an interrupted run.
        with open(out_path, "a", encoding="utf-8") as f:
            json.dump({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"],
                       "length": json_obj["length"]}, f, ensure_ascii=False)
            f.write('\n')

        torch.cuda.empty_cache()


def load_tao_config(dataset, config_path="eval/longbench/config/dataset2tao.json"):
    """Return the TAO settings stored for ``dataset`` in a JSON config file.

    Args:
        dataset: Dataset name used as the top-level key in the config file.
        config_path: Path to the JSON file mapping dataset names to their
            settings. Defaults to the repository-relative LongBench config,
            which is resolved against the current working directory.

    Returns:
        The value stored under ``dataset`` in the config file.

    Raises:
        ValueError: If ``dataset`` is not a key of the config file.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        dataset2tao = json.load(f)

    if dataset not in dataset2tao:
        raise ValueError(f"Dataset {dataset} not found in TAO configuration file")

    return dataset2tao[dataset]


def _load_json(path):
    """Read and parse the JSON file at ``path``, closing the handle afterwards."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def main(args):
    """Run LongBench generation for one dataset with the configured model.

    Steps: seed the RNGs, merge the per-dataset TAO settings into ``args``
    (``take`` mode only), create the output directory and configure logging,
    load the prompt/length configuration and the dataset, load the tokenizer
    and model, then call ``generate_longbench``.

    Args:
        args: Parsed command-line namespace. ``args.save_path`` is rewritten
            to the final output directory, and in ``take`` mode every key of
            the dataset's TAO config is set as an attribute on ``args``.

    Raises:
        ValueError: If ``args.longbench_type`` is unknown, if ``args.dataset``
            is not part of the selected benchmark, or (from
            ``load_tao_config``) if the dataset has no TAO configuration.
    """
    set_seed(args.seed)

    # Last path component of the model id; the whole string when there is no "/".
    model_version = args.model.split("/")[-1]

    tao_config = None
    if args.mode == "take":
        tao_config = load_tao_config(args.dataset)
        # Per-dataset TAO settings override the command-line values.
        for key, value in tao_config.items():
            setattr(args, key, value)
        args.save_path += f"_cs{args.chunk_size}_ks{args.kernel_size}_b{args.kv_budget}_tql{args.task_query_len}_wl{args.warmup_layers}_{args.pooling}"
    args.save_path = os.path.join(f"outputs/{model_version}/longbench", args.mode, args.version, args.save_path)
    Path(args.save_path).mkdir(parents=True, exist_ok=True)

    # Configure logging before the first log call: logging.info() on an
    # unconfigured root logger implicitly calls basicConfig(), which installs a
    # default handler and discards INFO records.
    config_logging(os.path.join(args.save_path, 'process.log'))
    if tao_config is not None:
        logging.info(f"Loading TAO config for dataset {args.dataset}:")
        logging.info(pprint.pformat(tao_config))
    logging.info('Arguments: ')
    logging.info(pprint.pformat(vars(args)))
    logging.info('--' * 30)

    # Setup for Configuration
    model2maxlen = _load_json("eval/longbench/config/model2maxlen.json")
    if "Llama-3.1-8B" in args.model:
        max_length = model2maxlen["meta-llama/Meta-Llama-3.1-8B-Instruct"]
    else:
        max_length = model2maxlen[args.model]

    if args.longbench_type == "longbench-e":
        supported_datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news",
                              "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc",
                              "repobench-p"]
        data_file_suffix = "_e.jsonl"
    elif args.longbench_type == "longbench":
        supported_datasets = ["narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "musique",
                              "gov_report", "qmsum", "multi_news", "trec", "triviaqa", "samsum",
                              "passage_count", "passage_retrieval_en", "lcc", "repobench-p"]
        data_file_suffix = ".jsonl"
    else:
        raise ValueError(f"Unknown longbench_type: {args.longbench_type}")

    if args.dataset not in supported_datasets:
        raise ValueError(f"Dataset {args.dataset} not found in datasets")

    # We design specific prompt format and max generation length for each task, feel free to modify them to optimize model output
    dataset2prompt = _load_json("eval/longbench/config/dateset2prompt_task.json")
    dataset2maxlen = _load_json("eval/longbench/config/dataset2maxlen.json")

    dataset = args.dataset
    data = load_dataset("eval/longbench/data", data_files=f"{dataset}{data_file_suffix}", split="train")

    out_path = os.path.join(args.save_path, f"{dataset}.jsonl")

    prompt_format = dataset2prompt[dataset]
    max_gen = dataset2maxlen[dataset]
    data_all = list(data)

    # Load Model & Tokenizer
    logging.info('Load Model & Tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained(args.model, device_map='auto', trust_remote_code=True)

    if args.mode == "take":
        # Imported lazily so that full_kv runs do not need the TAKE model code.
        from take.take.chunk import TakeKwargs
        from take.take.transformers_take.llama.modeling_llama_take import LlamaForCausalLM
        tao_kwargs = TakeKwargs(
            kv_budget=args.kv_budget,
            kv_warmup_budget=args.kv_warmup_budget,
            kv_prune_trigger_size=args.kv_prune_trigger_size,
            chunk_size=args.chunk_size,
            kernel_size=args.kernel_size,
            pooling=args.pooling,
            task_query_len=args.task_query_len,
            warmup_layers=args.warmup_layers,
            chunk_window_size=args.chunk_window_size,
            chunk_sink=args.chunk_sink,
            use_task_cache=args.use_task_cache,
            alpha=args.alpha,
            separators=None,
            test_performance=False
        )
        model = LlamaForCausalLM.from_pretrained(args.model, device_map='auto', attn_implementation='flash_attention_2',
                                                 torch_dtype=torch.float16, tao_kwargs=tao_kwargs,
                                                 tokenizer_path=args.model)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.model, device_map='auto',
                                                     attn_implementation='flash_attention_2',
                                                     torch_dtype=torch.float16)
    model.eval()

    # Generation
    generate_longbench(data=data_all, max_length=max_length, max_gen=max_gen, prompt_format=prompt_format,
                       dataset=dataset, model_name=args.model, model=model, tokenizer=tokenizer,
                       out_path=out_path, args=args)


if __name__ == "__main__":
    def _str2bool(value):
        """Parse a textual command-line boolean such as ``True``/``false``/``1``/``0``.

        ``type=bool`` would treat every non-empty string, including
        ``"False"``, as ``True``; this parser maps the usual spellings
        explicitly and rejects anything else.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string stays falsy, as it was under ``type=bool``.
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    # Model Arguments
    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct",
                        help="model name of model path")
    parser.add_argument("--seed", type=int, default=42, help="Seed")
    parser.add_argument("--save_path", default="", type=str, help="Path to save the output")

    # KV Compression
    parser.add_argument("--mode", type=str, default="take",
                        choices=["full_kv", "take"])

    # Evaluation
    parser.add_argument('--dataset', type=str, default='qasper', help="Dataset to evaluate on")
    parser.add_argument('--longbench_type', type=str, default='longbench', choices=['longbench', 'longbench-e'])

    # TAO specific parameters
    # NOTE: in "take" mode, main() overwrites these with the per-dataset TAO config values.
    parser.add_argument("--version", type=str, default="1.0", help="TAO version, e.g. 1.0, 2.0, 3.0, 4.0")

    parser.add_argument("--kv_compress_ratio", type=float, default=0.5)
    parser.add_argument("--kv_prune_trigger_size", type=int, default=4096)
    parser.add_argument("--chunk_size", type=int, default=4096)
    parser.add_argument("--task_query_len", type=int, default=64)
    parser.add_argument("--warmup_layers", type=int, default=12)
    parser.add_argument("--kv_budget", type=int, default=512)
    parser.add_argument("--kv_warmup_budget", type=int, default=8000)
    parser.add_argument("--kv_warmup_budget_ratio", type=float, default=0.5)
    parser.add_argument("--alpha", type=float, default=0.4)
    parser.add_argument("--chunk_window_size", type=int, default=4)
    parser.add_argument("--chunk_sink", type=int, default=4)
    parser.add_argument("--kernel_size", type=int, default=7)
    parser.add_argument("--pooling", type=str, default="avg")
    parser.add_argument("--use_task_cache", type=_str2bool, default=True)
    args = parser.parse_args()

    main(args)
