import numpy as np 
from typing import Any, Dict, List, Optional, Union
import torch 
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from .base import *

class RewardModule(BaseRewardModule):
    """Reward module that configures its base class for half precision.

    All behaviour is inherited from ``BaseRewardModule``; this subclass only
    fixes the ``torch_dtype`` keyword passed to the base constructor.

    NOTE(review): an earlier in-class implementation (loading an
    ``AutoModelForSequenceClassification`` from ``cfg.reward.model`` and
    scoring batches from the squeezed logits) used to live here as
    commented-out code. Presumably the base class provides that now --
    confirm against ``.base``.
    """

    def __init__(self, cfg):
        """Forward ``cfg`` to the base class, requesting ``torch.float16``.

        Args:
            cfg: Configuration object, passed through to the base
                constructor unchanged.
        """
        half_precision = torch.float16
        super().__init__(cfg, torch_dtype=half_precision)

class RewardCollator(BaseCollator):
    """Collator that renders prompt/response pairs as chats and tokenizes them.

    Relies on ``self.tokenizer``, ``self._build_query_str`` and
    ``self._build_response_str``, none of which are defined in this class
    (presumably supplied by ``BaseCollator`` -- verify against ``.base``).
    """

    def format_query_response(self, outputs):
        """Render each sample as a two-turn chat via the chat template.

        Args:
            outputs: Iterable of mappings, each read at the ``'prompt'`` and
                ``'response'`` keys.

        Returns:
            Whatever ``tokenizer.apply_chat_template`` returns for the batch
            of conversations with ``tokenize=False`` (untokenized output).
        """
        conversations = []
        for sample in outputs:
            # Build the user turn first, then the assistant turn.
            user_turn = {
                "role": "user",
                "content": self._build_query_str(sample["prompt"]),
            }
            assistant_turn = {
                "role": "assistant",
                "content": self._build_response_str(sample["response"]),
            }
            conversations.append([user_turn, assistant_turn])

        # The conversation is complete (ends with the assistant turn), so no
        # generation prompt is appended.
        return self.tokenizer.apply_chat_template(
            conversations,
            tokenize=False,
            add_generation_prompt=False,
        )

    def __call__(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Turn a batch of samples into padded PyTorch tokenizer output.

        Args:
            outputs: Batch of samples; see ``format_query_response``.

        Returns:
            The tokenizer's result for the rendered chats, requested with
            ``return_tensors="pt"`` and ``padding=True``.
        """
        rendered_chats = self.format_query_response(outputs)
        # NOTE(review): the template is rendered to text and then tokenized
        # with the tokenizer's default special-token handling; if the template
        # already emits a BOS token this may duplicate it -- confirm for the
        # reward model in use.
        encoded_batch = self.tokenizer(
            rendered_chats,
            return_tensors="pt",
            padding=True,
        )
        return encoded_batch