import os
import sys
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM
from tqdm import tqdm
import numpy as np
import random
import argparse
import time
import torch.distributed as dist
import torch.multiprocessing as mp
from models.modeling_llama_linear import LlamaAttention_LA_SA, LlamaModel_LA_SA_forward, LlamaAttention_LA_SA_1, LlamaAttention_LA_SA_2, LlamaModel_LA_SA_forward_1, LlamaModel_LA_SA_forward_2
from models.modeling_qwen_linear import Qwen2_Linear_forward, Qwen2Attention_linear
from models.modeling_mistral_linear import Mistral_LA_SA_forward, MistralAttention_LA_SA_2
from models.cake_llama import llama_model_forward_cake, llama_attn_forward_cake
from models.cake_mistral import mistral_model_forward_cake, mistral_attn_forward_cake
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel, LlamaDecoderLayer
# from models.based.based.models.gpt import GPTLMHeadModel
from models.cake_cache import CakeprefillKVCache
from models.cake_utils import CompressConfig
import torch


import argparse
from numpy import random
import json

def _per_layer(value, num_layers):
    """Return ``value`` unchanged if it is already a list, else repeat it ``num_layers`` times."""
    return value if isinstance(value, list) else [value] * num_layers


def _swap_attention(model, model_forward, build_attn):
    """Bind a custom model forward and replace every decoder layer's ``self_attn``.

    Args:
        model: causal-LM wrapper whose decoder is ``model.model``.
        model_forward: unbound function bound onto ``model.model`` as its ``forward``.
        build_attn: callable ``layer_idx -> nn.Module`` creating the replacement
            attention for that layer.

    The existing attention weights are copied into each replacement with
    ``load_state_dict(..., strict=False)``. The replacement is then moved to the
    dtype/device of the decoder's first parameter.
    """
    first_param = next(model.model.parameters())
    target_dtype, target_device = first_param.dtype, first_param.device
    model.model.forward = model_forward.__get__(model.model, LlamaModel)
    for layer_idx, layer in enumerate(model.model.layers):
        ref_attn = layer.self_attn
        custom_attn = build_attn(layer_idx)
        custom_attn.load_state_dict(ref_attn.state_dict(), strict=False)
        layer.self_attn = custom_attn.to(dtype=target_dtype, device=target_device)


