import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_scatter import scatter_mean


class ProcessModule(nn.Module):
    """
    Process Module for MIHC.

    This module processes raw multi-view data, including cell-based netlist and
    grid-based layout, into hypergraphs for subsequent processing.

    For each view, the raw node features ``x`` are projected to ``hidden_dim``
    by a per-view two-layer MLP and stored as ``.h``. If the view also has
    ``'net'`` nodes, their embeddings are initialised as follows:

    - Take the mean of the embeddings of the cell/grid nodes connected to each
      net.
    - Pass that mean through an MLP that is shared by both views.

    Args:
        cell_feature_dim (int): Dimension of cell node features
        grid_feature_dim (int): Dimension of grid node features
        hidden_dim (int): Dimension of hidden features
    """

    def __init__(self, cell_feature_dim, grid_feature_dim, hidden_dim):
        super().__init__()

        # Cell feature transformation
        self.cell_feature_transform = nn.Sequential(
            nn.Linear(cell_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Grid feature transformation
        self.grid_feature_transform = nn.Sequential(
            nn.Linear(grid_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Net feature initialization (one MLP shared by the cell and grid views)
        self.net_feature_init = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.hidden_dim = hidden_dim

    def forward(self, cell_hypergraph, grid_hypergraph):
        """
        Forward pass of the Process Module.

        The input graphs are modified in place:

        - ``cell_hypergraph['cell'].h`` and ``grid_hypergraph['grid'].h`` are
          set to the transformed node features.
        - When a graph has a ``'net'`` node type, its ``['net'].h`` is set to
          the initial net embeddings.

        Args:
            cell_hypergraph (HeteroData): Cell-based hypergraph. It holds
                ``'cell'`` node features ``x``. It may also hold ``'net'`` nodes
                linked through the ``('cell', 'to', 'net')`` edge type.
            grid_hypergraph (HeteroData): Grid-based hypergraph. It holds
                ``'grid'`` node features ``x``. It may also hold ``'net'`` nodes
                linked through the ``('grid', 'to', 'net')`` edge type.

        Returns:
            tuple: ``(cell_hypergraph, grid_hypergraph)``, the same objects
            that were passed in, now carrying the ``.h`` attributes.
        """
        # Process cell hypergraph
        cell_h = self.cell_feature_transform(cell_hypergraph['cell'].x)
        cell_hypergraph['cell'].h = cell_h

        # Process grid hypergraph
        grid_h = self.grid_feature_transform(grid_hypergraph['grid'].x)
        grid_hypergraph['grid'].h = grid_h

        # Initialize net features in both views through message passing
        self._init_net_features(cell_hypergraph, 'cell', cell_h)
        self._init_net_features(grid_hypergraph, 'grid', grid_h)

        return cell_hypergraph, grid_hypergraph

    def _init_net_features(self, graph, src_type, src_h):
        """Set ``graph['net'].h`` from the mean of connected ``src_type`` embeddings.

        Does nothing when ``graph`` has no ``'net'`` node type. When the
        ``(src_type, 'to', 'net')`` edge type is absent, every net embedding is
        zero, and ``net_feature_init`` is not applied in that case.

        Args:
            graph (HeteroData): Graph to update in place.
            src_type (str): Source node type, ``'cell'`` or ``'grid'``.
            src_h (Tensor): Embeddings of the ``src_type`` nodes, indexed by
                the first row of the edge index.
        """
        # Membership must be tested against ``node_types`` / ``edge_types``.
        # ``'net' in graph`` on a HeteroData checks attribute names such as
        # ``x`` and ``edge_index``, not node types.
        if src_type not in graph.node_types or 'net' not in graph.node_types:
            return

        net_store = graph['net']
        # ``num_nodes`` equals ``x.size(0)`` when ``x`` exists, but does not
        # require net nodes to carry an ``x`` tensor.
        num_nets = net_store.num_nodes
        edge_type = (src_type, 'to', 'net')

        if edge_type not in graph.edge_types:
            # If no edges exist, initialize with zeros, matching the dtype and
            # device of the source embeddings.
            net_store.h = src_h.new_zeros(num_nets, self.hidden_dim)
            return

        src_idx, net_idx = graph[edge_type].edge_index
        # Mean aggregation. A net with no incident source node gets a zero
        # mean, which is then passed through the MLP.
        net_mean = scatter_mean(src_h[src_idx], net_idx, dim=0, dim_size=num_nets)
        net_store.h = self.net_feature_init(net_mean)