from __future__ import annotations

import json
import os
import re
import torch
import numpy as np
import random
from pathlib import Path

from argparse import ArgumentParser, Namespace

from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)


model2maxlen = {
    "llama2": 3950,
    "llama-2": 3950,
    "llama3": 7950,
    "llama-3-70": 7950,
    "llama-3.1": 127000,
    "mistral": 31500,
    "qwen2.5": 127000,
    "gemma": 7950,
}

DATA_NAME_TO_PATH = {
    # Retrieval tasks
    "passkey": "passkey.jsonl",
    "number_string": "number_string.jsonl",
    "kv_retrieval": "kv_retrieval.jsonl",
    # Book tasks
    "longbook_sum_eng": "longbook_sum_eng.jsonl",
    "longbook_choice_eng": "longbook_choice_eng.jsonl",
    "longbook_qa_eng": "longbook_qa_eng.jsonl",
    "longbook_qa_chn": "longbook_qa_chn.jsonl",
    # "book_qa_eng": "longbook_eng/longbook_qa_eng.jsonl",
    "longdialogue_qa_eng": "longdialogue_qa_eng.jsonl",
    # Math tasks
    "math_find": "math_find.jsonl",
    "math_calc": "math_calc.jsonl",
    # Code tasks
    "code_run": "code_run.jsonl",
    "code_debug": "code_debug.jsonl",
}

dataset2maxlen = {
    "passkey": 15,
    "number_string": 20,
    "kv_retrieval": 80,
    "longbook_sum_eng": 1200,
    "longbook_choice_eng": 40,
    "longbook_qa_eng": 40,
    "longbook_qa_chn": 40,
    "longdialogue_qa_eng": 40,
    "math_find": 3,
    "math_calc": 30000,
    "code_run": 5,
    "code_debug": 5,
}


data2prompt = {
    "passkey": "There is an important info hidden inside a lot of irrelevant text. Find it and memorize it. I will quiz you about the important information.\n\n{context}\n\n{input}\n\nThe pass key is",  # noqa
    "number_string": "There is an important info hidden inside a lot of irrelevant text. Find it. I will quiz you about the important information there.\n\n{context}\n\n{input}\n\nThe sequence of digits is",  # noqa
    "kv_retrieval": "Extract the value corresponding to the specified key in the JSON object below.\n\n{context}\n\n{input}",  # noqa
    "longbook_sum_eng": "Summarize the book below.\n\n{context}\n\nSummary:",  # noqa
    "longbook_choice_eng": "Read the book and answer the question.\n\n{context}\n\nQuestion: {question}\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}\n\nThe letter of the correct answer is",  # noqa
    "longbook_qa_eng": "Read the book and answer the question. Be very concise in your answer.\n\n{context}\n\nQuestion: {question}\nAnswer:",  # noqa
    "longbook_qa_chn": "阅读以下书籍然后回答问题。\n\n{context}\n\n问题：{question}\n答案：",  # noqa
    "math_find": "{prefix}\n\n{context}\n\n{input}",
    "math_calc": "Let us calculate the intermediate values of an expression.\n\nExpression: 1 + 3 + 4\nValues: [1, 4, 8]\n\nExpression: 8 - 3 + 2 - 4\nValues: [8, 5, 7, 3]\n\nExpression: {context}\nValues:",  # noqa
    "code_run": "There is a function called {func} in the following Python code.\n\n{context}\n\nPlease compute the exact value of {func_call}. The value of {func_call} is",  # noqa
    "code_debug": "Following is a Python code where exactly one of the functions/methods has a deliberate error that makes it crash.\n\n{context}\n\nOptions:\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}\n\nThe correct option is:",  # noqa
    "longdialogue_qa_eng": 'Below is a dialogue script where one random occurrence of a character name is replaced with "$$MASK$$", and you should try to guess who that character is.\n\n{context}\n\nThe name that has been replaced with $$MASK$$ is likely',  # noqa
}


