import torch
import torch.nn as nn
import torch.nn.functional as F


class SubgraphExtractor(nn.Module):
    """Subgraph Extraction Module (based on GCIB).

    Scores every edge of a dense adjacency matrix from the embeddings of its
    two endpoint nodes, then splits the edges of each graph into an
    "invariant" subgraph (the top-scoring ``topk_ratio`` fraction) and a
    "variant" subgraph (the remaining edges).

    Args:
        input_dim: Size of the last dimension of the node feature tensor.
        hidden_dim: Size of the node embeddings produced by the encoder.
        topk_ratio: Fraction of each graph's existing edges assigned to the
            invariant subgraph.
    """

    def __init__(self, input_dim, hidden_dim, topk_ratio=0.5):
        super().__init__()
        self.topk_ratio = topk_ratio

        # Node encoder: a two-layer MLP applied to each node independently
        # (the adjacency matrix is not used at this stage).
        self.gnn_encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Mask generator: scores one edge from the concatenated embeddings of
        # its two endpoints, so its input width must be 2 * hidden_dim.
        self.mask_generator = nn.Sequential(
            nn.Linear(2 * hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x, adj):
        """Extract invariant and variant subgraphs.

        Args:
            x: Node features, [batch_size, num_nodes, input_dim].
            adj: Dense adjacency, [batch_size, num_nodes, num_nodes].

        Returns:
            dict with keys:
                'invariant_adj': ``adj`` restricted to the selected edges.
                'variant_adj': ``adj`` restricted to the remaining edges.
                'invariant_mask': 0/1 mask of the selected edges.
                'variant_mask': ``1 - invariant_mask``; this is 1 for every
                    non-selected node pair, including pairs with no edge.
                'node_features': encoder output,
                    [batch_size, num_nodes, hidden_dim].
        """
        # Node representations
        h = self.gnn_encoder(x)  # [batch, num_nodes, hidden_dim]

        # Soft score for every existing edge
        edge_mask = self._generate_edge_mask(h, adj)

        # Invariant subgraph: top-k scoring edges of each graph.
        # NOTE(review): the masks below are hard 0/1 tensors built by index
        # assignment, so no gradient reaches mask_generator through the
        # returned values - confirm whether a soft / straight-through mask
        # is intended.
        invariant_mask = self._topk_mask(edge_mask, ratio=self.topk_ratio)
        invariant_adj = adj * invariant_mask

        # Variant subgraph: remaining edges. Multiplying by adj zeroes the
        # non-edge positions where variant_mask is 1.
        variant_mask = 1 - invariant_mask
        variant_adj = adj * variant_mask

        return {
            'invariant_adj': invariant_adj,
            'variant_adj': variant_adj,
            'invariant_mask': invariant_mask,
            'variant_mask': variant_mask,
            'node_features': h
        }

    def _generate_edge_mask(self, h, adj):
        """Generate edge-level soft masks.

        Args:
            h: Node embeddings, [batch_size, num_nodes, hidden_dim].
            adj: Dense adjacency, [batch_size, num_nodes, num_nodes].

        Returns:
            [batch_size, num_nodes, num_nodes] tensor whose (i, j) entry is
            the sigmoid score of the pair (i, j) multiplied by ``adj[i, j]``.
        """
        batch_size, num_nodes, hidden = h.shape

        # Pairwise features [h_i || h_j]. expand() creates broadcast views,
        # so the only full-size allocation is the concatenated tensor.
        h_i = h.unsqueeze(2).expand(batch_size, num_nodes, num_nodes, hidden)
        h_j = h.unsqueeze(1).expand(batch_size, num_nodes, num_nodes, hidden)
        edge_feat = torch.cat([h_i, h_j], dim=-1)  # [B, N, N, 2 * hidden]

        # nn.Linear acts on the last dimension; drop the trailing size-1 axis.
        edge_mask = self.mask_generator(edge_feat).squeeze(-1)

        # Keep only actual existing edges
        return edge_mask * adj

    def _topk_mask(self, mask, ratio):
        """Select the top ``ratio`` fraction of edges in each graph.

        Args:
            mask: Edge scores, [batch_size, num_nodes, num_nodes]; entries
                greater than zero are counted as existing edges.
            ratio: Fraction of each graph's existing edges to keep.

        Returns:
            Tensor shaped like ``mask`` with 1 at the selected positions and
            0 elsewhere.
        """
        batch_size, num_nodes, _ = mask.shape

        # reshape() rather than view(): the input may be non-contiguous.
        scores = mask.reshape(batch_size, -1)

        # Number of edges to keep per graph, clamped to [0, num_edges] so an
        # out-of-range ratio can neither select non-edges nor exceed the row.
        edge_counts = (scores > 0).sum(dim=1)
        keep = (edge_counts.float() * ratio).long().clamp(min=0)
        keep = torch.minimum(keep, edge_counts)

        # A single host transfer yields plain ints for torch.topk.
        selected = torch.zeros_like(scores)
        for row, k in enumerate(keep.tolist()):
            if k > 0:
                _, top_idx = torch.topk(scores[row], k)
                selected[row, top_idx] = 1

        return selected.view(batch_size, num_nodes, num_nodes)