"""Loss functions for INN training.

This module provides the loss functions needed for bidirectional INN training:
- Maximum Mean Discrepancy (MMD) with Inverse Multiquadratic (IMQ) kernel
- Combined INN loss (L_y + L_z + L_x), in unconditional and conditional variants
"""

import torch
from typing import Sequence


def imq_kernel(
    x: torch.Tensor,
    y: torch.Tensor,
    bandwidths: Sequence[float] | None = None,
    c: float = 1.0,
) -> torch.Tensor:
    """Compute Inverse Multiquadratic (IMQ) kernel between two sets of samples.

    The IMQ kernel is defined as:
        k(x, y) = sum_sigma C / (C + ||x - y||^2 / sigma^2)

    where the sum is over multiple bandwidth values sigma.

    Args:
        x: First set of samples, shape (N, D).
        y: Second set of samples, shape (M, D).
        bandwidths: List of bandwidth values. If None, uses default heuristic.
        c: Constant C in the kernel formula.

    Returns:
        Kernel matrix of shape (N, M).
    """
    if bandwidths is None:
        # Default bandwidths based on median heuristic and multiscale
        bandwidths = [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]

    # Compute squared pairwise distances: ||x_i - y_j||^2
    # x: (N, D), y: (M, D)
    # distances: (N, M)
    x_sq = (x**2).sum(dim=1, keepdim=True)  # (N, 1)
    y_sq = (y**2).sum(dim=1, keepdim=True)  # (M, 1)
    xy = x @ y.t()  # (N, M)
    distances_sq = x_sq - 2 * xy + y_sq.t()  # (N, M)
    # Clamp to avoid numerical issues
    distances_sq = torch.clamp(distances_sq, min=0.0)

    # Sum over multiple bandwidths
    kernel = torch.zeros_like(distances_sq)
    for sigma in bandwidths:
        sigma_sq = sigma**2
        kernel = kernel + c / (c + distances_sq / sigma_sq)

    return kernel


def mmd_loss(
    x: torch.Tensor,
    y: torch.Tensor,
    kernel: str = "imq",
    bandwidths: Sequence[float] | None = None,
) -> torch.Tensor:
    """Estimate the squared Maximum Mean Discrepancy between two sample sets.

    The estimate follows
        MMD^2(P, Q) = E[k(x, x')] - 2*E[k(x, y)] + E[k(y, y')]
    where the two within-set expectations leave out the diagonal (pairs of a
    sample with itself) and are normalised by n*(n-1) and m*(m-1).

    Args:
        x: Samples from distribution P, shape (N, D).
        y: Samples from distribution Q, shape (M, D).
        kernel: Kernel type. Only "imq" (inverse multiquadratic) is accepted.
        bandwidths: Bandwidth values forwarded to ``imq_kernel``; None selects
            that function's defaults.

    Returns:
        MMD^2 estimate (scalar tensor).

    Raises:
        ValueError: If ``kernel`` is anything other than "imq".
    """
    if kernel != "imq":
        raise ValueError(f"Unknown kernel: {kernel}. Only 'imq' is supported.")

    num_x, num_y = x.shape[0], y.shape[0]

    # Gram matrices within each set and across the two sets.
    gram_xx = imq_kernel(x, x, bandwidths)  # (N, N)
    gram_yy = imq_kernel(y, y, bandwidths)  # (M, M)
    gram_xy = imq_kernel(x, y, bandwidths)  # (N, M)

    # Off-diagonal totals: every entry minus the self-similarity terms.
    off_diag_xx = gram_xx.sum() - gram_xx.diag().sum()
    off_diag_yy = gram_yy.sum() - gram_yy.diag().sum()

    # A set with a single sample has no distinct pairs, so its within-set
    # term contributes nothing.
    within_x = off_diag_xx / (num_x * (num_x - 1)) if num_x > 1 else 0.0
    within_y = off_diag_yy / (num_y * (num_y - 1)) if num_y > 1 else 0.0

    cross = 2 * gram_xy.sum() / (num_x * num_y)

    return within_x + within_y - cross


