import torch
import torch.nn.functional as F
from trl import DPOTrainer

class SelfDebiasDPOTrainer(DPOTrainer):
    """DPO trainer whose preference loss is ``L_SC + alpha * L_DPO``.

    Only ``dpo_loss`` is overridden; everything else is inherited from
    ``DPOTrainer``.

    Args:
        *args: Forwarded unchanged to ``DPOTrainer.__init__``.
        self_debias_alpha: Weight ``alpha`` applied to the sigmoid DPO term
            before it is added to the self-correction term.
        **kwargs: Forwarded unchanged to ``DPOTrainer.__init__``.
    """

    def __init__(self, *args, self_debias_alpha=0.25, **kwargs):
        super().__init__(*args, **kwargs)
        self.self_debias_alpha = self_debias_alpha

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
        loss_type: str = "sigmoid",
        model_output=None,
    ):
        """Compute the SelfDebias loss, Eq. (7): ``L_SC + alpha * L_DPO``.

        With ``h = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)``:

        * ``L_SC`` (self-correction loss) is the squared-error term
          ``(1 - beta * h) ** 2``.
        * ``L_DPO`` (direct preference loss) is the sigmoid term
          ``-logsigmoid(beta * h)``.

        Args:
            policy_chosen_logps: Policy log-probabilities of the chosen responses.
            policy_rejected_logps: Policy log-probabilities of the rejected responses.
            ref_chosen_logps: Reference log-probabilities of the chosen responses.
            ref_rejected_logps: Reference log-probabilities of the rejected responses.
            loss_type: Ignored. Accepted only so that base-trainer versions
                which pass a loss type positionally do not raise ``TypeError``;
                the combined loss above is always used.
            model_output: Ignored, for the same reason as ``loss_type``.

        Returns:
            A ``(losses, chosen_rewards, rejected_rewards)`` tuple. ``losses``
            is element-wise (no reduction is applied here); both reward
            tensors are detached from the autograd graph.
        """
        # h: how much more the policy prefers chosen over rejected than the
        # reference does.
        policy_margin = policy_chosen_logps - policy_rejected_logps
        reference_margin = ref_chosen_logps - ref_rejected_logps
        # self.beta is not set in this class; presumably provided by the base
        # trainer -- TODO confirm.
        scaled_logits = self.beta * (policy_margin - reference_margin)

        # L_DPO = -log(sigmoid(beta * h))
        loss_dpo = -F.logsigmoid(scaled_logits)

        # L_SC = (1 - beta * h)^2
        # NOTE(review): this regresses the beta-scaled margin toward 1; confirm
        # that target against Eq. (7) of the method being implemented.
        loss_sc = (1 - scaled_logits) ** 2

        final_loss = loss_sc + self.self_debias_alpha * loss_dpo

        # Implicit rewards, detached so they never contribute gradients.
        chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps).detach()

        return final_loss, chosen_rewards, rejected_rewards