# -*- coding: utf-8 -*-
import argparse
import json
import os
import re
from decimal import Decimal, InvalidOperation

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)

# System message sent before every question. It asks the model to wrap its
# answer in boxed{}, which is the marker extract_answer_from_output looks for.
SYSTEM_PROMPT = r"You are a helpful assistant. You need to solve some math problems and present the answers enclosed in boxed{}."



def extract_answer_from_output(output: str) -> str:
    r"""Return the contents of the last complete ``boxed{...}`` in *output*.

    Braces inside the box are balanced, so ``\boxed{\frac{1}{2}}`` yields
    ``\frac{1}{2}`` rather than a truncated ``\frac{1``. Box contents may
    span several lines. The *last* box is preferred because intermediate
    values are often boxed before the final answer. A box that never closes
    (e.g. output cut off by the generation budget) is skipped in favour of an
    earlier, complete one.

    Args:
        output: Decoded model response.

    Returns:
        The box contents with surrounding whitespace stripped, or ``""`` when
        no complete box is present.
    """
    marker = "boxed{"
    start = output.rfind(marker)
    while start != -1:
        content_start = start + len(marker)
        depth = 1
        for pos in range(content_start, len(output)):
            char = output[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return output[content_start:pos].strip()
        # This box never closed; fall back to the previous occurrence.
        start = output.rfind(marker, 0, start)
    return ""

def extract_ground_truth(answer_str: str) -> str:
    """Return the reference answer written after the last ``####`` marker.

    GSM8K solutions put the final answer after ``####``. The text following
    the last marker is returned with surrounding whitespace removed. When
    the marker does not occur at all, an empty string is returned.
    """
    _, marker, final_part = answer_str.rpartition("####")
    if not marker:
        return ""
    return final_part.strip()

def _normalize_answer(answer: str) -> str:
    """Canonicalize an answer string so that equivalent answers compare equal.

    Removes surrounding whitespace, thousands separators (``,``), ``$`` and
    ``%`` signs, and a trailing period. If what remains parses as a finite
    decimal number, it is rendered in canonical fixed-point form, so
    ``"1,000"``, ``"$1000"`` and ``"1000.00"`` all become ``"1000"``.
    Anything else is returned in its cleaned form.
    """
    cleaned = answer.strip()
    for junk in (",", "$", "%"):
        cleaned = cleaned.replace(junk, "")
    cleaned = cleaned.strip().rstrip(".")
    try:
        number = Decimal(cleaned)
    except InvalidOperation:
        return cleaned
    if not number.is_finite():
        return cleaned
    if number == 0:
        return "0"  # collapse "-0" / "0.00" onto a single spelling
    return format(number.normalize(), "f")


def evaluate_checkpoint(model_path, tokenizer, test_dataset, max_new_tokens=660, batch_size=48):
    """Load one checkpoint and return its exact-match accuracy on *test_dataset*.

    Each question is sent with ``SYSTEM_PROMPT`` through the tokenizer's chat
    template and answered with greedy decoding. A sample is correct when the
    last ``boxed{...}`` answer in the response equals the text after ``####``
    in the sample's ``answer`` field. Both sides are normalized with
    ``_normalize_answer`` before comparison. A sample with no extractable
    prediction is always counted as wrong.

    Args:
        model_path: Path (or hub id) passed to ``from_pretrained``.
        tokenizer: Tokenizer used for prompting and decoding. It should pad
            on the left: generated text is taken to start right after the
            padded prompt width.
        test_dataset: Sized iterable of samples with ``"question"`` and
            ``"answer"`` string fields.
        max_new_tokens: Generation budget per sample.
        batch_size: Number of questions generated together.

    Returns:
        Accuracy over all samples, or ``None`` if the model could not be
        loaded. Samples in batches whose generation raised are counted as
        incorrect; their number is reported on stdout.
    """
    print(f"\n[+] Loading model from: {model_path}")
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        model.eval()
    except Exception as e:
        # Best effort: one broken checkpoint should not abort the whole sweep.
        print(f"[!] Failed to load {model_path}: {e}")
        return None

    questions = [sample["question"] for sample in test_dataset]
    ground_truths = [
        _normalize_answer(extract_ground_truth(sample["answer"]))
        for sample in test_dataset
    ]
    total = len(questions)
    correct = 0
    failed = 0  # samples whose batch raised during generation

    num_batches = (total + batch_size - 1) // batch_size
    try:
        for i in tqdm(range(num_batches), desc=f"Evaluating {os.path.basename(model_path)}"):
            start = i * batch_size
            batch_questions = questions[start : start + batch_size]
            batch_gt = ground_truths[start : start + batch_size]

            batch_messages = [
                [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": q},
                ]
                for q in batch_questions
            ]
            prompts = tokenizer.apply_chat_template(
                batch_messages,
                add_generation_prompt=True,
                tokenize=False,
            )
            # The rendered chat template already contains its special tokens.
            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=False,
                add_special_tokens=False,
            ).to(model.device)
            # With left padding every prompt ends at this column, so all
            # tokens after it in a generated row are new text.
            input_length = inputs.input_ids.shape[1]

            try:
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=False,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                    )
            except Exception as e:
                # Keep going (e.g. after an out-of-memory error on a long
                # batch); the skipped samples still count toward `total`.
                print(f"[!] Generation error in batch {i}: {e}")
                failed += len(batch_questions)
                continue

            for row, gt in zip(outputs, batch_gt):
                response = tokenizer.decode(row[input_length:], skip_special_tokens=True)
                pred = _normalize_answer(extract_answer_from_output(response))
                # An empty prediction must never "match" an empty ground truth.
                if pred and pred == gt:
                    correct += 1
    finally:
        # Release GPU memory even if evaluation is interrupted, so the next
        # checkpoint can be loaded.
        del model
        torch.cuda.empty_cache()

    accuracy = correct / total if total > 0 else 0.0
    print(f"[✓] Accuracy: {accuracy:.4f} ({correct}/{total})")
    if failed:
        print(f"[!] {failed} samples were not generated due to errors and were counted as incorrect.")
    return accuracy

