import numpy as np 
from typing import Any, Dict, List, Optional, Union
import torch 
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from .base import *

from inference_rlhf.code.models.armo_rm_dt import process_conversation


class RewardModule(BaseRewardModule):
    """Reward module that initializes its base class with a bfloat16 torch dtype.

    All behavior beyond construction comes from ``BaseRewardModule``; this
    subclass only fixes the ``torch_dtype`` keyword handed to the base
    initializer.
    """

    def __init__(self, cfg):
        """Forward ``cfg`` to the base initializer with ``torch_dtype=torch.bfloat16``.

        Args:
            cfg: Configuration object, passed through unchanged to
                ``BaseRewardModule.__init__`` (its schema is not visible here).
        """
        # presumably the base class uses this dtype when loading the model
        # weights -- confirm in BaseRewardModule.
        model_dtype = torch.bfloat16
        super().__init__(cfg, torch_dtype=model_dtype)

    # def __init__(self, cfg): 
    #     self.model = AutoModelForSequenceClassification.from_pretrained(cfg.reward.model, trust_remote_code=True, torch_dtype=torch.bfloat16)

    # @torch.no_grad()
    # def get_reward(self, batch): 
    #     output = self.model(**batch)
    #     score = output.score.cpu().tolist()
    #     return score

class RewardCollator(BaseCollator):
    """Collator that renders prompt/response records as chat text and tokenizes them.

    Relies on ``self.tokenizer``, ``self._build_query_str`` and
    ``self._build_response_str``, none of which are defined in this class
    (presumably provided by ``BaseCollator`` -- confirm there).
    """

    def format_query_response(self, outputs: List[Dict[str, Any]]) -> List[str]:
        """
        Render every prompt/response pair through the tokenizer's chat template.

        Args:
            outputs: Records that each carry a ``'prompt'`` and a ``'response'`` entry.

        Returns:
            The result of ``apply_chat_template`` with ``tokenize=False`` over all
            conversations, i.e. untokenized chat-formatted text.
        """
        conversations = []
        for record in outputs:
            user_turn = {'role': 'user', 'content': self._build_query_str(record['prompt'])}
            assistant_turn = {'role': 'assistant', 'content': self._build_response_str(record['response'])}
            # Each two-turn exchange is passed through process_conversation
            # before templating.
            conversations.append(process_conversation([user_turn, assistant_turn]))

        # No generation prompt is appended: the assistant turn is already present.
        return self.tokenizer.apply_chat_template(
            conversations,
            tokenize=False,
            add_generation_prompt=False,
        )

    def __call__(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Convert a batch of outputs into tokenizer inputs.

        The batch is (1) formatted into chat text and (2) tokenized with
        padding enabled and ``return_tensors="pt"``.

        NOTE(review): the chat template may already insert special tokens
        (e.g. BOS); tokenizing the rendered text again with default settings
        could duplicate them -- confirm against the tokenizer in use.

        Args:
            outputs: Records that each carry a ``'prompt'`` and a ``'response'`` entry.

        Returns:
            Whatever ``self.tokenizer`` returns for the formatted batch.
        """
        chat_texts = self.format_query_response(outputs)
        encoded_batch = self.tokenizer(chat_texts, padding=True, return_tensors="pt")
        return encoded_batch