import os
from datasets import load_dataset
import torch
import json
from transformers import (
    AutoTokenizer,
    AutoConfig,
    LlamaTokenizer,
    LlamaForCausalLM,
    AutoModelForCausalLM,
    MistralConfig,
)
from tqdm import tqdm
import numpy as np
import random
import argparse
# from evaluation.quest_attention import enable_quest_attention_eval
# from evaluation.llama import enable_tuple_kv_cache_for_llama 
# from evaluation.mistral import enable_tuple_kv_cache_for_mistral

from methods.GEAR_and_OURS.GenerationBench.GenerationTest.GEARLM import SimulatedGearLlamaForCausalLM
from methods.GEAR_and_OURS.GenerationBench.GenerationTest.GEARLM import SimulatedGearMistralForCausalLM
from methods.GEAR_and_OURS.GenerationBench.GenerationTest.GEARLM import CompressionConfig
from methods.GEAR_and_OURS.GenerationBench.GenerationTest.GEARLM import SimulatedGearQwen2ForCausalLM
from methods.GEAR_and_OURS.GenerationBench.GenerationTest.GEARLM import SimulatedGearGemma2ForCausalLM


def str2bool(v):
    """Convert a command-line flag value into a ``bool``.

    ``bool`` inputs are returned unchanged. Strings are compared
    case-insensitively against common truthy / falsy spellings.

    Raises:
        argparse.ArgumentTypeError: if ``v`` is a string that is not one of
            the recognised spellings.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_args(args=None):
    """Parse command-line options for the LongBench prediction script.

    Args:
        args: Optional list of argument strings. ``None`` makes ``argparse``
            read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding model/task selection, Quest options, GEAR
        KV-cache compression settings and wandb logging settings.
    """
    supported_models = [
        "llama2-7b-chat-4k",
        "llama3-8b",
        "llama3-8b-instruct",
        "qwen2-7b-instruct",
        "gemma2-9b-it",
        "longchat-v1.5-7b-32k",
        "xgen-7b-8k",
        "internlm-7b-8k",
        "chatglm2-6b",
        "chatglm2-6b-32k",
        "chatglm3-6b-32k",
        "vicuna-v1.5-7b-16k",
        "Mistral-7B-Instruct-v0.3",
        "Meta-Llama-3.1-8B-Instruct",
    ]
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    def add_bool(flag, default, help_text=None):
        # A bare "--flag" means True; "--flag false" (any str2bool spelling)
        # sets the value explicitly.
        add(flag, type=str2bool, nargs='?', const=True, default=default, help=help_text)

    # Model / task selection
    add("--model", type=str, default=None, choices=supported_models)
    add("--e", action="store_true", help="Evaluate on LongBench-E")
    add("--pred_dir", type=str, default="pred")
    add("--task", type=str, help="task name")

    # Quest / batching
    add("--token_budget", type=int, default=None)
    add("--chunk_size", type=int, default=None)
    add("--quest", action="store_true", help="Enable Quest Attention")
    add("--batch_size", type=int, default=1)

    # GEAR compression
    add("--compress_method", type=str, default="None")
    add("--max_new_tokens", type=int, default=256)
    add("--attention_number", type=int, default=40)
    add("--quantize_bit", type=int, default=2)
    add("--loop", type=int, default=3)
    for rank_flag in ("--prefillrank", "--prefillrankv", "--rank", "--rankv"):
        add(rank_flag, type=float, default=2.0)
    add("--left", type=float, default=0.02)
    add_bool("--streaming", True)
    add("--streaming_gap", type=int, default=64)
    add("--group_num", type=int, default=0, help="Number of groups for quantization")
    add("--group_size", type=int, default=0, help="Size of each group for quantization")
    add("--top_kprun", type=float, default=0.0, help="Top-k pruning ratio for GEAR")
    add_bool("--stream_grouping", False, "Use streaming with grouping.")
    add("--input_axis", type=str, default='right',
        help="Axis for input Hadamard transform ('left', 'right', or 'None').")
    add("--error_axis", type=str, default='right',
        help="Axis for error Hadamard transform ('left', 'right', or 'None').")
    add("--first_method", type=str, default='None',
        help="First method for GEAR ('None', 'KIVI', 'GEAR-KCVT', 'PALU', 'PCC').")
    add("--first_transform", type=str, default='None',
        help="First transform for GEAR ('None', 'Hadamard', 'PCA', 'COV').")
    add("--second_method", type=str, default='None',
        help="Second method for GEAR ('None', 'KIVI', 'GEAR-KCVT', 'PALU', 'PCC').")
    add("--second_transform", type=str, default='None',
        help="Second transform for GEAR ('None', 'Hadamard', 'PCA', 'COV').")
    add("--hla_rank", type=int, default=0, help="HLA rank for GEAR.")
    add("--kv_transform", type=str, default="none", choices=["none", "hadamard", "pca", "cov"],
        help="Transform to apply in KV cache compression (default: 'none').")
    add_bool("--use_awq", False,
             "Use activation-aware quantization (AWQ) for residual quantization.")
    add("--awq_calibration_samples", type=int, default=128,
        help="Number of calibration samples for AWQ (default: 128).")

    # Wandb / reproducibility
    add_bool("--use_wandb", False, "Enable logging to wandb.")
    add("--wandb_project", type=str, default='PCC', help="Wandb project name.")
    add("--wandb_run_name", type=str, default='gsm8k_gear-kcvt_bit_rank_test', help="Wandb run name.")
    add("--seed", type=int, default=42)

    return parser.parse_args(args)


# This is the customized building prompt for chat models
def build_chat(tokenizer, prompt, model_name):
    """Wrap *prompt* in the chat format expected by *model_name*.

    The first substring of *model_name* that matches wins, checked in this
    order: ``chatglm3``, ``chatglm``, ``longchat``/``vicuna``, ``llama2``,
    ``xgen``, ``internlm``. Any other model gets *prompt* back unchanged.

    For ``chatglm3`` the result is whatever ``tokenizer.build_chat_input``
    returns, which is not necessarily a string.
    """
    if "chatglm3" in model_name:
        return tokenizer.build_chat_input(prompt)
    if "chatglm" in model_name:
        return tokenizer.build_prompt(prompt)
    if "longchat" in model_name or "vicuna" in model_name:
        # Imported lazily: fastchat is only needed for these two families.
        from fastchat.model import get_conversation_template

        conversation = get_conversation_template("vicuna")
        conversation.append_message(conversation.roles[0], prompt)
        conversation.append_message(conversation.roles[1], None)
        return conversation.get_prompt()
    if "llama2" in model_name:
        return f"[INST]{prompt}[/INST]"
    if "xgen" in model_name:
        system_header = (
            "A chat between a curious human and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
        )
        return system_header + f" ### Human: {prompt}\n###"
    if "internlm" in model_name:
        return f"<|User|>:{prompt}<eoh>\n<|Bot|>:"
    return prompt


def post_process(response, model_name):
    """Remove model-specific artefacts from a decoded generation.

    - ``xgen``: strip surrounding whitespace, then delete every
      ``"Assistant:"`` marker.
    - ``internlm``: keep only the text before the first ``"<eoa>"``.
    - Any other model: return *response* unchanged.
    """
    if "xgen" in model_name:
        return response.strip().replace("Assistant:", "")
    if "internlm" in model_name:
        before_marker, _, _ = response.partition("<eoa>")
        return before_marker
    return response


def get_pred(
    model,
    tokenizer,
    data,
    max_length,
    max_gen,
    prompt_format,
    dataset,
    device,
    model_name,
):
    """Greedily generate one prediction per LongBench sample in *data*.

    For every sample the prompt is:

    1. built with ``prompt_format.format(**sample)``;
    2. truncated in the middle to at most ``max_length`` tokens when too long
       (the left and right ends may contain crucial instructions);
    3. wrapped with ``build_chat`` unless *dataset* is one of the tasks that
       are evaluated without a chat template;
    4. decoded greedily for up to ``max_gen`` new tokens.

    Args:
        model: causal LM exposing a Hugging Face style ``generate``.
        tokenizer: the tokenizer matching *model*.
        data: iterable of sample mappings. Each needs ``answers``,
            ``all_classes`` and ``length`` keys, plus whatever fields
            *prompt_format* references.
        max_length: token budget for the prompt before the chat template
            is applied.
        max_gen: ``max_new_tokens`` passed to ``generate``.
        prompt_format: ``str.format`` template for this dataset.
        dataset: LongBench task name (underscore spelling, e.g. ``repobench_p``).
        device: device the encoded inputs are moved to.
        model_name: name used to pick model-specific prompting/post-processing.

    Returns:
        A list with one dict per sample, in input order, holding ``pred``,
        ``answers``, ``all_classes`` and ``length``.
    """
    # Chat models are better off without build prompts on these tasks. One
    # shared set drives both the template decision and the chatglm3 encoding
    # path, so the two checks cannot drift apart.
    no_chat_template_datasets = {"trec", "triviaqa", "samsum", "lsht", "lcc", "repobench_p"}
    use_chat_template = dataset not in no_chat_template_datasets
    is_chatglm3 = "chatglm3" in model_name
    # chatglm3 prompt lengths are measured without special tokens.
    length_tokenize_kwargs = {"add_special_tokens": False} if is_chatglm3 else {}

    preds = []
    for json_obj in tqdm(data):
        prompt = prompt_format.format(**json_obj)

        # Truncate in the middle, keeping `half` tokens from each end.
        tokenized_prompt = tokenizer(
            prompt, truncation=False, return_tensors="pt", **length_tokenize_kwargs
        ).input_ids[0]
        num_tokens = len(tokenized_prompt)
        if num_tokens > max_length:
            half = int(max_length / 2)
            # Explicit start index: `[-half:]` would return the WHOLE
            # sequence when half == 0.
            prompt = tokenizer.decode(
                tokenized_prompt[:half], skip_special_tokens=True
            ) + tokenizer.decode(
                tokenized_prompt[num_tokens - half:], skip_special_tokens=True
            )

        if use_chat_template:
            prompt = build_chat(tokenizer, prompt, model_name)

        if is_chatglm3 and use_chat_template:
            # build_chat returned tokenizer.build_chat_input(...) for chatglm3;
            # it is moved to the device directly rather than re-tokenized.
            model_inputs = prompt.to(device)
        else:
            model_inputs = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
        context_length = model_inputs.input_ids.shape[-1]

        generate_kwargs = {
            "max_new_tokens": max_gen,
            "num_beams": 1,
            "do_sample": False,
            "temperature": 1.0,
        }
        if dataset == "samsum":
            # Prevent illegal output on samsum (the model endlessly repeats
            # "\nDialogue"); also stop on a newline. Might be a prompting issue.
            generate_kwargs["min_length"] = context_length + 1
            generate_kwargs["eos_token_id"] = [
                tokenizer.eos_token_id,
                tokenizer.encode("\n", add_special_tokens=False)[-1],
            ]
        output = model.generate(**model_inputs, **generate_kwargs)[0]

        # generate returns prompt + continuation; decode only the new tokens.
        pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
        pred = post_process(pred, model_name)
        preds.append(
            {
                "pred": pred,
                "answers": json_obj["answers"],
                "all_classes": json_obj["all_classes"],
                "length": json_obj["length"],
            }
        )
    return preds


def seed_everything(seed):
    """Seed every RNG the script touches and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, the
    current CUDA device and all CUDA devices). It then disables cuDNN
    autotuning and requests deterministic cuDNN kernels.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def _compression_disabled(compress_method):
    """Return True when *compress_method* selects the uncompressed baseline.

    ``--compress_method`` defaults to the *string* ``"None"``. A real ``None``
    (e.g. from a programmatically built ``args``) is treated the same way.
    """
    return compress_method is None or compress_method == "None"