def main():
    """Evaluate every ``checkpoint-<step>`` directory under ``--model_root`` on GSM8K.

    Checkpoints are evaluated in ascending step order. The results are a JSON
    object mapping each checkpoint path to its accuracy, with ``-1.0`` for
    checkpoints that failed to load. The file is rewritten after every
    checkpoint, so partial results survive an interrupted run.

    Raises:
        ValueError: if ``--model_root`` contains no ``checkpoint-<step>``
            directories.
    """
    parser = argparse.ArgumentParser(description="Evaluate all checkpoints on GSM8K.")
    parser.add_argument("--model_root", type=str, required=True)
    parser.add_argument("--output_file", type=str, default="gsm8k_accuracies.txt")
    parser.add_argument("--gsm8k_test_path", type=str, required=True)
    parser.add_argument("--max_new_tokens", type=int, default=1024)
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="Qwen/Qwen3-4B",
        help="Tokenizer name or path; loaded from local files only.",
    )
    parser.add_argument("--batch_size", type=int, default=48)
    args = parser.parse_args()

    print("[+] Loading GSM8K test set...")
    test_dataset = load_dataset(
        "parquet",
        data_files={"test": args.gsm8k_test_path},
    )["test"]

    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_path,
        trust_remote_code=True,
        local_files_only=True,
        padding_side="left",  # batched decoder-only generation needs left padding
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Only accept "checkpoint-<digits>". Other names (e.g. "checkpoint-final")
    # would previously crash the numeric sort.
    step_pattern = re.compile(r"checkpoint-(\d+)")
    checkpoints = []
    for item in os.listdir(args.model_root):
        full_path = os.path.join(args.model_root, item)
        match = step_pattern.fullmatch(item)
        if match and os.path.isdir(full_path):
            checkpoints.append((int(match.group(1)), full_path))
    checkpoints.sort()
    checkpoint_dirs = [path for _, path in checkpoints]

    if not checkpoint_dirs:
        raise ValueError(f"No 'checkpoint-*' directories found in {args.model_root}")

    print(f"[+] Found {len(checkpoint_dirs)} checkpoints to evaluate.")

    results = {}
    for ckpt_dir in checkpoint_dirs:
        acc = evaluate_checkpoint(
            ckpt_dir, tokenizer, test_dataset, args.max_new_tokens, args.batch_size
        )
        results[ckpt_dir] = acc if acc is not None else -1.0
        # Persist after every checkpoint so a later crash keeps earlier results.
        with open(args.output_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4)

    print(f"\n[✓] All evaluations done. Results saved to: {args.output_file}")

# Run the evaluation sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()