import torch
import torch.nn as nn
import torch.nn.functional as F


class DIGLModel(nn.Module):
    """Domain-Invariant Graph Learning Model for DisC.

    The network is a three-layer MLP feature extractor followed by a task
    classification head. Two optional sub-modules are created on demand:

    - ``env_classifier``: predicts an environment id from the features
      (built when ``use_wasserstein`` is true and ``num_environments > 0``).
    - ``causal_layer``: a bottleneck MLP whose output is added residually
      to the features (built when ``use_causal_intervention`` is true).

    Parameters:
    - in_dim: Size of the flattened input vector
    - hidden_dim: Width of the hidden layers; features have ``hidden_dim // 2`` units
    - out_dim: Number of task classes
    - num_environments: Number of environment classes
    - use_wasserstein: Whether to build and use the environment classifier
    - use_causal_intervention: Whether to build and use the causal layer
    - dropout: Base dropout probability (scaled by 0.7 / 0.5 in deeper layers)
    """

    def __init__(self, in_dim=784, hidden_dim=256, out_dim=10,
                 num_environments=3, use_wasserstein=True,
                 use_causal_intervention=True, dropout=0.3):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_environments = num_environments
        self.use_wasserstein = use_wasserstein
        self.use_causal_intervention = use_causal_intervention

        # Feature extractor: in_dim -> hidden_dim -> hidden_dim -> hidden_dim // 2
        self.feature_extractor = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout * 0.7),

            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
        )

        # Task classifier: hidden_dim // 2 -> hidden_dim // 4 -> out_dim
        self.task_classifier = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(hidden_dim // 4, out_dim)
        )

        # Environment classifier (only exists when requested, so the rest of
        # the class checks for it with hasattr)
        if use_wasserstein and num_environments > 0:
            self.env_classifier = nn.Sequential(
                nn.Linear(hidden_dim // 2, hidden_dim // 4),
                nn.ReLU(),
                nn.Linear(hidden_dim // 4, num_environments)
            )

        # Causal intervention layer: bottleneck that maps features back to
        # their own dimension so it can be added as a residual
        if use_causal_intervention:
            self.causal_layer = nn.Sequential(
                nn.Linear(hidden_dim // 2, hidden_dim // 4),
                nn.ReLU(),
                nn.Linear(hidden_dim // 4, hidden_dim // 2)
            )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Apply Kaiming-normal init to Linear layers and identity init to BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, env_labels=None, return_features=False):
        """
        Forward propagation

        Parameters:
        - x: Input image [batch_size, channels, height, width] or [batch_size, features]
        - env_labels: Environment labels (accepted for interface compatibility;
          not used by this method)
        - return_features: Whether to return features

        Returns:
        - output: Dictionary with 'task_logits', plus 'env_logits' when the
          environment classifier exists and 'features' when requested
        """
        output = {}

        # Flatten image. reshape (rather than view) also accepts
        # non-contiguous inputs, e.g. tensors produced by permute().
        if x.dim() == 4:
            x = x.reshape(x.size(0), -1)

        # Extract features
        features = self.feature_extractor(x)

        # Causal intervention (residual connection)
        if self.use_causal_intervention and hasattr(self, 'causal_layer'):
            causal_effect = self.causal_layer(features)
            features = features + causal_effect

        # Task prediction
        task_logits = self.task_classifier(features)
        output['task_logits'] = task_logits

        # Environment prediction
        if self.use_wasserstein and hasattr(self, 'env_classifier'):
            # Use detach() to prevent gradient backpropagation to feature extractor
            env_logits = self.env_classifier(features.detach())
            output['env_logits'] = env_logits

        if return_features:
            output['features'] = features

        return output

    def compute_wasserstein_loss(self, features, env_labels, lambda_w=1.0):
        """Compute the environment-confusion loss.

        Despite the name, this is the negated cross-entropy of the
        environment classifier scaled by ``lambda_w``, not a Wasserstein
        distance.

        Parameters:
        - features: Feature tensor fed to the environment classifier
        - env_labels: Environment class indices
        - lambda_w: Loss weight

        Returns:
        - Scalar tensor ``-lambda_w * cross_entropy``; zero when the model
          has no environment classifier
        """
        if not hasattr(self, 'env_classifier'):
            return torch.tensor(0.0, device=features.device)

        env_logits = self.env_classifier(features)
        env_loss = F.cross_entropy(env_logits, env_labels)

        # Sign flip: minimizing the returned value maximizes the environment
        # classification loss. NOTE(review): features are not detached here,
        # so the gradient reaches both the features and the env_classifier
        # weights - confirm that this is the intended training scheme.
        return -lambda_w * env_loss

    def compute_causal_loss(self, features, labels, lambda_c=0.5):
        """Compute causal intervention loss.

        The loss is KL(one_hot(labels) || softmax(task_logits)) averaged over
        the batch. With a one-hot target this equals the cross-entropy of
        the true label.

        Parameters:
        - features: Feature tensor fed to the task classifier
        - labels: Integer class indices, shape [batch_size]
        - lambda_c: Loss weight

        Returns:
        - Scalar tensor ``lambda_c * kl``; zero when the model has no
          causal layer
        """
        if not hasattr(self, 'causal_layer'):
            return torch.tensor(0.0, device=features.device)

        task_logits = self.task_classifier(features)
        # log_softmax stays finite for large logit gaps, whereas
        # softmax(...).log() underflows to log(0) = -inf and poisons the loss.
        task_log_probs = F.log_softmax(task_logits, dim=-1)

        # One-hot encode the labels as the target distribution
        label_probs = torch.zeros_like(task_log_probs)
        label_probs.scatter_(1, labels.unsqueeze(1), 1)

        causal_loss = F.kl_div(task_log_probs, label_probs, reduction='batchmean')

        return lambda_c * causal_loss


class GIPLD(nn.Module):
    """GIPLD Main Model.

    Combines four sub-modules: a subgraph extractor, a feature encoder, an
    environment generator and a prototype aligner. Only the first three
    are used by ``forward``; the prototype aligner is constructed here but
    not called in this class.

    Parameters:
    - in_dim: Input feature dimension passed to the extractor and encoder
    - hidden_dim: Hidden dimension passed to the sub-modules
    - out_dim: Number of classes
    - num_environments: Number of environments to generate features for
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_environments=2):
        super().__init__()

        self.num_environments = num_environments

        # 1. Subgraph extractor
        self.subgraph_extractor = SubgraphExtractor(
            in_dim=in_dim,
            hidden_dim=hidden_dim
        )

        # 2. Feature encoder (using SimpleGNN)
        self.encoder = SimpleGNN(
            input_dim=in_dim,
            hidden_dim=hidden_dim,
            num_classes=out_dim,
            num_envs=num_environments
        )

        # 3. Environment generator
        # NOTE(review): hidden_dim // 2 is assumed to be the encoder's
        # feature dimension - confirm against SimpleGNN.
        self.env_generator = EnvironmentGenerator(
            input_dim=hidden_dim // 2,
            hidden_dim=hidden_dim,
            num_environments=num_environments
        )

        # 4. Prototype aligner
        self.prototype_aligner = PrototypeAligner(
            feature_dim=hidden_dim // 2,
            num_classes=out_dim
        )

        print("✅ GIPLD model initialized:")
        print(f"   Input dimension: {in_dim}")
        print(f"   Hidden dimension: {hidden_dim}")
        print(f"   Output dimension: {out_dim}")
        print(f"   Number of environments: {num_environments}")

    def forward(self, batch, method='erm', alpha=0.1):
        """Forward propagation.

        Parameters:
        - batch: Input passed to the encoder and the subgraph extractor
        - method: Accepted for interface compatibility; not used here
        - alpha: Accepted for interface compatibility; not used here

        Returns:
        - output: Dictionary with keys 'logits', 'features', 'env_features'
          (one entry per environment), 'invariant_features' and
          'variant_features'. The last two fall back to 'features' when
          the subgraph extractor is unavailable or raises.
        """
        # 1. Extract features
        features = self.encoder.extract_features(batch)

        # 2. Environment enhancement: one feature set per environment id
        env_features_list = [
            self.env_generator(features, env_id)
            for env_id in range(self.num_environments)
        ]

        # 3. Main task prediction
        logits = self.encoder.classifier(features)

        # 4. Subgraph division (best effort): start from the fallback so
        # both names are always bound.
        invariant_features = variant_features = features
        if hasattr(self.subgraph_extractor, 'forward'):
            try:
                invariant_features, variant_features = self.subgraph_extractor(batch)
            except Exception:
                # Deliberate fallback to the plain features. Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate.
                invariant_features = variant_features = features

        return {
            'logits': logits,
            'features': features,
            'env_features': env_features_list,
            'invariant_features': invariant_features,
            'variant_features': variant_features,
        }