def _none_if_disabled(value):
    """Map the CLI sentinel ``"none"`` (any letter case) to ``None``."""
    return None if value.lower() == "none" else value


def _build_compress_config(args):
    """Build and prepare the GEAR ``CompressionConfig`` described by *args*."""
    compress_config = CompressionConfig(
        compress_method=args.compress_method,
        rank=args.rank,
        rankv=args.rankv,
        prefill_rank=args.prefillrank,
        prefill_rankv=args.prefillrankv,
        loop=args.loop,
        quantize_bit=args.quantize_bit,
        group_num=args.group_num,
        group_size=args.group_size,
        top_k=args.top_kprun,
        left=args.left,
        attention_number=args.attention_number,
        batch_num=args.batch_size,
        streaming=args.streaming,
        streaming_gap=args.streaming_gap,
        stream_grouping=args.stream_grouping,
        input_axis=_none_if_disabled(args.input_axis),
        error_axis=_none_if_disabled(args.error_axis),
        first_method=_none_if_disabled(args.first_method),
        first_transform=_none_if_disabled(args.first_transform),
        second_method=_none_if_disabled(args.second_method),
        second_transform=_none_if_disabled(args.second_transform),
        hla_rank=args.hla_rank,
        kv_transform=_none_if_disabled(args.kv_transform),
    )
    # NOTE(review): --use_awq / --awq_calibration_samples are parsed by the CLI
    # but not forwarded to CompressionConfig here. Confirm this is intended.
    compress_config.copy_for_all_attention()
    # NOTE(review): the meaning of (4095, 4096) is defined by CompressionConfig,
    # presumably prefill / total sequence lengths. Confirm before changing.
    compress_config.calculate_compress_ratio_list(4095, 4096)
    return compress_config


