import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, HeteroConv, GCNConv, SAGEConv
from torch_scatter import scatter_max, scatter_mean
from torch_geometric.data import HeteroData


class SubgraphExtractor(nn.Module):
    """
    Subgraph Extractor for Information Bottleneck.

    Scores every cell node of the cell-based hypergraph with an importance
    probability, propagates those probabilities to grid nodes through the
    nets, and packages both views as "bottleneck" graphs that carry the
    probability-weighted node features.

    Args:
        hidden_dim (int): Dimension of hidden features
        num_layers (int): Number of GNN layers
    """

    def __init__(self, hidden_dim, num_layers=2):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # GNN layers for the cell-based hypergraph.
        # Both relations connect two *different* node types, so HeteroConv
        # hands each operator an (x_src, x_dst) tuple. GCNConv does not
        # support bipartite message passing; SAGEConv does, hence its use.
        self.gnn_layers = nn.ModuleList([
            HeteroConv({
                ('cell', 'to', 'net'): SAGEConv((hidden_dim, hidden_dim), hidden_dim),
                ('net', 'to', 'cell'): SAGEConv((hidden_dim, hidden_dim), hidden_dim),
            })
            for _ in range(num_layers)
        ])

        # Node importance prediction: one probability in (0, 1) per node
        self.node_importance = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, cell_hypergraph, grid_hypergraph):
        """
        Forward pass of the Subgraph Extractor.

        Args:
            cell_hypergraph (HeteroData): Cell-based hypergraph; its 'cell'
                and 'net' node stores must expose features as ``h``
            grid_hypergraph (HeteroData): Grid-based hypergraph

        Returns:
            tuple: ``(cell_bottleneck, grid_bottleneck, cell_node_probs,
                  grid_node_probs)`` - the bottleneck subgraphs for the
                  cell-based and grid-based views and the node probabilities
        """
        # Get initial node features
        x_dict = {
            'cell': cell_hypergraph['cell'].h,
            'net': cell_hypergraph['net'].h
        }

        edge_index_dict = {
            ('cell', 'to', 'net'): cell_hypergraph[('cell', 'to', 'net')].edge_index,
            ('net', 'to', 'cell'): cell_hypergraph[('net', 'to', 'cell')].edge_index
        }

        # Apply GNN layers (each followed by a ReLU) to extract node embeddings
        for layer in self.gnn_layers:
            x_dict = layer(x_dict, edge_index_dict)
            x_dict = {node_type: F.relu(x) for node_type, x in x_dict.items()}

        # Compute node importance probabilities for cell nodes
        cell_node_probs = self.node_importance(x_dict['cell']).squeeze(-1)

        # Map cell probabilities to grid nodes
        grid_node_probs = self._map_cell_to_grid_probs(
            cell_node_probs, cell_hypergraph, grid_hypergraph
        )

        # Create bottleneck subgraphs with weighted features
        cell_bottleneck = self._create_bottleneck_subgraph(
            cell_hypergraph, 'cell', cell_node_probs
        )

        grid_bottleneck = self._create_bottleneck_subgraph(
            grid_hypergraph, 'grid', grid_node_probs
        )

        return cell_bottleneck, grid_bottleneck, cell_node_probs, grid_node_probs

    def _map_cell_to_grid_probs(self, cell_probs, cell_hypergraph, grid_hypergraph):
        """
        Map cell node probabilities to grid nodes through shared nets.

        Probabilities are max-aggregated twice: cell -> net over the cell
        hypergraph, then net -> grid over the grid hypergraph.

        Args:
            cell_probs (torch.Tensor): Cell node probabilities
            cell_hypergraph (HeteroData): Cell-based hypergraph
            grid_hypergraph (HeteroData): Grid-based hypergraph

        Returns:
            torch.Tensor: Grid node probabilities
        """
        # First map cell probabilities to nets
        net_probs = self._propagate_max(
            cell_probs, cell_hypergraph, ('cell', 'to', 'net'), 'net'
        )

        # Then map net probabilities to grids.
        # NOTE(review): net_probs is sized by the cell hypergraph but indexed
        # with the grid hypergraph's net indices, so both graphs are assumed
        # to number their nets identically - confirm against the data loader.
        grid_probs = self._propagate_max(
            net_probs, grid_hypergraph, ('net', 'to', 'grid'), 'grid'
        )

        # Defensive guard: make sure no NaN reaches the callers
        return torch.nan_to_num(grid_probs, nan=0.0)

    @staticmethod
    def _propagate_max(src_probs, graph, edge_type, dst_type):
        """
        Max-aggregate source-node probabilities onto the destination nodes
        of one edge type.

        Args:
            src_probs (torch.Tensor): Probabilities of the source nodes
            graph (HeteroData): Graph holding the edge type
            edge_type (tuple): ``(src_type, relation, dst_type)`` key
            dst_type (str): Destination node type

        Returns:
            torch.Tensor: One probability per destination node; all zeros
                when the graph has no edges of ``edge_type``
        """
        num_dst = graph[dst_type].num_nodes

        if edge_type not in graph.edge_index_dict:
            # No such relation: nothing can be propagated. new_zeros keeps
            # dtype and device consistent with the aggregated branch.
            return src_probs.new_zeros(num_dst)

        edge_index = graph[edge_type].edge_index
        src_nodes = edge_index[0]
        dst_nodes = edge_index[1]

        # scatter_max returns (values, argmax); only the values are needed
        aggregated, _ = scatter_max(
            src_probs[src_nodes], dst_nodes, dim=0, dim_size=num_dst
        )
        return aggregated

    def _create_bottleneck_subgraph(self, original_graph, node_type, node_probs):
        """
        Create a bottleneck subgraph from the original hypergraph.

        The result keeps the full topology and (cloned) features of the
        original graph; for ``node_type`` it additionally stores the
        probability-weighted features as ``h_bottleneck`` and the
        probabilities themselves as ``prob``.

        Args:
            original_graph (HeteroData): Original hypergraph
            node_type (str): Node type ('cell' or 'grid')
            node_probs (torch.Tensor): Node probabilities, one per node of
                ``node_type``

        Returns:
            HeteroData: Bottleneck subgraph
        """
        # Create a new HeteroData for the bottleneck subgraph
        bottleneck = HeteroData()

        # Copy node features and multiply by node probabilities
        # (the probabilities are broadcast over the feature dimension)
        target_features = original_graph[node_type].h
        bottleneck[node_type].h = target_features.clone()
        bottleneck[node_type].h_bottleneck = target_features * node_probs.unsqueeze(1)
        bottleneck[node_type].prob = node_probs

        # Copy other node types and their features
        for n_type in original_graph.node_types:
            if n_type != node_type:
                bottleneck[n_type].h = original_graph[n_type].h.clone()

        # Copy edge structure
        for edge_type in original_graph.edge_types:
            bottleneck[edge_type].edge_index = original_graph[edge_type].edge_index.clone()

        return bottleneck


