import os, sys

# Configure the environment BEFORE importing any Hugging Face library:
# huggingface_hub derives its cache directory from XDG_CACHE_HOME / HF_HOME when
# it is first imported, and other packages (e.g. accelerate) may import it
# transitively, so setting these variables afterwards can be silently ignored.
cache_dir = "/work/hdd/bdkj/audreyh/.cache"
os.environ['XDG_CACHE_HOME'] = cache_dir
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np
import torch
from accelerate import Accelerator
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification, pipeline
from huggingface_hub import login


# Fallback chat template, used only when the tokenizer ships without one.
# Renders each turn as "<Role>: <content>" followed by a blank line and, when
# add_generation_prompt is set, appends a trailing "Assistant:".
SIMPLE_CHAT_TEMPLATE = "{% for message in messages %}{{message['role'].capitalize() + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"

# Prompt layout. TASK_DESC is an optional instruction placed before the
# question, separated from it by two SEP characters.
TASK_DESC = ""
QUESTION_FORMAT = "{question}"
ANSWER_FORMAT = "{answer}"
SEP = "\n"
question = "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?"
answer = " We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6."

# Only prepend the task description (and its separator) when there is one;
# with an empty TASK_DESC the user turn would otherwise start with a blank line.
user_content = QUESTION_FORMAT.format(question=question)
if TASK_DESC:
    user_content = TASK_DESC + 2 * SEP + user_content

# A single user/assistant exchange to be scored by the reward model.
message = [
    {'role': 'user',
     'content': user_content},
    {'role': 'assistant',
     'content': ANSWER_FORMAT.format(answer=answer)},
]

reward_model = "cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr"

# Left-pad when batching.
tokenizer = AutoTokenizer.from_pretrained(reward_model, padding_side="left")
# Registers "[PAD]" as the pad token. NOTE(review): if "[PAD]" is not already in
# the vocabulary this adds a new id without resizing the model's embeddings —
# confirm this matches how the reward model was trained.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
    tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

# Fall back to CPU so the script still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# "sentiment-analysis" is the text-classification pipeline task; here it wraps
# the reward model, loading its weights in bfloat16.
model = pipeline(
    "sentiment-analysis",
    model=reward_model,
    device=device,
    tokenizer=tokenizer,
    model_kwargs={"torch_dtype": torch.bfloat16},
)

# Render the conversation with the chat template and score it.
# function_to_apply="none" returns the raw classifier logit instead of a
# sigmoid/softmax-squashed value, so the printed number is the raw reward.
text = tokenizer.apply_chat_template(message, tokenize=False)
scores = model(text, function_to_apply="none")
print("reward:", scores)

# Opt-in interactive inspection: set REWARD_DEBUG=1 to drop into ipdb here,
# rather than halting every (possibly unattended) run at a breakpoint.
if os.environ.get("REWARD_DEBUG"):
    import ipdb
    ipdb.set_trace()

# To access gated/private models, authenticate without hard-coding a secret, e.g.
# login(token=os.environ["HF_TOKEN"])  (or run `huggingface-cli login` once).
# NOTE(review): a real access token used to be committed on this line; revoke it.


# TASK_DESC = "As an expert problem solver, solve the following mathematical questions step by step."
# QUESTION_FORMAT = "Q: {question}"
# ANSWER_FORMAT = "A:{answer}"
# SEP = "\n"

# question = "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?"
# answer = " We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6."

# message = [
#     {'role': 'user',
#      'content': TASK_DESC + 2*SEP + QUESTION_FORMAT.format(question=question)},
#     {'role': 'assistant',
#      'content': ANSWER_FORMAT.format(answer=answer)},
# ]
# messages = [message for _ in range(15)]
# messages.append(
#    [
#     {'role': 'user',
#      'content': TASK_DESC + 2*SEP + QUESTION_FORMAT.format(question=question)},
#     {'role': 'assistant',
#      'content': ANSWER_FORMAT.format(answer=answer[:10])},
#      ]
#      )


# ########################################################################################################################################
# # Begin code
# ########################################################################################################################################

# model_name = "Ray2333/GRM-Llama3.2-3B-rewardmodel-ft"
# accelerator = Accelerator() 

# tokenizer = AutoTokenizer.from_pretrained(model_name)
# tokenizer = accelerator.prepare(tokenizer)
# # kwargs = {"padding": 'max_length', "truncation": True, "return_tensors": "pt"}
# kwargs = {"padding": True, "return_tensors": "pt"}


# inputs = tokenizer.apply_chat_template(messages, tokenize=False)
# # input_ids = tokenizer.encode_plus(inputs, **kwargs)
# input_ids = tokenizer(inputs, **kwargs)


# model = AutoModelForSequenceClassification.from_pretrained(model_name, torch_dtype=torch.float16)
# model = accelerator.prepare(model)

# with torch.no_grad(): 
#     input_ids = {key: value.to(model.device) for key, value in input_ids.items()}
#     output = model(**input_ids).logits
#     score = output.squeeze(-1).cpu().tolist()

# '''
# for eurus
# '''
# # tokenizer_name = "mistralai/Mistral-7B-Instruct-v0.2"
# # tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
# # tokenizer.pad_token_id = tokenizer.eos_token_id

# # inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
# # input_ids = tokenizer(inputs, return_tensors="pt", padding=True,)

# # input_ids = {key: value.cuda() for key, value in input_ids.items()}

# # model = AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map="cuda", torch_dtype=torch.bfloat16)
# # model = accelerator.prepare(model)

# # with torch.no_grad(): 
# #     score = model(**input_ids).squeeze(-1).cpu().tolist()


# '''
# for armo-rm
# '''
# # input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", padding=True).to(accelerator.device)

# # model = AutoModelForSequenceClassification.from_pretrained(model_name, device_map="cuda", trust_remote_code=True, torch_dtype=torch.bfloat16)
# # model = accelerator.prepare(model)

# # with torch.no_grad():
# #     output = model(input_ids)
# #     score = output.score.cpu().tolist()