from transformers import AutoModelForCausalLM, AutoTokenizer
import json
from vllm import LLM, SamplingParams
import sys
import os
from util import is_equiv,extract_math_answer
from datasets import load_dataset
# export MKL_SERVICE_FORCE_INTEL=TRUE

# Root directory holding the model checkpoints and datasets referenced below.
# It must be filled in before running: with the empty default every path
# resolves under the filesystem root ("/...").
path_prefix = ""
model_dict = {
    "math": f"{path_prefix}/Qwen2.5-Math-72B-Instruct",
    "r1": f"{path_prefix}/DeepSeek-R1-Distill-Qwen-1___5B",
    "qwq": f"{path_prefix}/QwQ-32B",
}
# Dataset names this script knows how to load, prompt and score.
DATASETS = ("Math", "MetaMath", "GSM", "Still3")

# Command line: <script> MODEL_TYPE DATASET
if len(sys.argv) < 3:
    sys.exit(
        f"usage: {sys.argv[0]} MODEL_TYPE DATASET "
        f"(MODEL_TYPE in {sorted(model_dict)}, DATASET in {list(DATASETS)})"
    )
model_type = sys.argv[1]
# Math, MetaMath, GSM, Still3
teststr = sys.argv[2]
# NOTE(review): `start` and `end` are not referenced anywhere in the visible
# script; presumably meant to slice the dataset — confirm or remove.
start = 0
end = 1000000

if model_type not in model_dict:
    sys.exit(
        f"unknown MODEL_TYPE {model_type!r}; expected one of {sorted(model_dict)}"
    )
model_name = model_dict[model_type]

# Each dataset selects an input file and the JSON file the scored results are
# written to. Fail fast on an unknown name instead of hitting a NameError on
# `data_path` later on.
if teststr == "Math":
    data_path = f"{path_prefix}/MATH/train.json"
    output_path = f"./data/Mathtrain_{model_type}_06.json"
elif teststr == "MetaMath":
    data_path = f"{path_prefix}/MetaMathQA/MetaMathQA-395K.json"
    output_path = f"{path_prefix}/resdata/overthinking/MetaMath_{model_type}_res_300k.json"
elif teststr == "GSM":
    data_path = f"{path_prefix}/GSM8K/grade_school_math/data/train.jsonl"
    output_path = f"./data/GSMtrain_{model_type}_06.json"
elif teststr == "Still3":
    data_path = f"{path_prefix}/STILL-3-Preview-RL-Data-HF/data/train-00000-of-00001.parquet"
    output_path = f"./data/s1K-1.1_train_{model_type}_06_1.json"
else:
    sys.exit(f"unknown DATASET {teststr!r}; expected one of {list(DATASETS)}")

# Sampling settings applied to every request.
temperature = 0.6
max_tokens = 8192
# Number of GPUs the model is sharded across (vLLM tensor_parallel_size).
GPUS = 1

# Load the training split for the selected dataset into `train_data`:
#   Still3          -> parquet file, loaded as a Hugging Face `Dataset`
#   GSM             -> JSON Lines, one record per line
#   Math / MetaMath -> a single JSON document
# Text files are decoded explicitly as UTF-8 so loading does not depend on the
# platform's default locale encoding.
if teststr == "Still3":
    train_data = load_dataset("parquet", data_files=data_path, split='train')
elif teststr == "GSM":
    with open(data_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) rather than failing on them.
        train_data = [json.loads(line) for line in f if line.strip()]
else:
    with open(data_path, "r", encoding="utf-8") as f:
        train_data = json.load(f)

llm = LLM(
    model=model_name,
    gpu_memory_utilization=0.7,
    tensor_parallel_size=GPUS,
)
sampling_params = SamplingParams(
    temperature=temperature,
    max_tokens=max_tokens,
)

results = []

# Record field holding the problem statement, per dataset.
_QUESTION_FIELD = {
    "GSM": "question",
    "MetaMath": "query",
    "Math": "problem",
    "Still3": "question",
}
# Instruction appended to every problem so the final answer comes back in a
# \boxed{} expression for the answer extractor.
_BOXED_SUFFIX = "Please put your final answer within \\boxed{}."

# NOTE(review): prompts are sent as raw text through `llm.generate` (no chat
# template is applied) — confirm this is intended for instruct/reasoning models.
question_key = _QUESTION_FIELD[teststr]
input_list = [f"{record[question_key]}\n{_BOXED_SUFFIX}" for record in train_data]

print("Generating predictions...")
# One batched call; the scoring step pairs outputs with records by position.
output = llm.generate(input_list, sampling_params)

# Score each generation against its reference answer and append an augmented
# copy of the source record to `results`. Records, prompts and generations are
# aligned by position, so records and outputs are walked together (each record
# is materialized once, which matters for the HF `Dataset` used by Still3).
for record, generation in zip(train_data, output):
    completion = generation.outputs[0]  # first completion for this prompt
    model_answer = completion.text.strip()
    if teststr == "GSM":
        para_ques = record["question"]
        pred = extract_math_answer(para_ques, model_answer, "")
        # The reference `answer` goes through the same extractor as the output.
        gold = extract_math_answer(para_ques, record["answer"], "")
    elif teststr == "MetaMath":
        para_ques = record["query"]
        pred = extract_math_answer(para_ques, model_answer, record["type"])
        gold = extract_math_answer(para_ques, record["response"], record["type"])
    elif teststr == "Math":
        para_ques = record["problem"]
        # The stored `answer` field is compared as-is (no extraction).
        gold = record['answer']
        pred = extract_math_answer(para_ques, model_answer, record["type"])
    elif teststr == "Still3":
        para_ques = record["question"]
        gold = record['answer']
        pred = extract_math_answer(para_ques, model_answer, "")
    else:
        # Otherwise `gold` / `pred` would be unbound below.
        raise ValueError(f"unsupported dataset {teststr!r}")

    t = record.copy()
    t["token_count"] = len(completion.token_ids)
    t["gold_answer"] = gold
    t["model_response"] = model_answer
    t["model_answer"] = pred
    t["is_cor"] = is_equiv(gold, pred)

    results.append(t)

print("Saving results...")
# Output directories such as ./data are not guaranteed to exist; create them so
# a long generation run is not lost to a FileNotFoundError at the very end.
out_dir = os.path.dirname(output_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=4)
print("Results saved to", output_path)
# Fraction of records whose `is_cor` flag (from is_equiv) is truthy; undefined
# for an empty dataset, so report that instead of dividing by zero.
if results:
    accuracy = sum(1 for t in results if t["is_cor"]) / len(results)
    print(f"Classification Accuracy: {accuracy:.2%}")
else:
    print("No results to score; accuracy is undefined.")