import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import json
import os

# CUDA device (index 0) that RewardModel loads its model onto.
device = "cuda:0"

class RewardModel:
    """Score prompt/response pairs with a sequence-classification reward model.

    The model and its tokenizer are loaded through ``transformers`` from the
    given ``model_name``; :meth:`calculate_reward` then returns one scalar
    for a single-turn conversation.
    """

    # This checkpoint is loaded with a 10-label head; every other
    # checkpoint is loaded with a single-label head.
    _URM_MODEL_NAME = "LxzGordon/URM-LLaMa-3.1-8B"
    _URM_NUM_LABELS = 10
    _DEFAULT_NUM_LABELS = 1

    def __init__(self, model_name: str) -> None:
        """Load the reward model and its tokenizer.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        self.model_name = model_name
        print(f"{model_name}")

        # The two load paths differ only in the size of the classification
        # head, so pick that first and issue a single from_pretrained call.
        if model_name == self._URM_MODEL_NAME:
            num_labels = self._URM_NUM_LABELS
        else:
            num_labels = self._DEFAULT_NUM_LABELS

        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,
            attn_implementation="flash_attention_2",
            num_labels=num_labels,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def calculate_reward(self, prompt: str, response: str) -> float:
        """Return the model's score for ``response`` as a reply to ``prompt``.

        Args:
            prompt: Content of the user turn.
            response: Content of the assistant turn.

        Returns:
            The first logit of the first (only) sequence in the batch,
            converted to a Python number.
        """
        conversation = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ]
        text = self.tokenizer.apply_chat_template(conversation, tokenize=False)

        # Drop a leading BOS emitted by the chat template; presumably the
        # tokenizer call below adds its own, which would otherwise duplicate
        # it (TODO confirm for each supported tokenizer).
        bos_token = self.tokenizer.bos_token
        if bos_token is not None and text.startswith(bos_token):
            text = text[len(bos_token):]

        # Follow the model's actual device instead of a module-level global
        # so inputs and weights always end up on the same device.
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        # Inference only: no gradient tracking needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits[0][0].item()
