import os
import json
from utils import *
from glob import glob
from datetime import datetime
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import pandas as pd
import argparse

# Command-line interface: which model to sample from, how many sampling rounds
# to run per problem, and which GPUs to expose to the process.
parser = argparse.ArgumentParser(
    description="Sample k solutions per problem and keep the problems "
                "that are not solved in every round."
)
# --model is handed to both the tokenizer and the vLLM engine, so a missing
# value can never work; fail at parse time with a clear message instead.
parser.add_argument("--model", type=str, required=True,
                    help="model name or path used for the tokenizer and vLLM")
parser.add_argument("--model_name", type=str,
                    help="output sub-directory under ./data "
                         "(defaults to the last path component of --model)")
parser.add_argument("--k", type=int, default=8,
                    help="number of sampling rounds per problem")
parser.add_argument("--gpus", type=str, required=True, help="which GPUs to use")
args = parser.parse_args()

# With k < 1 no sampling happens and `num_pass == k` would hold for every
# sample, silently classifying the whole dataset as "easy".
if args.k < 1:
    parser.error("--k must be at least 1")

# Without this fallback the results would be written to "./data/None/".
if not args.model_name:
    args.model_name = os.path.basename(os.path.normpath(args.model))

# Normalise the GPU list so stray spaces or a trailing comma ("0, 1," etc.)
# neither corrupt CUDA_VISIBLE_DEVICES nor inflate the GPU count.
gpu_ids = [gpu.strip() for gpu in args.gpus.split(",") if gpu.strip()]
if not gpu_ids:
    parser.error("--gpus must list at least one GPU id, e.g. '0,1'")
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
num_gpus = len(gpu_ids)

# Value passed to vLLM as max_model_len.
MAX_MODEL_LEN = 32768

# Suffix appended to every problem statement asking for a \boxed{} answer.
instruction_following = (
    "\n\nYou must put your answer inside \\boxed{} "
    "and your final answer will be extracted automatically by the \\boxed{} tag."
)

# The tokenizer is used only to render chat-template prompts as plain text.
tokenizer = AutoTokenizer.from_pretrained(args.model)

# Generation engine; tensor_parallel_size is the number of GPUs given via --gpus.
llm = LLM(
    model=args.model,
    max_model_len=MAX_MODEL_LEN,
    tensor_parallel_size=num_gpus,
)

# Load the problems, attach per-sample bookkeeping fields, and render one
# chat-formatted prompt per problem (kept in the same order as test_data).
# NOTE(review): assumes read_json returns a list of dicts that each carry a
# 'problem' key - confirm against utils.read_json and the data file.
test_data = read_json("./data/training_data_with_gpt_reasoning.json")

prompts = []
for sample in test_data:
    sample['reasoning'] = []  # generated texts, one entry per sampling round
    sample['num_pass'] = 0    # number of rounds judged correct

    msg = [
        {"role": "user", "content": sample['problem'] + instruction_following}
    ]
    # tokenize=False yields the prompt as a string; add_generation_prompt=True
    # is requested so the template ends where the model's reply should begin.
    prompt = tokenizer.apply_chat_template(msg,
                                           tokenize=False,
                                           add_generation_prompt=True)
    prompts.append(prompt)

# Show one rendered prompt as a sanity check; an empty dataset must not crash
# the script with an IndexError here.
if prompts:
    print(prompts[0])
else:
    print("No samples loaded; there are no prompts to preview.")

# Stochastic decoding so that repeated rounds can produce different solutions.
sampling_params = SamplingParams(temperature=0.7,
                                 max_tokens=10000)

# Run args.k independent sampling rounds over the full prompt list. Each round
# stores the generated text on its sample and counts it as a pass when
# verl_math_equiv accepts it against the sample's expected answer.
for _ in range(args.k):
    outputs = llm.generate(prompts, sampling_params)
    # Pairing by position relies on outputs being returned in prompt order,
    # and on prompts having been built in test_data order.
    for sample, output in zip(test_data, outputs):
        reasoning = output.outputs[0].text
        sample['reasoning'].append(reasoning)
        if verl_math_equiv(reasoning, sample['expected_answer']):
            sample['num_pass'] += 1

# A sample is "easy" when it passed in every one of the args.k rounds; only
# the remaining samples are kept and written out.
valid_samples = [sample for sample in test_data if sample['num_pass'] != args.k]
num_easy_samples = len(test_data) - len(valid_samples)

print(f"Number of easy samples: {num_easy_samples}")
print(f"Number of valid samples: {len(valid_samples)}")
print(f"Number of total samples: {len(test_data)}")

os.makedirs(f"./data/{args.model_name}", exist_ok=True)
print("===============")
# Preview the first kept sample. Guarded because the list is empty when every
# sample passed all rounds (or the dataset was empty); an IndexError here would
# abort the script before the results are written.
if valid_samples:
    print(valid_samples[0])
else:
    print("No valid samples to preview.")
print("===============")
write_json(valid_samples, f"./data/{args.model_name}/stage2_valid_samples.json")