import os
import json
import random
import argparse
import numpy as np
import torch

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List

# RULER task names, in the order the driver loop below visits them.
ruler_datasets = [
    "niah_single_1", "niah_single_2", "niah_single_3",
    "niah_multikey_1", "niah_multikey_2", "niah_multikey_3",
    "niah_multiquery", "niah_multivalue",
    "cwe", "fwe", "vt",
]

# Per-task generation budget, passed to `model.generate` as `max_new_tokens`.
# Every task currently gets the same 64-token budget.
ruler_dataset2maxlen = dict.fromkeys(ruler_datasets, 64)

# Prompt token limit per model family; longer prompts are middle-truncated.
# Keys are matched as substrings of the lower-cased model path. Several keys
# are substrings of others (e.g. "llama-3" / "llama-3.1"), so a lookup should
# prefer the longest matching key.
model2maxlen = {
    "llama2": 3950,
    "llama-2": 3950,
    "llama3": 7950,
    "llama3.1": 128000,
    "llama-3": 7950,
    "llama-3.1": 128000,
    "mistral": 31500,
    "qwen2": 31500,
}

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic kernels.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random``, the PyTorch
            CPU generator, the current CUDA device and all CUDA devices.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Disable autotuning so kernel selection cannot vary between runs.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def build_chat(prompt, model_name):
    """Wrap ``prompt`` in the chat template of the model family in ``model_name``.

    The family is detected by case-insensitive substring match, checked in this
    order: ``"llama-3"``, ``"llama2"``, ``"mistral"``. Any other name returns
    ``prompt`` unchanged.

    Args:
        prompt: Raw user text.
        model_name: Model name or path used to pick the template.

    Returns:
        The templated prompt string.
    """
    family = model_name.lower()

    if "llama-3" in family:
        return f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    if "llama2" in family:
        return f"[INST]{prompt}[/INST]"
    if "mistral" in family:
        return f'<s>[INST] {prompt} [/INST]'

    return prompt

def _load_examples(data_file, model_path):
    """Read a JSON-lines task file and attach a ``"prompt"`` field to each example.

    The example's ``"input"`` text is wrapped with ``build_chat`` when the
    lower-cased ``model_path`` contains "llama2", "llama-3" or "mistral";
    otherwise it is used verbatim. Blank lines are skipped.
    """
    lowered = model_path.lower()
    use_chat = any(tag in lowered for tag in ("llama2", "llama-3", "mistral"))

    examples = []
    with open(data_file) as fp:
        for line in fp:
            if not line.strip():
                continue
            example = json.loads(line)
            prompt = example["input"]
            if use_chat:
                prompt = build_chat(prompt, lowered)
            example["prompt"] = prompt
            examples.append(example)
    return examples


def _resolve_model_max_len(model_path):
    """Return the prompt token limit for ``model_path`` from ``model2maxlen``.

    The longest key that is a substring of the lower-cased path wins. Without
    this, "llama-3" would shadow "llama-3.1" and Llama-3.1 prompts would be cut
    to the Llama-3 limit. Falls back to 4096 (with a warning) when nothing matches.
    """
    lowered = model_path.lower()
    matches = [key for key in model2maxlen if key in lowered]
    if not matches:
        print(f"Warning: model_max_len not found for {model_path}. Using a default of 4096.")
        return 4096
    return model2maxlen[max(matches, key=len)]


def _truncate_middle(tok, prompt, max_len):
    """Drop the middle of ``prompt`` when it tokenizes to more than ``max_len`` ids.

    Keeps the first and last ``int(max_len / 2)`` token ids and decodes each
    half with special tokens stripped. Re-encoding the result may differ by a
    few tokens, so this only approximately bounds the length. Prompts already
    within budget are returned unchanged.
    """
    ids = tok(prompt, add_special_tokens=True).input_ids
    if len(ids) <= max_len:
        return prompt
    half = int(max_len / 2)
    head = tok.decode(ids[:half], skip_special_tokens=True)
    tail = tok.decode(ids[-half:], skip_special_tokens=True)
    return head + tail


