import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoClusterEmptyPenalty(nn.Module):
    """
    Batch-level penalty that discourages any cluster from becoming (nearly) empty.

    Computes β · Σ_k ReLU(τ − p̂_k), where

    • p̂_k = mean softmax probability of cluster k over the current batch,
      i.e. the soft proportion of samples assigned to cluster k
    • τ = safety threshold. The constructor default is 0.45, which for two
      clusters penalises any split more lopsided than 45 / 55. Pass
      ``tau=None`` to use 1 / batch_size instead (≈ "at least one sample
      per cluster" in expectation).
    • β = penalty weight you set by hand

    If a cluster's proportion drops below τ, its ReLU term becomes positive
    and the loss rises; otherwise it contributes 0. Everything is
    differentiable w.r.t. the logits.

    Although named for two clusters, the computation itself works for any
    number of columns K in ``logits``.
    """

    def __init__(self, beta=1.0, tau=0.45, eps=1e-8):
        super().__init__()
        self.beta = beta
        self.tau = tau
        # Not used by forward(); kept so existing constructor calls and
        # saved configurations keep working.
        self.eps = eps

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits: [N, K] raw scores (K = 2 in the two-cluster case).

        Returns:
            Scalar tensor β · Σ_k ReLU(τ − p̂_k).

        Raises:
            ValueError: if the batch is empty (N == 0). In that case the
                cluster proportions are undefined: the batch mean would be
                NaN, and 1 / N would divide by zero.
        """
        batch_size = logits.size(0)
        if batch_size == 0:
            raise ValueError("TwoClusterEmptyPenalty requires a non-empty batch")

        probs = F.softmax(logits, dim=1)  # [N, K] soft assignments
        p_hat = probs.mean(dim=0)  # [K] soft cluster proportions

        # tau=None opts into τ = 1 / batch_size; otherwise use the configured value.
        tau = 1.0 / batch_size if self.tau is None else self.tau

        # Penalty: only active for clusters with p̂_k < τ.
        penalty = F.relu(tau - p_hat).sum()
        return self.beta * penalty


def chebyshev_distance(x, y):
    """
    Pairwise Chebyshev (L∞) distance between the rows of ``x`` and ``y``.

    For ``x`` of shape (n, d) and ``y`` of shape (m, d), returns an (n, m)
    tensor whose [i, j] entry is max_k |x[i, k] - y[j, k]|.

    Note: builds an (n, m, d) intermediate, so memory grows as n * m * d.
    """
    pairwise_gap = (x[:, None] - y[None]).abs()  # broadcast to (n, m, d)
    return pairwise_gap.max(dim=2).values  # reduce over features -> (n, m)

def chebyshev_mmd_loss(x, y, sigma=1.0):
    """
    Squared MMD between sample sets ``x`` and ``y`` under an L∞-Laplace kernel.

    Kernel: k(a, b) = exp(-‖a - b‖_∞ / sigma), with distances taken from
    ``chebyshev_distance``.

    Uses the biased (V-statistic) estimator
        (1/n²) Σ k_xx + (1/m²) Σ k_yy - (2/(n·m)) Σ k_xy,
    which includes the diagonal self-similarity terms.

    Args:
        x: first sample set; presumably (n, d). TODO: confirm with callers.
        y: second sample set; presumably (m, d) with the same d as ``x``.
        sigma: kernel bandwidth; positive values give a decaying kernel.

    Returns:
        Scalar tensor with the MMD estimate.

    NOTE(review): exp(-L∞ distance) may not be positive-definite for d >= 3,
    so the estimate is not guaranteed to be non-negative. Verify this is
    acceptable for the intended use.
    """
    def kernel(a, b):
        # Laplace-style kernel applied to the Chebyshev distance matrix.
        return torch.exp(-chebyshev_distance(a, b) / sigma)

    k_self_x = kernel(x, x)  # (n, n)
    k_self_y = kernel(y, y)  # (m, m)
    k_cross = kernel(x, y)  # (n, m)

    n = x.size(0)
    m = y.size(0)

    within_x = (1 / (n ** 2)) * torch.sum(k_self_x)
    within_y = (1 / (m ** 2)) * torch.sum(k_self_y)
    between = (2 / (n * m)) * torch.sum(k_cross)

    return within_x + within_y - between
