import os, sys 
# Point XDG_CACHE_HOME at a user-specific cache directory.
# NOTE(review): this is set before the torch / transformers / safe_rlhf imports
# below, presumably so those libraries resolve their cache location from it —
# keep this assignment ahead of those imports; confirm before reordering.
cache_dir = "/work/hdd/bdkj/audreyh/.cache"
os.environ['XDG_CACHE_HOME'] = cache_dir
# sys.path.append('/u/audreyh/workspace/safe-rlhf/safe_rlhf')
import torch
from transformers import AutoTokenizer
from safe_rlhf.models import AutoModelForScore
from accelerate import Accelerator
import ipdb


# Building blocks of the demo prompt: a task description, a question template,
# an answer template, and the separator placed between description and question.
SEP = "\n"
TASK_DESC = (
    "As an expert problem solver, "
    "solve the following mathematical questions step by step."
)
QUESTION_FORMAT = "Q: {question}"
ANSWER_FORMAT = "A:{answer}"

# A single worked example (question plus full step-by-step answer).
question = (
    "There are 15 trees in the grove. "
    "Grove workers will plant trees in the grove today. "
    "After they are done, there will be 21 trees. "
    "How many trees did the grove workers plant today?"
)
answer = (
    " We start with 15 trees. "
    "Later we have 21 trees. "
    "The difference must be the number of trees they planted. "
    "So, they must have planted 21 - 15 = 6 trees. "
    "The answer is 6."
)

# One conversation: a user turn followed by an assistant turn.
message = [
    dict(
        role='user',
        content=TASK_DESC + SEP + SEP + QUESTION_FORMAT.format(question=question),
    ),
    dict(
        role='assistant',
        content=ANSWER_FORMAT.format(answer=answer),
    ),
]

# Batch of 16 conversations: the same `message` object repeated 15 times, then
# one extra conversation whose assistant answer is cut to its first 10 characters.
messages = [message] * 15 + [
    [
        dict(
            role='user',
            content=TASK_DESC + SEP + SEP + QUESTION_FORMAT.format(question=question),
        ),
        dict(
            role='assistant',
            content=ANSWER_FORMAT.format(answer=answer[:10]),
        ),
    ],
]
def format_query_response(messages):
    """Render each two-turn conversation as a single prompt string.

    Args:
        messages: Iterable of conversations. Each conversation is indexable;
            element 0 is the user turn and element 1 is the assistant turn,
            and each turn is a mapping with a ``'content'`` key.

    Returns:
        A list with one string per conversation, in input order, of the form
        ``"BEGINNING OF CONVERSATION: USER: <query> ASSISTANT: <response>"``.
        An empty input yields an empty list.

    Raises:
        IndexError: If a conversation has fewer than two turns.
        KeyError: If a turn has no ``'content'`` key.
    """
    # Use a distinct loop name so the `messages` argument is never rebound
    # while it is being iterated.
    return [
        f"BEGINNING OF CONVERSATION: USER: {conversation[0]['content']}"
        f" ASSISTANT: {conversation[1]['content']}"
        for conversation in messages
    ]


# --- Score the demo conversations with the reward model ---
model_name = "PKU-Alignment/beaver-7b-v1.0-reward"
accelerator = Accelerator()

# Load the scoring model in bfloat16 with automatic device mapping, then pass
# it through Accelerate.
model = accelerator.prepare(
    AutoModelForScore.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map='auto',
    )
)

# The tokenizer is passed through `prepare` in the same way as the model.
tokenizer = accelerator.prepare(AutoTokenizer.from_pretrained(model_name))

# Flatten every conversation to one prompt string and tokenize them as a single
# padded batch. Despite its name, `input_ids` holds the whole tokenizer output,
# which is unpacked as keyword arguments for the forward pass below.
inputs = format_query_response(messages)
input_ids = tokenizer(inputs, padding=True, return_tensors='pt').to(accelerator.device)
# Debug breakpoint: inspect the tokenized batch before running the model.
import ipdb; ipdb.set_trace()

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    output = model(**input_ids)
    # Drop the trailing dimension of `end_scores` and convert to a plain list.
    # NOTE(review): presumably one scalar reward per sequence — confirm against
    # the safe_rlhf score-model output.
    scores = output['end_scores'].squeeze(-1).cpu().tolist()
    # Debug breakpoint: inspect `output` and `scores`.
    ipdb.set_trace()

