import torch
import torch.nn.functional as F

def kl_divergence_loss(
    logits_trained: torch.Tensor,   # logits from trained model (main model)
    logits_frozen: torch.Tensor,    # logits from frozen model
    attention_mask: torch.Tensor,   # shape (batch, seq)
    reduction: str = 'batchmean',
    temperature: float = 1.0,
) -> torch.Tensor:
    """Return the masked token-level KL(trained || frozen) between two logit sets.

    Both logit tensors are turned into distributions with a softmax over the
    last dimension, and the divergence is summed over that dimension, giving
    one value per position. Positions are then weighted by ``attention_mask``
    so that padding (mask value 0) contributes nothing.

    Args:
        logits_trained: Logits of the model being trained; the last dimension
            is the one normalised over (expected layout: batch, seq, vocab).
        logits_frozen: Logits of the reference model, same shape as
            ``logits_trained``. This function does not detach them; wrap the
            reference forward pass in ``torch.no_grad()`` (or detach) if no
            gradient should flow into it.
        attention_mask: Per-position weights, shape (batch, seq); 1 for real
            tokens and 0 for padding.
        reduction: ``'batchmean'`` divides the masked sum by the mask sum
            (i.e. a mean over unmasked positions, not over batch rows),
            ``'sum'`` returns the masked sum, ``'none'`` returns the masked
            per-position values with shape (batch, seq).
        temperature: Positive value both logit tensors are divided by before
            the softmax. The result is not rescaled by ``temperature ** 2``.

    Returns:
        A scalar tensor for ``'batchmean'`` and ``'sum'``, or a (batch, seq)
        tensor for ``'none'``.

    Raises:
        ValueError: If ``reduction`` is not a supported mode or
            ``temperature`` is not strictly positive.
    """
    # Fail fast on bad arguments, before any (B, T, V)-sized work is done.
    if reduction not in ('batchmean', 'sum', 'none'):
        raise ValueError(f"Unknown reduction: {reduction}")
    if temperature <= 0:
        # Zero would divide the logits by zero; a negative value would flip
        # the ordering of the logits and silently produce a meaningless loss.
        raise ValueError(f"temperature must be positive, got {temperature}")

    if temperature != 1.0:
        logits_trained = logits_trained / temperature
        logits_frozen = logits_frozen / temperature

    # log_softmax is used (rather than log(softmax)) for numerical stability.
    log_probs_trained = F.log_softmax(logits_trained, dim=-1)  # (B, T, V)
    log_probs_frozen = F.log_softmax(logits_frozen, dim=-1)
    probs_trained = log_probs_trained.exp()

    # KL(trained || frozen) = sum_v p_trained * (log p_trained - log p_frozen)
    kl = (probs_trained * (log_probs_trained - log_probs_frozen)).sum(dim=-1)  # (B, T)

    kl = kl * attention_mask  # mask out padding

    if reduction == 'none':
        return kl
    if reduction == 'sum':
        return kl.sum()

    # 'batchmean': average over unmasked positions. When every position is
    # masked the numerator is already 0, so substitute 1 for the zero
    # denominator to return 0 instead of NaN (0 / 0). Any non-zero mask sum
    # is left untouched, and no host/device sync is required.
    mask_total = attention_mask.sum()
    mask_total = mask_total.masked_fill(mask_total == 0, 1)
    return kl.sum() / mask_total

