import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple, Dict, Optional
import warnings


class WassersteinDistance(nn.Module):
    """Wasserstein-type distance between two Gaussian distributions.

    The strategy is selected by ``method``:

    * ``'exact'``    -- closed-form 2-Wasserstein distance between Gaussians.
    * ``'sinkhorn'`` -- entropy-regularised optimal transport between samples
      drawn from the two Gaussians (stochastic; Euclidean ground cost).
    * ``'trace'``    -- cheap surrogate built from the mean gap and the
      Frobenius norm of the covariance gap.

    Args:
        method: One of ``'exact'``, ``'sinkhorn'`` or ``'trace'``. The value
            is validated when ``forward`` is called, not at construction.
        reg: Entropic regularisation strength of the Sinkhorn kernel.
        epsilon: Small constant used to clamp eigenvalues and to guard
            square roots and divisions.
    """

    def __init__(self, method='exact', reg=1.0, epsilon=1e-8):
        super().__init__()
        self.method = method
        self.reg = reg
        self.epsilon = epsilon

    def forward(self, mu1: torch.Tensor, sigma1: torch.Tensor,
                mu2: torch.Tensor, sigma2: torch.Tensor) -> torch.Tensor:
        """Compute the distance between N(mu1, sigma1) and N(mu2, sigma2).

        Args:
            mu1: Mean vector of the first Gaussian.
            sigma1: Square covariance matrix of the first Gaussian.
            mu2: Mean vector of the second Gaussian.
            sigma2: Square covariance matrix of the second Gaussian.

        Returns:
            Scalar tensor holding the distance.

        Raises:
            ValueError: If ``self.method`` is not a supported method name.
        """
        # Both covariances are symmetrised and regularised before use.
        sigma1 = self._ensure_spd(sigma1)
        sigma2 = self._ensure_spd(sigma2)

        if self.method == 'exact':
            return self._exact_wasserstein(mu1, sigma1, mu2, sigma2)
        elif self.method == 'sinkhorn':
            return self._sinkhorn_wasserstein(mu1, sigma1, mu2, sigma2)
        elif self.method == 'trace':
            return self._trace_approximation(mu1, sigma1, mu2, sigma2)
        else:
            raise ValueError(f"Unknown method: {self.method}")

    def _ensure_spd(self, sigma):
        """Symmetrise ``sigma`` and add a small ridge to its diagonal."""
        # A fixed 1e-4 ridge is used (instead of self.epsilon, whose default
        # is 1e-8) for float32 stability.
        ridge = 1e-4
        sigma = (sigma + sigma.T) / 2
        return sigma + torch.eye(sigma.size(0), device=sigma.device) * ridge

    def _exact_wasserstein(self, mu1: torch.Tensor, sigma1: torch.Tensor,
                           mu2: torch.Tensor, sigma2: torch.Tensor) -> torch.Tensor:
        """Closed-form 2-Wasserstein distance between two Gaussians.

        Computes ``sqrt(||mu1 - mu2||^2 + tr(S1 + S2 - 2 (S1^1/2 S2 S1^1/2)^1/2))``.
        If the linear algebra fails, the trace term is replaced by the
        Frobenius norm of ``sigma1 - sigma2`` and a warning is emitted.
        """
        mean_diff = torch.norm(mu1 - mu2, p=2)

        try:
            sqrt_sigma1 = self._matrix_sqrt(sigma1)
            middle = sqrt_sigma1 @ sigma2 @ sqrt_sigma1
            sqrt_middle = self._matrix_sqrt(middle)
            trace_term = torch.trace(sigma1 + sigma2 - 2 * sqrt_middle)
        except RuntimeError as err:
            # torch.linalg errors derive from RuntimeError; anything else
            # (e.g. KeyboardInterrupt) must propagate.
            warnings.warn(
                f"Exact covariance term failed ({err}); "
                "falling back to the Frobenius norm of the covariance gap.",
                RuntimeWarning,
                stacklevel=2,
            )
            trace_term = torch.norm(sigma1 - sigma2, p='fro')

        # Round-off can push the trace term slightly below zero.
        w2_sq = mean_diff**2 + torch.clamp(trace_term, min=0)

        # epsilon keeps the square root differentiable at zero distance.
        return torch.sqrt(w2_sq + self.epsilon)

    def _matrix_sqrt(self, A: torch.Tensor) -> torch.Tensor:
        """Square root of a symmetric matrix via eigendecomposition.

        Eigenvalues are clamped to at least ``self.epsilon``. If the
        decomposition fails, a scaled identity built from the mean diagonal
        entry of ``A`` is returned and a warning is emitted.
        """
        try:
            eigvals, eigvecs = torch.linalg.eigh(A)
        except RuntimeError as err:
            warnings.warn(
                f"Eigendecomposition failed ({err}); "
                "using a scaled-identity matrix square root.",
                RuntimeWarning,
                stacklevel=2,
            )
            return torch.eye(A.size(0), device=A.device) * torch.sqrt(torch.mean(torch.diag(A)))

        sqrt_eigvals = torch.sqrt(torch.clamp(eigvals, min=self.epsilon))
        return eigvecs @ torch.diag(sqrt_eigvals) @ eigvecs.T

    def _sinkhorn_wasserstein(self, mu1: torch.Tensor, sigma1: torch.Tensor,
                              mu2: torch.Tensor, sigma2: torch.Tensor,
                              n_samples: int = 100) -> torch.Tensor:
        """Entropy-regularised transport cost between Gaussian samples.

        ``n_samples`` points are drawn from each Gaussian, both sample sets
        receive uniform weights, and 50 Sinkhorn scaling iterations are run
        on the kernel ``exp(-C / reg)`` where ``C`` is the pairwise Euclidean
        distance matrix.

        Returns:
            ``sum(P * C)`` for the transport plan ``P = diag(u) K diag(v)``.
            The result is stochastic because it depends on the drawn samples.
        """
        samples1 = self._sample_gaussian(mu1, sigma1, n_samples)
        samples2 = self._sample_gaussian(mu2, sigma2, n_samples)

        # Pairwise Euclidean ground cost and the corresponding Gibbs kernel.
        C = torch.cdist(samples1, samples2, p=2)
        K = torch.exp(-C / self.reg)

        # Uniform marginals: every sample carries mass 1 / n_samples.
        a = torch.full((n_samples,), 1.0 / n_samples, device=C.device, dtype=C.dtype)
        b = torch.full((n_samples,), 1.0 / n_samples, device=C.device, dtype=C.dtype)

        u = torch.ones_like(a)
        v = torch.ones_like(b)
        for _ in range(50):
            # Each update rescales the plan to match one marginal.
            u = a / (K @ v + self.epsilon)
            v = b / (K.T @ u + self.epsilon)

        # Transport plan: rows scaled by u, columns scaled by v.
        plan = u[:, None] * K * v[None, :]
        return torch.sum(plan * C)

    def _trace_approximation(self, mu1: torch.Tensor, sigma1: torch.Tensor,
                             mu2: torch.Tensor, sigma2: torch.Tensor) -> torch.Tensor:
        """Cheap surrogate: mean gap plus a damped covariance gap.

        Returns ``||mu1 - mu2|| + 0.1 * sqrt(||sigma1 - sigma2||_F + epsilon)``.
        No matrix decomposition is involved.
        """
        mean_diff = torch.norm(mu1 - mu2, p=2)
        cov_diff = torch.norm(sigma1 - sigma2, p='fro')
        return mean_diff + 0.1 * torch.sqrt(cov_diff + self.epsilon)

    def _sample_gaussian(self, mu: torch.Tensor, sigma: torch.Tensor,
                         n_samples: int) -> torch.Tensor:
        """Draw ``n_samples`` points from N(mu, sigma).

        A Cholesky factor of the regularised covariance is used; when that
        fails, a factor built from the eigendecomposition is used instead.

        Returns:
            Tensor with one sample per row (``n_samples`` rows).
        """
        d = mu.size(0)
        sigma = self._ensure_spd(sigma)

        try:
            factor = torch.linalg.cholesky(sigma)
        except RuntimeError:
            # Cholesky rejects matrices that are not numerically positive
            # definite; clamped eigenvalues still give a usable factor.
            eigvals, eigvecs = torch.linalg.eigh(sigma)
            eigvals = torch.clamp(eigvals, min=self.epsilon)
            factor = eigvecs @ torch.diag(torch.sqrt(eigvals))

        z = torch.randn(n_samples, d, device=mu.device)
        return mu + z @ factor.T