datasets = [
    "code_debug",
    "code_run",
    "kv_retrieval",
    "longbook_choice_eng",
    "longbook_qa_chn",
    "longbook_qa_eng",
    "longbook_sum_eng",
    "longdialogue_qa_eng",
    "math_calc",
    "math_find",
    "number_string",
    "passkey",
]


def build_chat(prompt):
    prompt = f"[INST] {prompt} [/INST]"
    return prompt


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(seed)


def check_benchmark_availability(data_path):
    if not os.path.exists(data_path):
        os.makedirs(data_path)

    hf_endpoint = "https://huggingface.co"
    if os.environ.get("HF_ENDPOINT"):
        hf_endpoint = os.environ.get("HF_ENDPOINT")

    base_url = f"{hf_endpoint}/datasets/xinrongzhang2022/InfiniteBench/resolve/main/"

    for dataset in datasets:
        file_path = os.path.join(data_path, f"{dataset}.jsonl")
        if not os.path.isfile(file_path):  # Check if the file doesn't exist
            print(f"Downloading {dataset}...")

            wget_command = f"wget {base_url}{dataset}.jsonl?download=true -O {file_path}"
            os.system(wget_command)

    print("All benchmark data ready.")


def iter_jsonl(fname, cnt=None):
    i = 0
    with open(fname, "r") as fin:
        for line in fin:
            if i == cnt:
                break
            yield json.loads(line)
            i += 1


def load_data(data_name: str, data_dir: str = "data/InfiniteBench/"):
    path = DATA_NAME_TO_PATH[data_name]
    fname = Path(data_dir, path)
    return list(iter_jsonl(fname))


def truncate_input(input: list, max_length: int, manner="middle"):
    if len(input) <= max_length:
        return input
    if manner == "middle":
        split = max_length // 2
        return input[0:split] + input[-split:]
    else:
        return None


def truncate_by_tokens(input, tok, max_tokens, manner: str = "middle"):
    tokens = tok.encode(input)
    len_before = len(tokens)
    print(f"# tokens before: {len_before}")
    tokens = truncate_input(tokens, max_length=max_tokens, manner=manner)
    len_after = len(tokens)  # type: ignore
    print(f"# tokens after: {len_after}")
    assert len_after <= len_before
    assert len_after <= max_tokens
    return tok.decode(tokens, skip_special_tokens=True)


def get_answer(eg: dict, data_name: str):
    if data_name in ["code_debug", "longbook_choice_eng"]:
        OPTIONS = "ABCD"
        if isinstance(eg["answer"], str):
            ret = [eg["answer"], OPTIONS[eg["options"].index(eg["answer"])]]
        elif isinstance(eg["answer"], list):
            if len(eg["answer"]) == 1:
                ret = [eg["answer"][0], OPTIONS[eg["options"].index(eg["answer"][0])]]
            elif len(eg["answer"]) == 2 and eg["answer"][1] in ["A", "B", "C", "D"]:
                ret = eg["answer"]
            else:
                raise ValueError
        else:
            raise ValueError
        return ret

    return eg["answer"]


