import torch
import torch.nn.functional as F

def kl_divergence_loss(
    logits_trained,        # logits from trained model (main model)
    logits_frozen,         # logits from frozen model
    attention_mask,        # shape (batch, seq)
    sample_mask=None,      # shape (batch,) - 1 for samples to include, 0 to exclude
    reduction='batchmean',
    temperature=1.0
):
    """
    Compute the token-level KL(trained || frozen) between two sets of logits.

    For every position the divergence is
    ``sum_v p_trained(v) * (log p_trained(v) - log p_frozen(v))``, where both
    distributions are softmaxes over the last dimension of the
    (temperature-scaled) logits. Each position is then weighted by
    ``attention_mask`` and, if given, by the per-sample ``sample_mask``.

    Positions whose combined weight is zero are excluded *before* the softmax
    (their logits are replaced by zeros), so NaN/Inf logits at padded or
    excluded positions cannot leak into the loss or its gradients.

    Args:
        logits_trained: Logits from the trained model, shape (B, T, V)
        logits_frozen: Logits from the frozen model, shape (B, T, V)
        attention_mask: Attention mask for valid tokens, shape (B, T)
        sample_mask: Optional mask for which samples to include in loss, shape (B,)
                     1.0 for samples to include (e.g., interleaved), 0.0 to exclude
        reduction: One of 'batchmean', 'sum', or 'none'.
                   'batchmean' divides the summed KL by the total weight of the
                   active tokens (clamped to at least 1), i.e. it is a
                   per-token mean, not a per-sample mean.
                   'sum' returns the summed KL over active tokens.
                   'none' returns the weighted per-token KL, shape (B, T).
        temperature: Positive softmax temperature applied to both sets of
                     logits (default 1.0). The loss is not rescaled by
                     ``temperature ** 2``.

    Returns:
        KL divergence loss scalar (or a (B, T) tensor if reduction='none')

    Raises:
        ValueError: If ``reduction`` is not a supported value or
            ``temperature`` is not positive.
    """
    # Validate cheap arguments first so that a typo does not cost a full
    # (B, T, V) softmax before being reported.
    if reduction not in ('batchmean', 'sum', 'none'):
        raise ValueError(f"Unknown reduction: {reduction}")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # Combined per-token weight: attention mask, optionally gated per sample.
    weights = attention_mask  # (B, T)
    if sample_mask is not None:
        # (B,) -> (B, 1) so it broadcasts across the sequence dimension.
        weights = attention_mask * sample_mask.unsqueeze(1)  # (B, T)

    # Multiplying by a zero weight does not remove NaN/Inf (0 * nan == nan),
    # so neutralise the logits of ignored positions up front. With all-zero
    # logits both distributions are uniform and the KL there is exactly 0;
    # masked_fill also passes a zero gradient back to the replaced entries.
    ignored = (weights == 0).unsqueeze(-1)  # (B, T, 1)
    logits_trained = logits_trained.masked_fill(ignored, 0.0)
    logits_frozen = logits_frozen.masked_fill(ignored, 0.0)

    # Optionally (if needed) apply temperature
    if temperature != 1.0:
        logits_trained = logits_trained / temperature
        logits_frozen = logits_frozen / temperature

    # Compute log-probs
    log_probs_trained = F.log_softmax(logits_trained, dim=-1)  # (B, T, V)
    log_probs_frozen = F.log_softmax(logits_frozen, dim=-1)    # (B, T, V)
    probs_trained = log_probs_trained.exp()

    # KL(trained || frozen) per token
    kl = (probs_trained * (log_probs_trained - log_probs_frozen)).sum(dim=-1)  # (B, T)

    # Apply the weights (keeps support for non-binary weights).
    kl = kl * weights  # (B, T)

    if reduction == 'none':
        return kl
    if reduction == 'sum':
        return kl.sum()

    # 'batchmean': normalize by the total weight of active tokens.
    denom = weights.sum()
    return kl.sum() / denom.clamp(min=1.0)  # Avoid division by zero

