import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MutualInformationLoss(nn.Module):
    """Differentiable estimators of the statistical dependence between two batches.

    The estimator is selected by ``method``:

    * ``'hsic'``    - Hilbert-Schmidt Independence Criterion with RBF kernels.
    * ``'infonce'`` - InfoNCE lower bound using cosine-similarity logits.
    * ``'dv'``      - Donsker-Varadhan bound evaluated with a critic network.
    * ``'mine'``    - MINE-style bound evaluated with a critic network.

    Args:
        method: Name of the estimator used by :meth:`forward`.
        sigma: Optional RBF bandwidth used by :meth:`hsic` when no per-call
            bandwidth is given. ``None`` selects the median heuristic.
    """

    def __init__(self, method='hsic', sigma=None):
        super().__init__()
        self.method = method
        self.sigma = sigma

    def forward(self, x, y):
        """
        Estimate mutual information between x and y

        Args:
            x: [batch_size, dim_x]
            y: [batch_size, dim_y]

        Returns:
            MI estimate (scalar tensor) for every method.

        Raises:
            ValueError: if ``self.method`` is not a known estimator name.

        Note:
            ``forward`` cannot receive a critic, so for ``'dv'`` and ``'mine'``
            a freshly initialised critic is built on every call. Call
            :meth:`donsker_varadhan` / :meth:`mine_estimator` directly with a
            trained ``T_network`` to evaluate an existing critic.
        """
        if self.method == 'hsic':
            return self.hsic(x, y)
        if self.method == 'infonce':
            return self.infonce_mi(x, y)
        if self.method == 'dv':
            # The critic-based estimators return (estimate, critic); forward()
            # promises a scalar, so only the estimate is passed on.
            estimate, _ = self.donsker_varadhan(x, y)
            return estimate
        if self.method == 'mine':
            estimate, _ = self.mine_estimator(x, y)
            return estimate
        raise ValueError(f"Unknown method: {self.method}")

    def hsic(self, x, y, sigma_x=None, sigma_y=None):
        """
        Hilbert-Schmidt Independence Criterion (HSIC)

        A kernel dependence measure (biased estimator
        ``trace(K H L H) / (n - 1)^2``), used here as a differentiable proxy
        for mutual information.

        Args:
            x: [batch_size, dim_x]
            y: [batch_size, dim_y]
            sigma_x: RBF bandwidth for ``x``. Falls back to ``self.sigma``,
                then to the median heuristic.
            sigma_y: RBF bandwidth for ``y``, with the same fallbacks.

        Returns:
            Scalar tensor. Batches with fewer than two samples yield zero.
        """
        batch_size = x.shape[0]
        if batch_size < 2:
            # HSIC is undefined for fewer than two samples (division by
            # (n - 1)^2). Return a zero that is still attached to the graph so
            # a stray short batch neither produces NaN nor breaks backward().
            return (x.sum() + y.sum()) * 0.0

        # Honour the bandwidth configured on the module when none is passed.
        if sigma_x is None:
            sigma_x = self.sigma
        if sigma_y is None:
            sigma_y = self.sigma

        # Center the features (kept from the original implementation).
        x = x - x.mean(dim=0, keepdim=True)
        y = y - y.mean(dim=0, keepdim=True)

        # Compute kernel matrices
        K = self._rbf_kernel(x, sigma_x)
        L = self._rbf_kernel(y, sigma_y)

        # Centering matrix H = I - (1/n) 11^T, built in the kernel's dtype so
        # that float64 inputs do not hit a dtype mismatch in the matmul.
        H = torch.eye(batch_size, device=K.device, dtype=K.dtype) - 1.0 / batch_size
        return torch.trace(K @ H @ L @ H) / ((batch_size - 1) ** 2)

    def infonce_mi(self, x, y, temperature=0.1):
        """
        InfoNCE mutual information lower bound estimation

        Row ``i`` of ``x`` and row ``i`` of ``y`` form the positive pair; all
        other rows of ``y`` act as negatives. The bound is
        ``log(N) - cross_entropy`` and therefore never exceeds ``log(N)``.

        Args:
            x: [batch_size, dim_x]
            y: [batch_size, dim_y]; the feature size must match ``x`` because
                the logits are dot products of the normalised rows.
            temperature: Divisor applied to the cosine similarities.

        Returns:
            Scalar tensor with the lower-bound estimate.
        """
        batch_size = x.shape[0]

        # Cosine-similarity logits
        x_norm = F.normalize(x, dim=1)
        y_norm = F.normalize(y, dim=1)
        logits = torch.mm(x_norm, y_norm.T) / temperature

        # Positive sample pairs (diagonal)
        positives = torch.diag(logits)

        # Cross-entropy of picking the positive among ALL candidates. The
        # positive must be part of the normaliser for the bound to hold;
        # logsumexp also avoids overflow at small temperatures.
        nll = torch.logsumexp(logits, dim=1) - positives

        return float(np.log(batch_size)) - nll.mean()

    def donsker_varadhan(self, x, y, T_network=None):
        """
        Donsker-Varadhan representation for mutual information estimation

        Computes ``E_joint[T] - log E_marginal[exp(T)]`` where marginal pairs
        are formed by randomly permuting the rows of ``y``.

        Args:
            x: [batch_size, dim_x]
            y: [batch_size, dim_y]
            T_network: Callable critic mapping ``[batch, dim_x + dim_y]`` to
                one score per row. A new 128-unit MLP is created when omitted.

        Returns:
            Tuple ``(mi_estimate, T_network)``; the estimate is a scalar
            tensor and the critic is the one actually used.
        """
        # Positive sample pairs
        xy = torch.cat([x, y], dim=1)
        if T_network is None:
            T_network = self._default_critic(xy.shape[1], 128, xy)
        T_positive = T_network(xy)

        # Negative sample pairs (y randomly shuffled)
        perm = torch.randperm(y.shape[0], device=y.device)
        T_negative = T_network(torch.cat([x, y[perm]], dim=1))

        return self._dv_bound(T_positive, T_negative), T_network

    def mine_estimator(self, x, y, T_network=None, ema_decay=0.1):
        """
        Mutual Information Neural Estimation (MINE)

        Evaluates the same Donsker-Varadhan objective as
        :meth:`donsker_varadhan`, with a wider default critic.

        Args:
            x: [batch_size, dim_x]
            y: [batch_size, dim_y]
            T_network: Callable critic; a new 256-unit MLP is created when
                omitted.
            ema_decay: Accepted for interface compatibility; currently unused
                (no moving-average gradient correction is applied).

        Returns:
            Tuple ``(mi_estimate, T_network)``; the estimate is a scalar
            tensor.
        """
        # Positive sample pairs
        xy_joint = torch.cat([x, y], dim=1)
        if T_network is None:
            T_network = self._default_critic(xy_joint.shape[1], 256, xy_joint)
        T_joint = T_network(xy_joint)

        # Negative sample pairs (y independently sampled)
        perm = torch.randperm(y.shape[0], device=y.device)
        T_marginal = T_network(torch.cat([x, y[perm]], dim=1))

        return self._dv_bound(T_joint, T_marginal), T_network

    @staticmethod
    def _default_critic(in_features, hidden, like):
        """Build the default 3-layer ReLU critic on ``like``'s device (and float dtype)."""
        critic = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ).to(like.device)
        if like.is_floating_point():
            critic = critic.to(like.dtype)
        return critic

    @staticmethod
    def _dv_bound(t_joint, t_marginal):
        """Return ``mean(t_joint) - log(mean(exp(t_marginal)))`` as a scalar tensor."""
        scores = t_marginal.reshape(-1)
        # log-mean-exp = logsumexp - log(n); stable for large critic outputs.
        log_mean_exp = torch.logsumexp(scores, dim=0) - float(np.log(scores.shape[0]))
        return t_joint.mean() - log_mean_exp

    def _rbf_kernel(self, x, sigma=None):
        """RBF kernel function

        Args:
            x: [batch_size, dim]
            sigma: Bandwidth. ``None`` uses the median of the strictly
                positive pairwise distances, or 1.0 when there are none.

        Returns:
            [batch_size, batch_size] kernel matrix.
        """
        pairwise_dists = torch.cdist(x, x, p=2)

        if sigma is None:
            # Median heuristic over the non-zero distances. When every point
            # coincides (or the batch has one row) there are no such
            # distances, so fall back to a unit bandwidth instead of taking
            # the median of an empty tensor.
            positive = pairwise_dists[pairwise_dists > 0]
            sigma = torch.median(positive) if positive.numel() > 0 else 1.0

        kernel = torch.exp(-pairwise_dists ** 2 / (2 * sigma ** 2))

        # Tiny diagonal jitter, created in the kernel's own dtype.
        jitter = torch.eye(kernel.shape[0], device=kernel.device, dtype=kernel.dtype)
        return kernel + jitter * 1e-8