def _configure_compression_cache(model, args, cache_type):
    """Write per-layer compression hyper-parameters onto each attention config.

    Used for the pyramidkv-family caches (snapkv, pyramidkv, h2o, cam, l2norm,
    adakv, headkv, think, streamingllm). ``cache_type`` must already be lower-cased.
    """
    cache_size = args.cache_size
    if cache_type == "streamingllm":
        # NOTE(review): presumably reserves 4 slots for attention-sink tokens — TODO confirm
        window_sizes = cache_size - 4
    else:
        window_sizes = args.window_sizes

    if cache_type == "headkv":
        # The first line of the file is a JSON object; each value is a list of head scores.
        with open(args.head_path, "r") as file:
            head_list = json.loads(file.readline())
        head_score_list = [np.mean(scores) for scores in head_list.values()]
        # Normalise so that all head scores sum to 1.
        head_score_list = torch.tensor(np.asarray(head_score_list) / sum(head_score_list))
        total_attention = head_score_list.reshape(
            model.config.num_hidden_layers, model.config.num_attention_heads)
        total_pool_capacity = (
            (args.cache_size // args.head_beta)
            * model.config.num_hidden_layers
            * model.config.num_attention_heads
        )
        min_num = args.cache_size - args.cache_size // args.head_beta
        head_capacity = torch.round(total_attention * total_pool_capacity + min_num).int()
        model.model.config.head_capacity = head_capacity

    layers = len(model.model.layers)
    window_sizes = _per_layer(window_sizes, layers)
    cache_sizes = _per_layer(cache_size, layers)
    kernel_sizes = _per_layer(7, layers)
    ratios = _per_layer(args.pruning_ratio, layers)
    recent_sizes = _per_layer(args.recent_size, layers)
    for i, layer in enumerate(model.model.layers):
        config = layer.self_attn.config
        config.window_size = window_sizes[i]
        config.max_capacity_prompt = cache_sizes[i]
        config.kernel_size = kernel_sizes[i]
        config.pooling = "maxpool"
        config.merge = args.merge
        config.floor = args.floor
        config.ratio = ratios[i]
        config.recent_size = recent_sizes[i]


def _configure_cake(model, args, model_forward, attn_forward, tau1, tau2, gamma=200.0):
    """Install CAKE model/attention forwards and the cascading-eviction settings."""
    compress_config = CompressConfig(True, True)  # compress=True, cascading=True
    compress_config.cache_size = args.cache_size
    compress_config.window_size = args.window_sizes
    compress_config.hyper = [tau1, tau2, gamma]

    model.model.forward = model_forward.__get__(model.model, LlamaModel)
    layers = model.config.num_hidden_layers
    for layer in model.model.layers:
        layer.self_attn.forward = attn_forward.__get__(layer.self_attn, LlamaAttention)

    # The settings are written once, onto the LAST layer's attention config.
    # NOTE(review): this relies on every layer sharing one config object (as HF
    # attention modules usually do) — TODO confirm.
    last_attn = model.model.layers[-1].self_attn
    config = last_attn.config
    config.key_size = [compress_config.cache_size - compress_config.window_size] * layers
    config.window_size = [compress_config.window_size] * layers
    config.prefill = [True] * layers
    config.decoding_evict = [None] * layers
    config.tau1 = compress_config.hyper[0]
    config.tau2 = compress_config.hyper[1]
    config.gamma = compress_config.hyper[2]
    # NOTE(review): `[obj] * layers` repeats ONE CakeprefillKVCache instance for all
    # layers; presumably intended for cascading allocation — confirm.
    config.prefill_cake_evict = [CakeprefillKVCache(
        cache_size=compress_config.cache_size,
        window_size=compress_config.window_size,
        k_seq_dim=2,
        v_seq_dim=2,
        num_heads=last_attn.num_heads,
        num_layers=layers,
        use_cascading=compress_config.cascading,
    )] * layers


def load_model_and_tokenizer(path, model_name, device, args, rank):
    """Load a causal LM plus tokenizer and install the KV-cache variant ``args.cache_type``.

    Args:
        path: model path or hub id passed to ``from_pretrained``.
        model_name: model key (``args.model``). For LA2/cake, the substrings
            "llama", "qwen" and "mistral" select the architecture-specific code.
        device: unused; kept for interface compatibility.
        args: namespace from ``parse_config``.
        rank: GPU index used as ``device_map`` when ``args.data_parallel`` is set.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.

    Raises:
        ValueError: LA2 or cake was requested for a ``model_name`` with no matching
            implementation. Previously this silently fell back to full attention.
    """
    cache_type = args.cache_type
    cache_key = cache_type.lower()
    from pyramidkv.monkeypatch import replace_llama, replace_mistral
    replace_llama(cache_key)
    replace_mistral(cache_key)

    device_map = {"": rank} if args.data_parallel else "auto"

    tokenizer = AutoTokenizer.from_pretrained(
        path,
        use_fast=args.use_fast_tokenizer,
        padding_side="left",
    )

    model = AutoModelForCausalLM.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map=device_map,
        use_cache=args.use_cache,
        attn_implementation=args.attn_implementation,
    )

    # Fall back to EOS as the pad token when the tokenizer defines none.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Mirror the pad/eos ids onto the model and generation configs.
    model.config.pad_token_id = tokenizer.pad_token_id
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.generation_config.eos_token_id = tokenizer.eos_token_id

    tokenizer.padding_side = "left"

    compression_caches = ("snapkv", "pyramidkv", "h2o", "cam", "l2norm",
                          "adakv", "headkv", "think", "streamingllm")

    if cache_type == "FullKV":
        print("using Full KV")

    # Compare case-insensitively so the "H2O" CLI choice is configured too.
    if cache_key in compression_caches:
        _configure_compression_cache(model, args, cache_key)
    elif cache_type == "LA":
        print("Using LA cache")
        _swap_attention(
            model, LlamaModel_LA_SA_forward,
            lambda i: LlamaAttention_LA_SA(model.config, i, args.cache_size, args.window_sizes))
    elif cache_type == "LA1":
        print("Using LA1 cache")
        _swap_attention(
            model, LlamaModel_LA_SA_forward_1,
            lambda i: LlamaAttention_LA_SA_1(
                model.config, i, args.cache_size, args.window_sizes,
                args.linear_cache_size, args.alpha, args.beta))
    elif cache_type == "LA2":
        print("Using LA2 cache")
        la_args = (args.cache_size, args.window_sizes, args.linear_cache_size, args.alpha, args.beta)
        if "llama" in model_name:
            _swap_attention(model, LlamaModel_LA_SA_forward_2,
                            lambda i: LlamaAttention_LA_SA_2(model.config, i, *la_args))
        elif "qwen" in model_name:
            _swap_attention(model, Qwen2_Linear_forward,
                            lambda i: Qwen2Attention_linear(model.config, i, *la_args))
        elif "mistral" in model_name:
            print("using LAKE based mistral")
            _swap_attention(model, Mistral_LA_SA_forward,
                            lambda i: MistralAttention_LA_SA_2(model.config, i, *la_args))
        else:
            raise ValueError(f"cache_type={cache_type!r} has no implementation for model {model_name!r}")
    elif cache_type == "based1.4b":
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    elif cache_type == "cake":
        if "llama" in model_name:
            print("Using cake cache llama")
            tau1, tau2 = (1.6, 0.4) if args.cache_size == 1024 else (1.6, 0.6)
            _configure_cake(model, args, llama_model_forward_cake, llama_attn_forward_cake, tau1, tau2)
        elif "mistral" in model_name:
            print("Using cake cache")
            if args.cache_size == 1024:
                tau1, tau2 = 0.8, 0.5
            elif args.cache_size == 512:
                tau1, tau2 = 1.0, 0.8
            else:
                tau1, tau2 = 1.6, 0.6
            _configure_cake(model, args, mistral_model_forward_cake, mistral_attn_forward_cake, tau1, tau2)
        else:
            raise ValueError(f"cache_type={cache_type!r} has no implementation for model {model_name!r}")

    model.eval()
    return model, tokenizer