def _load_tokenizer(path, **extra_kwargs):
    """Load a slow, left-padding tokenizer whose pad token is its EOS token."""
    tokenizer = AutoTokenizer.from_pretrained(
        path,
        padding_side="left",
        use_fast=False,
        **extra_kwargs,
    )
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def load_model_and_tokenizer(path, model_name, device, args):
    """Load the (optionally GEAR-compressed) causal LM at *path* and its tokenizer.

    With ``args.compress_method`` equal to ``"None"`` (or ``None``), the plain
    ``AutoModelForCausalLM`` is loaded. Otherwise a GEAR simulated model is
    chosen by substring match on *path*: ``Llama``, ``Mistral``, ``Qwen2`` or
    ``gemma-2``.

    Args:
        path: local directory or hub id of the checkpoint.
        model_name: unused; kept for call-site compatibility.
        device: unused (placement comes from ``device_map="auto"``); kept for
            call-site compatibility.
        args: parsed CLI namespace from ``parse_args``.

    Returns:
        ``(model, tokenizer)``, with the model switched to eval mode.

    Raises:
        ValueError: compression was requested but *path* matches none of the
            supported GEAR model families.
    """
    config = AutoConfig.from_pretrained(
        path,
        use_flash_attn=False,
        trust_remote_code=True,
    )

    compress_config = (
        None if _compression_disabled(args.compress_method) else _build_compress_config(args)
    )

    # Applied to every checkpoint: fp16 weights, dispatched by accelerate.
    model_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}

    if compress_config is None:
        model = AutoModelForCausalLM.from_pretrained(
            path, trust_remote_code=True, config=config, **model_kwargs
        )
        tokenizer = _load_tokenizer(path)
    elif "Llama" in path:
        model = SimulatedGearLlamaForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            **model_kwargs,
            compress_config=compress_config,
            config=config,
        )
        tokenizer = _load_tokenizer(path)
    elif "Mistral" in path:
        mistral_config = MistralConfig.from_pretrained(
            path,
            use_flash_attn=False,
            trust_remote_code=True,
        )
        model = SimulatedGearMistralForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            **model_kwargs,
            compress_config=compress_config,
            config=mistral_config,
        )
        tokenizer = _load_tokenizer(path, trust_remote_code=True)
    elif "Qwen2" in path:
        model = SimulatedGearQwen2ForCausalLM.from_pretrained(
            path,
            config=config,
            **model_kwargs,
            compress_config=compress_config,
        )
        tokenizer = _load_tokenizer(path, trust_remote_code=True)
    elif "gemma-2" in path:
        model = SimulatedGearGemma2ForCausalLM.from_pretrained(
            path,
            config=config,
            **model_kwargs,
            compress_config=compress_config,
        )
        tokenizer = _load_tokenizer(path, trust_remote_code=True)
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"compress_method={args.compress_method!r} is not supported for model {path}; "
            "expected a Llama, Mistral, Qwen2 or gemma-2 checkpoint"
        )

    model = model.eval()
    return model, tokenizer


