import torch
import torch.nn as nn
import torch.nn.functional as F


class AdversarialLoss(nn.Module):
    """Adversarial environment-discrimination loss.

    A small MLP discriminator predicts the environment label of each
    embedding, and the loss is the cross-entropy of that prediction. When
    ``grad_reverse`` is enabled, the embeddings first pass through a gradient
    reversal layer. Its forward pass is the identity, but gradients flowing
    back into the embeddings are multiplied by ``-lambda_``.

    Args:
        num_environments: number of environment classes predicted by the
            discriminator.
        grad_reverse: if True, apply gradient reversal before the
            discriminator in :meth:`forward`.
        lambda_gp: gradient-penalty weight. It is only stored on the module;
            :meth:`gradient_penalty` returns the *unweighted* penalty, so
            callers must apply this factor themselves.
        input_dim: size of the embedding feature dimension expected by the
            discriminator (default 128, the previously hard-coded value).
    """

    class _GradientReversal(torch.autograd.Function):
        """Identity forward; backward multiplies the gradient by ``-lambda_``.

        Defined at class scope (not inside ``__init__``) so that modules
        holding a reference to ``apply`` remain picklable.
        """

        @staticmethod
        def forward(ctx, x, lambda_):
            ctx.lambda_ = lambda_
            # Return a view rather than the input object itself (the usual
            # idiom for identity autograd Functions).
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            # lambda_ is a scaling hyperparameter, so it gets no gradient.
            return grad_output.neg() * ctx.lambda_, None

    def __init__(self, num_environments=2, grad_reverse=True, lambda_gp=10.0,
                 input_dim=128):
        super().__init__()
        self.num_environments = num_environments
        self.grad_reverse = grad_reverse
        self.lambda_gp = lambda_gp

        # Environment discriminator: input_dim -> 64 -> 32 -> environment logits.
        self.discriminator = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_environments),
        )

        if grad_reverse:
            # Public attribute kept for backward compatibility with callers
            # that use it directly.
            self.grad_reverse_func = self._GradientReversal.apply

    def forward(self, embeddings, env_labels, lambda_=1.0):
        """
        Compute the adversarial (environment-classification) loss.

        Args:
            embeddings: [batch_size, dim] representations; ``dim`` must equal
                the ``input_dim`` given at construction.
            env_labels: [batch_size] integer environment labels.
            lambda_: gradient reversal strength (only used when
                ``grad_reverse`` is enabled).

        Returns:
            Scalar cross-entropy between discriminator logits and
            ``env_labels`` (``F.cross_entropy`` default mean reduction).
        """
        if self.grad_reverse:
            embeddings = self.grad_reverse_func(embeddings, lambda_)

        env_logits = self.discriminator(embeddings)
        return F.cross_entropy(env_logits, env_labels)

    def gradient_penalty(self, real_embeddings, fake_embeddings):
        """
        WGAN-GP style gradient penalty on random interpolations.

        Samples one mixing coefficient per example, interpolates between
        ``real_embeddings`` and ``fake_embeddings``, and penalizes the
        squared deviation of the per-sample gradient norm from 1. The
        gradient is that of the *sum* of all discriminator logits.

        Args:
            real_embeddings: [batch_size, ...] tensor.
            fake_embeddings: tensor broadcast-compatible with
                ``real_embeddings``.

        Returns:
            Scalar unweighted penalty (multiply by ``self.lambda_gp`` if desired).
        """
        batch_size = real_embeddings.shape[0]

        # One coefficient per sample, broadcast over all remaining dims.
        # Drawn in the embeddings' dtype so e.g. half inputs are not
        # silently promoted to float32 (which would mismatch the weights).
        alpha_shape = (batch_size,) + (1,) * (real_embeddings.dim() - 1)
        alpha = torch.rand(
            alpha_shape,
            device=real_embeddings.device,
            dtype=real_embeddings.dtype,
        )

        interpolated = alpha * real_embeddings + (1 - alpha) * fake_embeddings
        interpolated.requires_grad_(True)

        disc_interpolated = self.discriminator(interpolated)

        gradients = torch.autograd.grad(
            outputs=disc_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(disc_interpolated),
            create_graph=True,  # the penalty itself must be differentiable
            retain_graph=True,
        )[0]

        gradients_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
        return ((gradients_norm - 1) ** 2).mean()


class InvarianceAdversarialLoss(nn.Module):
    """Invariance loss: one shared class discriminator must recover the true
    labels from every environment's embeddings.

    Args:
        num_classes: number of target classes.
        input_dim: size of the embedding feature dimension expected by the
            discriminator (default 128, the previously hard-coded value).
    """

    def __init__(self, num_classes, input_dim=128):
        super().__init__()
        self.num_classes = num_classes

        # Class discriminator shared across all environments.
        self.class_discriminator = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, num_classes),
        )

    def forward(self, env_embeddings_list, labels):
        """
        Compute the invariance loss.

        Args:
            env_embeddings_list: non-empty sequence of [batch_size, dim]
                representations, one per environment.
            labels: [batch_size] true class labels. The same tensor is used
                for every environment, so each environment's rows must be
                aligned with it.

        Returns:
            Mean over environments of the per-environment cross-entropy.

        Raises:
            ValueError: if ``env_embeddings_list`` is empty.
        """
        if len(env_embeddings_list) == 0:
            raise ValueError(
                "env_embeddings_list must contain at least one environment"
            )

        per_env_losses = [
            F.cross_entropy(self.class_discriminator(env_emb), labels)
            for env_emb in env_embeddings_list
        ]
        return sum(per_env_losses) / len(per_env_losses)


