from transformers import AutoModelForCausalLM, AutoTokenizer
import json
from vllm import LLM, SamplingParams
import sys
import os
from util import is_equiv,extract_math_answer
from datasets import load_dataset
from budget import budget_cnt


# Usage: python <script> <test_set> <cuda_device_id> <checkpoint_step>
#   test_set is one of: Math, GSM, Omnimath, aime, Math500
if len(sys.argv) < 4:
    # Print a readable usage line instead of an IndexError traceback.
    sys.exit(
        f"usage: {sys.argv[0]} <Math|GSM|Omnimath|aime|Math500> "
        "<cuda_device_id> <checkpoint_step>"
    )
test_str = sys.argv[1]
device_id = sys.argv[2]
# Restrict this process to the requested GPU before any model is loaded below.
os.environ["CUDA_VISIBLE_DEVICES"] = device_id
n_step = sys.argv[3]
# Root for checkpoints, datasets and outputs; with "" every path starts at "/".
path_prefix = ""
model_name = f"{path_prefix}/checkpoints/s1_alpha_02/huggingface_checkpoint/checkpoint_global_step_{n_step}"
output_path = f"{path_prefix}/resdata/s1_alpha_02_{n_step}_{test_str}test_output.json"

# Decoding configuration passed to vLLM.
temperature = 0.6
max_tokens = 10240  # upper bound on generated tokens per completion
sampling_n = 1      # completions per prompt; only outputs[0] is scored
GPUS = 1            # used as vLLM's tensor_parallel_size
# NOTE(review): `tokenizer` is not referenced elsewhere in the visible script;
# it is kept in case other code imports it.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the evaluation split for the selected benchmark into `dataset`, an
# indexable collection of per-problem records.
if test_str == "Math":
    data_path = f"{path_prefix}/datasets/MATH/test.json"
    # Use a context manager so the file handle is closed deterministically.
    with open(data_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)
elif test_str == "GSM":
    data_path = f"{path_prefix}/datasets/GSM8K-HF/test_budget.parquet"
    # A single parquet file loaded this way is exposed as the 'train' split.
    dataset = load_dataset('parquet', data_files=data_path, split='train')
elif test_str == "Omnimath":
    data_path = f"{path_prefix}/datasets/Omni-MATH"
    dataset = load_dataset(data_path, split='test')
elif test_str == "aime":
    data_path = f"{path_prefix}/datasets/aime_2024"
    dataset = load_dataset(data_path, split='train')
elif test_str == "Math500":
    data_path = f"{path_prefix}/datasets/Math500.jsonl"
    # JSON Lines: one record per line; skip blank lines (e.g. a trailing newline).
    with open(data_path, "r", encoding="utf-8") as f:
        dataset = [json.loads(line) for line in f if line.strip()]
else:
    # Fail fast, before the expensive vLLM engine is constructed.
    raise ValueError(
        f"Unknown test set {test_str!r}; expected one of "
        "Math, GSM, Omnimath, aime, Math500"
    )

# vLLM engine for the selected checkpoint. gpu_memory_utilization=0.7 is the
# fraction of GPU memory vLLM is allowed to claim; GPUS sets tensor parallelism.
llm = LLM(
    model=model_name,
    tensor_parallel_size=GPUS,
    gpu_memory_utilization=0.7,
)
# Decoding settings shared by every prompt.
sampling_params = SamplingParams(
    n=sampling_n,
    temperature=temperature,
    max_tokens=max_tokens,
)

# Accumulators: chat-formatted prompts, then one scored record per prompt.
input_list = []
results = []

def my_chat_template(prompt):
    """Return *prompt* wrapped as a single user turn followed by the assistant marker.

    The user turn starts with a fixed instruction that asks the model to emit a
    token budget in <budget> tags and its answer in <solution> tags, then the
    literal "Question: " and the question text itself.
    """
    instruction = 'Answer the given question. You should first estimate the total number of tokens you will need to answer this question based on its difficulty. Then you think about the reasoning process in the mind and provide the user with the answer. The token budget and whole solution are enclosed within <budget> </budget> and <solution> </solution> tags, respectively, i.e., <budget> token budget here, just an integer </budget><solution> solution here, please output the final answer within \\boxed{} </solution>.\n\nQuestion: '
    pieces = (
        "<｜begin▁of▁sentence｜><｜User｜>",
        instruction,
        prompt,
        "<｜Assistant｜>",
    )
    return "".join(pieces)