# def load_vllm_model(model_path):
#     llm = LLM(model=model_path)
#     return llm

def generate_prompt_landmark(n_garbage, seed, percent):
    """Build a passkey-retrieval prompt with the key buried inside filler text.

    Args:
        n_garbage: total number of filler characters, split between the text before
            and after the key line.
        seed: seed for the (numpy) global RNG that draws the pass key. The previous
            global RNG state is restored before returning.
        percent: fraction of ``n_garbage`` placed before the key line.

    Returns:
        ``(prompt, pass_key)``, where ``pass_key`` is the drawn integer as a string.
    """
    saved_state = random.get_state()
    random.seed(seed)

    prefix_len = int(percent * n_garbage)
    suffix_len = n_garbage - prefix_len

    filler = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    haystack = " ".join(filler for _ in range(50000))
    assert len(haystack) >= n_garbage

    pass_key = random.randint(50000, 500000)
    sections = (
        "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.",
        haystack[:prefix_len],
        f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key.",
        # Both filler sections start from the beginning of the haystack.
        haystack[:suffix_len],
        "What is the pass key? The pass key is",
    )

    random.set_state(saved_state)
    return "\n".join(sections), str(pass_key)

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` maps every non-empty string, including
    ``"False"``, to ``True``. This converter makes ``--flag False`` really mean False.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_config():
    """Parse command-line options for the passkey-retrieval evaluation.

    Returns:
        argparse.Namespace with model, length-sweep, KV-cache and generation options.
    """
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--model', type=str, default=None, choices=["llama2-7b-chat-4k", "llama2-13b-chat-4k", "based1.4b", "mamba-7B", "llama-2-7b-grouped-linear", "llama-3.1-8B-instruct", "qwen2.5-7b-instruct", "mistral-0.3-7b-32k"])
    parser.add_argument('--pretraining_length', type=int, default=32000)
    parser.add_argument('--scale', type=str, default="13b")
    parser.add_argument('--max_length', type=str, default="256k")
    parser.add_argument('--min_length', type=str, default="1k")
    parser.add_argument('--gap', type=str, default="8k")
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--num_tests', type=int, default=10, help='number of repeat testing for each length')
    parser.add_argument('--cache_type', type=str, default="FullKV",
                        choices=["cake", "FullKV", "LA", "snapkv", "pyramidkv", "H2O", "cam", "l2norm", "adakv",
                                 "headkv", "think", "LA1", "LA2", "streamingllm"])
    parser.add_argument("--cache_size", type=int, default=128, help="")
    parser.add_argument("--window_sizes", type=int, default=96, help="")
    parser.add_argument("--linear_cache_size", type=int, default=0, help="")
    parser.add_argument("--alpha", type=float, default=0.6, help="alpha param of score")
    parser.add_argument("--beta", type=float, default=0.5, help="beta param of score")
    parser.add_argument("--use_cache", type=_str2bool, default=True, help="")
    parser.add_argument("--attn_implementation", type=str, default="flash_attention_2",
                        choices=["flash_attention_2", "sdpa", "eager"])
    parser.add_argument("--data_parallel", type=_str2bool, default=True, help="")
    parser.add_argument("--use_fast_tokenizer", type=_str2bool, default=True, help="")

    # Options read by load_model_and_tokenizer for the pyramidkv-family caches.
    # They were previously missing, so those cache types crashed with AttributeError.
    # NOTE(review): the defaults are assumptions — TODO confirm against the
    # pyramidkv reference scripts.
    parser.add_argument("--pruning_ratio", type=float, default=0.4, help="pruning ratio of the key cache")
    parser.add_argument("--recent_size", type=int, default=32, help="number of recent tokens always kept")
    parser.add_argument("--merge", type=str, default=None, help="KV merge method (None disables merging)")
    parser.add_argument("--floor", type=float, default=0.2, help="floor hyper-parameter (AdaKV)")
    parser.add_argument("--head_path", type=str, default=None, help="path to head-score JSON (required for headkv)")
    parser.add_argument("--head_beta", type=float, default=1.01, help="head_beta hyper-parameter (HeadKV)")

    args = parser.parse_args()
    if args.cache_type == "headkv" and args.head_path is None:
        parser.error("--head_path is required when --cache_type is headkv")
    return args


# if __name__ == "__main__":
#     max_gen = 128000
#
#
#     args = parse_config()
#
#     output_name = f"output.jsonl"
#     print("results will be save to:", output_name)
#     model_path = args.model
#     model = load_vllm_model(model_path)
#
#     # hyper params
#     k = 1000
#     max_length = int(args.max_length.replace("k", '')) * k
#     min_length = int(args.min_length.replace("k", '')) * k
#     gap = int(args.gap.replace("k", '')) * k
#     num_per = 10
#     depth_percent = 1 / num_per
#
#     # length_list = [k] + [i for i in range(4*k, max_length + 1, gap)]
#     length_list = [i for i in range(min_length, max_length + 1, gap)]
#
#     results = []
#     # sampling_params = SamplingParams(
#     #     temperature=0.0,
#     #     top_p=1.0,
#     #     use_beam_search=True,
#     #     best_of=4,
#     #     max_tokens=5,
#     # )
#
#     for length in length_list:
#         # This is a rough ratio to control the number of texts and tokens
#         n_garbage = int(3.75 * length // k * k)
#
#         depths = [depth_percent * i for i in range(1, num_per + 1)]
#         for depth in depths:
#             passed_tests = 0
#             all_accuries = {}
#             prompts = []
#             answers = []
#             for j in range(args.num_tests):
#                 prompt, answer = generate_prompt_landmark(n_garbage, j, depth)
#                 prompts.append(prompt)
#                 answers.append(answer)
#
#
#
#             outputs = model.generate(
#                 **input,
#                 max_new_tokens=max_gen,
#                 num_beams=1,
#                 do_sample=False,
#                 temperature=1.0,
#                 top_p=None,
#                 top_k=None,
#                 # attention_mask=attention_mask,
#             )[0]
#
#
#
#             for output, answer in zip(outputs, answers):
#                 print("[prediction]:  ", repr(output.outputs[0].text))
#                 print("[ground truth]:  ", repr(answer))
#                 if answer in output.outputs[0].text:
#                     passed_tests += 1
#             accuracy = float(passed_tests) / args.num_tests
#             res = {"context_length": f"{length // k}k", "depth_percent": depth * 100, "score": accuracy}
#             results.append(res)
#             print(res)
#             with open(output_name, "a") as f:
#                 print(json.dumps(res), file=f)

if __name__ == "__main__":
    # Only the pass key (a 5-6 digit number) must be produced. The previous budget
    # of 128000 new tokens let a model ramble for up to the whole budget per prompt.
    max_gen = 32
    args = parse_config()

    output_name = "output.jsonl"
    print("results will be saved to:", output_name)
    with open("config/model2path.json", "r") as f:
        model2path = json.load(f)

    # Load the model with transformers. --gpu is used as the device rank; it was
    # previously ignored because rank was hard-coded to 0.
    model, tokenizer = load_model_and_tokenizer(
        path=model2path[args.model],   # model path
        model_name=args.model,         # model key
        device=args.gpu,               # GPU id (unused by the loader)
        args=args,                     # parsed CLI arguments
        rank=args.gpu,                 # device_map target when data_parallel is set
    )

    # Length sweep: "8k" -> 8000, etc.
    k = 1000
    max_length = int(args.max_length.replace("k", '')) * k
    min_length = int(args.min_length.replace("k", '')) * k
    gap = int(args.gap.replace("k", '')) * k
    num_per = 10
    depth_percent = 1 / num_per

    length_list = [i for i in range(min_length, max_length + 1, gap)]
    results = []

    for length in length_list:
        # Rough characters-per-token ratio used to size the filler text.
        n_garbage = int(3.75 * length // k * k)
        depths = [depth_percent * i for i in range(1, num_per + 1)]

        for depth in depths:
            passed_tests = 0

            for j in range(args.num_tests):
                prompt, answer = generate_prompt_landmark(n_garbage, j, depth)

                # Tokenize one prompt at a time (no batching).
                inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

                # Greedy generation for this single prompt.
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_gen,
                    num_beams=1,
                    do_sample=False,
                    temperature=1.0,
                    top_p=None,
                    top_k=None,
                )

                # Decode ONLY the newly generated tokens. generate() returns prompt +
                # continuation, and the prompt itself states the pass key, so decoding
                # the full sequence made every test pass.
                prompt_len = inputs["input_ids"].shape[1]
                pred = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

                print("[prediction]:  ", repr(pred))
                print("[ground truth]:", repr(answer))
                if answer in pred:
                    passed_tests += 1

            accuracy = float(passed_tests) / args.num_tests
            res = {
                "context_length": f"{length // k}k",
                "depth_percent": depth * 100,
                "score": accuracy,
            }
            results.append(res)
            print(res)
            with open(output_name, "a") as f:
                print(json.dumps(res), file=f)

        # for depth in depths:
        #     passed_tests = 0
        #     prompts, answers = [], []
        #     for j in range(args.num_tests):
        #         prompt, answer = generate_prompt_landmark(n_garbage, j, depth)
        #         prompts.append(prompt)
        #         answers.append(answer)
        #
        #     # ✅ transformers generate 사용
        #     inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        #
        #     outputs = model.generate(
        #         **inputs,
        #         max_new_tokens=max_gen,
        #         num_beams=1,
        #         do_sample=False,
        #         temperature=1.0,
        #         top_p=None,
        #         top_k=None,
        #     )
        #
        #     decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        #
        #     for pred, answer in zip(decoded, answers):
        #         print("[prediction]:  ", repr(pred))
        #         print("[ground truth]:", repr(answer))
        #         if answer in pred:
        #             passed_tests += 1
        #
        #     accuracy = float(passed_tests) / args.num_tests
        #     res = {
        #         "context_length": f"{length // k}k",
        #         "depth_percent": depth * 100,
        #         "score": accuracy,
        #     }
        #     results.append(res)
        #     print(res)
        #     with open(output_name, "a") as f:
        #         print(json.dumps(res), file=f)
