from typing import Any, Dict, List, Tuple, Union

import torch
from transformers import PreTrainedModel
from trl import RewardTrainer

class RewardModel():
    """Score text samples with a reward model, or an ensemble of reward models.

    A single model or a list/tuple of models is wrapped into one callable that
    maps a tokenized minibatch to a 2-D tensor with one row of flattened
    logits per model. ``get_scores`` then optionally reduces over the model
    dimension ("min" or "mean"). An optional gold reward model and its
    tokenizer are wrapped and stored the same way.
    """

    def __init__(self, models, tokenizer, gold_rm = None, g_tokenizer = None, agg_fn="mean", batch_size = 6, g_batch_size = 1, max_length = 48):
        """Wrap the models and remember the scoring configuration.

        Args:
            models: One model or a list/tuple of models. Each must be callable
                with the tokenizer output as keyword arguments, expose a
                ``device`` attribute, and return an object with ``logits``.
            tokenizer: Callable used to tokenize samples for ``models``.
            gold_rm: Optional gold model(s), wrapped like ``models``.
            g_tokenizer: Tokenizer paired with ``gold_rm``.
            agg_fn: Default aggregation over models used by ``score``
                ("min", "mean", or anything else for no reduction).
            batch_size: Default minibatch size used by ``score``.
            g_batch_size: Stored for callers; not read inside this class.
            max_length: Truncation length passed to the tokenizer.
        """
        self.model = self._prepare_model(models)
        self.tokenizer = tokenizer
        self.gold_rm = self._prepare_model(gold_rm)
        self.g_tokenizer = g_tokenizer
        self.batch_size = batch_size
        self.g_batch_size = g_batch_size
        # Initialised empty; not read or written anywhere else in this class.
        self.eval_original_scores = {}
        self.max_length = max_length
        self.agg_fn = agg_fn

    def _prepare_model(self, model):
        """Return a callable mapping a tokenized minibatch to stacked logits.

        Returns ``None`` when ``model`` is ``None``. For a list or tuple the
        result has one row per model; for a single model a leading dimension
        of size one is added so both cases produce a 2-D tensor.
        """
        if model is None:
            return None
        # isinstance (rather than an exact type check) so tuples and list
        # subclasses are treated as ensembles as well.
        if isinstance(model, (list, tuple)):
            return lambda minibatch: torch.vstack(
                [m(**minibatch.to(m.device)).logits.flatten() for m in model]
            )
        return lambda minibatch: model(**minibatch.to(model.device)).logits.flatten()[None]

    def get_scores(self, samples: List[str], model, tokenizer, batch_size = 2, agg_fn="min"):
        """Score ``samples`` in minibatches and concatenate the results.

        Args:
            samples: Texts to score.
            model: Callable produced by ``_prepare_model``.
            tokenizer: Callable that tokenizes a list of strings.
            batch_size: Number of samples per forward pass.
            agg_fn: "min" or "mean" reduces over dim 0 of the model output
                (the per-model dimension); any other value keeps the
                unreduced output.

        Returns:
            The per-batch scores joined with ``torch.hstack``. An empty
            ``samples`` yields an empty tensor.
        """
        # torch.hstack raises on an empty list, so short-circuit empty input.
        if len(samples) == 0:
            return torch.empty(0)

        scores_list = []
        for start in range(0, len(samples), batch_size):
            sub_samples = samples[start : start + batch_size]
            encoded = tokenizer(
                sub_samples,
                truncation=True,
                max_length=self.max_length,
                padding=True,
                return_tensors="pt",
            )
            with torch.no_grad():
                model_output = model(encoded)
                if agg_fn == "min":
                    # Tensor.min(dim=...) returns (values, indices); keep values.
                    sub_scores = model_output.min(dim=0)[0]
                elif agg_fn == "mean":
                    sub_scores = model_output.mean(dim=0)
                else:
                    sub_scores = model_output
            scores_list.append(sub_scores)
        return torch.hstack(scores_list)

    def score(self, samples: List[str], prompts: List[str], outputs: List[str], model = None, tokenizer = None, is_eval=False, cache_key="model", **kwargs):
        """Score ``samples`` using the instance defaults unless overridden.

        ``model`` and ``tokenizer`` fall back to the ones given at
        construction. ``batch_size`` and ``agg_fn`` default to the values
        configured on the instance and can be overridden through ``kwargs``.
        ``prompts``, ``outputs``, ``is_eval`` and ``cache_key`` are accepted
        but not used by this method.
        """
        if model is None:
            model = self.model
        if tokenizer is None:
            tokenizer = self.tokenizer
        # Forward the configured aggregation; previously self.agg_fn was
        # stored but never used, so get_scores always fell back to "min".
        options = {"batch_size": self.batch_size, "agg_fn": self.agg_fn, **kwargs}
        return self.get_scores(samples, model, tokenizer, **options)


class RegularizedRewardTrainer(RewardTrainer):
    """Reward trainer whose pairwise loss adds a penalty on the summed pair rewards."""

    def compute_loss(
        self,
        model: Union[PreTrainedModel, torch.nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
        eta=1e-3,
        num_items_in_batch=None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """Compute the regularized pairwise preference loss.

        The loss is
        ``-logsigmoid(r_chosen - r_rejected).mean()
        + eta * ((r_chosen + r_rejected) ** 2).mean()``.

        Args:
            model: Model called with ``input_ids`` and ``attention_mask``;
                the first element of its output is taken as the reward.
            inputs: Batch providing ``input_ids_chosen``,
                ``attention_mask_chosen``, ``input_ids_rejected`` and
                ``attention_mask_rejected``.
            return_outputs: When true, also return both reward tensors.
            eta: Weight of the squared-sum regularization term.
            num_items_in_batch: Accepted for compatibility with Trainer
                versions that pass this keyword; it is not used here.

        Returns:
            The scalar loss, or ``(loss, {"rewards_chosen": ...,
            "rewards_rejected": ...})`` when ``return_outputs`` is true.
        """
        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
        )[0]
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
        )[0]

        # Preference term: rewards the chosen sample scoring above the rejected one.
        preference_loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
        # Regularizer: penalises the squared sum of each pair's rewards.
        sum_penalty = ((rewards_chosen + rewards_rejected) ** 2).mean()
        loss = preference_loss + eta * sum_penalty

        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected,
            }
        return loss
