from transformers import AutoModelForCausalLM, AutoTokenizer
import json
from vllm import LLM, SamplingParams
import sys
import os
from util import is_equiv,extract_math_answer
from datasets import load_dataset
from budget import budget_cnt


# ---- Command-line configuration -------------------------------------------
# Usage: <script> <test_str> <device_id> <n_step> <num_tokens>
#   test_str   : benchmark name; selects the dataset-loading branch below
#   device_id  : value exported as CUDA_VISIBLE_DEVICES
#   n_step     : checkpoint step, interpolated into the checkpoint path
#   num_tokens : token budget quoted in the prompt and used for budget matching
if len(sys.argv) < 5:
    # Fail with a readable message instead of a bare IndexError.
    sys.exit(f"Usage: {sys.argv[0]} <test_str> <device_id> <n_step> <num_tokens>")

test_str = sys.argv[1]
device_id = sys.argv[2]
# Exported before the vLLM engine is constructed further down.
os.environ["CUDA_VISIBLE_DEVICES"] = device_id
n_step = sys.argv[3]
num_tokens = int(sys.argv[4])

# Root under which checkpoints, datasets and results live. With an empty
# prefix the paths below start at the filesystem root ("/checkpoints/...").
path_prefix = ""

model_name = f"{path_prefix}/checkpoints/l1_max_replication_still3_0514/huggingface_checkpoint/checkpoint_global_step_{n_step}"
output_path = f"{path_prefix}/resdata/l1_max_replication_still3_0514_{n_step}_{test_str}_output.json"

# ---- Generation settings ---------------------------------------------------
temperature = 0.6
max_tokens = 10240  # upper bound on generated tokens per completion
sampling_n = 1      # completions requested per prompt
GPUS = 1            # passed to vLLM as tensor_parallel_size
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the benchmark selected by test_str into `dataset`.
# Local JSON / JSONL files are read as UTF-8 (matching how results are written)
# and through context managers so the file handles are always closed.
if test_str == "Math":
    data_path = f"{path_prefix}/datasets/MATH/test.json"
    with open(data_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)
elif test_str == "GSM":
    data_path = f"{path_prefix}/datasets/GSM8K-HF/test_budget.parquet"
    dataset = load_dataset('parquet', data_files=data_path, split='train')
elif test_str == "Omnimath":
    data_path = f"{path_prefix}/datasets/Omni-MATH"
    dataset = load_dataset(data_path, split='test')
elif test_str == "aime":
    data_path = f"{path_prefix}/datasets/aime_2024"
    dataset = load_dataset(data_path, split='train')
elif test_str == "Math500":
    data_path = f"{path_prefix}/datasets/Math500.jsonl"
    with open(data_path, "r", encoding="utf-8") as f:
        # One JSON object per line; blank lines are skipped rather than
        # crashing json.loads.
        dataset = [json.loads(line) for line in f if line.strip()]
else:
    # Without this branch `dataset` stays undefined and the script fails
    # later with an unrelated NameError.
    raise ValueError(
        f"Unsupported test_str {test_str!r}; expected one of "
        "'Math', 'GSM', 'Omnimath', 'aime', 'Math500'."
    )

# Build the vLLM engine for the checkpoint and the sampling configuration
# shared by every prompt.
llm = LLM(
    model=model_name,
    tensor_parallel_size=GPUS,
    gpu_memory_utilization=0.7,
)
sampling_params = SamplingParams(
    n=sampling_n,
    temperature=temperature,
    max_tokens=max_tokens,
)

# Prompts sent to the model, and the per-example result records.
input_list, results = [], []

def my_chat_template(prompt, num_tokens):
    """Return *prompt* followed by the answer-format instruction and budget hint.

    The result is the question text, a blank line, the fixed step-by-step /
    ``\\boxed{}`` instruction, and a sentence telling the model to think for
    ``num_tokens`` tokens.
    """
    instruction = "\n\nLet's think step by step and output the final answer within \\boxed{}."
    budget_hint = f" Think for {num_tokens} tokens."
    return prompt + instruction + budget_hint
    

# Turn every dataset record into a prompt. GSM records keep the question under
# extra_info -> question; the other benchmarks expose it as 'problem'.
for sample in dataset:
    if test_str == "GSM":
        question = sample['extra_info']['question']
    elif test_str in ("Math", "Math500", "Omnimath", "aime"):
        question = sample['problem']
    else:
        # Unknown benchmark name: nothing is queued for this record.
        continue
    input_list.append(my_chat_template(question, num_tokens))

# Run the whole prompt list through the engine in a single call.
print("Generating predictions...")
output = llm.generate(input_list, sampling_params)

# Score each completion against its reference answer and record length and
# budget information alongside a copy of the source record.

# A response "matches" the budget when its token count lies within 50%-150%
# of the requested number of tokens; the window is the same for every example.
budget_min, budget_max = num_tokens * 0.5, num_tokens * 1.5

for sample, request_output in zip(dataset, output):
    # Only the first sampled completion of each request is evaluated.
    completion = request_output.outputs[0]
    model_answer = completion.text.strip()
    response_length = len(completion.token_ids)

    # GSM stores question / reference under nested keys; every other
    # benchmark uses flat 'problem' / 'answer' fields.
    if test_str == "GSM":
        para_ques = sample['extra_info']['question']
        gold = sample['reward_model']['ground_truth']
    elif test_str in ("Math", "Math500", "Omnimath", "aime"):
        para_ques = sample['problem']
        gold = sample['answer']
    pred = extract_math_answer(para_ques, model_answer, "")

    record = sample.copy()
    is_cor = is_equiv(gold, pred)

    record["budget"] = num_tokens
    # A zero budget makes the window meaningless, so the match is left unset.
    if num_tokens:
        record['budget_match'] = budget_min <= response_length <= budget_max
    else:
        record['budget_match'] = None

    record.update(
        gold_answer=gold,
        model_response=model_answer,
        model_answer=pred,
        response_length=response_length,
        is_cor=is_cor,
    )
    results.append(record)

# Persist the per-example records, then print aggregate metrics.
print("Saving results...")
# Create the destination directory up front so a missing folder cannot
# discard the generations that were just produced.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=4)
print("Results saved to", output_path)

total = len(results)
if total == 0:
    # An empty dataset would otherwise raise ZeroDivisionError below.
    print("No results to summarize.")
else:
    # Records whose budget check was actually evaluated (budget_match is a bool).
    budgeted = [t for t in results if t['budget_match'] is not None]

    accuracy = sum(1 for t in results if t["is_cor"]) / total
    length = sum(t['response_length'] for t in results) / total
    print(f"Answer Accuracy: {accuracy:.2%}")
    print(f"Average Response Length: {length:.2f}")
    print(f"Meaningful Budget rate: {len(budgeted) / total:.2%}")
    if budgeted:
        average_budget = sum(t['budget'] for t in budgeted) / len(budgeted)
        print(f"Average budget: {average_budget:.2f}")
    else:
        # Happens when the requested budget is 0: no record has a budget to average.
        print("Average budget: N/A")
    matched = sum(1 for t in results if t['budget_match'])
    print(f"Budget match rate: {matched / total:.2%}")