import numpy as np 
from typing import Any, Dict, List, Optional, Union
import torch 
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from .base import *


class RewardModule(BaseRewardModule):
    """Score texts with a sequence-classification reward model.

    The model named by ``cfg.reward.model`` is loaded into a Hugging Face
    ``"sentiment-analysis"`` (text-classification) pipeline on ``"cuda"`` in
    bfloat16. :meth:`get_reward` returns one scalar score per input text.
    """

    def __init__(self, cfg, **kwargs):
        """Build the scoring pipeline.

        Args:
            cfg: Config object; only ``cfg.reward.model`` (a model name or path
                accepted by ``from_pretrained``) is read here.
            **kwargs: Accepted for interface compatibility; unused.
        """
        # NOTE(review): BaseRewardModule.__init__ is never called here; confirm
        # the base class does not rely on its initializer.
        # Keyword arguments forwarded on every pipeline call:
        # - return_all_scores=True (legacy, deprecated flag) returns the score
        #   of every label for each input. Its replacement, top_k=None, also
        #   sorts labels by score, which would change what ``output[0]`` picks
        #   in get_reward for multi-label heads -- keep that in mind when
        #   migrating.
        # - function_to_apply="none" skips sigmoid/softmax, so the raw head
        #   output (the unnormalized reward) is returned.
        self.pipe_kwargs = {
            "return_all_scores": True,
            "function_to_apply": "none",
        }
        self.model = pipeline(
            "sentiment-analysis",
            model=cfg.reward.model,
            device="cuda",
            tokenizer=AutoTokenizer.from_pretrained(cfg.reward.model),
            model_kwargs={"torch_dtype": torch.bfloat16},
        )

    @torch.no_grad()
    def get_reward(self, batch):
        """Return one reward per text in ``batch``.

        Args:
            batch: Sized sequence of input strings; all of them are scored in a
                single pipeline batch.

        Returns:
            List of floats, one per input and in input order: the score of the
            first entry in each input's per-label score list. An empty
            ``batch`` yields ``[]``.
        """
        if len(batch) == 0:
            # The pipeline would build a DataLoader with batch_size=0, which is
            # rejected with ValueError; there is nothing to score anyway.
            return []
        outputs = self.model(batch, batch_size=len(batch), **self.pipe_kwargs)
        # Presumably the reward head has a single label, so output[0] is the
        # reward logit -- confirm for multi-label reward models.
        return [output[0]["score"] for output in outputs]


class RewardCollator(BaseCollator):
    """Turn generation records into text for a reward model.

    Each record is rendered as a single-turn user/assistant conversation with
    the tokenizer's chat template (untokenized). The tokenizer's BOS token text
    is then removed from the result.
    """

    def format_query_response(self, query, response) -> str:
        """Render one query/response pair as chat-template text.

        Args:
            query: Query, passed through ``self._build_query_str`` (inherited
                helper, not shown here) to produce the user message content.
            response: Response, passed through ``self._build_response_str``
                (inherited helper, not shown here) to produce the assistant
                message content.

        Returns:
            The chat-formatted conversation string (no generation prompt
            appended). Every occurrence of the tokenizer's BOS token text is
            removed; nothing is removed if the tokenizer defines no BOS token.
        """
        qr = [
            {'role': 'user', 'content': self._build_query_str(query)},
            {'role': 'assistant', 'content': self._build_response_str(response)},
        ]
        text = self.tokenizer.apply_chat_template(qr, tokenize=False, add_generation_prompt=False)
        # Presumably stripped because the downstream tokenizer re-adds its own
        # BOS when encoding, which would otherwise duplicate it -- confirm.
        # Some tokenizers have bos_token=None; str.replace(None, "") raises
        # TypeError, so only strip when a BOS string is actually defined.
        bos = self.tokenizer.bos_token
        if bos:
            text = text.replace(bos, "")
        return text

    def __call__(self, outputs: List[Dict[str, Any]]) -> List[str]:
        """Format a batch of generation records.

        Args:
            outputs: Records, each holding ``'prompt'`` and ``'response'`` keys.

        Returns:
            One formatted string per record, in input order.
        """
        return [self.format_query_response(output['prompt'], output['response']) for output in outputs]