def parse_args() -> Namespace:
    p = ArgumentParser()

    p.add_argument("--seed", type=int, default=42, help="")
    p.add_argument("--base_dir", type=str, default="")
    p.add_argument("--dataset", type=str, default="")
    p.add_argument("--data_file", type=str, default="")
    p.add_argument("--save_dir", type=str, default="")
    p.add_argument("--add_file_name", type=str, default="")

    p.add_argument(
        "--model_name",
        type=str,
        default=None,
        help="if specified, we will load the model to generate the predictions.",
    )
    p.add_argument(
        "--model_path",
        type=str,
        default=None,
        help="if specified, we will load the model to generate the predictions.",
    )
    p.add_argument("--use_fast_tokenizer", type=bool, default=True, help="")
    p.add_argument("--output_attentions", type=bool, default=False, help="")

    p.add_argument(
        "--max_num_examples",
        type=int,
        default=None,
        help="maximum number of examples to evaluate per task.",
    )
    p.add_argument(
        "--sample_method",
        type=str,
        default="topk",
        choices=["random", "topk"],
        help="how to sample the examples.",
    )

    p.add_argument("--max_new_tokens", type=int, default=None, help="")

    p.add_argument(
        "--eval_batch_size", type=int, default=1, help="batch size for evaluation."
    )
    p.add_argument("--window_sizes", type=int, default=32, help="")

    # Parameters for restkv
    p.add_argument("--use_wo", action="store_true", help="Whether to use Wo")
    p.add_argument("--use_ema", action="store_true", help="Whether to use EMA")
    p.add_argument("--use_pyramid", action="store_true", help="Whether to use Pyramid")
    p.add_argument("--alpha", type=float, default=0.2, help="hyper-parameter used in EMA")
    p.add_argument("--kernel_sizes", type=int, default=5, help="")
    p.add_argument(
        "--pooling",
        type=str,
        default="avgpool",
        choices=["maxpool", "avgpool", "adaptive"],
    )

    # Parameters for Ablation Study
    p.add_argument(
        "--ss",
        type=str,
        default="none",
        choices=["none", "avgpool", "maxpool", "adaptive"],
    )
    p.add_argument(
        "--ts", type=str, default="none", choices=["none", "mean", "inv-ema", "ema"]
    )
    p.add_argument(
        "--indicator",
        type=str,
        default="random",
        choices=["random", "attn", "attn-v", "reconstruction"],
    )

    p.add_argument("--use_cache", type=bool, default=True, help="")
    p.add_argument(
        "--attn_implementation",
        type=str,
        default="flash_attention_2",
        choices=["flash_attention_2", "sdpa", "eager"],
    )
    p.add_argument("--method", type=str, default=None)
    p.add_argument("--quant_method", type=str, default=None, choices=["kivi", "kvquant"])
    p.add_argument("--nbits", type=int, default=8, help="")
    p.add_argument("--max_capacity_prompts", type=int, default=512, help="")
    p.add_argument("--max_capacity_prompts_ratio", type=float, default=-1, help="")
    p.add_argument(
        "--steps",
        type=int,
        default=-1,
        help="maximum number of examples to evaluate per task.",
    )
    p.add_argument("--merge", type=str, default=None, help="kv merge method(look-m)")
    p.add_argument(
        "--floor", type=float, default=0.2, help="hyper-parameter used in AdaKV"
    )
    p.add_argument(
        "--head_path",
        type=str,
        default="./data/heads_score/Meta-Llama-3-8B-Instruct_retrieval_reasoning_heads.json",
        help="Path to head score (HeadKV)",
    )
    p.add_argument(
        "--head_beta", type=float, default=1.01, help="hyper-parameter used on HeadKV"
    )
    p.add_argument("--recent_size", type=int, default=32, help="")
    p.add_argument(
        "--pruning_ratio", type=float, default=0.4, help="pruning ratio of Key Cache"
    )

    p.add_argument(
        "--use_chat_format",
        action="store_true",
        help="If given, we will use the chat format for the prompts.",
    )
    p.add_argument(
        "--chat_formatting_function",
        type=str,
        default="eval.templates.create_prompt_with_tulu_chat_format",
        help="The function to use to create the chat format. This function will be dynamically imported. Please see examples in `eval/templates.py`.",
    )

    return p.parse_args()


def get_length(input: str, tokenizer) -> int:
    """
    Get the length of the input string after tokenization.
    """
    tokens = tokenizer.encode(input)
    return len(tokens)