def _load_json(path):
    """Read the JSON document at *path*, closing the file afterwards."""
    # Explicit UTF-8: the prompt config may contain non-ASCII (e.g. Chinese) text.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def main():
    """Run LongBench (or LongBench-E with ``--e``) prediction for one model.

    One ``<dataset>-full.jsonl`` file (``<dataset>-<token_budget>.jsonl``
    with ``--quest``) is written per task under ``<pred_dir>/<model>/``, or
    under ``<pred_dir>_e/<model>/`` with ``--e``.
    """
    args = parse_args()
    if args.model is None:
        raise SystemExit("error: --model is required")
    # Seed after parsing so --seed (default 42) actually takes effect.
    seed_everything(args.seed)

    model2path = _load_json("benchmarks/LongBench/config/model2path.json")
    model2maxlen = _load_json("benchmarks/LongBench/config/model2maxlen.json")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = args.model
    model, tokenizer = load_model_and_tokenizer(
        model2path[model_name], model_name, device, args
    )
    max_length = model2maxlen[model_name]

    if args.e:
        datasets = [
            "qasper",
            "multifieldqa_en",
            "hotpotqa",
            "2wikimqa",
            "gov_report",
            "multi_news",
            "trec",
            "triviaqa",
            "samsum",
            "passage_count",
            "passage_retrieval_en",
            "lcc",
            "repobench_p",
        ]
    else:
        datasets = [
            "narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa",
            "musique", "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa",
            "samsum", "lsht", "passage_count", "passage_retrieval_en", "passage_retrieval_zh",
            "lcc", "repobench_p",
        ]

    # Task-specific prompt formats and generation lengths; feel free to tune them.
    dataset2prompt = _load_json("benchmarks/LongBench/config/dataset2prompt.json")
    dataset2maxlen = _load_json("benchmarks/LongBench/config/dataset2maxlen.json")

    out_dir = f"{args.pred_dir}_e/{model_name}" if args.e else f"{args.pred_dir}/{model_name}"
    os.makedirs(out_dir, exist_ok=True)

    for dataset in datasets:
        # The HF hub names this task "repobench-p"; the local configs, output
        # files and get_pred all use "repobench_p".
        hub_config = "repobench-p" if dataset == "repobench_p" else dataset
        if args.e:
            hub_config = f"{hub_config}_e"
        data = load_dataset("THUDM/LongBench", hub_config, split="test", trust_remote_code=True)

        if args.quest:
            out_path = f"{out_dir}/{dataset}-{args.token_budget}.jsonl"
        else:
            out_path = f"{out_dir}/{dataset}-full.jsonl"

        preds = get_pred(
            model,
            tokenizer,
            data,
            max_length,
            dataset2maxlen[dataset],
            dataset2prompt[dataset],
            dataset,
            device,
            model_name,
        )
        with open(out_path, "w", encoding="utf-8") as f:
            for pred in preds:
                json.dump(pred, f, ensure_ascii=False)
                f.write("\n")


if __name__ == "__main__":
    main()
