import math
from trl import ORPOTrainer
import torch
import torch.nn.functional as F
from typing import Tuple
import numpy as np

class MPOTrainer(ORPOTrainer):
    """``ORPOTrainer`` variant that supports several pairwise preference losses.

    The pairwise loss is selected by ``self.loss_type`` (see ``odds_ratio_loss``).
    For ``"tampo"``, the preference loss and the NLL loss are interpolated with a
    weight that follows ``self.alpha_schedule`` over training progress.

    NOTE(review): ``loss_type``, ``po_cutoff`` and ``alpha_schedule`` are read from
    ``self`` but are not set in this class; presumably they are attached via the
    training config or by the caller. TODO: confirm.
    """

    def odds_ratio_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the per-example preference loss selected by ``self.loss_type``.

        With ``diff = policy_chosen_logps - policy_rejected_logps``, the supported
        loss types are as follows. Each is multiplied by ``self.beta``.

        * ``"softplus"``: ``-softplus(diff)``
        * ``"cpo"``: ``-logsigmoid(diff)``
        * ``"orpo"``: ``-logsigmoid(log_odds)``, the ORPO odds-ratio loss
        * ``"leakyrelu"``: ``-leaky_relu(diff, negative_slope=0.1)``
        * ``"slic"``: ``-relu(diff - self.po_cutoff)``
        * ``"sft"``: ``-relu(diff)``. ``get_batch_loss_metrics`` ignores it and
          trains on the NLL loss only.
        * ``"simpo"`` / ``"tampo"``: ``-logsigmoid(diff - self.po_cutoff)``

        Args:
            policy_chosen_logps: Policy log probabilities of the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Policy log probabilities of the rejected responses. Shape: (batch_size,)

        Returns:
            A tuple of five tensors:
            ``(losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen)``.

            * ``losses``: the per-example loss.
            * ``chosen_rewards`` / ``rejected_rewards``: ``beta * logps``, detached and
              moved to the accelerator device.
            * ``log_odds_ratio``: the batch mean of ``logsigmoid(log_odds)``.
            * ``log_odds_chosen``: the batch mean of ``log_odds``.

            The last two are returned for logging regardless of ``loss_type``.

        Raises:
            ValueError: If ``self.loss_type`` is not one of the values listed above.
        """
        diff = policy_chosen_logps - policy_rejected_logps

        # ORPO log odds, derived from Eqs. (4) and (7) of https://arxiv.org/abs/2403.07691.
        # It uses log(odds(p)) = log p - log(1 - p) with p = exp(logp); log1p keeps 1 - p accurate.
        # It is computed once because it feeds both the "orpo" loss and the logged metrics.
        # NOTE: log1p(-exp(logp)) is -inf when logp == 0, so log_odds can be non-finite there.
        log_odds = diff - (
            torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
        )

        if self.loss_type == "softplus":
            losses = -self.beta * F.softplus(diff)
        elif self.loss_type == "cpo":
            losses = -self.beta * F.logsigmoid(diff)
        elif self.loss_type == "orpo":
            losses = -self.beta * F.logsigmoid(log_odds)
        elif self.loss_type == "leakyrelu":
            losses = -self.beta * F.leaky_relu(diff, negative_slope=0.1)
        elif self.loss_type == "slic":
            losses = -self.beta * F.relu(diff - self.po_cutoff)
        elif self.loss_type == "sft":
            losses = -self.beta * F.relu(diff)
        elif self.loss_type in ("simpo", "tampo"):
            losses = -self.beta * F.logsigmoid(diff - self.po_cutoff)
        else:
            raise ValueError("Invalid loss type")

        ratio = F.logsigmoid(log_odds)

        chosen_rewards = self.beta * policy_chosen_logps.to(self.accelerator.device).detach()
        rejected_rewards = self.beta * policy_rejected_logps.to(self.accelerator.device).detach()

        return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds)

    def _tampo_alpha(self) -> float:
        """Return the TAMPO weight on the preference loss for the current training step.

        Training progress is ``p = global_step / max_steps``, which is normally in [0, 1].
        The weight follows ``self.alpha_schedule``:

        * ``"linear"``: ``p``
        * ``"square"``: ``p ** 2``
        * ``"tanh"``: ``(tanh(10 * p - 7.5) + 1) / 2``, a smooth ramp centred at ``p = 0.75``

        Raises:
            ValueError: If ``self.alpha_schedule`` is not one of the schedules above.
        """
        max_steps = self.state.max_steps
        # max_steps is 0 on a fresh TrainerState, e.g. when evaluate() runs before train().
        # Treat that case as the start of training rather than dividing by zero.
        progress = self.state.global_step / max_steps if max_steps > 0 else 0.0
        if self.alpha_schedule == "linear":
            return progress
        if self.alpha_schedule == "square":
            return progress ** 2
        if self.alpha_schedule == "tanh":
            return (math.tanh(progress * 10 - 7.5) + 1) / 2
        raise ValueError("Invalid schedule type")

    def get_batch_loss_metrics(
        self,
        model,
        batch,
        train_eval,
    ):
        """Compute the training loss and the logging metrics for one batch (train or eval).

        The total loss depends on ``self.loss_type``:

        * ``"tampo"``: ``alpha * mean(losses) + (1 - alpha) * nll_loss``, with ``alpha``
          taken from ``_tampo_alpha()``.
        * ``"sft"``: the NLL loss only.
        * Any other type: ``mean(losses) + nll_loss``.

        When ``self.aux_loss_enabled`` is set, the model's auxiliary loss (the 6th output of
        ``concatenated_forward``) is added, scaled by its coefficient, as ``ORPOTrainer`` does.

        Args:
            model: Model used for the concatenated chosen/rejected forward pass.
            batch: Batch in the format accepted by ``self.concatenated_forward``.
            train_eval: ``"eval"`` prefixes every metric name with ``eval_``; any other
                value leaves the names unprefixed.

        Returns:
            ``(loss, metrics)``, where ``metrics`` maps metric names to Python numbers.
        """
        forward_output = self.concatenated_forward(model, batch)
        # Outputs 2-3 are the chosen/rejected logits; they are not used here.
        policy_chosen_logps, policy_rejected_logps, _, _, policy_nll_loss = forward_output[:5]

        losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
            policy_chosen_logps, policy_rejected_logps
        )
        po_loss = losses.mean()

        if self.loss_type == "tampo":
            alpha = self._tampo_alpha()
            loss = alpha * po_loss + (1 - alpha) * policy_nll_loss
        elif self.loss_type == "sft":
            loss = policy_nll_loss
        else:
            loss = po_loss + policy_nll_loss

        if self.aux_loss_enabled:
            # Mirror ORPOTrainer: add the model's auxiliary (e.g. MoE router) loss.
            aux_loss_coef = getattr(self, "aux_loss_coef", None)
            if aux_loss_coef is None:
                # Fall back to the model config when the trainer does not expose the coefficient.
                aux_loss_coef = getattr(getattr(model, "config", None), "router_aux_loss_coef", 0.0)
            loss = loss + aux_loss_coef * forward_output[5]

        reward_accuracies = (chosen_rewards > rejected_rewards).float()

        prefix = "eval_" if train_eval == "eval" else ""
        gather = self.accelerator.gather_for_metrics
        # gather_for_metrics is a collective call, so every process must issue these in the same order.
        metrics = {}
        metrics[f"{prefix}rewards/chosen"] = gather(chosen_rewards).mean()
        metrics[f"{prefix}rewards/rejected"] = gather(rejected_rewards).mean()
        metrics[f"{prefix}rewards/accuracies"] = gather(reward_accuracies).mean()
        metrics[f"{prefix}rewards/margins"] = gather(chosen_rewards - rejected_rewards).mean()
        metrics[f"{prefix}logps/rejected"] = gather(policy_rejected_logps).detach().mean()
        metrics[f"{prefix}logps/chosen"] = gather(policy_chosen_logps).detach().mean()
        metrics[f"{prefix}nll_loss"] = gather(policy_nll_loss).detach().mean()
        metrics[f"{prefix}log_odds_ratio"] = gather(log_odds_ratio).mean()
        metrics[f"{prefix}log_odds_chosen"] = gather(log_odds_chosen).mean()
        metrics = {name: value.item() for name, value in metrics.items()}

        return loss, metrics