import numpy as np 
from typing import Any, Dict, List, Optional, Union
import torch 
from transformers import pipeline
from transformers import AutoTokenizer
from .base import BaseCollator
from safe_rlhf.models import AutoModelForScore


class RewardModule:
    """Wrap a safe_rlhf score model and expose a batched scalar-reward API."""

    def __init__(self, cfg):
        """Load the score model named by ``cfg.reward.model``.

        The model is loaded in bfloat16 with ``device_map='auto'``. Note that
        ``trust_remote_code=True`` executes code shipped with the checkpoint,
        so ``cfg.reward.model`` must point at a trusted source.
        """
        self.model = AutoModelForScore.from_pretrained(cfg.reward.model, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='auto')

    @torch.no_grad()
    def get_reward(self, batch: Dict[str, Any]) -> List[float]:
        """Score a tokenized batch and return one Python float per sequence.

        Args:
            batch: Mapping of model keyword arguments (e.g. the output of a
                tokenizer call). Tensor values are moved to the model's device;
                any non-tensor values are passed through unchanged.

        Returns:
            The model's ``end_scores`` with the trailing dimension squeezed,
            converted to a Python list (assumes ``end_scores`` is
            ``(batch, 1)`` — TODO confirm for multi-dim score heads).
        """
        # With device_map='auto' and a single target device, no input-moving
        # hooks are installed, so CPU tensors from the tokenizer must be moved
        # explicitly. Moving to the first-parameter device is also harmless
        # when hooks are present (they re-route inputs as needed).
        device = self.model.device
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }
        output = self.model(**inputs)
        scores = output['end_scores'].squeeze(-1).cpu().tolist()
        return scores

class RewardCollator(BaseCollator):
    """Turn generated prompt/response records into a padded reward-model batch."""

    def format_query_response(self, outputs):
        """Render each record as one conversation string for the score model.

        Each element of ``outputs`` is read for its ``'prompt'`` and
        ``'response'`` keys, which are passed through the inherited
        ``_build_query_str`` / ``_build_response_str`` helpers.

        NOTE(review): the exact template (spacing after ``ASSISTANT:``, no
        trailing EOS) presumably has to match how the reward model was
        trained — confirm against the training pipeline.
        """
        rendered = []
        for record in outputs:
            user_text = self._build_query_str(record['prompt'])
            assistant_text = self._build_response_str(record['response'])
            rendered.append(
                f"BEGINNING OF CONVERSATION: USER: {user_text} ASSISTANT: {assistant_text}"
            )
        return rendered

    def __call__(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Format ``outputs`` and tokenize them into padded PyTorch tensors."""
        conversations = self.format_query_response(outputs)
        encoded = self.tokenizer(conversations, return_tensors="pt", padding=True)
        return encoded