"""KL divergence utilities for VAE training."""

from __future__ import annotations
import torch


def kl_diag_gaussian(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Return KL(N(mu, sigma^2) || N(0, I)) for a diagonal Gaussian.

    Uses the closed form ``0.5 * sum(exp(logvar) + mu^2 - 1 - logvar)``,
    where ``logvar`` is the log of the per-dimension variance.

    Args:
        mu: Posterior mean, shape (B, D) or (B, K, D).
        logvar: Posterior log-variance, same shape as ``mu``.

    Returns:
        The divergence reduced over the latent axes. When ``mu`` is 3-D the
        last two axes are summed; otherwise only the last axis is summed.
        For the documented input shapes this yields one value per sample,
        shape (B,).
    """
    # Elementwise contribution of every latent coordinate to the divergence.
    per_element = logvar.exp() + mu.pow(2) - 1.0 - logvar

    # Token-level latents (B, K, D) collapse both K and D; standard latents
    # (B, D) collapse D only.
    latent_axes = (-1, -2) if mu.ndim == 3 else -1

    return 0.5 * per_element.sum(dim=latent_axes)