def _configure_kv_compression(target_model, capacity):
    """Set the KV-compression attributes on the model configs.

    NOTE(review): these attributes are presumably read by the patched attention
    from ``score_kv.monkeypatch`` (not visible here); values are passed as-is.
    """
    inner = target_model.model.config
    inner.window_size = 8
    inner.base_capacity = capacity
    inner.kernel_size = 7
    inner.skip = 0
    inner.normalize = True
    inner.pooling = "maxpool"
    inner.floor = 0.2

    target_model.config.skip = 0
    target_model.config.normalize = True
    target_model.config.floor = 0.2
    target_model.config.first_n_token = 0


def main(args):
    """Generate predictions for one RULER task file and write them as JSON lines.

    Uses the module-level ``model`` and ``tokenizer`` created by the script entry
    point. Reads ``args.data_file`` and writes one record per example
    (``answers``, ``pred``, ``length``, ``_id``) to
    ``<save_dir>/<model>_<max_capacity_prompts>/<context_length>/<dataset>/<method>.json``.
    The generation budget comes from ``ruler_dataset2maxlen[args.dataset]``;
    ``args.max_new_tokens`` is not consulted here.
    """
    print("Loading data...")

    test_data = _load_examples(args.data_file, args.model_path)
    if not test_data:
        print(f"Skipping {args.data_file}: no examples found.")
        return

    input_max_len = max(ex["length"] for ex in test_data)
    print(f"Max Length is {input_max_len}")

    if args.max_num_examples and len(test_data) > args.max_num_examples:
        if args.sample_method == "random":
            test_data = random.sample(test_data, args.max_num_examples)
        elif args.sample_method == "topk":
            test_data = test_data[:args.max_num_examples]

    print("Finish loading model and tokenizer")

    model_name_simple = args.model_path.split("/")[-1]
    output_dir = os.path.join(args.save_dir, f"{model_name_simple}_{args.max_capacity_prompts}", str(args.context_length), args.dataset)
    os.makedirs(output_dir, exist_ok=True)

    model_max_len = _resolve_model_max_len(args.model_path)
    output_max_len = ruler_dataset2maxlen[args.dataset]
    batch_size = args.eval_batch_size

    with open(os.path.join(output_dir, f"{args.method}.json"), "w") as fout:
        for start in tqdm(range(0, len(test_data), batch_size)):
            # The final batch may be shorter than batch_size.
            batch = test_data[start:start + batch_size]
            batch_prompts = [ex["prompt"] for ex in batch]

            tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=True).to('cuda')

            if tokenized_prompts.input_ids.shape[-1] > model_max_len:
                # Truncate each over-long prompt on its own so no example is dropped.
                batch_prompts = [_truncate_middle(tokenizer, p, model_max_len) for p in batch_prompts]
                tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=True).to('cuda')

            if args.method.lower() != "fullkv":
                # Re-applied every batch, as the original did, in case generation mutates the config.
                _configure_kv_compression(model, args.max_capacity_prompts)

            # Prompts are padded to a common length, so generated tokens start here for every row.
            context_length = tokenized_prompts.input_ids.shape[-1]

            output = model.generate(
                **tokenized_prompts,
                output_attentions=args.output_attentions,
                max_new_tokens=output_max_len,
                num_beams=1,
                do_sample=False,
                temperature=1.0,
                min_length=context_length + 1,
                eos_token_id=[tokenizer.eos_token_id],
            )

            batch_generations = tokenizer.batch_decode(output[:, context_length:], skip_special_tokens=True)

            torch.cuda.empty_cache()

            for example, generation in zip(batch, batch_generations):
                record = {
                    "answers": example["outputs"],
                    "pred": generation,
                    "length": example["length"],
                    "_id": example["index"],
                }
                fout.write(json.dumps(record) + "\n")