class InterpretableSubgraphModule(nn.Module):
    """
    Interpretable Subgraph Module for MIHC.

    This module implements the Information Bottleneck mechanism to identify
    critical subgraphs highly correlated with circuit congestion, and uses
    them to reweight the node features of both hypergraph views.

    Args:
        hidden_dim (int): Dimension of hidden features
        bottleneck_enable (bool): Whether to enable bottleneck mechanism
        num_layers (int): Number of GNN layers used by the subgraph extractor
    """

    def __init__(self, hidden_dim, bottleneck_enable=True, num_layers=2):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.bottleneck_enable = bottleneck_enable

        # Subgraph extractor
        self.subgraph_extractor = SubgraphExtractor(hidden_dim, num_layers=num_layers)

        # Feature reweighting for augmentation: each MLP consumes the
        # concatenation [original | bottleneck] features (2 * hidden_dim)
        # and emits one gate in (0, 1) per feature.
        self.cell_reweight = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )

        self.grid_reweight = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )

    def forward(self, cell_hypergraph, grid_hypergraph):
        """
        Forward pass of the Interpretable Subgraph Module.

        Args:
            cell_hypergraph (HeteroData): Cell-based hypergraph
            grid_hypergraph (HeteroData): Grid-based hypergraph

        Returns:
            tuple: ``(augmented_cell_hypergraph, augmented_grid_hypergraph,
                  cell_bottleneck, grid_bottleneck, cell_probs, grid_probs)``.
                  When the bottleneck is disabled the inputs are returned
                  unchanged, both bottlenecks are ``None`` and the
                  probabilities are all ones.
        """
        if not self.bottleneck_enable:
            # If bottleneck is disabled, return the original hypergraphs
            cell_store = cell_hypergraph['cell']
            grid_store = grid_hypergraph['grid']
            return (
                cell_hypergraph,
                grid_hypergraph,
                None,
                None,
                torch.ones(cell_store.num_nodes, device=cell_store.h.device),
                torch.ones(grid_store.num_nodes, device=grid_store.h.device)
            )

        # Extract bottleneck subgraphs
        cell_bottleneck, grid_bottleneck, cell_probs, grid_probs = self.subgraph_extractor(
            cell_hypergraph, grid_hypergraph
        )

        # Augment cell hypergraph with bottleneck information
        augmented_cell_hypergraph = self._augment_hypergraph(
            cell_hypergraph,
            cell_bottleneck,
            'cell',
            self.cell_reweight
        )

        # Augment grid hypergraph with bottleneck information
        augmented_grid_hypergraph = self._augment_hypergraph(
            grid_hypergraph,
            grid_bottleneck,
            'grid',
            self.grid_reweight
        )

        return (
            augmented_cell_hypergraph,
            augmented_grid_hypergraph,
            cell_bottleneck,
            grid_bottleneck,
            cell_probs,
            grid_probs
        )

    def _augment_hypergraph(self, original, bottleneck, node_type, reweight_mlp):
        """
        Augment hypergraph with bottleneck information.

        The features of ``node_type`` are scaled by ``1 + gate``, where the
        gate is produced by ``reweight_mlp`` from the concatenated original
        and bottleneck features. All other node features and every edge
        index are cloned unchanged.

        Args:
            original (HeteroData): Original hypergraph
            bottleneck (HeteroData): Bottleneck subgraph; its ``node_type``
                store must provide ``h_bottleneck``
            node_type (str): Node type ('cell' or 'grid')
            reweight_mlp (nn.Module): MLP for feature reweighting

        Returns:
            HeteroData: Augmented hypergraph
        """
        # Concatenate original and bottleneck features
        original_features = original[node_type].h
        bottleneck_features = bottleneck[node_type].h_bottleneck
        concat_features = torch.cat([original_features, bottleneck_features], dim=1)

        # Compute reweighting factors and apply them
        reweight_factors = reweight_mlp(concat_features)
        reweighted_features = original_features * (1 + reweight_factors)

        # Create a new HeteroData for the augmented hypergraph
        augmented = HeteroData()

        # Fill node stores in the original order. The target node type gets
        # the reweighted features directly, which avoids cloning a tensor
        # that would be overwritten immediately afterwards.
        for n_type in original.node_types:
            if n_type == node_type:
                augmented[n_type].h = reweighted_features
            else:
                augmented[n_type].h = original[n_type].h.clone()

        # Copy edge structure
        for edge_type in original.edge_types:
            augmented[edge_type].edge_index = original[edge_type].edge_index.clone()

        return augmented