import torch
import torch.nn as nn
import numpy as np


class WassersteinPrototype(nn.Module):
    """Wasserstein prototype computation and alignment module.

    For every class, the embeddings of each environment are summarised by a
    Gaussian (mean, covariance).  ``alignment_loss`` is the weighted sum of
    squared 2-Wasserstein distances between those Gaussians and their
    Wasserstein barycenter.
    """

    def __init__(self, feature_dim, num_classes, num_environments, wasserstein_method='exact'):
        """
        Args:
            feature_dim: dimensionality of the embedding vectors.
            num_classes: number of class ids iterated by ``alignment_loss``.
            num_environments: number of environment ids (``0..n-1``).
            wasserstein_method: forwarded as ``method`` to
                ``WassersteinDistance``.
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.num_environments = num_environments
        # Kept as a function-scope import, exactly as in the original code.
        from loss.wasserstein import WassersteinDistance
        self.wasserstein = WassersteinDistance(method=wasserstein_method)

        # Initialize class prototypes (learnable parameters)
        self.prototypes = nn.Parameter(
            torch.randn(num_classes, feature_dim)
        )

        # Environment weights (uniform distribution).  Registered as a
        # non-persistent buffer so it follows ``.to()`` / ``.cuda()`` while
        # staying out of ``state_dict`` (existing checkpoints keep loading).
        self.register_buffer(
            'env_weights',
            torch.ones(num_environments) / num_environments,
            persistent=False,
        )

    def compute_gaussian_params(self, embeddings):
        """Return the sample mean and covariance of ``embeddings``.

        Args:
            embeddings: tensor of shape ``[batch_size, feature_dim]``.

        Returns:
            ``(mean, cov)`` with shapes ``[feature_dim]`` and
            ``[feature_dim, feature_dim]``.  With fewer than two samples the
            covariance cannot be estimated and the identity is returned.
        """
        dim = embeddings.shape[1]
        mean = torch.mean(embeddings, dim=0)

        if embeddings.shape[0] > 1:
            # reshape: torch.cov yields a 0-d tensor when there is only a
            # single feature; downstream code needs a square matrix.
            cov = torch.cov(embeddings.T).reshape(dim, dim)
        else:
            cov = torch.eye(dim, device=embeddings.device, dtype=embeddings.dtype)

        return mean, cov

    def wasserstein_distance_gaussian(self, mean1, cov1, mean2, cov2):
        """Squared 2-Wasserstein distance between two Gaussians.

        Formula: W² = ||μ1 - μ2||² + Tr(Σ1 + Σ2 - 2(Σ1^{1/2} Σ2 Σ1^{1/2})^{1/2})

        Returns:
            0-d tensor holding W² (not W), so both terms share the same
            units.
        """
        # Squared mean difference, matching the formula above.
        delta = mean1 - mean2
        mean_term = torch.sum(delta * delta)

        sqrt_cov1 = self._matrix_sqrt(cov1)
        cross = self._matrix_sqrt(sqrt_cov1 @ cov2 @ sqrt_cov1)

        trace_term = torch.trace(cov1) + torch.trace(cov2) - 2 * torch.trace(cross)
        # The exact value is non-negative; clamp away round-off and the small
        # bias introduced by the eigenvalue floor in ``_matrix_sqrt``.
        return mean_term + torch.clamp(trace_term, min=0.0)

    def _matrix_sqrt(self, matrix):
        """Square root of a symmetric positive semi-definite matrix.

        Accepts a single ``[d, d]`` matrix or a batch ``[..., d, d]``.
        Eigenvalues are floored at ``1e-8`` before the square root.
        """
        # Symmetrise so accumulated round-off in matrix products is averaged.
        symmetric = 0.5 * (matrix + matrix.transpose(-2, -1))
        eigvals, eigvecs = torch.linalg.eigh(symmetric)
        sqrt_eigvals = torch.sqrt(torch.clamp(eigvals, min=1e-8))
        # Scale eigenvector columns instead of building a diagonal matrix.
        return (eigvecs * sqrt_eigvals.unsqueeze(-2)) @ eigvecs.transpose(-2, -1)

    def compute_barycenter(self, gaussian_params_list, weights=None, num_iters=10):
        """Compute the Wasserstein barycenter of a set of Gaussians.

        The barycenter mean is the weighted average of the means.  The
        covariance is obtained with the fixed-point iteration
        ``S <- sum_i w_i (S^{1/2} Σ_i S^{1/2})^{1/2}``, started from the
        weighted average of the covariances.

        Args:
            gaussian_params_list: non-empty list of ``(mean, cov)`` tuples.
            weights: optional per-entry weights (normalised internally).
                Defaults to ``self.env_weights`` when its length matches the
                list, otherwise to uniform weights.
            num_iters: number of fixed-point iterations.

        Returns:
            ``(barycenter_mean, barycenter_cov)``.

        Raises:
            ValueError: if ``gaussian_params_list`` is empty.
        """
        if not gaussian_params_list:
            raise ValueError("gaussian_params_list must not be empty")

        means = torch.stack([mean for mean, _ in gaussian_params_list])
        covs = torch.stack([cov for _, cov in gaussian_params_list])
        count = means.shape[0]

        if weights is None:
            if self.env_weights.numel() == count:
                weights = self.env_weights
            else:
                weights = torch.ones(count)
        weights = torch.as_tensor(weights, dtype=means.dtype, device=means.device)
        weights = weights / weights.sum()

        barycenter_mean = (weights[:, None] * means).sum(dim=0)
        barycenter_cov = (weights[:, None, None] * covs).sum(dim=0)

        for _ in range(num_iters):
            root = self._matrix_sqrt(barycenter_cov)
            # ``root`` broadcasts over the stacked covariances.
            transported = self._matrix_sqrt(root @ covs @ root)
            barycenter_cov = (weights[:, None, None] * transported).sum(dim=0)

        return barycenter_mean, barycenter_cov

    def alignment_loss(self, invariant_embeddings, labels):
        """Prototype alignment loss.

        Args:
            invariant_embeddings: mapping ``{env_id: [batch_size, feature_dim]}``
                indexed by ``0..num_environments-1``.
            labels: ``[batch_size]`` tensor shared by all environments, or a
                dict ``{env_id: [batch_size]}`` with per-environment labels.

        Returns:
            0-d tensor: for every class present, the weighted sum of squared
            Wasserstein distances between each environment's Gaussian and the
            class barycenter.  A zero tensor if no class is present.
        """
        labels_per_env = isinstance(labels, dict)
        total_loss = None

        for class_id in range(self.num_classes):
            # Gaussians of this class, with the environments they came from
            # so that weights stay aligned when an environment is skipped.
            env_ids = []
            gaussians = []

            for env_id in range(self.num_environments):
                env_labels = labels[env_id] if labels_per_env else labels
                class_mask = (env_labels == class_id)
                if not bool(class_mask.any()):
                    continue
                class_emb = invariant_embeddings[env_id][class_mask]
                gaussians.append(self.compute_gaussian_params(class_emb))
                env_ids.append(env_id)

            if not gaussians:
                continue

            reference = gaussians[0][0]
            weights = self.env_weights.to(device=reference.device, dtype=reference.dtype)
            weights = weights[env_ids]
            weights = weights / weights.sum()

            # The barycenter is treated as a constant target (as the original
            # constant initialisation was); gradients flow only through the
            # per-environment Gaussians.
            with torch.no_grad():
                barycenter_mean, barycenter_cov = self.compute_barycenter(
                    gaussians, weights=weights
                )

            for (mean, cov), weight in zip(gaussians, weights):
                distance = self.wasserstein_distance_gaussian(
                    barycenter_mean, barycenter_cov, mean, cov
                )
                term = weight * distance
                total_loss = term if total_loss is None else total_loss + term

        if total_loss is None:
            return torch.zeros((), device=self.env_weights.device)
        return total_loss