import os
import json
from utils import *
from glob import glob
from datetime import datetime
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import pandas as pd
import argparse

# Command-line options:
#   --model  local checkpoint directory or model ID to load with vLLM/transformers
#   --k      number of explanations per sample; one hint is generated for each
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="./saved_models/Llama-3.2-3B-Instruct", type=str)
parser.add_argument("--k", default=8, type=int)
args = parser.parse_args()

def remove_hint_heading(input_string):
    """Remove every line that mentions "hint" and strip the result.

    The match is a case-insensitive substring test, so a line is dropped
    whether it is a heading such as "Hint:" or "## HINT" or ordinary text that
    happens to contain "hint"/"hints".
    NOTE(review): content lines mentioning "hint" are removed too; confirm
    that is intended.

    Args:
        input_string: Text whose lines are separated by "\n".

    Returns:
        The surviving lines joined with "\n", with leading and trailing
        whitespace stripped.
    """
    kept = (row for row in input_string.split('\n') if 'hint' not in row.lower())
    return '\n'.join(kept).strip()

# Context length (prompt + generated tokens) the vLLM engine is configured for.
MAX_MODEL_LEN = 10240

# The HF tokenizer is used below to render chat messages into prompt strings;
# the vLLM engine performs the actual generation with tensor parallelism of 8.
tokenizer = AutoTokenizer.from_pretrained(args.model)
llm = LLM(
    model=args.model,
    tensor_parallel_size=8,
    max_model_len=MAX_MODEL_LEN,
)

# Per-model data directory name, taken from the last component of the model
# path or model ID. normpath drops a trailing separator, so
# "./saved_models/Foo/" still yields "Foo" (a bare split('/')[-1] gives "").
model_name = os.path.basename(os.path.normpath(args.model))

# Input: samples that already carry explanation_1..explanation_{k} fields.
with open(f'./data/{model_name}/step-by-step-explanation_{args.k}.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)

# Tutor instructions shared by every prompt (loop-invariant, so built once).
SYSTEM_PROMPT = "You are a tutor. You are given a set of question, correct answer and solution. Your job is to provide a hint for the problem. The hint should help the student learn the core concept (e.g. formula, lemma, or necessary knowledge) needed to solve this problem. The hint should be concise, to the point, but high level. Do not include any detailed steps or calculations or the final answer."

# Build args.k prompts per sample, one per explanation field. The layout is
# sample-major: prompts[i * args.k + k] is sample i paired with explanation k + 1.
prompts = []
for sample in test_data:
    for k in range(args.k):
        msg = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Question: {sample['problem']}\n\nAnswer: {sample['expected_answer']}\n\nSolution: {sample[f'explanation_{k+1}']}\n\nNow, please provide a hint for this problem to help the student learn the core concept."},
        ]
        prompt = tokenizer.apply_chat_template(msg,
                                               tokenize=False,
                                               add_generation_prompt=True)
        prompts.append(prompt)

# Echo one rendered prompt to sanity-check the chat template. The guard keeps
# an empty input file (or --k 0) from crashing with IndexError.
if prompts:
    print(prompts[0])

# Nucleus sampling (temperature 1.0, top_p 0.9); each completion is capped at
# 1024 generated tokens. Only the first completion per prompt is read later.
sampling_params = SamplingParams(max_tokens=1024, temperature=1.0, top_p=0.9)

outputs = llm.generate(prompts, sampling_params)

# Attach generated hints to their samples. vLLM returns outputs in prompt
# order and prompts were built sample-major, so the hint for sample i and
# explanation k + 1 is at flat index i * args.k + k.
for sample_idx, sample in enumerate(test_data):
    for k in range(args.k):
        generated_hint = outputs[sample_idx * args.k + k].outputs[0].text
        # Drop a bold "**Hint:**" label, then any remaining line mentioning "hint".
        generated_hint = generated_hint.replace("**Hint:**", "").strip()
        sample[f'hint_{k+1}'] = remove_hint_heading(generated_hint)

# Ensure the output directory exists once, then write the augmented samples
# into the same per-model directory the input was read from.
os.makedirs(f"./data/{model_name}", exist_ok=True)
write_json(test_data, f"./data/{model_name}/step-by-step-explanation-and-hint_{args.k}.json")