import torch
from torch.nn import functional as F
from transformers import Trainer


# length-normalized per-example NLL (mean over valid tokens)
def batch_mean_nll_from_logits(logits, labels, ignore_index=-100):
    """Return the length-normalized NLL of each sequence in a causal-LM batch.

    Logits at position ``t`` are scored against the label at position
    ``t + 1`` (the causal shift). Positions whose shifted label equals
    ``ignore_index`` are excluded from both the NLL sum and the token count.

    Args:
        logits: Tensor of shape ``[B, T, V]``.
        labels: Integer tensor of shape ``[B, T]``; ``ignore_index`` marks
            positions that must not contribute.
        ignore_index: Label value to skip. Defaults to ``-100``.

    Returns:
        Tensor of shape ``[B]`` with the mean token NLL of each row, in at
        least float32 precision. A row with no valid tokens yields 0 because
        its token count is clamped to 1.
    """
    # Causal shift: logits[:, t] predict labels[:, t + 1].
    shift_logits = logits[:, :-1, :]
    # Upcast half-precision logits before the softmax. Otherwise the per-token
    # NLL and its per-row sum are rounded to fp16/bf16, and the small NLL
    # differences that DPO relies on drown in that rounding. promote_types
    # keeps float64 inputs as float64.
    shift_logits = shift_logits.to(torch.promote_types(shift_logits.dtype, torch.float32))
    shift_labels = labels[:, 1:].long()
    valid = shift_labels.ne(ignore_index)

    vocab = shift_logits.size(-1)
    ce_tok = F.cross_entropy(
        shift_logits.reshape(-1, vocab),
        shift_labels.reshape(-1),
        ignore_index=ignore_index,
        reduction="none",
    ).view_as(shift_labels)  # [B, T-1]

    tok_counts = valid.sum(dim=-1).clamp_min(1)
    return (ce_tok * valid).sum(dim=-1) / tok_counts  # [B]


class DPOTrainer(Trainer):
    """``Trainer`` subclass that optimizes a DPO-style preference loss.

    Each batch carries two tokenized sequences per example: the *chosen*
    response under the ``idk_*`` keys and the *rejected* response under the
    ``forget_*`` keys. The preference margin is built from length-normalized
    NLLs (``batch_mean_nll_from_logits``), computed under both the trainable
    policy and a frozen reference model.

    NOTE(review): this class never moves ``ref_model`` to a device or wraps
    it for mixed precision; confirm the caller places it alongside the
    policy's inputs.
    """

    def __init__(self, *args, ref_model=None, beta=0.1, center_margins=True, **kwargs):
        """Create the trainer.

        Args:
            *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
            ref_model: Frozen reference model (required). It is switched to
                eval mode and all of its parameters stop requiring grads.
            beta: Scale applied to the policy-minus-reference margin inside
                the log-sigmoid.
            center_margins: If True (the original behavior), subtract the
                detached batch mean from both the policy and reference
                margins before the loss. The centered values sum to zero over
                the batch, so the loss is then always >= log(2) (exactly
                log(2) for a batch of one), and raising every margin by the
                same amount leaves it unchanged, so it cannot saturate the
                way standard DPO does. Pass False for the standard objective.

        Raises:
            ValueError: If ``ref_model`` is None.
        """
        # Explicit check (not ``assert``, which ``python -O`` strips), done
        # before the base-class setup runs.
        if ref_model is None:
            raise ValueError("ref_model must be supplied")
        super().__init__(*args, **kwargs)
        self.ref_model = ref_model.eval()
        for p in self.ref_model.parameters():
            p.requires_grad_(False)
        self.beta = beta
        self.center_margins = center_margins
        # Disable the reference model's KV cache (original rationale:
        # stability with PEFT). Each reference call in compute_loss is a single
        # full-sequence forward, so no cache is reused anyway. Models without a
        # ``config`` attribute are left as-is.
        try:
            self.ref_model.config.use_cache = False
        except AttributeError:
            pass

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the DPO loss for one batch (chosen = IDK, rejected = FORGET).

        Args:
            model: Trainable policy model.
            inputs: Mapping with ``idk_input_ids``, ``idk_attention_mask`` and
                ``idk_labels`` (chosen), plus ``forget_input_ids``,
                ``forget_attention_mask`` and ``forget_labels`` (rejected).
            return_outputs: If True, return ``(loss, None)`` instead of ``loss``.
            **kwargs: Extra keyword arguments from the caller; ignored.

        Returns:
            Scalar loss tensor, or ``(loss, None)`` when ``return_outputs``.
        """
        idk_ids = inputs["idk_input_ids"]
        idk_mask = inputs["idk_attention_mask"]
        idk_lbls = inputs["idk_labels"]
        forg_ids = inputs["forget_input_ids"]
        forg_mask = inputs["forget_attention_mask"]
        forg_lbls = inputs["forget_labels"]

        # Policy forwards keep the autograd graph. ``labels`` is deliberately
        # not passed: the model would compute (and we would discard) its own
        # LM loss, since the per-example NLL is recomputed below from logits.
        out_ch = model(input_ids=idk_ids, attention_mask=idk_mask)  # chosen = IDK
        out_rj = model(input_ids=forg_ids, attention_mask=forg_mask)  # rejected = FORGET
        nll_ch = batch_mean_nll_from_logits(out_ch.logits, idk_lbls)  # [B]
        nll_rj = batch_mean_nll_from_logits(out_rj.logits, forg_lbls)  # [B]

        # Reference forwards: no graph needed, the reference is frozen.
        with torch.no_grad():
            ref_ch = self.ref_model(input_ids=idk_ids, attention_mask=idk_mask)
            ref_rj = self.ref_model(input_ids=forg_ids, attention_mask=forg_mask)
            nll_ch_ref = batch_mean_nll_from_logits(ref_ch.logits, idk_lbls)
            nll_rj_ref = batch_mean_nll_from_logits(ref_rj.logits, forg_lbls)

        # Preference margin = loglik(chosen) - loglik(rejected)
        #                   = nll(rejected) - nll(chosen).
        delta_pi = -(nll_ch - nll_rj)  # policy margin
        delta_ref = -(nll_ch_ref - nll_rj_ref)  # reference margin

        # Diagnostics on the RAW margins. Computing them after centering (as
        # before) made every logged mean identically ~0.
        diagnostics = {
            "dpo_delta_pi_mean": delta_pi.mean().item(),
            "dpo_delta_ref_mean": delta_ref.mean().item(),
            "dpo_margin_mean": (delta_pi - delta_ref).mean().item(),
        }

        if self.center_margins:
            # Means are detached, so gradients still flow through every
            # per-example delta_pi; only the loss value is shifted.
            delta_pi = delta_pi - delta_pi.mean().detach()
            delta_ref = delta_ref - delta_ref.mean().detach()

        loss_vec = -F.logsigmoid(self.beta * (delta_pi - delta_ref))
        loss = loss_vec.mean()

        self.log(diagnostics)

        return (loss, None) if return_outputs else loss
