import os
import json
from utils import *
from glob import glob
from datetime import datetime
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import pandas as pd
import argparse

# Command-line interface for the explanation-sampling script.
parser = argparse.ArgumentParser()
# --model is deliberately NOT marked required: argparse never applies a
# default to a required option, so combining the two made the default dead.
parser.add_argument(
    "--model",
    type=str,
    default="./saved_models/Llama-3.2-3B-Instruct",
    help="Path or identifier of the model to load "
         "(default: ./saved_models/Llama-3.2-3B-Instruct).",
)
# Each round stores one sampled explanation per sample under the key
# 'explanation_<round>' (rounds are numbered from 1).
parser.add_argument(
    "--k",
    type=int,
    default=8,
    help="Number of sampling rounds to run over the dataset (default: 8).",
)
args = parser.parse_args()

# Context length (in tokens) the vLLM engine is configured with.
MAX_MODEL_LEN = 8192

# The tokenizer is only needed to render chat-formatted prompts; it is loaded
# from the same location as the model weights.
tokenizer = AutoTokenizer.from_pretrained(args.model)

# Engine settings collected in one place before constructing the engine.
# NOTE(review): tensor_parallel_size is hard-coded to 4 — presumably this
# expects 4 available GPUs; confirm against the target machine.
engine_options = {
    "model": args.model,
    "max_model_len": MAX_MODEL_LEN,
    "tensor_parallel_size": 4,
}
llm = LLM(**engine_options)

# Load the dataset. Each record must provide 'problem' and 'expected_answer'.
# (The variable is called test_data although the file is training_data.json;
# the name is kept because the rest of the script refers to it.)
# JSON is read as UTF-8 explicitly so decoding does not depend on the
# platform's locale encoding.
with open('./data/training_data.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)


def _build_prompt(sample):
    """Render one sample as a chat prompt ending in a pre-filled assistant turn.

    The user turn states the problem together with its expected answer and
    asks for a step-by-step explanation. The assistant turn is pre-filled up
    to "### Step 1:"; the template is applied with
    continue_final_message=True and add_generation_prompt=False so that this
    partial assistant message is the end of the prompt.
    """
    question = (
        f"Question: {sample['problem']}\n\n"
        f"I was told the answer is {sample['expected_answer']} but I don't know why. "
        f"Please explain why the answer is {sample['expected_answer']} step by step."
    )
    msg = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": "Let's go through the problem step by step.\n\n### Step 1:"},
    ]
    return tokenizer.apply_chat_template(msg,
                                         tokenize=False,
                                         add_generation_prompt=False,
                                         continue_final_message=True)


# One prompt per sample, in dataset order.
prompts = [_build_prompt(sample) for sample in test_data]

# Show the first rendered prompt as a sanity check; skipped for an empty
# dataset, where indexing prompts[0] would raise IndexError.
if prompts:
    print(prompts[0])

# Sampling configuration shared by every round.
# NOTE(review): max_tokens (8192) equals the max_model_len the engine is
# created with earlier in this file, so a full-length completion cannot fit
# alongside the prompt — confirm how vLLM handles that limit.
sampling_params = SamplingParams(temperature=1.0,
                                 top_p=0.9,
                                 max_tokens=8192)

# Results go to ./data/<model name>/, where the name is the last component of
# the model path. Trailing slashes are stripped first so that a path such as
# "./saved_models/foo/" still yields "foo" instead of an empty name.
# The directory is created once, before sampling, so it also exists when
# --k is 0 and the loop body never runs.
output_dir = f"./data/{args.model.rstrip('/').split('/')[-1]}"
os.makedirs(output_dir, exist_ok=True)

# Text prepended to each completion so the stored explanation starts with the
# step-by-step header followed by the generated continuation.
explanation_prefix = "Let's go through the problem step by step.\n\n### Step 1: "

for k in range(args.k):
    # One full pass over all prompts per round; outputs are paired with the
    # samples positionally.
    outputs = llm.generate(prompts, sampling_params)
    for sample, request_output in zip(test_data, outputs):
        generated_explanation = request_output.outputs[0].text
        # Rounds are stored under 1-based keys: explanation_1 .. explanation_k.
        sample[f"explanation_{k + 1}"] = explanation_prefix + generated_explanation.strip()

# Written once after all rounds; the file name records the requested --k.
write_json(test_data, f"{output_dir}/step-by-step-explanation_{args.k}.json")