def prefill_template(prompt, num_tokens):
    """Build the chat prompt and pre-fill the assistant's reply with a token budget.

    The user turn is identical to the one produced by ``my_chat_template``. The
    assistant turn is then started with ``<budget>{num_tokens}</budget>``, so
    generation continues right after a fixed budget.

    Args:
        prompt: The question text.
        num_tokens: The budget to insert. It may be an int or a str and is
            converted with ``str()``. Previously, an int raised TypeError.

    Returns:
        The full prompt string.
    """
    prefix = 'Answer the given question. You should first estimate the total number of tokens you will need to answer this question based on its difficulty. Then you think about the reasoning process in the mind and provide the user with the answer. The token budget and whole solution are enclosed within <budget> </budget> and <solution> </solution> tags, respectively, i.e., <budget> token budget here, just an integer </budget><solution> solution here, please output the final answer within \\boxed{} </solution>.\n\nQuestion: '
    return (
        "<｜begin▁of▁sentence｜><｜User｜>" + prefix + prompt
        + "<｜Assistant｜><budget>" + str(num_tokens) + "</budget>"
    )

# Turn every dataset record into a chat-formatted prompt. Only the field that
# holds the question text differs between benchmarks; records of an
# unrecognised benchmark are skipped.
for record in dataset:
    if test_str == "GSM":
        # GSM records carry the question as the first message of a chat-style list.
        question = record['prompt'][0]['content']
    elif test_str in ("Math", "Math500", "Omnimath", "aime"):
        question = record['problem']
    else:
        continue
    input_list.append(my_chat_template(question))
    
print("Generating predictions...")
output = llm.generate(input_list, sampling_params)

# vLLM returns one result per prompt, in prompt order, and input_list[idx] was
# built from dataset[idx], so output[idx] lines up with dataset[idx].
for idx, request_output in enumerate(output):
    completion = request_output.outputs[0]  # only the first sample is scored
    model_answer = completion.text.strip()
    response_length = len(completion.token_ids)

    record = dataset[idx]
    # Pick the question text and reference answer for this benchmark's schema.
    if test_str == "GSM":
        para_ques = record['prompt'][0]['content']
        gold = record['reward_model']['ground_truth']
    elif test_str in ("Math", "Math500", "Omnimath", "aime"):
        para_ques = record['problem']
        gold = record['answer']
    else:
        raise ValueError(f"Unsupported test set: {test_str!r}")
    pred = extract_math_answer(para_ques, model_answer, "")

    # Copy so the per-result annotations below do not mutate the dataset record.
    t = record.copy()
    is_cor = is_equiv(gold, pred)

    t["budget"], budget_min, budget_max = budget_cnt(model_answer)
    # NOTE(review): any falsy budget (None, but also 0) is treated as
    # "no budget"; confirm this matches budget_cnt's contract.
    if t["budget"]:
        t['budget_match'] = budget_min <= response_length <= budget_max
    else:
        t['budget_match'] = None

    t["gold_answer"] = gold
    t["model_response"] = model_answer
    t["model_answer"] = pred
    t["response_length"] = response_length
    t["is_cor"] = is_cor

    results.append(t)

print("Saving results...")
# Create the destination directory if needed so a finished run is not lost at save time.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=4)
print("Results saved to", output_path)

n_results = len(results)
if n_results == 0:
    # Every metric below is a ratio over the results; avoid dividing by zero.
    print("No results were produced; skipping metrics.")
else:
    accuracy = sum(1 for t in results if t["is_cor"]) / n_results
    length = sum(t['response_length'] for t in results) / n_results
    # Results for which a budget was found (budget_match is True/False, not None).
    budgeted = [t for t in results if t['budget_match'] is not None]
    print(f"Answer Accuracy: {accuracy:.2%}")
    print(f"Average Response Length: {length:.2f}")
    print(f"Meaningful Budget rate: {len(budgeted) / n_results:.2%}")
    if budgeted:
        avg_budget = sum(t['budget'] for t in budgeted) / len(budgeted)
        print(f"Average budget: {avg_budget:.2f}")
    else:
        # It is possible that no response yields a budget; report that rather than crash.
        print("Average budget: N/A (no result has a budget)")
    budget_match_rate = sum(1 for t in results if t['budget_match']) / n_results
    print(f"Budget match rate: {budget_match_rate:.2%}")