class DisentanglementLoss(nn.Module):
    """Disentanglement Loss Combination

    Combines two terms computed over per-environment embeddings:

    1. a mutual-information term between the invariant and variant
       embeddings of each environment (to be minimised), and
    2. an environment-contrast term that decreases as variant embeddings of
       different environments move apart.

    Args:
        mi_method: Estimator name forwarded to ``MutualInformationLoss``.
        mi_weight: Multiplier for the mutual-information term.
        contrast_weight: Multiplier for the contrast term; the term is skipped
            when this is not positive.
    """

    def __init__(self, mi_method='hsic', mi_weight=1.0, contrast_weight=0.5):
        super().__init__()
        self.mi_method = mi_method
        self.mi_weight = mi_weight
        self.contrast_weight = contrast_weight

        self.mi_estimator = MutualInformationLoss(method=mi_method)

    def forward(self, invariant_embs, variant_embs, env_labels=None):
        """
        Compute disentanglement loss

        Args:
            invariant_embs: dict {env_id: [batch_size, dim]}
            variant_embs: dict {env_id: [batch_size, dim]}; must contain every
                key of ``invariant_embs``.
            env_labels: environment labels (passed through to
                :meth:`environment_contrast_loss`)

        Returns:
            dict with keys ``'total'``, ``'mi_loss'`` and ``'contrast_loss'``.
            ``'contrast_loss'`` is the integer 0 when the contrast term is
            skipped.

        Raises:
            ValueError: if ``invariant_embs`` is empty.
            KeyError: if an environment is missing from ``variant_embs``.
        """
        if not invariant_embs:
            # Averaging over zero environments is undefined; fail with a clear
            # message instead of a ZeroDivisionError further down.
            raise ValueError("invariant_embs must contain at least one environment")

        # 1. Mutual information minimization loss
        mi_loss = 0
        for env_id, inv_emb in invariant_embs.items():
            estimate = self.mi_estimator(inv_emb, variant_embs[env_id])
            # Critic-based estimators may hand back (estimate, critic); only
            # the estimate contributes to the loss.
            if isinstance(estimate, tuple):
                estimate = estimate[0]
            if torch.is_tensor(estimate) and estimate.numel() == 1:
                estimate = estimate.reshape(())
            # Out-of-place add: never mutate the tensor the estimator returned.
            mi_loss = mi_loss + estimate

        mi_loss = mi_loss / len(invariant_embs)

        # 2. Environment contrast loss (encourage variant representations to distinguish environments)
        contrast_loss = 0
        if self.contrast_weight > 0 and len(variant_embs) > 1:
            contrast_loss = self.environment_contrast_loss(variant_embs, env_labels)

        # Total loss
        total_loss = self.mi_weight * mi_loss + self.contrast_weight * contrast_loss

        return {
            'total': total_loss,
            'mi_loss': mi_loss,
            'contrast_loss': contrast_loss
        }

    def environment_contrast_loss(self, variant_embs, env_labels=None):
        """Environment contrast loss

        For every unordered pair of environments, rows are matched by index
        (row k of one against row k of the other, after truncating both to the
        shorter batch) and ``exp(-L2 distance)`` is averaged over the rows.
        The result is the mean over all environment pairs.

        Args:
            variant_embs: dict {env_id: [batch_size, dim]}
            env_labels: currently unused.

        Returns:
            Scalar tensor, or the integer 0 when no pair could be evaluated.
        """
        env_ids = list(variant_embs)
        loss = 0
        count = 0

        for i, env_i in enumerate(env_ids):
            for env_j in env_ids[i + 1:]:
                emb_i = variant_embs[env_i]
                emb_j = variant_embs[env_j]

                # Batch sizes might be different
                min_batch = min(emb_i.shape[0], emb_j.shape[0])
                if min_batch == 0:
                    # Nothing to compare; the mean of an empty tensor is NaN.
                    continue
                emb_i = emb_i[:min_batch]
                emb_j = emb_j[:min_batch]

                # Compute distance
                distance = F.pairwise_distance(emb_i, emb_j, p=2)

                # Want representations from different environments to be far apart
                loss = loss + torch.exp(-distance).mean()
                count += 1

        if count > 0:
            loss = loss / count

        return loss
