import json
import torch
from tqdm import tqdm
import fire
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from utils.grader import check_is_correct
from utils.parser import extract_answer

import os
import sys
sys.path.append('PyramidKV')

from pyramidkv.monkeypatch import replace_llama

# KV-cache method name handed to PyramidKV's replace_llama().
# NOTE(review): this runs at import time, before any model is loaded below;
# presumably replace_llama() monkeypatches transformers' Llama attention for the
# chosen method — confirm in PyramidKV/pyramidkv/monkeypatch.py.
method = "pyramidkv"
replace_llama(method.lower())


def evaluate_prediction(model_pred: str, gold_answer: str) -> bool:
    """Score a single model response against its reference answer.

    The final answer is first pulled out of *model_pred* with
    ``extract_answer``; the result is then compared with *gold_answer*
    by ``check_is_correct``, whose verdict is returned unchanged.
    """
    return check_is_correct(extract_answer(model_pred), gold_answer)

# Default instruction sent as the system message for every problem.
_DEFAULT_SYSTEM_PROMPT = "Please reason step by step, and put your final answer within \\boxed{}."


def requestLRM(model, tokenizer, prompt, max_new_tokens=16384, *, system_prompt=_DEFAULT_SYSTEM_PROMPT):
    """Generate one chat completion for *prompt* and return only the new text.

    Args:
        model: Causal LM exposing ``.device`` and ``.generate``.
        tokenizer: Tokenizer with a chat template (``apply_chat_template``).
        prompt: User message content (the problem statement).
        max_new_tokens: Upper bound on generated tokens.
        system_prompt: System message content; defaults to the
            step-by-step / ``\\boxed{}`` instruction.

    Returns:
        The decoded continuation with the prompt tokens removed and special
        tokens skipped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # Single, unpadded sequence: its prompt length is the tensor's width.
    prompt_len = model_inputs.input_ids.shape[1]

    with torch.inference_mode():
        generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # generate() echoes the prompt before the continuation; keep only new tokens.
    return tokenizer.decode(generated_ids[0, prompt_len:], skip_special_tokens=True)

def _load_jsonl(path: str) -> list:
    """Parse a UTF-8 JSON-Lines file into a list of objects, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _reset_kv_seq_len(model) -> None:
    """Remove a leftover ``kv_seq_len`` attribute from every layer's ``self_attn``.

    Layers that never set the attribute are skipped individually, so one
    missing attribute no longer aborts the cleanup of the remaining layers.
    NOTE(review): presumably the PyramidKV-patched attention stores this
    between calls and it must not leak into the next sample — confirm in
    pyramidkv.monkeypatch.
    """
    layers = getattr(getattr(model, "model", None), "layers", None)
    if layers is None:
        return
    for layer in layers:
        attn = getattr(layer, "self_attn", None)
        if attn is not None and hasattr(attn, "kv_seq_len"):
            delattr(attn, "kv_seq_len")


def compute_pass_at_1(
    data_path: str,
    model_path: str,
    output_path: str,
    *,
    window_size: int = 32,
    max_capacity_prompt: int = 64,
    max_new_tokens: int = 16384,
) -> float:
    """Evaluate a model on a JSONL problem set and return its pass@1.

    Each input line must be a JSON object with ``"problem"`` and ``"answer"``
    keys. One JSON record per problem (problem, gold answer, prediction,
    correctness) is appended to *output_path*, one record per line.

    Args:
        data_path: JSON-Lines dataset path (UTF-8).
        model_path: Model name or local path for ``from_pretrained``.
        output_path: File to which per-problem results are appended.
        window_size: PyramidKV ``window_size`` set on the model config.
        max_capacity_prompt: PyramidKV ``max_capacity_prompt`` set on the
            model config.
        max_new_tokens: Generation budget per problem.

    Returns:
        Fraction of problems answered correctly (0.0 for an empty dataset).
    """
    dataset = _load_jsonl(data_path)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation="sdpa",
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # PyramidKV knobs are attached to the loaded model's config, as before.
    # NOTE(review): presumably the patched attention reads them from there at
    # generation time — confirm against pyramidkv.monkeypatch.
    model.config.window_size = window_size
    model.config.max_capacity_prompt = max_capacity_prompt

    correct = 0
    total = len(dataset)

    with tqdm(total=total, desc="Evaluating", ncols=80) as pbar, open(
        output_path, "a", encoding="utf-8"
    ) as fout:
        for i, item in enumerate(dataset, 1):
            gold_answer = item["answer"]
            problem = item["problem"]

            model_pred = requestLRM(model, tokenizer, problem, max_new_tokens=max_new_tokens)
            is_correct = evaluate_prediction(model_pred, gold_answer)

            _reset_kv_seq_len(model)

            result = {
                "problem": problem,
                "gold_answer": gold_answer,
                "model_pred": model_pred,
                "is_correct": is_correct,
            }
            # enumerate() starts at 1, so the first sample is i == 1.
            if i == 1:
                print(result)
            fout.write(json.dumps(result, ensure_ascii=False) + "\n")
            # Flush each record so a long run that is killed keeps its progress.
            fout.flush()

            if is_correct:
                correct += 1
            pbar.set_description(f"pass@1: {correct / i:.2%}")
            pbar.update(1)

    final_pass1 = correct / total if total > 0 else 0.0
    print(f"\nFinal pass@1: {final_pass1:.2%}")
    return final_pass1

def main(model_path: str, data_path: str, output_path: str):
    """Command-line entry point: run the pass@1 evaluation once.

    The computed score is printed by ``compute_pass_at_1``; nothing is
    returned here.
    """
    compute_pass_at_1(
        data_path=data_path,
        model_path=model_path,
        output_path=output_path,
    )

if __name__ == "__main__":
    # Expose main() as a CLI via python-fire, e.g.
    #   python <script> --model_path ... --data_path ... --output_path ...
    fire.Fire(main)
