"""
losses.py
Loss functions for AWML training.
Includes ELBO, conservation penalty, modular independence penalty, and transfer regularization.
"""

import torch
import torch.nn.functional as F


def elbo_loss(obs, obs_recon, dists, z_samples, beta=1.0):
    """
    Negative-ELBO style loss: reconstruction error plus a weighted KL term.

    obs: target observations.
    obs_recon: reconstructed observations, compared to obs with a
        mean-reduced MSE.
    dists: iterable of torch.distributions objects (the posteriors). Each one
        must have a KL divergence registered against Normal(0, 1).
    z_samples: unused here; kept so the call signature stays unchanged.
    beta: weight applied to the summed KL term.

    Returns recon_loss + beta * kl_loss, where kl_loss adds up, over all
    entries of dists, the element-wise mean KL to a standard normal.
    """
    recon_loss = F.mse_loss(obs_recon, obs, reduction="mean")

    # The prior is the same for every posterior, so build it once instead of
    # once per distribution.
    prior = torch.distributions.Normal(0.0, 1.0)
    kl_loss = sum(
        torch.distributions.kl_divergence(posterior, prior).mean()
        for posterior in dists
    )
    return recon_loss + beta * kl_loss


def conservation_penalty(z_t, z_next, invariants_fn):
    """
    Penalty for a conserved quantity changing between two latent states.

    z_t: latent state at the current step.
    z_next: latent state at the following step.
    invariants_fn: callable I(z) evaluated on each state; its two outputs are
        compared with a mean-reduced MSE.

    Returns the MSE between I(z_t) and I(z_next), which is zero when the
    invariant is preserved exactly.
    """
    # Evaluate the invariant on the current state first, then on the next one.
    before, after = (invariants_fn(state) for state in (z_t, z_next))
    return F.mse_loss(before, after, reduction="mean")


def modular_independence_penalty(z_samples):
    """
    Approximate total correlation penalty.
    Encourage independence between modules.

    z_samples: tensor whose first dimension indexes samples; all remaining
        dimensions are flattened into one feature axis, so every pair of
        flattened latent coordinates is penalized (module boundaries are not
        distinguished here).

    Returns the Frobenius norm of the off-diagonal part of the sample
    covariance matrix of the flattened latents (a 0-dim tensor). At least two
    samples are needed for a meaningful covariance estimate.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed latents,
    # are accepted as well.
    z_flat = z_samples.reshape(z_samples.size(0), -1)

    # torch.cov expects variables in rows and observations in columns. With a
    # single flattened dimension it returns a 0-dim tensor, so lift the result
    # to 2-D to keep the diagonal handling below valid.
    cov = torch.atleast_2d(torch.cov(z_flat.T))

    # Remove the variances; only cross-covariances are penalized.
    off_diag = cov - torch.diag(torch.diagonal(cov))
    return torch.linalg.matrix_norm(off_diag, ord="fro")


def transfer_regularization(params, shared_params):
    """
    Sum of per-tensor MSE distances between params and shared_params.

    params: iterable of current parameter tensors.
    shared_params: iterable of shared cross-domain parameter tensors, paired
        with params positionally; they are detached, so they act as fixed
        targets and receive no gradient from this term.

    Pairing stops at the shorter of the two iterables. With no pairs the
    result is the plain integer 0.
    """
    total = 0
    for current, reference in zip(params, shared_params):
        anchor = reference.detach()
        total = total + F.mse_loss(current, anchor)
    return total
