import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import functional as F
from transformers import Trainer


def get_batch_nll(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the summed next-token negative log-likelihood of each sequence.

    The inputs are shifted so that the logits at position ``t`` are scored
    against the label at position ``t + 1``. Label positions equal to
    ``-100`` contribute nothing to the sum.

    Args:
        logits: Model outputs indexed as ``[batch, seq, vocab]``.
        labels: Target token ids indexed as ``[batch, seq]``; ``-100`` marks
            positions to ignore.

    Returns:
        A ``[batch]`` tensor with the NLL summed over each sequence's
        non-ignored target positions.
    """
    # Causal-LM shift: drop the last prediction and the first target.
    pred_logits = logits[:, :-1, :].contiguous()
    targets = labels[:, 1:].contiguous().to(
        device=pred_logits.device, dtype=torch.long
    )

    ignored = targets.eq(-100)
    # Negative ids are not valid gather indices, so clamp them to 0; the
    # -100 positions are zeroed out again right after the gather.
    safe_index = targets.clamp_min(0).unsqueeze(-1)  # [B, T-1, 1]
    token_logp = (
        F.log_softmax(pred_logits, dim=-1).gather(-1, safe_index).squeeze(-1)
    )  # [B, T-1]
    token_logp = token_logp.masked_fill(ignored, 0.0)
    return -token_logp.sum(dim=-1)  # [B]


def _freeze_ref(m):
    """Freeze ``m`` for use as a reference model and return it.

    Disables gradients on every parameter and switches ``m`` to eval mode.
    It also tries to set ``m.config.use_cache = True``; any exception raised
    by that attempt (for example when ``m`` has no ``config``) is ignored.
    """
    for param in m.parameters():
        param.requires_grad_(False)
    m.eval()
    try:
        # Best-effort only: failures here are deliberately swallowed.
        setattr(m.config, "use_cache", True)
    except Exception:
        pass
    return m


class KTOTrainer(Trainer):
    """``Trainer`` that optimizes a KTO-style sigmoid loss against a frozen reference.

    Each batch passed to :meth:`compute_loss` must provide, for both the
    ``idk`` and the ``forget`` prefix, the entries ``<prefix>_input_ids``,
    ``<prefix>_attention_mask`` and ``<prefix>_labels``.

    Args:
        ref_model: Reference model. It is frozen on construction and only
            ever run without gradients. Required.
        beta: Scale applied to the forget log-ratio; the loss is also
            multiplied by ``2 / beta``.
        *args, **kwargs: Forwarded unchanged to ``Trainer``.

    Raises:
        ValueError: If ``ref_model`` is not supplied.
    """

    def __init__(self, *args, ref_model=None, beta=0.1, **kwargs):
        # Validate before the (potentially expensive) base-class setup, and
        # with a real exception: ``assert`` is stripped under ``python -O``.
        if ref_model is None:
            raise ValueError("ref_model must be supplied")
        super().__init__(*args, **kwargs)
        self.ref_model = _freeze_ref(ref_model)
        self.beta = beta

    def _wrap_model(self, model, training=True, *args, **kwargs):
        """Wrap ``model`` as the base class does and move the reference model
        to the training device.

        Extra positional/keyword arguments are forwarded untouched so the
        override keeps working when the base-class hook is called with more
        arguments than ``model`` and ``training``.
        """
        model = super()._wrap_model(model, training, *args, **kwargs)
        self.ref_model = self.ref_model.to(self.args.device)
        return model

    @staticmethod
    def _sequence_nll(model, inputs, prefix):
        """Run ``model`` on the ``prefix`` part of ``inputs``; return per-sequence NLL ``[B]``."""
        labels = inputs[f"{prefix}_labels"]
        out = model(
            input_ids=inputs[f"{prefix}_input_ids"],
            attention_mask=inputs[f"{prefix}_attention_mask"],
            labels=labels,
        )
        return get_batch_nll(out.logits, labels)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the KTO sigmoid loss for one batch.

        With ``log_ratio = log p_pi - log p_ref`` (sequence-level, summed over
        tokens), the loss is::

            kl     = mean(log_ratio on the idk batch)      # no gradient
            margin = kl - beta * log_ratio(forget batch)   # [B]
            loss   = mean(1 - sigmoid(margin)) * 2 / beta

        Only the policy forward pass on the forget batch carries gradients.

        Args:
            model: The policy model being trained.
            inputs: Batch mapping with the ``idk_*`` and ``forget_*`` entries
                described on the class.
            return_outputs: When true, also return a dict of detached
                diagnostics (``kl_term`` and ``log_ratio_mean``).
            **kwargs: Accepted for compatibility with the base-class
                signature and ignored.

        Returns:
            The scalar loss, or ``(loss, diagnostics)`` when
            ``return_outputs`` is true.
        """
        # Use no_grad, not inference_mode: tensors created under
        # inference_mode (including anything a model caches during its
        # forward pass) cannot be saved for backward, which can break the
        # grad-enabled policy pass below.
        with torch.no_grad():
            idk_nll_pol = self._sequence_nll(model, inputs, "idk")  # [B]
            idk_nll_ref = self._sequence_nll(self.ref_model, inputs, "idk")  # [B]
            # log p_pi - log p_ref == nll_ref - nll_pol; same sign convention
            # as ``log_ratios`` below. Scalar, no grad.
            kl_term = (idk_nll_ref - idk_nll_pol).mean()

            fg_nll_ref = self._sequence_nll(self.ref_model, inputs, "forget")  # [B]

        # ----- FORGET branch (policy with grads) -----
        fg_nll_pol = self._sequence_nll(model, inputs, "forget")  # [B]

        # log p_pi(FORGET) - log p_ref(FORGET) == fg_nll_ref - fg_nll_pol
        log_ratios = fg_nll_ref - fg_nll_pol

        # kto_sigmoid form; kl_term broadcasts over the batch dimension.
        margin = kl_term - self.beta * log_ratios  # [B]
        loss = (1.0 - torch.sigmoid(margin)).mean() * (2.0 / self.beta)

        if not return_outputs:
            return loss
        return loss, {
            "kl_term": kl_term.detach(),
            "log_ratio_mean": log_ratios.mean().detach(),
        }