if __name__ == "__main__":
    # Script entry point: parse CLI flags, load tokenizer/model once, then run
    # `main` for every (context length, RULER task) pair whose data file exists.

    def _str2bool(value):
        """Parse a CLI boolean.

        argparse's ``type=bool`` maps every non-empty string, including
        "False", to True, so such flags could never be switched off.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=42, help="")
    parser.add_argument("--base_dir", type=str, default="")
    parser.add_argument("--dataset", type=str, default="")
    parser.add_argument("--data_file", type=str, default="")
    parser.add_argument("--save_dir", type=str, default="results_ruler")

    parser.add_argument("--model_name", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
    parser.add_argument("--model_path", type=str, default=None, help="if specified, we will load the model to generate the predictions.")
    # NOTE(review): parsed but not used; the tokenizer below is always loaded with use_fast=False.
    parser.add_argument("--use_fast_tokenizer", type=_str2bool, default=True, help="")
    parser.add_argument("--output_attentions", type=_str2bool, default=False, help="")

    parser.add_argument("--max_num_examples", type=int, default=None, help="maximum number of examples to evaluate per task.")
    parser.add_argument("--sample_method", type=str, default="topk", choices=["random", "topk"], help="how to sample the examples.")

    parser.add_argument("--max_new_tokens", type=int, default=None, help="")

    parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")

    parser.add_argument("--use_cache", type=_str2bool, default=True, help="")
    parser.add_argument("--attn_implementation", type=str, default="flash_attention_2", choices=["flash_attention_2", "sdpa", "eager"])

    parser.add_argument("--method", type=str, default="FullKV", help="KV cache management method (e.g., FullKV, SnapKV, PyramidKV, StreamingLLM, AdativeKV)")

    parser.add_argument("--nbits", type=int, default=8, help="Number of bits for quantization.")

    parser.add_argument("--max_capacity_prompts", type=int, default=512, help="Maximum capacity for prompts in the KV cache.")
    parser.add_argument("--max_capacity_prompts_ratio", type=float, default=-1, help="Ratio for determining max_capacity_prompts.")

    parser.add_argument("--steps", type=int, default=-1, help="maximum number of examples to evaluate per task.")

    # Previously the loop below read an undefined `context_length_list` (NameError).
    parser.add_argument(
        "--context_lengths",
        type=int,
        nargs="+",
        default=[4096, 8192, 16384, 32768, 65536, 131072],
        help="RULER context lengths to evaluate; data is read from data/RULER/<length>/<task>.jsonl and missing files are skipped.",
    )

    parser.add_argument(
        "--use_chat_format",
        action="store_true",
        help="If given, we will use the chat format for the prompts."
    )
    parser.add_argument(
        "--chat_formatting_function",
        type=str,
        default="eval.templates.create_prompt_with_tulu_chat_format",
        help="The function to use to create the chat format. This function will be dynamically imported. Please see examples in `eval/templates.py`."
    )

    args = parser.parse_args()

    set_seed(args.seed)

    if not args.model_path:
        parser.error("--model_path is required")

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        use_fast=False
    )

    if args.method.lower() != 'fullkv':
        # Patch attention classes before the model is instantiated.
        from score_kv.monkeypatch import replace_llama, replace_mistral
        replace_llama(args.method)
        replace_mistral(args.method)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        use_cache=args.use_cache,
        attn_implementation=args.attn_implementation
    )

    # Left padding keeps every prompt's last token aligned for batched generation.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model.eval()

    context_length_list = args.context_lengths

    for context_length in context_length_list:
        for dataset in ruler_datasets:
            # `main` reads these fields from args, overriding any CLI values.
            args.context_length = context_length
            args.dataset = dataset

            args.data_file = f"data/RULER/{context_length}/{args.dataset}.jsonl"

            if not os.path.exists(args.data_file):
                print(f"Skipping {args.data_file}: file not found.")
                continue

            main(args)