class EnvironmentDivergenceLoss(nn.Module):
    """Loss that rewards divergence between environment representations.

    ``forward`` returns the *negative* mean pairwise divergence over all
    environment pairs, so minimizing this loss maximizes divergence.

    Args:
        divergence_type: ``'mmd'`` (Gaussian-kernel MMD) or ``'wasserstein'``
            (mean/covariance moment-matching proxy). Checked in ``forward``
            when at least two environments are supplied.
        cov_weight: weight of the covariance term in
            :meth:`wasserstein_distance` (default 0.1, the previously
            hard-coded value).
    """

    def __init__(self, divergence_type='mmd', cov_weight=0.1):
        super().__init__()
        self.divergence_type = divergence_type
        self.cov_weight = cov_weight

    def forward(self, env_embeddings_list):
        """
        Maximize divergence between representations from different environments.

        Args:
            env_embeddings_list: sequence of [batch_size, dim] tensors.

        Returns:
            Negative mean pairwise divergence. Returns 0 when fewer than two
            environments are given.

        Raises:
            ValueError: for an unknown ``divergence_type`` (only reached with
                two or more environments).
        """
        if len(env_embeddings_list) == 0:
            # Nothing to compare and no tensor to take a device from.
            return torch.tensor(0.0)
        if len(env_embeddings_list) < 2:
            return torch.tensor(0.0, device=env_embeddings_list[0].device)

        if self.divergence_type == 'mmd':
            pair_divergence = self.mmd_loss
        elif self.divergence_type == 'wasserstein':
            pair_divergence = self.wasserstein_distance
        else:
            raise ValueError(f"Unknown divergence type: {self.divergence_type}")

        pair_values = [
            pair_divergence(first, second)
            for idx, first in enumerate(env_embeddings_list)
            for second in env_embeddings_list[idx + 1:]
        ]
        # Negate so that minimizing the loss maximizes divergence.
        return -sum(pair_values) / len(pair_values)

    def mmd_loss(self, x, y, sigma=None):
        """Biased (V-statistic) Maximum Mean Discrepancy with a Gaussian kernel.

        A single bandwidth is shared by the xx, yy and xy terms. Computing it
        separately per term (as happens when ``sigma=None`` is forwarded to
        each kernel call) mixes different kernels and corrupts the estimate.

        Args:
            x: [n, dim] samples.
            y: [m, dim] samples.
            sigma: kernel bandwidth; if None it is derived from x and y
                together via :meth:`_default_bandwidth`.
        """
        if sigma is None:
            sigma = self._default_bandwidth(x, y)

        xx = self._gaussian_kernel(x, x, sigma)
        yy = self._gaussian_kernel(y, y, sigma)
        xy = self._gaussian_kernel(x, y, sigma)

        return xx.mean() + yy.mean() - 2 * xy.mean()

    def wasserstein_distance(self, x, y):
        """Moment-matching proxy for a distribution distance.

        Not a true Wasserstein distance. It returns the L2 distance between
        the sample means plus ``self.cov_weight`` times the Frobenius distance
        between the sample covariances. Each input needs at least two rows for
        a finite (unbiased) covariance.
        """
        mean_gap = torch.norm(x.mean(dim=0) - y.mean(dim=0), p=2)
        cov_gap = torch.norm(torch.cov(x.T) - torch.cov(y.T), p='fro')
        return mean_gap + self.cov_weight * cov_gap

    @staticmethod
    def _default_bandwidth(x, y):
        """Bandwidth heuristic: dim * variance of all entries of x and y.

        Falls back to 1.0 when that variance is exactly zero.
        """
        dim = x.shape[1]
        sigma = dim * torch.var(torch.cat([x.flatten(), y.flatten()]))
        if sigma == 0:
            sigma = 1.0
        return sigma

    def _gaussian_kernel(self, x, y, sigma):
        """Gaussian kernel matrix K[i, j] = exp(-||x_i - y_j||^2 / (2 * sigma)).

        Args:
            x: [n, dim] tensor.
            y: [m, dim] tensor.
            sigma: bandwidth; if None it is derived from this (x, y) pair.

        Returns:
            [n, m] kernel matrix.
        """
        if sigma is None:
            sigma = self._default_bandwidth(x, y)

        # Broadcasting [n, 1, dim] - [1, m, dim] avoids materializing tiled copies.
        sq_dist = (x.unsqueeze(1) - y.unsqueeze(0)).pow(2).sum(dim=2)
        return torch.exp(-sq_dist / (2 * sigma))
