import numpy as np 
from typing import Any, Dict, List, Optional, Union
import torch 
from transformers import pipeline
from transformers import AutoTokenizer, AutoModel
from .base import BaseCollator

class RewardModule:
    """Thin wrapper around a pretrained reward model used to score batches.

    The checkpoint named by ``cfg.reward.model`` is loaded through
    ``AutoModel.from_pretrained`` with remote code enabled, bfloat16
    weights, and ``device_map="cuda"``.
    """

    def __init__(self, cfg):
        """Load the reward model referenced by ``cfg.reward.model``.

        Args:
            cfg: Configuration object; only ``cfg.reward.model`` is read here.
        """
        load_options = dict(
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
        )
        self.model = AutoModel.from_pretrained(cfg.reward.model, **load_options)

    @torch.no_grad()
    def get_reward(self, batch):
        """Run the reward model on ``batch`` and return its scores.

        Args:
            batch: Mapping whose entries are unpacked as keyword arguments
                into the model call.

        Returns:
            The model output with a trailing singleton dimension squeezed,
            copied to the CPU and converted with ``tolist()``.
        """
        # NOTE(review): the batch is forwarded untouched -- no device transfer
        # happens here, so tensors presumably need to already live on the
        # model's device. Confirm against the caller.
        raw_scores = self.model(**batch)
        cpu_scores = raw_scores.squeeze(-1).cpu()
        return cpu_scores.tolist()



class RewardCollator(BaseCollator):
    """Collator that renders prompt/response pairs as chat text and tokenizes them.

    Two tokenizers are involved: ``self.chat_tokenizer`` (loaded here) is used
    only for its chat template, while ``self.tokenizer`` performs the actual
    tokenization and padding. ``self.tokenizer``, ``_build_query_str`` and
    ``_build_response_str`` are not defined in this class -- presumably they
    come from ``BaseCollator``; confirm there.
    """

    def __init__(self, cfg, **kwargs):
        """Initialise the base collator and load the chat-template tokenizer.

        Args:
            cfg: Configuration object, forwarded to ``BaseCollator``.
            **kwargs: Extra keyword arguments, forwarded to ``BaseCollator``.
        """
        super().__init__(cfg, **kwargs)
        self.chat_tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.2",
            trust_remote_code=True,
        )
        # Reuse the EOS id as the padding id; __call__ tokenizes with padding=True.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def format_query_response(self, outputs):
        """Render each output as a two-turn (user, assistant) chat string.

        Args:
            outputs: Iterable of mappings that provide ``'prompt'`` and
                ``'response'`` entries.

        Returns:
            Whatever ``apply_chat_template`` yields for the batch of
            conversations with ``tokenize=False`` (untokenized text).
        """
        conversations = []
        for item in outputs:
            user_turn = {
                'role': 'user',
                'content': self._build_query_str(item['prompt']),
            }
            assistant_turn = {
                'role': 'assistant',
                'content': self._build_response_str(item['response']),
            }
            conversations.append([user_turn, assistant_turn])
        return self.chat_tokenizer.apply_chat_template(
            conversations,
            tokenize=False,
            add_generation_prompt=False,
        )

    def __call__(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Turn a list of outputs into a padded, tokenized batch.

        Args:
            outputs: Items providing ``'prompt'`` and ``'response'`` entries.

        Returns:
            The result of calling ``self.tokenizer`` on the rendered chat
            strings with ``return_tensors="pt"`` and ``padding=True``.
        """
        rendered = self.format_query_response(outputs)
        # NOTE(review): the text comes out of a chat template and is tokenized
        # here with the tokenizer's default special-token handling -- confirm
        # that special tokens (e.g. BOS) are not being duplicated.
        return self.tokenizer(rendered, return_tensors="pt", padding=True)