import torch
import torch.nn as nn
import torch.nn.functional as F


class DisentangleLoss(nn.Module):
    """Representation Disentanglement Loss (Mutual Information Minimization).

    Dispatches to one of two dependence measures between the invariant and
    variant embeddings, selected by ``method``:

    * ``'hsic'``    -- biased HSIC estimate with RBF kernels.
    * ``'infonce'`` -- InfoNCE (softmax cross-entropy over cosine similarity).
    """

    def __init__(self, method='hsic'):
        super().__init__()
        # Validated lazily in forward() so the auxiliary losses on this class
        # remain usable regardless of the configured method.
        self.method = method

    def forward(self, invariant_emb, variant_emb, env_labels=None):
        """
        Calculate disentanglement loss
        invariant_emb: [batch_size, feature_dim] invariant representations
        variant_emb: [batch_size, feature_dim] variant representations
        env_labels: environment labels (passed through to the InfoNCE loss,
            which currently does not use them)

        Returns a scalar tensor.
        Raises ValueError if ``self.method`` is not 'hsic' or 'infonce'.
        """
        if self.method == 'hsic':
            return self.hsic_loss(invariant_emb, variant_emb)
        elif self.method == 'infonce':
            return self.infonce_loss(invariant_emb, variant_emb, env_labels)
        else:
            raise ValueError(f"Unknown method: {self.method}")

    def hsic_loss(self, x, y):
        """Hilbert-Schmidt Independence Criterion (HSIC)

        Biased estimator ``trace(K H L H) / n^2`` with ``H = I - 1/n``, where
        K and L are RBF kernel matrices of ``x`` and ``y``.

        x: [batch_size, dim_x]
        y: [batch_size, dim_y]
        Returns a non-negative scalar tensor (0 for a single sample).
        """
        n = x.shape[0]

        # The RBF kernel is translation-invariant, so subtracting the feature
        # mean does not change the statistic; it only keeps magnitudes small
        # for the distance computation.
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # Kernel matrices
        K = self._rbf_kernel(x)
        L = self._rbf_kernel(y)

        # Independence must be measured on the *centered* kernel H K H
        # (double centering: remove row means and column means, add back the
        # grand mean).
        K_centered = (
            K
            - K.mean(dim=0, keepdim=True)
            - K.mean(dim=1, keepdim=True)
            + K.mean()
        )

        # trace(K_c @ L @ H) == trace(K_c @ L) == sum(K_c * L) because K_c is
        # already centered and L is symmetric; this avoids an O(n^3) matmul.
        return (K_centered * L).sum() / (n ** 2)

    def infonce_loss(self, x, y, env_labels):
        """InfoNCE loss

        For each row i, the positive pair is (x[i], y[i]) and every
        (x[i], y[j]), j != i, is a negative. The loss is the mean softmax
        cross-entropy of the cosine-similarity matrix against its diagonal.

        x: [batch_size, feature_dim]
        y: [batch_size, feature_dim]
        env_labels: accepted for interface compatibility; not used.

        NOTE(review): minimizing this value increases the similarity of paired
        embeddings; confirm the caller applies the sign appropriate for a
        disentanglement objective.
        """
        batch_size = x.shape[0]

        # Cosine similarity via normalized matmul; avoids materializing the
        # [batch, batch, feature_dim] broadcast.
        x_unit = F.normalize(x, p=2, dim=1, eps=1e-8)
        y_unit = F.normalize(y, p=2, dim=1, eps=1e-8)
        sim_matrix = x_unit @ y_unit.t()

        # -log(exp(pos) / sum_j exp(sim_ij)) with the diagonal as the target
        # class, computed through log-softmax for numerical stability.
        targets = torch.arange(batch_size, device=x.device)
        return F.cross_entropy(sim_matrix, targets)

    def _rbf_kernel(self, x, sigma=None):
        """RBF kernel function

        x: [n, feature_dim]
        sigma: kernel bandwidth; when None the median heuristic is used
            (median of distances between distinct samples).
        Returns the [n, n] kernel matrix exp(-d^2 / (2 sigma^2)).
        """
        pairwise_dists = torch.cdist(x, x, p=2)

        if sigma is None:
            n = x.shape[0]
            # Exclude the diagonal: its zeros would otherwise drag the median
            # to 0 for small batches and turn the kernel into NaN.
            off_diagonal = ~torch.eye(n, dtype=torch.bool, device=x.device)
            distinct_dists = pairwise_dists[off_diagonal]
            if distinct_dists.numel() > 0:
                # Clamp guards against all-identical samples (median == 0).
                min_sigma = torch.finfo(pairwise_dists.dtype).eps
                sigma = torch.median(distinct_dists).clamp_min(min_sigma)
            else:
                # Single sample: bandwidth is irrelevant, any positive value works.
                sigma = pairwise_dists.new_ones(())

        kernel = torch.exp(-pairwise_dists.pow(2) / (2 * sigma ** 2))
        return kernel

    def environment_contrastive_loss(self, variant_embeddings):
        """
        Environment contrastive loss: Encourage variant representations to capture environmental differences
        variant_embeddings: dict {env_id: [batch_size, feature_dim]}

        Sums, over every unordered pair of environments, the mean of
        exp(-distance) between row-aligned embeddings. Always returns a scalar
        tensor (zero when fewer than two environments are given).
        """
        embeddings = list(variant_embeddings.values())
        if not embeddings:
            return torch.zeros(())

        # Start from a tensor so the return type does not depend on the
        # number of environments.
        loss = embeddings[0].new_zeros(())

        for i, emb_i in enumerate(embeddings):
            for emb_j in embeddings[i + 1:]:
                # Maximize differences between different environments
                diff = F.pairwise_distance(emb_i, emb_j, p=2)
                # Want large distance, so minimize exp(-distance)
                loss = loss + torch.exp(-diff).mean()

        return loss