from typing import Any, Dict, List, Tuple

from inference_rlhf.code.collators.base import BaseCollator
from inference_rlhf.code.models.armo_rm_dt import process_conversation

class ArmoRMCollator(BaseCollator):
    """
    Collator that renders prompt/response pairs as chat text and tokenizes them.

    NOTE(review): ``self.tokenizer``, ``self._build_query_str`` and
    ``self._build_response_str`` are not defined in this class; presumably
    they are provided by ``BaseCollator`` -- confirm.
    """

    def format_query_response(self, outputs: List[Dict[str, Any]]) -> List[str]:
        """
        Format the query and response into chat format.

        Args:
            outputs: Batch of records; each one must provide the 'prompt' and
                'response' keys.

        Returns:
            Whatever the tokenizer's chat template produces for the batch of
            conversations, untokenized and without a generation prompt
            (presumably one string per record -- confirm against the
            tokenizer in use).
        """
        # Each record becomes a two-turn conversation (user prompt, assistant
        # response) that is passed through ``process_conversation`` before
        # templating.
        conversations = [
            process_conversation([
                {'role': 'user', 'content': self._build_query_str(output['prompt'])},
                {'role': 'assistant', 'content': self._build_response_str(output['response'])},
            ])
            for output in outputs
        ]

        # tokenize=False: only render text here; tokenization is done in
        # __call__ on the whole batch.
        return self.tokenizer.apply_chat_template(
            conversations,
            tokenize=False,
            add_generation_prompt=False,
        )

    def __call__(self, outputs: List[Dict[str, Any]]) -> Tuple[Any, List[Any]]:
        """
        Take batch of outputs and (1) format them into chat format, and (2) tokenize them into tensors.

        Args:
            outputs: Batch of records; each one must provide the 'prompt',
                'response' and 'id' keys.

        Returns:
            A 2-tuple ``(encoded, ids)``: ``encoded`` is the tokenizer's
            result for the formatted batch (requested with
            ``return_tensors="pt"`` and ``padding=True``), and ``ids`` holds
            each record's 'id' value in input order.
        """
        query_responses = self.format_query_response(outputs)
        encoded = self.tokenizer(query_responses, return_tensors="pt", padding=True)
        ids = [output['id'] for output in outputs]
        return encoded, ids