import os
import sys
import argparse
import json
import torch

project_root = os.environ.get("PROJECT_ROOT")
if project_root and project_root not in sys.path:
    sys.path.append(project_root)
from env import er_model

data_root = os.environ.get("DATA_ROOT")

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(token)``, which is ``True``
    for every non-empty string, so ``--do_sample False`` used to enable
    sampling. This converter interprets the text of the token instead.

    Args:
        value: The raw token (an actual ``bool`` is returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy: under the old ``type=bool`` behaviour it
    # was the only token that produced False, so existing launch scripts that
    # pass "" keep working.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The populated ``argparse.Namespace``. Comma-separated options
        (``naive_bit``, ``mean_score``, ``max_score``) are returned as raw
        strings; splitting them is left to the caller.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="qwen7b")
    parser.add_argument("--dataset", type=str, default="math")
    parser.add_argument("--reward_model", type=str, default="prm")
    parser.add_argument("--scheduler", type=str, default="score_descent")
    parser.add_argument("--prompt_type", type=str, default="better")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--max_steps", type=int, default=16384)
    parser.add_argument("--prefill_bit", type=int, default=3)
    # Boolean options use _str2bool so that "False"/"0"/"no" really mean False.
    parser.add_argument("--past_key_values", type=_str2bool, default=None)
    parser.add_argument("--naive_bit", type=str, default="4,3")
    parser.add_argument("--high_bit_steps", type=int, default=512)
    parser.add_argument("--part", type=str, default="cot")
    parser.add_argument("--do_sample", type=_str2bool, default=True)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--output", type=str, default="results/test.json", help="Path to save the results JSON file. Defaults to results/test.json.")
    parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to evaluate. Use -1 to evaluate all samples.")
    parser.add_argument("--prune_path", type=str, default=None, help="Path to the prune function. If not specified, will not prune.")
    parser.add_argument("--split", type=_str2bool, default=True, help="Boolean value forwarded to the model as its 'split' option (true/false).")
    parser.add_argument("--mean_score", type=str, default="0.7967,0.7590")
    parser.add_argument("--max_score", type=str, default="0.96,0.94")
    parser.add_argument("--sol_precision", type=int, default=3)
    parser.add_argument("--windows", type=int, default=0)
    return parser.parse_args(argv)

args = parse_args()

# Optional pruning data: --prune_path names a sub-directory of
# <PROJECT_ROOT>/src/code/cot_split/results whose data.json carries a
# "column_averages" entry; that entry is what gets handed to the model.
if args.prune_path is not None:
    if project_root is None:
        # Without this guard the f-string below would quietly build a path
        # that starts with the literal text "None/".
        raise RuntimeError(
            "PROJECT_ROOT must be set in the environment to use --prune_path"
        )
    prune_path = f"{project_root}/src/code/cot_split/results/{args.prune_path}/data.json"
    # Read explicitly as UTF-8 instead of the platform default encoding.
    with open(prune_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    prune_func = data["column_averages"]
else:
    prune_func = None
    print("prune_func is None")

# Make the requested GPU the current CUDA device at the beginning of the script.
# torch.device does the parsing, so device strings without an explicit index
# ("cuda", "cpu") no longer raise IndexError as a manual split(":")[1] did.
_device = torch.device(args.device)
gpu_idx = _device.index
if _device.type == "cuda" and gpu_idx is not None:
    torch.cuda.set_device(gpu_idx)

# Keyword arguments handed to er_model. Comma-separated CLI strings are
# expanded into lists of numbers here; everything else is passed through.
# Key order is kept stable because this mapping is also written to the
# results file as the "config" section.
args_dict = dict(
    model=args.model,
    dataset=args.dataset,
    reward_model=args.reward_model,
    prompt_type=args.prompt_type,
    scheduler=args.scheduler,
    device=args.device,
    max_steps=args.max_steps,
    prefill_bit=args.prefill_bit,
    past_key_values=args.past_key_values,
    naive_bit=list(map(int, args.naive_bit.split(","))),
    high_bit_steps=args.high_bit_steps,
    part=args.part,
    do_sample=args.do_sample,
    temperature=args.temperature,
    prune_func=prune_func,
    split=args.split,
    mean_score=list(map(float, args.mean_score.split(","))),
    max_score=list(map(float, args.max_score.split(","))),
    sol_precision=args.sol_precision,
    windows=args.windows,
)
model = er_model(**args_dict)

# --num_samples -1 is translated to None before being passed to evaluate()
# (presumably "no limit" -- confirm against er_model.evaluate).
num_samples = None if args.num_samples == -1 else args.num_samples
results, content = model.evaluate(num_samples=num_samples)

# Persist the run configuration together with the evaluation results.
output_data = {
    "config": args_dict,
    "results": results
}
output_path = args.output

# Create the output directory if it doesn't exist. A bare filename has an
# empty dirname, and os.makedirs("") raises FileNotFoundError, so skip it.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(output_data, f, indent=4, ensure_ascii=False)

print(f"\nResults saved to: {output_path}")