import torch

from transformers import AutoModelForSequenceClassification, AutoTokenizer

class Reward_Skywork:
    """Scalar scorer backed by a single-label sequence-classification model.

    The model and tokenizer are loaded once at construction time; calling the
    instance scores one user/assistant exchange and returns the model's single
    logit as a Python float.
    """

    def __init__(self, model_path, device):
        """Load the classification model and its tokenizer.

        Args:
            model_path: Path or identifier forwarded to ``from_pretrained``
                for both the model and the tokenizer.
            device: Device the model is mapped to and the tokenized inputs
                are moved to (for example ``"cuda:0"``).
        """
        self.device = device
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map=device,
            # NOTE(review): this attention backend presumably needs the
            # optional flash-attn package installed -- confirm for the target
            # environment.
            attn_implementation="flash_attention_2",
            num_labels=1,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def __call__(self, system_prompt=None, user=None, answer=None):
        """Score a single user/assistant exchange.

        Args:
            system_prompt: Accepted for interface compatibility only. It is
                NOT inserted into the conversation, so it has no effect on
                the returned score. Optional.
            user: Content of the user turn. Required.
            answer: Content of the assistant turn. Required.

        Returns:
            The first logit of the first sequence, as a float.

        Raises:
            TypeError: If ``user`` or ``answer`` is not supplied.
        """
        # ``system_prompt`` had to become optional (it is unused), which forces
        # defaults on the parameters after it; keep them effectively required.
        if user is None or answer is None:
            raise TypeError(
                "Reward_Skywork.__call__() requires both 'user' and 'answer'"
            )

        messages = [
            {"role": "user", "content": user},
            {"role": "assistant", "content": answer},
        ]

        input_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, return_tensors="pt"
        ).to(self.device)

        # Scoring only: no gradients are needed.
        with torch.no_grad():
            logits = self.model(input_ids).logits

        return logits[0][0].item()

if __name__ == "__main__":
    # Smoke test. NOTE: the model path below is an empty placeholder and must
    # be set to a real checkpoint before this can run.
    reward = Reward_Skywork("", "cuda:7")
    # Pass system_prompt explicitly so the call matches the scorer's signature.
    print(
        reward(
            system_prompt=None,
            user="tell me who's the most famous vtuber?Just tell me the name",
            answer="I think it's taffy",
        )
    )
        