import json
import os
import random
import argparse

import numpy as np
from tqdm import tqdm

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers import Qwen3ForCausalLM, Qwen3Config
from rnsa.cache_utils import RNSACache
from rnsa.qwen3 import RNSAQwen3ForCausalLM, RNSAQwen3Config
from rnsa.llama import RNSALlamaForCausalLM, RNSALlamaConfig

# Per-dataset field names of a JSONL record. main() reads entry 0 as the key
# holding the question text; entry 1 is not read in this file (presumably the
# reference-answer key -- confirm against the scoring code).
dataset2key = dict(
    gsm8k=["question", "answer"],
    aime24=["question", "answer"],
    math=["problem", "answer"],
)

# Fallback value for --max_length, looked up by dataset name when the flag is
# left at its -1 default.
dataset2max_length = {
    "gsm8k": 8 * 1024,
    "aime24": 16 * 1024,
    "math": 8 * 1024,
}


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN flags.

    Args:
        seed: Integer handed unchanged to every seeding call below.
    """
    # Host-side generators from the standard library and NumPy.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: default generator, then the CUDA seeding entry points.
    for seed_fn in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    # cuDNN: deterministic mode on, benchmark mode off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# User-turn prompt; "{question}" is the only placeholder, and "{{}}" renders
# as a literal "{}" after str.format().
prompt_template = (
    "You are given a math problem.\n\n"
    "Problem: {question}\n\n"
    " You need to solve the problem step by step. First, you need to provide the chain-of-thought, then provide the final answer.\n\n"
    " Provide the final answer in the format: Final answer:  \\boxed{{}}"
)


def apply_chat_template(prompts, tokenizer):
    """Render each prompt as a single-turn user conversation.

    Args:
        prompts: Iterable of prompt strings.
        tokenizer: Object exposing ``apply_chat_template``; it is called once
            per prompt with ``tokenize=False``, ``add_generation_prompt=True``
            and ``enable_thinking=True``.

    Returns:
        A list with one rendered result per input prompt, in input order.
    """
    return [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        for text in prompts
    ]


def main(args):
    """Generate answers for every example of a JSONL dataset and save them.

    Reads one JSON object per line from ``args.dataset_path``, formats the
    question with ``prompt_template``, runs greedy generation in batches and
    writes each example (extended with the chat prompt, the decoded output
    and token counts) as one JSON line to ``args.save_path``.

    Uses the module-level ``model`` and ``tokenizer`` objects that the
    ``__main__`` section creates before calling this function.

    Args:
        args: Parsed command-line namespace. Reads ``save_path``,
            ``dataset_path``, ``dataset_name``, ``eval_batch_size``,
            ``method`` and ``max_length``.

    Raises:
        KeyError: If ``args.dataset_name`` is not in ``dataset2key`` or an
            example lacks the question field.
    """
    # dirname() is "" for a bare file name and os.makedirs("") raises, so
    # only create the directory when there is one.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    question_key = dataset2key[args.dataset_name][0]
    prompts = []
    test_data = []

    with open(args.dataset_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank (e.g. trailing) lines instead of crashing in json.loads.
                continue
            example = json.loads(line)
            # Expose the dataset-specific field under the name the template expects.
            example["question"] = example[question_key]
            prompt = prompt_template.format(**example)

            example["prompt"] = prompt
            example["index"] = len(test_data)
            prompts.append(prompt)
            test_data.append(example)

    # The context manager closes the output file even if generation fails
    # part-way; the explicit encoding is needed because ensure_ascii=False
    # emits non-ASCII characters verbatim.
    with open(args.save_path, "w", encoding="utf-8") as fout:
        for i in tqdm(range(0, len(prompts), args.eval_batch_size)):
            batch_prompts = apply_chat_template(
                prompts[i : i + args.eval_batch_size], tokenizer
            )
            # The chat template has already rendered the special tokens;
            # adding them again here can duplicate BOS for some tokenizers.
            tokenized_prompts = tokenizer(
                batch_prompts,
                padding="longest",
                return_tensors="pt",
                add_special_tokens=False,
            ).to("cuda")

            # NOTE: i advances by eval_batch_size, so this fires for the first
            # batch only (first two batches when the batch size is 1).
            if i < 2:
                print("Input prompt:\n", batch_prompts)

            # Width of the padded prompt block shared by every row of the batch.
            input_width = tokenized_prompts["input_ids"].shape[1]
            # Number of real (non-padding) prompt tokens per row.
            prefill_lengths = tokenized_prompts["attention_mask"].sum(dim=1).tolist()

            if args.method == "rnsa":
                past_key_values = RNSACache(
                    max_seq_len=args.max_length,
                    device="cuda",
                )
            else:
                past_key_values = None

            start_time = torch.cuda.Event(enable_timing=True)
            end_time = torch.cuda.Event(enable_timing=True)
            start_time.record()

            print("start generating")
            output = model.generate(
                **tokenized_prompts,
                max_length=args.max_length,
                do_sample=False,
                num_beams=1,
                past_key_values=past_key_values,
            )
            end_time.record()
            torch.cuda.synchronize()
            generation_time = start_time.elapsed_time(end_time) / 1000.0  # in seconds
            print(f"Generation time for batch {i // args.eval_batch_size}: {generation_time:.2f} seconds")

            # Each output row is the (left-padded) prompt followed by the new
            # tokens, so the continuation starts at the padded width for every
            # row. Slicing at the per-row prefill length would leak the tail
            # of shorter prompts into the decoded output.
            generated = output[:, input_width:]
            # Padding appended after a row finished is excluded. When the pad
            # id equals the EOS id the closing EOS is not counted either.
            output_token_counts = (
                (generated != tokenizer.pad_token_id).sum(dim=1).tolist()
            )

            batch_outputs = tokenizer.batch_decode(
                generated,
                skip_special_tokens=True,
            )

            if i < 2:
                print("Output prediction\n", batch_outputs)

            torch.cuda.empty_cache()

            for j, output_text in enumerate(batch_outputs):
                sample_idx = i + j
                prefill = prefill_lengths[j]
                output_tokens = output_token_counts[j]

                record = test_data[sample_idx]
                # Overwrite the raw prompt with the chat-formatted one actually fed to the model.
                record["prompt"] = batch_prompts[j]
                record["output"] = output_text
                record["prefill_tokens"] = prefill
                record["output_tokens"] = output_tokens
                record["total_tokens"] = prefill + output_tokens
                record["sample_idx"] = sample_idx

                fout.write(json.dumps(record, ensure_ascii=False) + "\n")

            # Persist finished batches so partial results survive a later crash.
            fout.flush()


def parse_arguments():
    """Declare the script's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: One attribute per option listed below.
    """
    # (flag, add_argument keyword arguments), registered in listing order.
    option_specs = [
        ("--seed", dict(type=int, default=42)),
        ("--dataset_path", dict(type=str)),
        ("--save_path", dict(type=str)),
        ("--model_path", dict(type=str)),
        ("--attn_impl", dict(type=str, default="rnsa_flex")),
        ("--max_length", dict(type=int, default=-1)),
        ("--eval_batch_size", dict(type=int, default=1)),
        # method config
        (
            "--method",
            dict(
                type=str,
                default=None,
                choices=["rkv", "fullkv", "snapkv", "streamingllm", "h2o", "rnsa"],
            ),
        ),
        ("--kv_budget", dict(type=int, default=None)),
        ("--window_size", dict(type=int, default=None)),
        ("--compress_strategy", dict(type=str, default="knorm_alpha")),
        ("--first_tokens", dict(type=int, default=4)),
        ("--mix_lambda", dict(type=float, default=0.07)),
        ("--retain_ratio", dict(type=float, default=0.2)),
        # Kept as the strings "True"/"False"; callers compare against 'True'.
        ("--update_kv", dict(type=str, default="True", choices=["True", "False"])),
        (
            "--retain_direction",
            dict(type=str, default="last", choices=["last", "first"]),
        ),
        # model config
        (
            "--divide_method",
            dict(type=str, default="step_length", choices=["newline", "step_length"]),
        ),
        ("--divide_length", dict(type=int, default=128)),
        (
            "--compression_content",
            dict(
                type=str,
                default="all",
                choices=["think", "all"],
                help="whether to compress the whole model output or only the think part",
            ),
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


# Script entry point: parse the CLI, load the model and tokenizer as module
# globals (main() reads them), then run the evaluation loop.
if __name__ == "__main__":
    args = parse_arguments()
    set_seed(args.seed)

    # The three paths have no defaults; fail with a readable message instead
    # of an AttributeError/TypeError on None further down.
    missing = [
        name
        for name in ("dataset_path", "save_path", "model_path")
        if getattr(args, name) is None
    ]
    if missing:
        raise SystemExit(
            "Missing required argument(s): "
            + ", ".join(f"--{name}" for name in missing)
        )

    # Dataset name = file name up to its first dot, e.g. "data/gsm8k.jsonl" -> "gsm8k".
    args.dataset_name = os.path.basename(args.dataset_path).split(".")[0]
    if args.dataset_name not in dataset2key:
        raise ValueError(
            f"Unknown dataset '{args.dataset_name}' (derived from --dataset_path); "
            f"expected one of {sorted(dataset2key)}"
        )
    if args.max_length == -1:
        args.max_length = dataset2max_length[args.dataset_name]

    # ====== build compression config ======
    # NOTE(review): compression_config is assembled here but not read anywhere
    # else in this file -- confirm whether it is still needed.
    compression_config = {
        "method": args.method,
        "method_config": {
            "budget": args.kv_budget,
            "window_size": args.window_size,
            "mix_lambda": args.mix_lambda,
            "retain_ratio": args.retain_ratio,
            "retain_direction": args.retain_direction,
            "first_tokens": args.first_tokens,
        },
        "compression": None,
        "update_kv": args.update_kv,
    }
    model_config = {
        "divide_method": args.divide_method,
        "divide_length": args.divide_length,
        "compression_content": args.compression_content,
    }

    if args.method == 'rnsa':
        # The RNSA model class is chosen from the checkpoint path.
        model_path_lower = args.model_path.lower()
        if "qwen3" in model_path_lower:
            model_cls = RNSAQwen3ForCausalLM
        elif "llama" in model_path_lower:
            model_cls = RNSALlamaForCausalLM
        else:
            raise ValueError(
                "--method rnsa supports only Qwen3 and Llama checkpoints; "
                f"neither 'qwen3' nor 'llama' occurs in model path '{args.model_path}'"
            )

        model = model_cls.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            load_rnsa_weights=True,
            low_cpu_mem_usage=True,
            use_cache=True,
            device_map="cuda",
        )
        model.config.attn_impl = args.attn_impl
        # --update_kv arrives as the string "True"/"False".
        model.config.compress_memory = (args.update_kv == 'True')
        model.config.memory_size = args.kv_budget
        model.config.compress_strategy = args.compress_strategy
        print(f"Using RNSA model with config: {model.config}")
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map="auto",
            use_cache=True,
            attn_implementation=args.attn_impl,
        )
        # Record where the tokenizer should be loaded from (read below).
        model.config.base_model = args.model_path

    model.eval()

    model.config.update(model_config)

    # In the RNSA branch base_model is not set here; presumably the checkpoint
    # config already carries it -- confirm.
    # Left padding keeps the generated tokens contiguous at the end of each row.
    tokenizer = AutoTokenizer.from_pretrained(
        model.config.base_model, use_fast=True, padding_side="left"
    )

    # Batched tokenisation needs a pad token; fall back to EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    main(args)