def create_prompt(eg: dict, data_name: str) -> str:
    """
    Create prompt for a given example.

    Args:
        eg: example dict
        data_name: name of the dataset/task
        model_name: optional, used to fetch model-specific templates.
    """
    # Directly use the appropriate template if the model_name is provided.
    template = data2prompt[data_name]

    # Now create the prompt based on the template and task data
    if data_name == "code_run":
        find_result = re.findall(r"func_[0-9]+\(\-?[0-9]+\)", eg["input"])
        func_call = find_result[0]
        func = func_call.split("(")[0]
        return template.format(
            func=func,
            func_call=func_call,
            context=eg["context"],
        )
    elif data_name in ["code_debug", "code_debug_qa"]:
        code = eg["context"]
        if data_name == "code_debug":
            return template.format(
                context=code,
                OPTION_A=eg["options"][0],
                OPTION_B=eg["options"][1],
                OPTION_C=eg["options"][2],
                OPTION_D=eg["options"][3],
            )
        return template.format(context=code)
    elif data_name == "longdialogue_qa_eng":
        script = eg["context"]
        prompt = template.format(context=script)
        return prompt
    elif data_name in [
        "longbook_choice_eng",
        "longbook_qa_eng",
        "longbook_sum_eng",
        "longbook_qa_chn",
    ]:
        book = eg["context"]
        if data_name == "longbook_choice_eng":
            return template.format(
                question=eg["input"],
                context=book,
                OPTION_A=eg["options"][0],
                OPTION_B=eg["options"][1],
                OPTION_C=eg["options"][2],
                OPTION_D=eg["options"][3],
            )
        elif data_name == "longbook_qa_eng":
            return template.format(
                question=eg["input"],
                context=book,
            )
        elif data_name == "longbook_sum_eng":
            return template.format(context=book)
        elif data_name == "longbook_qa_chn":
            return template.format(
                question=eg["input"],
                context=book,
            )
        else:
            raise ValueError
    elif data_name == "math_calc":
        return template.format(context=eg["context"])
    elif data_name == "math_find":
        prompt = eg["input"]
        context = eg["context"]
        find_result = re.findall(r"The .+ of", prompt)
        assert find_result, f"Cannot find the target number in {prompt}"
        target_number = find_result[0].lower()[:-3]
        prefix = f"What is {target_number} in the following list?"
        return template.format(
            prefix=prefix,
            context=context,
            input=prompt,
        )

    # Default behavior if content key exists
    if "content" in eg:
        content = eg["content"]
        del eg["content"]
        eg["context"] = content

    format_dict = {
        "context": eg["context"],
        "input": eg["input"],
    }
    prompt = template.format(**format_dict)
    return prompt


def truncate_input(input: list, max_length: int, manner="middle"):
    if len(input) <= max_length:
        return input
    if manner == "middle":
        split = max_length // 2
        return input[0:split] + input[-split:]
    else:
        return None


def truncate_by_tokens(input, tok, max_tokens, manner: str = "middle"):
    tokens = tok.encode(input)
    len_before = len(tokens)
    tokens = truncate_input(tokens, max_length=max_tokens, manner=manner)
    len_after = len(tokens)  # type: ignore
    assert len_after <= len_before
    assert len_after <= max_tokens
    return tok.decode(tokens, skip_special_tokens=True)