def inn_loss(
    model: torch.nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    lambda_y: float = 1.0,
    lambda_z: float = 1.0,
    lambda_x: float = 1.0,
    mmd_bandwidths: Sequence[float] | None = None,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Compute total INN loss for bidirectional training.

    The loss has three components:
    - L_y: Forward MSE loss between predicted and true labels.
    - L_z: MMD loss between latent z and standard Gaussian prior.
    - L_x: Backward MMD loss between reconstructed and original designs.

    Total loss: L = lambda_y * L_y + lambda_z * L_z + lambda_x * L_x

    Args:
        model: INN model with forward(x) -> (y, z) and inverse(y, z) -> x.
            It must also expose a ``latent_dim`` attribute, which sets the
            width of the latent sample drawn for the backward pass.
        x: Design parameters batch, shape (N, input_dim).
        y: Labels batch, shape (N, output_dim).
        lambda_y: Weight for forward MSE loss.
        lambda_z: Weight for latent MMD loss.
        lambda_x: Weight for backward MMD loss.
        mmd_bandwidths: Bandwidth values for MMD kernel.

    Returns:
        total_loss: Weighted sum of all losses.
        loss_dict: Dictionary with individual loss values (Python floats
            under the keys "L_y", "L_z", "L_x", "total") for logging.
    """
    batch_size = x.shape[0]

    # Forward pass: x -> (y_pred, z_pred)
    y_pred, z_pred = model.forward(x)

    # L_y: Forward MSE loss
    L_y = torch.nn.functional.mse_loss(y_pred, y)

    # L_z: Latent distribution loss (MMD with standard Gaussian)
    z_prior = torch.randn_like(z_pred)
    L_z = mmd_loss(z_pred, z_prior, bandwidths=mmd_bandwidths)

    # L_x: Backward reconstruction loss
    # Sample z from the prior and reconstruct x from (y, z_sampled). The
    # sample takes the dtype of the model's own latent output so that models
    # running in a non-default precision (e.g. float64) do not receive a
    # mismatched float32 tensor in the inverse pass.
    z_sampled = torch.randn(
        batch_size, model.latent_dim, device=x.device, dtype=z_pred.dtype
    )
    x_reconstructed = model.inverse(y, z_sampled)
    L_x = mmd_loss(x_reconstructed, x, bandwidths=mmd_bandwidths)

    # Total loss
    total_loss = lambda_y * L_y + lambda_z * L_z + lambda_x * L_x

    # .item() detaches the values into plain floats for logging only; the
    # returned total_loss keeps its autograd graph.
    loss_dict = {
        "L_y": L_y.item(),
        "L_z": L_z.item(),
        "L_x": L_x.item(),
        "total": total_loss.item(),
    }

    return total_loss, loss_dict


def conditional_inn_loss(
    model: torch.nn.Module,
    x: torch.Tensor,
    y: torch.Tensor,
    c: torch.Tensor,
    lambda_y: float = 1.0,
    lambda_z: float = 1.0,
    lambda_x: float = 1.0,
    mmd_bandwidths: Sequence[float] | None = None,
) -> tuple[torch.Tensor, dict[str, float]]:
    """Compute total conditional INN loss for bidirectional training.

    Same structure as the unconditional INN loss, but every model call also
    receives the conditioning tensor c.

    The loss has three components:
    - L_y: Forward MSE loss between predicted and true labels.
    - L_z: MMD loss between latent z and standard Gaussian prior.
    - L_x: Backward MMD loss between reconstructed and original designs.

    Total loss: L = lambda_y * L_y + lambda_z * L_z + lambda_x * L_x

    Args:
        model: Conditional INN model with forward(x, c) -> (y, z) and
            inverse(y, z, c) -> x. It must also expose a ``latent_dim``
            attribute, which sets the width of the latent sample drawn for
            the backward pass.
        x: Design parameters batch, shape (N, input_dim).
        y: Labels batch, shape (N, output_dim).
        c: Conditioning variables batch, shape (N, cond_dim).
        lambda_y: Weight for forward MSE loss.
        lambda_z: Weight for latent MMD loss.
        lambda_x: Weight for backward MMD loss.
        mmd_bandwidths: Bandwidth values for MMD kernel.

    Returns:
        total_loss: Weighted sum of all losses.
        loss_dict: Dictionary with individual loss values (Python floats
            under the keys "L_y", "L_z", "L_x", "total") for logging.
    """
    batch_size = x.shape[0]

    # Forward pass: x -> (y_pred, z_pred) given c
    y_pred, z_pred = model.forward(x, c)

    # L_y: Forward MSE loss
    L_y = torch.nn.functional.mse_loss(y_pred, y)

    # L_z: Latent distribution loss (MMD with standard Gaussian)
    z_prior = torch.randn_like(z_pred)
    L_z = mmd_loss(z_pred, z_prior, bandwidths=mmd_bandwidths)

    # L_x: Backward reconstruction loss
    # Sample z from the prior and reconstruct x from (y, z_sampled, c). The
    # sample takes the dtype of the model's own latent output so that models
    # running in a non-default precision (e.g. float64) do not receive a
    # mismatched float32 tensor in the inverse pass.
    z_sampled = torch.randn(
        batch_size, model.latent_dim, device=x.device, dtype=z_pred.dtype
    )
    x_reconstructed = model.inverse(y, z_sampled, c)
    L_x = mmd_loss(x_reconstructed, x, bandwidths=mmd_bandwidths)

    # Total loss
    total_loss = lambda_y * L_y + lambda_z * L_z + lambda_x * L_x

    # .item() detaches the values into plain floats for logging only; the
    # returned total_loss keeps its autograd graph.
    loss_dict = {
        "L_y": L_y.item(),
        "L_z": L_z.item(),
        "L_x": L_x.item(),
        "total": total_loss.item(),
    }

    return total_loss, loss_dict
