from typing import Any, Dict, List
from importlib import import_module

from transformers import AutoTokenizer

class BaseCollator:
    """Base collator turning prompt/response records into tokenized batches.

    The class loads a tokenizer and the formatting templates of the configured
    task, and offers helpers to render questions and answers. How a query and
    a response are combined into one string is left to subclasses, which must
    override ``_format_query_response``.
    """

    def __init__(self, cfg):
        """Load the tokenizer, the optional task description and task formats.

        Args:
            cfg: Configuration object. Reads ``cfg.reward.model``,
                ``cfg.evaluation.include_prompt`` and ``cfg.task.name``; when
                the prompt is included it also reads ``cfg.task.TASK_DESC``,
                ``cfg.reward.INST`` and ``cfg.reward.replace_inst``.
        """
        self.cfg = cfg
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.reward.model)
        self.include_prompt = cfg.evaluation.include_prompt
        self.task_desc = ""
        if self.include_prompt:
            task_desc = cfg.task.TASK_DESC
            inst = cfg.reward.INST
            # The reward instruction either replaces the task description or
            # is appended directly after it (no separator is inserted).
            self.task_desc = inst if cfg.reward.replace_inst else task_desc + inst

        # The module name is absolute, so no ``package`` anchor is needed.
        tm = import_module(f"inference_rlhf.code.tasks.{cfg.task.name}")
        self.question_format = tm.QUESTION_FORMAT
        self.answer_format = tm.ANSWER_FORMAT
        self.sep = tm.SEP

    def __call__(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Format every record in ``outputs`` and tokenize the results.

        Args:
            outputs: Records providing the ``'prompt'`` and ``'response'`` keys.

        Returns:
            Whatever the tokenizer returns for the formatted strings.
        """
        query_responses = self.format_query_response(outputs)
        return self.tokenize(query_responses)

    def format_q(self, question, n: int = 0) -> str:
        """Render ``question`` with the task's question template.

        Args:
            question: Value substituted for ``{question}`` in the template.
            n: Number of separators appended after the rendered question.
        """
        return self.question_format.format(question=question) + n * self.sep

    def format_a(self, answer, n: int = 0) -> str:
        """Render ``answer`` with the task's answer template.

        Args:
            answer: Value substituted for ``{answer}`` in the template.
            n: Number of separators appended after the rendered answer.
        """
        return self.answer_format.format(answer=answer) + n * self.sep

    def _build_query_str(self, question) -> str:
        """Return the formatted question, prefixed by the task description.

        The description and two separators are prepended only when
        ``include_prompt`` is truthy.
        """
        if self.include_prompt:
            query = self.task_desc + self.sep * 2
        else:
            query = ""
        query += self.format_q(question, n=0)
        return query

    def _build_response_str(self, response) -> str:
        """Return ``response`` rendered with the answer template."""
        return self.format_a(response)

    def _format_query_response(self, query, response):
        """Combine one query and one response into a single string.

        Subclass hook: the base class has no formatting policy of its own.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        # Fail here rather than letting ``None`` values reach the tokenizer,
        # where the resulting error would not point at the missing override.
        raise NotImplementedError(
            f"{type(self).__name__} must implement _format_query_response()"
        )

    def format_query_response(self, outputs: List[Dict[str, Any]]) -> List[Any]:
        """Apply ``_format_query_response`` to each record of ``outputs``.

        Each record is read through its ``'prompt'`` and ``'response'`` keys.
        """
        return [
            self._format_query_response(output['prompt'], output['response'])
            for output in outputs
        ]

    def tokenize(self, qr):
        """Tokenize ``qr`` with padding enabled, requesting ``"pt"`` tensors."""
        return self.tokenizer(qr, return_tensors="pt", padding=True)