class RecursiveWassersteinBarycenter(nn.Module):
    """Per-class Gaussian prototypes obtained by a recursive barycenter search.

    For every class the per-environment Gaussian statistics are collected and
    a prototype (mean, covariance) is refined recursively until the weighted
    distance to those statistics stops changing or ``max_iter`` is reached.

    Args:
        feature_dim: Dimensionality of the feature vectors.
        num_categories: Number of classes (one prototype per class).
        max_iter: Maximum recursion depth of the barycenter refinement.
        tol: Absolute change in total distance below which the refinement
            is considered converged.
        wasserstein_method: Method name forwarded to ``WassersteinDistance``.
    """

    def __init__(self, feature_dim: int, num_categories: int = 10,
                 max_iter: int = 10, tol: float = 1e-4,
                 wasserstein_method: str = 'trace'):
        super().__init__()

        self.feature_dim = feature_dim
        self.num_categories = num_categories
        self.max_iter = max_iter
        self.tol = tol

        # Distance used both for convergence tracking and for alignment.
        self.wasserstein = WassersteinDistance(method=wasserstein_method)

        # Learnable class prototypes: means and per-dimension log-variances.
        self.prototype_means = nn.Parameter(
            torch.zeros(num_categories, feature_dim)
        )
        self.prototype_log_vars = nn.Parameter(
            torch.zeros(num_categories, feature_dim)
        )

        # Distances recorded by the most recent barycenter computation.
        self.convergence_history = []

    @property
    def prototype_covs(self) -> torch.Tensor:
        """Diagonal covariance matrix of every learnable prototype.

        Returns:
            Tensor of shape ``[num_categories, feature_dim, feature_dim]``.
        """
        variances = torch.exp(self.prototype_log_vars) + self.wasserstein.epsilon
        return torch.diag_embed(variances)

    def compute_barycenter_recursive(
        self,
        distributions: List[Tuple[torch.Tensor, torch.Tensor]],
        weights: Optional[torch.Tensor] = None,
        current_prototype: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        iteration: int = 0,
        verbose: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, List[float]]:
        """Recursively refine a barycenter of Gaussian distributions.

        Args:
            distributions: Non-empty list of ``(mean, cov)`` tuples.
            weights: Weight per distribution; uniform when omitted.
            current_prototype: Current ``(mean, cov)`` estimate; when omitted
                it is initialised with the weighted average of the inputs.
            iteration: Current recursion depth. A call with ``iteration == 0``
                starts a new computation and a fresh convergence history.
            verbose: Whether to print progress information.

        Returns:
            Final prototype mean, final prototype covariance, and the list of
            total distances recorded during this computation.

        Raises:
            ValueError: If ``distributions`` is empty.
        """
        if not distributions:
            raise ValueError("distributions must contain at least one (mean, cov) pair")

        if iteration == 0:
            # Each top-level computation gets its own history. Otherwise the
            # step size in _update_prototype (which shrinks with the history
            # length) would depend on how many barycenters were computed
            # before, and all callers would share one list object.
            self.convergence_history = []

        n_distributions = len(distributions)

        if weights is None:
            weights = torch.ones(n_distributions, device=distributions[0][0].device) / n_distributions

        if current_prototype is None:
            # Start from the weighted average of the input statistics.
            weight_sum = weights.sum()
            init_mean = sum(w * mu for w, (mu, _) in zip(weights, distributions)) / weight_sum
            init_cov = sum(w * cov for w, (_, cov) in zip(weights, distributions)) / weight_sum
            current_prototype = (init_mean, init_cov)

        current_mean, current_cov = current_prototype

        # Weighted distance from the current estimate to every input.
        total_distance = 0.0
        for (mu, cov), w in zip(distributions, weights):
            total_distance += w * self.wasserstein(current_mean, current_cov, mu, cov)

        self.convergence_history.append(total_distance.item())

        if verbose:
            print(f"Iteration {iteration}: Total distance = {total_distance.item():.6f}")

        # Stop when the iteration budget is exhausted.
        if iteration >= self.max_iter:
            if verbose:
                print(f"Reached maximum iterations ({self.max_iter})")
            return current_mean, current_cov, self.convergence_history

        # Stop when the total distance no longer changes noticeably.
        if iteration > 0:
            prev_distance = self.convergence_history[-2] if len(self.convergence_history) > 1 else float('inf')
            if abs(total_distance.item() - prev_distance) < self.tol:
                if verbose:
                    print(f"Converged at iteration {iteration}")
                return current_mean, current_cov, self.convergence_history

        # L-Step: move the prototype, then recurse on the new estimate.
        updated_mean, updated_cov = self._update_prototype(
            current_mean, current_cov, distributions, weights
        )

        return self.compute_barycenter_recursive(
            distributions, weights, (updated_mean, updated_cov),
            iteration + 1, verbose
        )

    def _update_prototype(
        self,
        current_mean: torch.Tensor,
        current_cov: torch.Tensor,
        distributions: List[Tuple[torch.Tensor, torch.Tensor]],
        weights: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Take one damped step toward the weighted average of the inputs.

        The update direction is the weighted mean of ``mu - current_mean``
        (and likewise for the covariance); it is a simplified stand-in for
        the gradient of the barycenter objective.

        Returns:
            Updated mean and updated (symmetrised, regularised) covariance.
        """
        grad_mean = torch.zeros_like(current_mean)
        grad_cov = torch.zeros_like(current_cov)
        total_weight = 0.0

        for (mu, cov), w in zip(distributions, weights):
            grad_mean += w * (mu - current_mean)
            grad_cov += w * (cov - current_cov)
            total_weight += w

        if total_weight > 0:
            grad_mean = grad_mean / total_weight
            grad_cov = grad_cov / total_weight

        # Step sizes decay with the number of distances recorded so far.
        decay = 1 + len(self.convergence_history)
        lr_mean = 0.1 / decay
        lr_cov = 0.01 / decay

        updated_mean = current_mean + lr_mean * grad_mean
        updated_cov = current_cov + lr_cov * grad_cov

        # Keep the covariance symmetric with a regularised diagonal.
        updated_cov = self.wasserstein._ensure_spd(updated_cov)

        return updated_mean, updated_cov

    def align_to_prototypes(
        self,
        features: torch.Tensor,
        labels: torch.Tensor,
        max_alignment_iter: int = 5,
        alignment_strength: float = 0.5
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pull features toward the learnable prototype of their class (E-Step).

        Every sample is modelled as a Gaussian centred on its feature vector
        with covariance ``0.1 * I``. On all but the last iteration each
        feature is moved toward its class prototype mean with a step size
        that shrinks linearly over the iterations.

        Args:
            features: Feature tensor ``[batch_size, feature_dim]``.
            labels: Class index per sample ``[batch_size]``.
            max_alignment_iter: Number of alignment iterations.
            alignment_strength: Initial step size of the alignment.

        Returns:
            Aligned features and the alignment loss averaged over samples and
            iterations. For an empty batch or a non-positive iteration count
            the features are returned unchanged with a zero loss.
        """
        batch_size = features.size(0)
        aligned_features = features.clone()

        if batch_size == 0 or max_alignment_iter <= 0:
            # Nothing to align; avoid dividing by zero below.
            return aligned_features, features.new_zeros(())

        # The prototypes do not change inside this method, so the covariance
        # stack and the per-sample covariance are built only once.
        prototype_covs = self.prototype_covs
        feature_cov = torch.eye(self.feature_dim, device=features.device) * 0.1

        total_alignment_loss = 0.0
        last_iteration = max_alignment_iter - 1

        for iteration in range(max_alignment_iter):
            alignment_loss = 0.0
            # Step size decreases gradually over the iterations.
            step_size = alignment_strength * (1.0 - iteration / max_alignment_iter)

            for i in range(batch_size):
                y = labels[i]
                feature = aligned_features[i]
                prototype_mean = self.prototype_means[y]

                alignment_loss += self.wasserstein(
                    feature, feature_cov, prototype_mean, prototype_covs[y]
                )

                # The final iteration only measures the loss.
                if iteration < last_iteration:
                    aligned_features[i] = feature + step_size * (prototype_mean - feature)

            total_alignment_loss += alignment_loss / batch_size

        return aligned_features, total_alignment_loss / max_alignment_iter

    def forward(
        self,
        features_list: List[torch.Tensor],
        labels_list: List[torch.Tensor],
        num_environments: int,
        verbose: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
        """Compute one barycenter prototype per class across environments.

        Args:
            features_list: Features of each environment.
            labels_list: Labels of each environment.
            num_environments: Number of environments; must equal the length
                of both lists.
            verbose: Whether to print progress information.

        Returns:
            Stacked prototype means, stacked prototype covariances, and a
            statistics dict with keys ``'convergence_history'`` (one list per
            class that had samples), ``'alignment_losses'`` (left empty here)
            and ``'inter_class_distances'`` (one value per class pair).
        """
        assert len(features_list) == len(labels_list) == num_environments

        stats = {
            'convergence_history': [],
            'alignment_losses': [],
            'inter_class_distances': []
        }

        final_prototype_means = []
        final_prototype_covs = []

        for y in range(self.num_categories):
            if verbose:
                print(f"\nComputing prototype for class {y}")

            # Gaussian statistics of class y in every environment that has it.
            class_distributions = []
            for env_idx in range(num_environments):
                env_features = features_list[env_idx]
                mask = (labels_list[env_idx] == y)
                if mask.sum() > 0:
                    class_features = env_features[mask]
                    class_mean = class_features.mean(dim=0)
                    if class_features.size(0) > 1:
                        class_cov = torch.cov(class_features.T)
                    else:
                        # A single sample has no covariance estimate.
                        class_cov = torch.eye(self.feature_dim, device=class_features.device) * 0.1
                    class_distributions.append((class_mean, class_cov))

            if not class_distributions:
                # No samples of this class anywhere: use a random prototype.
                prototype_mean = torch.randn(self.feature_dim, device=features_list[0].device) * 0.1
                prototype_cov = torch.eye(self.feature_dim, device=features_list[0].device)
            else:
                # L-Step: recursive barycenter of the per-environment Gaussians.
                prototype_mean, prototype_cov, conv_history = self.compute_barycenter_recursive(
                    class_distributions, verbose=verbose
                )
                # Store a copy so later computations cannot alter this entry.
                stats['convergence_history'].append(list(conv_history))

            final_prototype_means.append(prototype_mean)
            final_prototype_covs.append(prototype_cov)

        prototype_means = torch.stack(final_prototype_means)
        prototype_covs = torch.stack(final_prototype_covs)

        # Pairwise prototype distances, for evaluating class separation.
        for i in range(self.num_categories):
            for j in range(i + 1, self.num_categories):
                dist = self.wasserstein(
                    prototype_means[i], prototype_covs[i],
                    prototype_means[j], prototype_covs[j]
                )
                stats['inter_class_distances'].append(dist.item())

        return prototype_means, prototype_covs, stats


class AlternatingOptimization(nn.Module):
    """Alternate between prototype updates (L-Step) and feature alignment (E-Step).

    Args:
        feature_dim: Dimensionality of the feature vectors.
        num_categories: Number of classes.
        max_alternating_iter: Maximum number of L-Step/E-Step rounds.
        prototype_update_iter: Iteration budget handed to the barycenter
            calculator as its ``max_iter``.
        alignment_iter: Number of alignment iterations per E-Step.
    """

    def __init__(self, feature_dim: int, num_categories: int = 10,
                 max_alternating_iter: int = 5,
                 prototype_update_iter: int = 3,
                 alignment_iter: int = 3):
        super().__init__()

        self.feature_dim = feature_dim
        self.num_categories = num_categories
        self.max_alternating_iter = max_alternating_iter
        self.prototype_update_iter = prototype_update_iter
        self.alignment_iter = alignment_iter

        # Computes prototypes and performs the per-environment alignment.
        self.barycenter_calculator = RecursiveWassersteinBarycenter(
            feature_dim=feature_dim,
            num_categories=num_categories,
            max_iter=prototype_update_iter
        )

        # Per-round records; they accumulate across calls to forward().
        self.optimization_history = {
            'total_loss': [],
            'alignment_loss': [],
            'prototype_change': []
        }

    def forward(
        self,
        invariant_features: List[torch.Tensor],
        labels: List[torch.Tensor],
        num_environments: int
    ) -> Tuple[List[torch.Tensor], torch.Tensor, Dict]:
        """Run the alternating optimization.

        Each round performs:
        1. L-Step: given the current features, update the prototypes.
        2. E-Step: given prototypes, align the features of each environment.

        The loop stops early once, after the first round, the average
        alignment loss drops below ``1e-3``.

        Args:
            invariant_features: Invariant features of each environment.
            labels: Labels of each environment.
            num_environments: Number of environments.

        Returns:
            Aligned features per environment, the prototype means of the
            last L-Step, and the optimization history dict.
        """
        # Work on copies so the caller's tensors are left untouched.
        aligned_features = [feat.clone() for feat in invariant_features]
        prev_prototypes = None

        for alt_iter in range(self.max_alternating_iter):
            print(f"\n{'='*50}")
            print(f"Alternating Optimization Iteration {alt_iter + 1}/{self.max_alternating_iter}")
            print(f"{'='*50}")

            # Step 1: L-Step - update prototypes from the current features.
            print("L-Step: Updating prototypes...")
            prototype_means, prototype_covs, prototype_stats = self.barycenter_calculator(
                aligned_features, labels, num_environments, verbose=True
            )

            # Step 2: E-Step - align the features of every environment.
            # NOTE(review): align_to_prototypes reads the calculator's
            # learnable prototype parameters, not the prototype_means /
            # prototype_covs computed in the L-Step above -- confirm that
            # this is intended.
            print("\nE-Step: Aligning features to prototypes...")
            total_alignment_loss = 0.0

            for env_idx in range(num_environments):
                aligned_env_features, alignment_loss = self.barycenter_calculator.align_to_prototypes(
                    aligned_features[env_idx], labels[env_idx],
                    max_alignment_iter=self.alignment_iter,
                    alignment_strength=0.3
                )

                aligned_features[env_idx] = aligned_env_features
                total_alignment_loss += alignment_loss

            avg_alignment_loss = total_alignment_loss / num_environments

            # Record the round. The stored total loss is detached so the
            # history does not keep the autograd graph of every round alive.
            # NOTE(review): 'mean_distance' is looked up with a default of 0;
            # it is not among the keys this file's barycenter calculator
            # sets -- verify.
            self.optimization_history['total_loss'].append(
                avg_alignment_loss.detach() + prototype_stats.get('mean_distance', 0)
            )
            self.optimization_history['alignment_loss'].append(avg_alignment_loss.item())

            # Track how far the prototype means moved since the last round.
            if prev_prototypes is not None:
                prototype_change = torch.norm(prototype_means - prev_prototypes, p='fro')
                self.optimization_history['prototype_change'].append(prototype_change.item())
                print(f"Prototype change: {prototype_change.item():.6f}")

            prev_prototypes = prototype_means.clone()

            print(f"Alignment loss: {avg_alignment_loss.item():.6f}")

            # Simplified convergence check on the alignment loss.
            if alt_iter > 0 and avg_alignment_loss < 1e-3:
                print(f"\nConverged early at iteration {alt_iter + 1}")
                break

        return aligned_features, prototype_means, self.optimization_history


# Helper functions
def compute_gaussian_parameters(features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Estimate the mean and covariance of a Gaussian from feature rows.

    Args:
        features: Tensor with one sample per row.

    Returns:
        The per-column mean and the sample covariance. When fewer than two
        rows are available the covariance falls back to ``0.1 * I``.
    """
    center = features.mean(dim=0)

    if features.size(0) <= 1:
        # Too few samples for a covariance estimate: use a scaled identity.
        fallback = torch.eye(features.size(1), device=features.device) * 0.1
        return center, fallback

    return center, torch.cov(features.T)


def wasserstein_alignment_loss(
    invariant_features: torch.Tensor,
    variant_features: Optional[torch.Tensor],
    labels: torch.Tensor,
    prototype_means: torch.Tensor,
    prototype_covs: torch.Tensor,
    wasserstein_calculator: "WassersteinDistance",
    lambda_align: float = 1.0,
    lambda_sep: float = 0.1
) -> Tuple[torch.Tensor, Dict]:
    """
    Compute the Wasserstein alignment loss.

    Every feature vector is modelled as a Gaussian centred on the vector with
    covariance ``0.1 * I``.

    Args:
        invariant_features: Invariant features, one sample per row
        variant_features: Variant features, one sample per row, or None to
            skip the separation term
        labels: Class index per sample
        prototype_means: Prototype means, indexed by class
        prototype_covs: Prototype covariances, indexed by class
        wasserstein_calculator: Callable returning the distance for
            ``(mu1, sigma1, mu2, sigma2)``
        lambda_align: Alignment loss weight
        lambda_sep: Separation loss weight

    Returns:
        A tuple ``(total_loss, parts)`` where ``parts`` is a dict with the
        unweighted ``'align_loss'`` and ``'sep_loss'`` terms.
    """

    batch_size = invariant_features.size(0)

    # The per-sample covariance is the same for every sample: build it once.
    inv_cov = torch.eye(invariant_features.size(1), device=invariant_features.device) * 0.1

    # Alignment loss: invariant features should be close to their class prototype.
    align_loss = 0.0
    for i in range(batch_size):
        y = labels[i]
        align_loss += wasserstein_calculator(
            invariant_features[i], inv_cov,
            prototype_means[y], prototype_covs[y]
        )
    align_loss = align_loss / batch_size

    # Separation loss: invariant and variant features should be different.
    sep_loss = 0.0
    if variant_features is not None:
        var_cov = torch.eye(variant_features.size(1), device=variant_features.device) * 0.1
        for i in range(batch_size):
            distance = wasserstein_calculator(
                invariant_features[i], inv_cov, variant_features[i], var_cov
            )
            # A large distance is desired, so penalise with the negative log;
            # the 1e-6 offset keeps the log finite at zero distance.
            sep_loss += -torch.log(distance + 1e-6)
        sep_loss = sep_loss / batch_size

    total_loss = lambda_align * align_loss + lambda_sep * sep_loss

    return total_loss, {'align_loss': align_loss, 'sep_loss': sep_loss}


# Module exports: the public names made available by a star-import of this module
__all__ = [
    'WassersteinDistance',
    'RecursiveWassersteinBarycenter',
    'AlternatingOptimization',
    'compute_gaussian_parameters',
    'wasserstein_alignment_loss'
]
