import math

from trl import GRPOTrainer


class AdaptiveGRPOTrainer(GRPOTrainer):
    """GRPO trainer whose KL coefficient ``beta`` is adapted toward a target KL.

    After every training-mode ``compute_loss`` call, the most recently logged
    ``"kl"`` metric is compared against ``target_kl``. ``self.beta`` is then
    scaled by ``1 + k_beta * e``, where ``e`` is the relative KL error
    ``(kl - target_kl) / target_kl`` clipped to ``[-max_rel_error,
    max_rel_error]``. If KL is above target, beta grows; if it is below
    target, beta shrinks. The new value is logged as ``"adaptive_beta"``.

    The update runs once per ``compute_loss`` call. If the parent calls it
    once per micro-batch (gradient accumulation), beta changes several times
    per optimizer step.

    Because the update is multiplicative, a beta of exactly 0 stays 0.
    NOTE(review): ``self.beta`` is presumably initialised by ``GRPOTrainer``
    from its training args, and ``"kl"`` is presumably only logged when that
    beta is non-zero. Confirm against the installed TRL version.

    Args:
        target_kl: Desired KL value. A non-positive value disables adaptation.
        *args: Forwarded unchanged to ``GRPOTrainer.__init__``.
        k_beta: Gain applied to the clipped relative KL error (default 0.1).
        max_rel_error: Symmetric clip bound on the relative KL error, which
            limits how much beta can change per update (default 0.2).
        **kwargs: Forwarded unchanged to ``GRPOTrainer.__init__``.
    """

    def __init__(
        self, target_kl, *args, k_beta=0.1, max_rel_error=0.2, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.target_kl = target_kl
        self.k_beta = k_beta
        self.max_rel_error = max_rel_error

    def compute_loss(self, *args, **kwargs):
        """Compute the parent GRPO loss, then adapt ``self.beta`` for later calls.

        The returned loss is exactly what ``GRPOTrainer.compute_loss`` produced.
        The beta update only affects subsequent calls.
        """
        loss = super().compute_loss(*args, **kwargs)
        mode = "train" if self.model.training else "eval"
        # Adapt only while training, and only with a positive target. A target
        # of exactly 0 would otherwise divide by zero below.
        if mode == "eval" or self.target_kl <= 0:
            return loss
        # ``.get`` avoids both a KeyError on a plain dict and inserting an
        # empty list into a defaultdict when the parent logged no KL value.
        kl_history = self._metrics[mode].get("kl")
        if not kl_history:
            return loss
        current_kl = float(kl_history[-1])
        # A NaN/inf KL survives min/max clipping and would poison beta for
        # the rest of training via the multiplicative update, so skip it.
        if not math.isfinite(current_kl):
            return loss
        rel_error = (current_kl - self.target_kl) / self.target_kl
        rel_error = max(-self.max_rel_error, min(rel_error, self.max_rel_error))
        self.beta *= 1 + self.k_beta * rel_error
        self._metrics[mode]["adaptive_beta"].append(self.beta)
        return loss
