import json
import random
import sys
import os
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset



def last_boxed_only_string(string):
    r"""Return the last ``\boxed{...}`` expression found in *string*.

    The search starts at the right-most ``\boxed``; if there is none, the
    right-most ``\fbox`` is used instead.  From that position the text is
    scanned forward, tracking brace depth, and the substring running from
    the command up to and including the closing brace that brings the depth
    back to zero is returned.

    Returns ``None`` when neither command is present or when no such
    closing brace is found.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos, char in enumerate(string[start:], start=start):
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            # Depth is only tested on a closing brace, so a stray "}" that
            # appears before any "{" can never terminate the match.
            if depth == 0:
                return string[start:pos + 1]

    return None

def remove_boxed(s):
    r"""Strip the ``\boxed{`` ... ``}`` wrapper from *s*.

    Args:
        s: A string expected to look like ``\boxed{<content>}``.  ``None``
            and other non-string values are accepted and yield ``None``.

    Returns:
        The text between ``\boxed{`` and the final ``}``, or ``None`` when
        *s* is not a string, does not start with ``\boxed{``, or does not
        end with ``}``.  Strings wrapped in ``\fbox{...}`` are therefore
        rejected as well.
    """
    left = "\\boxed{"
    # Explicit checks instead of `assert`: assertions are removed under
    # `python -O`, which would let malformed input be sliced into garbage.
    if not isinstance(s, str):
        return None
    if not (s.startswith(left) and s.endswith("}")):
        return None
    return s[len(left):-1]


def parse_question_answer(subdir, file):
    r"""Load one problem JSON file and extract its question and answer.

    The file at ``os.path.join(subdir, file)`` is parsed as JSON and its
    ``"problem"``, ``"level"``, ``"type"`` and ``"solution"`` entries are
    read.

    Args:
        subdir: Directory containing the problem file.
        file: File name of the problem inside *subdir*.

    Returns:
        A tuple ``(prob_content, prob_level, prob_type, answer)`` where
        ``prob_level`` is the integer following ``"Level "`` in the level
        field (``None`` if it cannot be parsed) and ``answer`` is the
        content of the last ``\boxed{...}`` in the solution (``None`` if
        none can be extracted).

    Raises:
        Exception: Whatever ``json.load`` raises is re-raised after the
            failure has been printed.
        KeyError: If one of the expected entries is missing (assuming the
            JSON document is an object).
    """
    with open(os.path.join(subdir, file), 'r', encoding="utf-8") as fp:
        try:
            problem_data = json.load(fp)
        except Exception as e:
            print(f"Error loading JSON from {file}", e)
            # Bare raise keeps the original traceback intact.
            raise

    prob_content = problem_data["problem"]
    prob_level = problem_data["level"]
    prob_type = problem_data["type"]
    try:
        prob_level = int(prob_level.split("Level ")[1])
    except (AttributeError, IndexError, TypeError, ValueError):
        # Level is missing the "Level " prefix, is non-numeric (e.g.
        # "Level ?"), or is not a string at all.
        prob_level = None

    answer = remove_boxed(last_boxed_only_string(problem_data["solution"]))
    return prob_content, prob_level, prob_type, answer

def forward(model_name, n_samples, dataset_name):
    """Sample completions for every question of a benchmark using vLLM.

    Args:
        model_name: Model identifier or path; passed to both ``LLM`` and
            ``AutoTokenizer.from_pretrained``.
        n_samples: Number of completions to generate per prompt (the ``n``
            of ``SamplingParams``).
        dataset_name: Either ``"math500"`` or ``"gsm8k"``.

    Returns:
        A dict mapping each question's text to the list of generated
        completion strings.  Questions with identical text share a single
        entry whose list contains the completions of all of them.

    Raises:
        ValueError: If *dataset_name* is not one of the supported names.
    """
    # Name of the row field holding the problem text, per dataset.
    question_fields = {"math500": "problem", "gsm8k": "question"}
    if dataset_name not in question_fields:
        # Fail fast, before any dataset or model is loaded.
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; "
            f"expected one of {sorted(question_fields)}"
        )
    question_field = question_fields[dataset_name]

    if dataset_name == "math500":
        ds = load_dataset("HuggingFaceH4/MATH-500")['test']
    else:
        ds = load_dataset("gsm8k", "main", split="test")

    # A multi-GPU variant used previously: LLM(..., tensor_parallel_size=2)
    llm = LLM(model=model_name, gpu_memory_utilization=0.9)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    random.seed(0)
    questions = []
    prompts = []
    for row in ds:
        prob_content = row[question_field]
        questions.append(prob_content)

        messages = [
            {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
            {"role": "user", "content": prob_content},
        ]
        prompts.append(
            tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        )

    sampling_params = SamplingParams(n=n_samples, temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=512)
    outputs = llm.generate(prompts, sampling_params)

    response_dict = {}
    # NOTE(review): pairing by position assumes generate() returns results
    # in the same order as the prompts — confirm against the vLLM version used.
    for output, question in zip(outputs, questions):
        reply = [completion.text for completion in output.outputs]
        # Duplicate question texts accumulate into one entry.
        response_dict.setdefault(question, []).extend(reply)

    return response_dict

if __name__ == "__main__":
    # Command line: <script> EXP_NAME DATASET MODEL_PATH
    #   EXP_NAME   - output directory (created if missing)
    #   DATASET    - dataset name forwarded to forward()
    #   MODEL_PATH - model identifier/path forwarded to forward()
    if len(sys.argv) < 4:
        sys.exit(f"Usage: {sys.argv[0]} EXP_NAME DATASET MODEL_PATH")
    EXP_NAME = sys.argv[1]
    dataset = sys.argv[2]
    model_path = sys.argv[3]

    # makedirs also creates missing parent directories and tolerates an
    # already existing output directory.
    os.makedirs(EXP_NAME, exist_ok=True)

    _output = forward(model_path, 500, dataset)
    out_path = f"{EXP_NAME}/samples.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(_output, f)
    # Report success only once the file has been written and closed.
    print(f"Saved samples to {out_path}")