import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn


class InvariantEncoder(nn.Module):
    """Invariant Subgraph Encoder (captures class-related information).

    Encodes a batch of dense graphs with a stack of PyG message-passing
    layers, pools the node embeddings of each graph into a single vector and
    projects that vector with a small MLP.

    Args:
        input_dim: Size of each node's input feature vector.
        hidden_dim: Width of every GNN layer and of the pooled graph embedding.
        output_dim: Size of the returned graph embedding.
        num_layers: Number of GNN layers.
        gnn_type: One of ``'gcn'``, ``'gin'``, ``'gat'`` (4 heads, averaged
            because ``concat=False``) or ``'sage'``.
        dropout: Dropout probability applied between GNN layers and inside the
            output projection.

    Raises:
        ValueError: If ``gnn_type`` is not one of the supported names.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2,
                 gnn_type='gcn', dropout=0.1):
        super().__init__()

        self.gnn_type = gnn_type
        self.dropout = dropout

        # GNN layers: the first maps input_dim -> hidden_dim, the rest keep
        # hidden_dim.
        self.gnn_layers = nn.ModuleList()
        in_channels = input_dim

        for _ in range(num_layers):
            if gnn_type == 'gcn':
                layer = pyg_nn.GCNConv(in_channels, hidden_dim)
            elif gnn_type == 'gin':
                layer = pyg_nn.GINConv(
                    nn.Sequential(
                        nn.Linear(in_channels, hidden_dim),
                        nn.ReLU(),
                        nn.Linear(hidden_dim, hidden_dim)
                    )
                )
            elif gnn_type == 'gat':
                layer = pyg_nn.GATConv(in_channels, hidden_dim, heads=4, concat=False)
            elif gnn_type == 'sage':
                layer = pyg_nn.SAGEConv(in_channels, hidden_dim)
            else:
                raise ValueError(f"Unknown GNN type: {gnn_type}")

            self.gnn_layers.append(layer)
            in_channels = hidden_dim

        # Output projection applied to the pooled graph embedding.
        self.output_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim)
        )

        # One BatchNorm per GNN layer; forward only applies the first
        # num_layers - 1 of them (there is no normalisation after the last).
        self.batch_norms = nn.ModuleList([
            nn.BatchNorm1d(hidden_dim) for _ in range(num_layers)
        ])

    def forward(self, x, adj):
        """Encode a batch of dense graphs into graph-level embeddings.

        Args:
            x: Node features of shape ``[batch_size, num_nodes, input_dim]``.
            adj: Dense adjacency of shape ``[batch_size, num_nodes, num_nodes]``.
                Every non-zero entry ``adj[b, i, j]`` becomes an edge ``i -> j``
                of graph ``b``. Entry values are not used as edge weights.

        Returns:
            Tensor of shape ``[batch_size, output_dim]``.
        """
        batch_size, num_nodes, _ = x.shape

        # The batch is processed as one disjoint graph of
        # batch_size * num_nodes nodes: node i of graph b is row
        # b * num_nodes + i of the flattened features. Edge endpoints must be
        # shifted by that same offset, or every graph's edges would land on
        # graph 0's nodes. adj.nonzero() yields (b, i, j) rows in row-major
        # order on adj's device. A graph with no non-zero entries simply
        # contributes no edges.
        nz = adj.nonzero()
        offset = nz[:, 0] * num_nodes
        edge_index = torch.stack((nz[:, 1] + offset, nz[:, 2] + offset))

        # reshape rather than view so non-contiguous inputs are accepted.
        h = x.reshape(-1, x.shape[-1])

        last = len(self.gnn_layers) - 1
        for i, layer in enumerate(self.gnn_layers):
            h = layer(h, edge_index)
            if i < last:
                h = self.batch_norms[i](h)
                h = F.relu(h)
                h = F.dropout(h, p=self.dropout, training=self.training)

        # Restore the batch dimension: [batch_size, num_nodes, hidden].
        h = h.reshape(batch_size, num_nodes, -1)

        graph_emb = self.attention_pooling(h)
        return self.output_proj(graph_emb)

    def attention_pooling(self, h):
        """Pool node embeddings into one vector per graph using softmax weights.

        Node scores come from ``self.attn_weight`` (expected shape
        ``[hidden_dim, d]``) if such an attribute has been attached to the
        module. This class never creates it, so by default a node's score is
        the mean of its embedding. Scores are softmax-normalised over the node
        axis and used to weight-sum the node embeddings.

        Args:
            h: Node embeddings of shape ``[batch_size, num_nodes, hidden_dim]``.

        Returns:
            ``[batch_size, hidden_dim]`` in the default single-score case.
            ``[batch_size, hidden_dim, d]`` when ``attn_weight`` produces
            ``d > 1`` scores per node.
        """
        if hasattr(self, 'attn_weight'):
            scores = torch.einsum('bnh,hd->bnd', h, self.attn_weight)
        else:
            scores = h.mean(dim=-1, keepdim=True)
        attn = F.softmax(scores, dim=1)

        pooled = torch.einsum('bnh,bnd->bhd', h, attn)
        # Drop the trailing score axis (size 1 by default), not the hidden
        # axis, so the result has the hidden_dim width output_proj expects.
        return pooled.squeeze(-1)


class VariantEncoder(nn.Module):
    """Variant Subgraph Encoder (captures environment-related information).

    Wraps an ``InvariantEncoder`` and passes its graph embedding through a
    two-layer MLP (``env_proj``). A second two-layer MLP
    (``env_discriminator``) maps an embedding to a single unnormalised score.
    It is only used through ``discriminate_environment``. Any adversarial
    training against it would happen outside this class.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2,
                 gnn_type='gcn', dropout=0.1):
        super().__init__()

        # Submodules are built in a fixed order (encoder, projection,
        # discriminator), which fixes parameter initialisation order.
        self.main_encoder = InvariantEncoder(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            num_layers=num_layers,
            gnn_type=gnn_type,
            dropout=dropout,
        )

        def mlp_from_embedding(out_features):
            # output_dim -> hidden_dim -> ReLU -> out_features
            return nn.Sequential(
                nn.Linear(output_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_features),
            )

        # Environment-aware projection, kept at output_dim width.
        self.env_proj = mlp_from_embedding(output_dim)

        # Environment discriminator head: one score per embedding.
        self.env_discriminator = mlp_from_embedding(1)

    def forward(self, x, adj):
        """Return ``env_proj`` applied to the ``main_encoder`` embedding of ``(x, adj)``."""
        return self.env_proj(self.main_encoder(x, adj))

    def discriminate_environment(self, embeddings):
        """Score ``embeddings`` with ``env_discriminator`` (one value per row)."""
        return self.env_discriminator(embeddings)


class SharedEncoder(nn.Module):
    """Shared Encoder (for invariant and variant representations).

    A plain MLP of ``num_layers`` Linear layers, with ReLU and dropout between
    consecutive layers and nothing after the last one.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of every Linear layer, and therefore of the output.
        num_layers: Number of Linear layers. With 0 layers the module returns
            its input unchanged.
        dropout: Dropout probability between layers (default 0.1).
    """

    def __init__(self, input_dim, hidden_dim, num_layers=3, dropout=0.1):
        super().__init__()

        self.layers = nn.ModuleList()
        in_dim = input_dim
        for _ in range(num_layers):
            self.layers.append(nn.Linear(in_dim, hidden_dim))
            in_dim = hidden_dim

        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP to ``x``; only the last dimension is transformed."""
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = self.activation(x)
                x = self.dropout(x)
        return x