def main(args):
    print("Loading data...")

    test_data = []

    prompts = []
    inputs = []
    contexts = []
    answerss = []
    datasets = []
    _ids = []
    ground_truths = []

    input_max_len = 0

    model_path = args.model_path.lower()

    for key in model2maxlen:
        if key in model_path:
            model_max_len = model2maxlen[key]

    output_max_len = dataset2maxlen[args.dataset]

    examples = load_data(args.dataset)
    for example in examples:
        example["length"] = get_length(example["context"], tokenizer)
        length = example["length"]
        if length > input_max_len:
            input_max_len = length

        prompt = create_prompt(example, args.dataset)

        if "llama2" in args.model_path.lower():
            prompt = build_chat(prompt)

        example["prompt"] = prompt

        test_data.append(example)

    print(f"Max Length is {input_max_len}")

    if args.max_num_examples and len(test_data) > args.max_num_examples:
        if args.sample_method == "random":
            test_data = random.sample(test_data, args.max_num_examples)
        elif args.sample_method == "topk":
            test_data = test_data[: args.max_num_examples]

    for example in test_data:
        prompts.append(example["prompt"])
        inputs.append(example["input"])
        contexts.append(example["context"])
        answerss.append(example["answer"])
        datasets.append(args.dataset)
        _ids.append(example["id"])
        ground_truths.append(get_answer(example, args.dataset))

    print("Finish loading model and tokenizer")
    model_name = model_path.split("/")[-1]

    os.makedirs(
        os.path.join(
            args.save_dir, f"{model_name}_{args.max_capacity_prompts}", args.dataset
        ),
        exist_ok=True,
    )

    fout = open(
        os.path.join(
            args.save_dir,
            f"{model_name}_{args.max_capacity_prompts}",
            args.dataset,
            f"{args.method}_{args.add_file_name}.json",
        ),
        "w",
    )

    for i in tqdm(range(0, len(prompts), args.eval_batch_size)):
        batch_prompts = prompts[i : i + args.eval_batch_size]

        batch_datasets = datasets[i : i + args.eval_batch_size]
        batch_ids = _ids[i : i + args.eval_batch_size]
        batch_ground_truths = ground_truths[i : i + args.eval_batch_size]

        tokenized_prompts = tokenizer(
            batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=True
        ).to("cuda")
        batch_input_ids = tokenized_prompts.input_ids
        attention_mask = tokenized_prompts.attention_mask

        if len(batch_input_ids[0]) > model_max_len:
            half = int(model_max_len / 2)
            prompt = tokenizer.decode(
                batch_input_ids[0][:half], skip_special_tokens=True
            ) + tokenizer.decode(batch_input_ids[0][-half:], skip_special_tokens=True)

            tokenized_prompts = tokenizer(
                prompt, padding="longest", return_tensors="pt", add_special_tokens=True
            ).to("cuda")
            batch_input_ids = tokenized_prompts.input_ids
            attention_mask = tokenized_prompts.attention_mask

        if args.max_capacity_prompts != -1:
            max_capacity_prompts = args.max_capacity_prompts
        elif args.max_capacity_prompts_ratio != -1:
            max_capacity_prompts = round(
                batch_input_ids.shape[1] * args.max_capacity_prompts_ratio
            )

        if args.method != "FullKV":
            if args.method.lower() in [
                "restkv",
            ]:
                window_sizes = args.window_sizes

            if args.method.lower() == "headkv":
                with open(args.head_path, "r") as file:
                    head_list = json.loads(file.readline())
                head_score_list = [np.mean(l[1]) for l in head_list.items()]
                head_score_list = torch.tensor(head_score_list / sum(head_score_list))
                total_attention = head_score_list.reshape(
                    model.config.num_hidden_layers, model.config.num_attention_heads
                )
                total_pool_capacity = (
                    (args.max_capacity_prompts // args.head_beta)
                    * model.config.num_hidden_layers
                    * model.config.num_attention_heads
                )
                min_num = (
                    args.max_capacity_prompts
                    - args.max_capacity_prompts // args.head_beta
                )
                head_capacity = torch.round(
                    total_attention * total_pool_capacity + min_num
                ).int()
                model.model.config.head_capacity = head_capacity

            kernel_sizes = args.kernel_sizes
            pooling = args.pooling
            ratio = args.pruning_ratio
            recent_size = args.recent_size

            layers = len(model.model.layers)
            # check if window_sizes is a list
            if not isinstance(window_sizes, list):
                window_sizes = [window_sizes] * layers
            if not isinstance(max_capacity_prompts, list):
                max_capacity_prompts = [max_capacity_prompts] * layers
            if not isinstance(kernel_sizes, list):
                kernel_sizes = [kernel_sizes] * layers
            if not isinstance(ratio, list):
                ratio = [ratio] * layers
            if not isinstance(recent_size, list):
                recent_size = [recent_size] * layers
            for i in range(layers):
                model.model.layers[i].self_attn.config.window_size = window_sizes[i]
                model.model.layers[
                    i
                ].self_attn.config.max_capacity_prompt = max_capacity_prompts[i]
                model.model.layers[i].self_attn.config.kernel_size = kernel_sizes[i]
                model.model.layers[i].self_attn.config.pooling = pooling
                model.model.layers[i].self_attn.config.merge = args.merge
                model.model.layers[i].self_attn.config.floor = args.floor
                model.model.layers[i].self_attn.config.ratio = ratio[i]
                model.model.layers[i].self_attn.config.recent_size = recent_size[i]
                if args.method.lower() == "restkv":
                    model.model.layers[i].self_attn.config.use_wo = args.use_wo
                    model.model.layers[i].self_attn.config.use_ema = args.use_ema
                    model.model.layers[i].self_attn.config.use_pyramid = args.use_pyramid
                    model.model.layers[i].self_attn.config.alpha = args.alpha
                elif args.method.lower() == "ablation":
                    model.model.layers[i].self_attn.config.ss = args.ss
                    model.model.layers[i].self_attn.config.ts = args.ts
                    model.model.layers[i].self_attn.config.indicator = args.indicator

        context_length = batch_input_ids.shape[-1]

        if args.quant_method == None:
            output = model.generate(
                **tokenized_prompts,
                output_attentions=args.output_attentions,
                max_new_tokens=output_max_len,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                min_length=context_length + 1,
                eos_token_id=[tokenizer.eos_token_id],
                pad_token_id=tokenizer.pad_token_id,
            )
        else:
            output = model.generate(
                **tokenized_prompts,
                output_attentions=args.output_attentions,
                max_new_tokens=output_max_len,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                min_length=context_length + 1,
                eos_token_id=[tokenizer.eos_token_id],
                cache_implementation="quantized",
                cache_config={
                    "nbits": args.nbits,
                    "backend": "HQQ",
                    "device": "cuda",
                    "residual_length": output_max_len,
                    "axis_key": 1,
                    "q_group_size": 64,
                },
            )

        batch_outputs = tokenizer.batch_decode(
            [output[0][context_length:]], skip_special_tokens=True
        )

        batch_generations = batch_outputs

        torch.cuda.empty_cache()

        for j in range(args.eval_batch_size):
            example = {}

            example["ground_truth"] = batch_ground_truths[j]
            example["prediction"] = batch_generations[j]

            example["dataset"] = batch_datasets[j]
            example["id"] = batch_ids[j]

            fout.write(json.dumps(example) + "\n")


if __name__ == "__main__":
    args = parse_args()
    check_benchmark_availability("./data/InfiniteBench/")

    set_seed(args.seed)
    if args.quant_method == "kvquant":
        from restkv.quantcache import KVQuantizedCache
        from transformers import cache_utils

        cache_utils.HQQQuantizedCache = KVQuantizedCache
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path, use_fast=args.use_fast_tokenizer, padding_side="left"
    )

    from restkv.monkeypatch import (
        replace_llama,
        replace_mistral,
        replace_qwen,
        replace_gemma,
    )

    if "llama" in args.model_path.lower():
        replace_llama(args.method.lower())
    elif "mistral" in args.model_path.lower():
        replace_mistral(args.method.lower())
    elif "qwen" in args.model_path.lower():
        replace_qwen(args.method.lower())
    elif "gemma" in args.model_path.lower():
        replace_gemma(args.method.lower())

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        use_cache=args.use_cache,
        attn_implementation=args.attn_implementation,
    )

    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model.eval()
    save_dir = args.save_dir

    max_capacity_prompts = args.max_capacity_prompts

    for idx, dataset in enumerate(datasets):
        print(
            f"Working on max_capacity_prompts {args.max_capacity_prompts} dataset {dataset} - {idx}/{len(datasets)}"
        )

        args.dataset = dataset

        args.data_file = f"data/InfiniteBench/{args.dataset}.jsonl"

        main(args)
