import torch
import torch.nn as nn
import torch.nn.functional as F


class EnvironmentGenerator(nn.Module):
    """Environment generator for adversarial training.

    Holds one small MLP per environment that maps an input with last
    dimension ``input_dim`` back to ``input_dim`` features. It also holds two
    heads that consume such representations:

    * ``discriminator`` -- ``num_environments`` logits (which environment a
      representation came from);
    * ``classifier`` -- ``num_classes`` logits (the task label).

    Args:
        input_dim: Feature size of inputs and of the produced representations.
        hidden_dim: Hidden width shared by every MLP in this module.
        num_environments: Number of environment transformers and of
            discriminator outputs.
        num_classes: Number of task classes predicted by ``classifier``.
    """

    def __init__(self, input_dim, hidden_dim, num_environments=2, num_classes=10):
        super().__init__()
        self.num_environments = num_environments
        self.num_classes = num_classes

        # One feature transformer per environment (input_dim -> input_dim).
        self.env_transformers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, input_dim),
            )
            for _ in range(num_environments)
        ])

        # Adversarial discriminator: logits over the environments.
        self.discriminator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_environments),
        )

        # Class-preserving classifier: logits over the task classes.
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x, env_id):
        """Transform ``x`` with the transformer of environment ``env_id``.

        ``env_id`` is always reduced modulo ``num_environments``. Out-of-range
        ids, including negative ones, therefore select a valid transformer
        instead of raising ``IndexError``. In-range ids are unaffected.

        Args:
            x: Input tensor; its last dimension is presumably ``input_dim``
                (required by the first ``nn.Linear``).
            env_id: Integer environment index.

        Returns:
            The transformed tensor (last dimension ``input_dim``).
        """
        return self.env_transformers[env_id % self.num_environments](x)

    def adversarial_loss(self, env_representations):
        """Discriminator loss: classify each representation by its environment.

        The representation at position ``k`` is labelled as environment ``k``
        and scored with multi-class cross entropy on ``self.discriminator``'s
        logits. The result is the mean of these per-environment losses.

        Each environment is evaluated once, on its full batch. Batches need
        not share a size, because every cross-entropy term is computed
        independently.

        Args:
            env_representations: A sequence of per-environment feature
                tensors, or a tensor stacked along dim 0. Each entry is
                presumably ``(batch, input_dim)``.

        Returns:
            A 0-d tensor. It is zero when fewer than two representations are
            given, or when fewer than two of them are non-empty.

        Raises:
            ValueError: If more representations are given than the
                discriminator has environment outputs.
        """
        # ``len`` rather than truthiness so a stacked tensor is accepted too.
        if env_representations is None or len(env_representations) < 2:
            return torch.tensor(0.0, device=next(self.parameters()).device)

        if len(env_representations) > self.num_environments:
            # Target index k must be a valid discriminator output; fail
            # clearly here instead of deep inside cross_entropy (or via a
            # device-side assert on CUDA).
            raise ValueError(
                f"got {len(env_representations)} environment representations "
                f"but the discriminator only has {self.num_environments} outputs"
            )

        losses = []
        for env_idx, feats in enumerate(env_representations):
            batch = feats.shape[0]
            if batch == 0:
                continue  # nothing to classify for this environment
            logits = self.discriminator(feats)
            targets = torch.full((batch,), env_idx, dtype=torch.long, device=feats.device)
            losses.append(F.cross_entropy(logits, targets))

        # A discriminator loss needs at least two populated environments.
        if len(losses) < 2:
            return torch.tensor(0.0, device=env_representations[0].device)
        return torch.stack(losses).mean()

    def invariance_loss(self, env_representations, labels):
        """Task loss: every environment's representation must predict ``labels``.

        Each representation whose batch size equals ``labels.shape[0]`` is
        fed through ``self.classifier`` and scored with cross entropy against
        ``labels``. Representations with a different batch size are skipped.

        Args:
            env_representations: A sequence of per-environment feature
                tensors, or a tensor stacked along dim 0.
            labels: Integer class targets, presumably of shape ``(batch,)``.

        Returns:
            The mean loss over the used representations as a 0-d tensor, or
            a zero tensor on ``labels.device`` if none could be used.
        """
        if env_representations is None or len(env_representations) == 0:
            return torch.tensor(0.0, device=labels.device)

        losses = [
            F.cross_entropy(self.classifier(feats), labels)
            for feats in env_representations
            if feats.shape[0] == labels.shape[0]
        ]
        if not losses:
            return torch.tensor(0.0, device=labels.device)
        return torch.stack(losses).mean()

    def simple_adversarial_loss(self, env_representations):
        """Encourage different environments to have dissimilar mean features.

        For every pair of environments, computes the cosine similarity
        between their batch-mean feature vectors. It returns the negated
        average, so minimising this loss pushes the means apart.

        Args:
            env_representations: A sequence of per-environment feature
                tensors, or a tensor stacked along dim 0.

        Returns:
            A tensor of shape ``(1,)`` when at least two representations are
            given (``F.cosine_similarity`` on ``(1, D)`` inputs returns
            ``(1,)``). Otherwise it returns a 0-d zero tensor. That zero is
            placed on the first representation's device, or on the module's
            parameter device when the input is empty.
        """
        count = 0 if env_representations is None else len(env_representations)
        if count < 2:
            device = env_representations[0].device if count else next(self.parameters()).device
            return torch.tensor(0.0, device=device)

        # Compute each environment's mean once instead of once per pair.
        means = [feats.mean(dim=0).unsqueeze(0) for feats in env_representations]
        similarities = [
            F.cosine_similarity(means[i], means[j])
            for i in range(count)
            for j in range(i + 1, count)
        ]
        # Minimise similarity == maximise the difference between environments.
        return -sum(similarities) / len(similarities)