import os, sys

# Point XDG_CACHE_HOME at a custom directory. This is done before the
# third-party imports below so the variable is already set when they load.
cache_dir = "/work/hdd/bdkj/audreyh/.cache"
os.environ["XDG_CACHE_HOME"] = cache_dir

import numpy as np 
import torch 
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import login

from vllm import LLM, SamplingParams

# Instruction placed at the start of the user turn.
TASK_DESC = (
    "As an expert problem solver, "
    "solve the following mathematical questions step by step."
)
# Optional suffixes (disabled) that ask for an explicit final-answer line:
# TASK_DESC += " Then, report the final answer following the phrase 'The answer is'."
# TASK_DESC += " In the last line of your response, report the final answer following the phrase 'The answer is '."

# Templates for the question and answer parts of the conversation.
QUESTION_FORMAT = "Q: {question}"
ANSWER_FORMAT = "A:{answer}"
# Separator inserted between the instruction and the question.
SEP = "\n"


# Worked example: one math word problem and a step-by-step reference answer.
question = (
    "There are 15 trees in the grove. "
    "Grove workers will plant trees in the grove today. "
    "After they are done, there will be 21 trees. "
    "How many trees did the grove workers plant today?"
)
# question = "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?"

# The leading space is intentional: ANSWER_FORMAT is "A:{answer}" with no
# space after the colon.
answer = (
    " We start with 15 trees. "
    "Later we have 21 trees. "
    "The difference must be the number of trees they planted. "
    "So, they must have planted 21 - 15 = 6 trees. "
    "The answer is 6."
)

# Hub checkpoint id used for both the tokenizer and the vLLM engine.
# The Phi-3 id used to be assigned first and then immediately overwritten
# (a dead store); it is kept here only as a commented-out alternative.
# model_name = "microsoft/Phi-3-small-8k-instruct"
# NOTE(review): the "rm" in this checkpoint name suggests a reward model --
# confirm it is meant to be used for text generation.
model_name = "OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5"

# Load the tokenizer that belongs to the selected checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)

# Alternative (disabled): user turn only, rendered with the generation prompt.
# messages = [
#     {'role': 'user',
#      'content': TASK_DESC + 2*SEP + QUESTION_FORMAT.format(question=question)},
# ]
# prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Two-turn conversation: the instruction plus question as the user turn,
# followed by the reference answer as the assistant turn.
messages = [
    dict(
        role='user',
        content=f"{TASK_DESC}{SEP * 2}{QUESTION_FORMAT.format(question=question)}",
    ),
    dict(
        role='assistant',
        content=ANSWER_FORMAT.format(answer=answer),
    ),
]

# Render the conversation to a plain string (no tokenization) and show it.
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print(prompt)

# Decoding configuration passed to vLLM for every prompt.
sampling_params = SamplingParams(
    n=5,               # number of completions requested per prompt
    temperature=1,
    top_p=1,
    top_k=-1,
    max_tokens=512,    # upper bound on generated tokens per completion
    # NOTE(review): presumably 0 requests only the sampled token's
    # log-probability -- confirm against the installed vLLM version.
    logprobs=0,
    # Stop generation when the tokenizer's EOS token id is produced.
    stop_token_ids=[tokenizer.eos_token_id],
)

# Build the vLLM engine for the selected checkpoint.
llm = LLM(model=model_name, trust_remote_code=True)

# Generate for the single prompt; the result list has one entry per prompt.
outputs = llm.generate(prompts=[prompt], sampling_params=sampling_params)
output = outputs[0]

# Print each sampled completion followed by a divider line.
for out in output.outputs:
    response = out.text
    print(response, '-------', sep='\n')

