import os
import sys
import argparse
import json
import torch

# Optionally make PROJECT_ROOT importable so project modules such as `env`
# can be found. The directory is appended, so it is searched after every
# entry already on sys.path.
project_root = os.getenv("PROJECT_ROOT")
if project_root:
    if project_root not in sys.path:
        sys.path.append(project_root)
from env import er_model

# Root directory under which models, datasets and reward models live;
# None when the DATA_ROOT environment variable is unset.
data_root = os.getenv("DATA_ROOT")

def _str2bool(value):
    """Convert a command-line string such as ``"false"`` or ``"1"`` to ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is ``True``
    for every non-empty string, so ``--do_sample False`` would still enable
    sampling. This converter parses the text explicitly instead.

    Args:
        value: Raw option text (a ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the evaluation script's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="qwen38")
    parser.add_argument("--dataset", type=str, default="gsm8k")
    parser.add_argument("--reward_model", type=str, default="skywork")
    parser.add_argument("--scheduler", type=str, default="part_split")
    parser.add_argument("--prompt_type", type=str, default="better")
    parser.add_argument("--device", type=str, default="cuda:1")
    parser.add_argument("--max_steps", type=int, default=2048)
    parser.add_argument("--prefill_bit", type=int, default=8)
    # Three states: None when omitted, otherwise an explicitly parsed True/False.
    parser.add_argument("--past_key_values", type=_str2bool, default=None)
    # Comma-separated integers, e.g. "7,6"; split and int()-converted by the caller.
    parser.add_argument("--naive_bit", type=str, default="7,6")
    parser.add_argument("--high_bit_steps", type=int, default=512)
    parser.add_argument("--part", type=str, default="cot")
    parser.add_argument("--do_sample", type=_str2bool, default=True)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument(
        "--output",
        type=str,
        default="results/eval_qwen38_gsm8k_test.json",
        help="Path to save the results JSON file (default: %(default)s).",
    )
    return parser.parse_args(argv)


# Lookup tables from the short names accepted on the command line to their
# on-disk locations. Every entry is an f-string rooted at `data_root`, so an
# unset DATA_ROOT produces a literal "None/..." prefix.

# --model choices -> packed model checkpoints.
model_path = dict(
    qwen7b=f"{data_root}/quantize_model/packed/qwen7b-distill",
    qwen38=f"{data_root}/quantize_model/packed/qwen3-8b",
)

# --dataset choices -> dataset directories.
dataset_path = dict(
    gsm8k=f"{data_root}/gsm8k",
    competition_math=f"{data_root}/efficient-reasoning/competition_math",
)

# --reward_model choices -> reward model directories.
reward_model_path = dict(
    prm=f"{data_root}/Qwen2.5-Math-PRM-7B",
    skywork=f"{data_root}/Skywork-Reward-Llama-3.1-8B-v0.2",
)

def _resolve_path(table, key, option):
    """Look up ``key`` in a name->path table, exiting with a readable error.

    Args:
        table: Mapping of short names to path strings (e.g. ``model_path``).
        key: Name selected on the command line.
        option: Command-line flag that supplied ``key``; used in the message.

    Returns:
        The path string mapped to ``key``.

    Raises:
        SystemExit: If ``key`` is not in ``table``. The message lists the
            valid names instead of showing a bare KeyError traceback.
    """
    try:
        return table[key]
    except KeyError:
        valid = ", ".join(sorted(table))
        raise SystemExit(
            f"Unknown value {key!r} for {option}; expected one of: {valid}"
        ) from None


args = parse_args()

# All path tables are built from DATA_ROOT. When it is unset they silently
# contain strings such as "None/gsm8k", so fail fast with a clear message.
# This check runs after parse_args() so that --help still works.
if not data_root:
    raise SystemExit(
        "DATA_ROOT environment variable is not set; "
        "it is required to locate models, datasets and reward models."
    )

# Create the output directory before the (potentially long) evaluation.
# A missing "results/" directory, or an unwritable location, then fails
# immediately instead of after all the work is done.
output_path = args.output
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

args_dict = {
    "model": args.model,
    "model_path": _resolve_path(model_path, args.model, "--model"),
    "dataset": args.dataset,
    "dataset_path": _resolve_path(dataset_path, args.dataset, "--dataset"),
    "reward_model": args.reward_model,
    "reward_model_path": _resolve_path(
        reward_model_path, args.reward_model, "--reward_model"
    ),
    "prompt_type": args.prompt_type,
    "scheduler": args.scheduler,
    "device": args.device,
    "max_steps": args.max_steps,
    "prefill_bit": args.prefill_bit,
    "past_key_values": args.past_key_values,
    "naive_bit": [int(bit) for bit in args.naive_bit.split(",")],
    "high_bit_steps": args.high_bit_steps,
    "part": args.part,
    "do_sample": args.do_sample,
    "temperature": args.temperature,
}
model = er_model(**args_dict)
# NOTE(review): the meaning of num_samples=0 is defined by env.er_model
# (presumably "evaluate all samples"); confirm there.
results = model.evaluate(num_samples=0)
print(results)

output_data = {
    "config": args_dict,
    "results": results,
}

# NOTE(review): json.dump raises TypeError if `results` contains non-JSON
# values (e.g. tensors); the structure returned by evaluate() is not visible here.
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(output_data, f, indent=4, ensure_ascii=False)

print(f"\nResults saved to: {output_path}")