import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GCNConv, HeteroConv, SAGEConv
from torch_scatter import scatter_mean, scatter_max
from torch_geometric.data import HeteroData


class MultiViewHGNNLayer(nn.Module):
    """
    Multi-View Heterogeneous Graph Neural Network Layer.

    This layer implements the node-to-hyperedge and hyperedge-to-node message passing
    mechanism to fuse topological and geometric information.

    One call performs three steps:

    1. Node-to-hyperedge: mean-aggregate linearly transformed ``cell`` embeddings
       into ``net`` nodes (cell view) and ``grid`` embeddings into ``net`` nodes
       (grid view).
    2. Net fusion: concatenate the two per-net embeddings and project them back
       to ``hidden_dim`` (Linear -> ReLU -> Dropout).
    3. Hyperedge-to-node: update ``cell`` embeddings from the fused nets plus a
       cell -> net -> cell two-hop path, and ``grid`` embeddings from the fused
       nets plus a GAT over ``('grid', 'adjacent', 'grid')`` edges when present.

    Args:
        hidden_dim (int): Dimension of hidden features
        num_heads (int): Number of attention heads; must be positive and divide
            ``hidden_dim`` so the concatenated GAT heads have width ``hidden_dim``
        dropout (float): Dropout rate

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``hidden_dim``.
    """

    def __init__(self, hidden_dim, num_heads=4, dropout=0.1):
        super(MultiViewHGNNLayer, self).__init__()

        # The GAT below uses ``hidden_dim // num_heads`` channels per head with
        # ``concat=True``, so its output width equals ``hidden_dim`` only when the
        # division is exact; otherwise ``grid_fusion`` fails on a shape mismatch.
        if num_heads <= 0 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.hidden_dim = hidden_dim

        # Node-to-hyperedge message passing
        # Cell-to-net transformation
        self.cell_to_net = nn.Linear(hidden_dim, hidden_dim)

        # Grid-to-net transformation
        self.grid_to_net = nn.Linear(hidden_dim, hidden_dim)

        # Net fusion transformation (cell-view net ++ grid-view net -> hidden_dim)
        self.net_fusion = nn.Linear(hidden_dim * 2, hidden_dim)

        # Hyperedge-to-node message passing
        # Cell-based aggregation (cell -> net -> cell two-hop path)
        self.cell_hyperedge_agg = nn.Linear(hidden_dim, hidden_dim)

        # Net-based aggregation for cell
        self.net_to_cell_agg = nn.Linear(hidden_dim, hidden_dim)

        # Grid-based aggregation.
        # NOTE(review): registered but not used by forward() in this class.
        self.grid_hyperedge_agg = nn.Linear(hidden_dim, hidden_dim)

        # Net-based aggregation for grid
        self.net_to_grid_agg = nn.Linear(hidden_dim, hidden_dim)

        # Grid spatial aggregation using attention
        self.grid_spatial_agg = GATConv(
            hidden_dim,
            hidden_dim // num_heads,
            num_heads,
            dropout=dropout,
            concat=True
        )

        # Final fusion transformations
        self.cell_fusion = nn.Linear(hidden_dim * 2, hidden_dim)
        self.grid_fusion = nn.Linear(hidden_dim * 2, hidden_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, cell_hypergraph, grid_hypergraph):
        """
        Forward pass of the Multi-View HGNN Layer.

        Args:
            cell_hypergraph (HeteroData): Cell-based hypergraph. Reads ``h`` on the
                ``'cell'`` node type and the ``('cell', 'to', 'net')`` /
                ``('net', 'to', 'cell')`` edge types when present.
            grid_hypergraph (HeteroData): Grid-based hypergraph. Reads ``h`` on the
                ``'grid'`` node type and the ``('grid', 'to', 'net')``,
                ``('net', 'to', 'grid')`` and ``('grid', 'adjacent', 'grid')``
                edge types when present.

        Returns:
            tuple: ``(updated_cell_hypergraph, updated_grid_hypergraph)``. Each is a
            new ``HeteroData`` holding clones of every input ``edge_index`` plus
            ``h`` and ``num_nodes`` for its two node types. The ``'net'`` embeddings
            are the same fused tensor in both.
        """
        # Create new HeteroData objects for updated hypergraphs
        updated_cell_hypergraph = HeteroData()
        updated_grid_hypergraph = HeteroData()

        # Copy edge structure (only edge_index; other edge attributes are not carried over)
        for edge_type in cell_hypergraph.edge_types:
            updated_cell_hypergraph[edge_type].edge_index = cell_hypergraph[edge_type].edge_index.clone()

        for edge_type in grid_hypergraph.edge_types:
            updated_grid_hypergraph[edge_type].edge_index = grid_hypergraph[edge_type].edge_index.clone()

        # Step 1: Node-to-hyperedge message passing
        # Process cell-to-net messages
        cell_to_net_embeddings = self._node_to_hyperedge(
            cell_hypergraph,
            'cell',
            'net',
            ('cell', 'to', 'net'),
            self.cell_to_net
        )

        # Process grid-to-net messages
        grid_to_net_embeddings = self._node_to_hyperedge(
            grid_hypergraph,
            'grid',
            'net',
            ('grid', 'to', 'net'),
            self.grid_to_net
        )

        # Step 2: Fuse net embeddings from both views.
        # Both views are expected to index the same nets. If their counts differ,
        # zero-pad the shorter one rather than truncating. Truncation would drop
        # nets that net -> cell / net -> grid edges still reference, and the gather
        # in _hyperedge_to_node would then index out of range.
        num_nets = max(cell_to_net_embeddings.size(0), grid_to_net_embeddings.size(0))
        cell_to_net_embeddings = self._pad_rows(cell_to_net_embeddings, num_nets)
        grid_to_net_embeddings = self._pad_rows(grid_to_net_embeddings, num_nets)

        # Concatenate embeddings and apply fusion
        fused_net_embeddings = torch.cat([cell_to_net_embeddings, grid_to_net_embeddings], dim=1)
        fused_net_embeddings = self.net_fusion(fused_net_embeddings)
        fused_net_embeddings = F.relu(fused_net_embeddings)
        fused_net_embeddings = self.dropout(fused_net_embeddings)

        # Step 3: Hyperedge-to-node message passing
        # Process net-to-cell messages
        net_to_cell_embeddings = self._hyperedge_to_node(
            cell_hypergraph,
            'net',
            'cell',
            ('net', 'to', 'cell'),
            fused_net_embeddings,
            self.net_to_cell_agg
        )

        # Process cell-based hyperedge messages (cell-to-cell through hyperedges)
        cell_hyperedge_embeddings = self._hyperedge_through_path(
            cell_hypergraph,
            'cell',
            'net',
            'cell',
            ('cell', 'to', 'net'),
            ('net', 'to', 'cell'),
            cell_hypergraph['cell'].h,
            self.cell_hyperedge_agg
        )

        # Fuse cell embeddings
        cell_embeddings = torch.cat([net_to_cell_embeddings, cell_hyperedge_embeddings], dim=1)
        cell_embeddings = self.cell_fusion(cell_embeddings)
        cell_embeddings = F.relu(cell_embeddings)
        cell_embeddings = self.dropout(cell_embeddings)

        # Process net-to-grid messages
        net_to_grid_embeddings = self._hyperedge_to_node(
            grid_hypergraph,
            'net',
            'grid',
            ('net', 'to', 'grid'),
            fused_net_embeddings,
            self.net_to_grid_agg
        )

        # Process grid spatial messages (using attention-based aggregation)
        if ('grid', 'adjacent', 'grid') in grid_hypergraph.edge_types:
            grid_features = grid_hypergraph['grid'].h
            grid_edge_index = grid_hypergraph[('grid', 'adjacent', 'grid')].edge_index
            grid_spatial_embeddings = self.grid_spatial_agg(grid_features, grid_edge_index)
        else:
            # If no spatial edges, use the original grid embeddings
            grid_spatial_embeddings = grid_hypergraph['grid'].h

        # Fuse grid embeddings
        grid_embeddings = torch.cat([net_to_grid_embeddings, grid_spatial_embeddings], dim=1)
        grid_embeddings = self.grid_fusion(grid_embeddings)
        grid_embeddings = F.relu(grid_embeddings)
        grid_embeddings = self.dropout(grid_embeddings)

        # Update node embeddings in the output hypergraphs. ``num_nodes`` is set
        # explicitly because PyG does not treat an attribute named ``h`` as a
        # node-count source. Without it, a following layer would pass
        # ``dim_size=None`` to scatter_mean and ``None`` to its zero fallbacks.
        updated_cell_hypergraph['cell'].h = cell_embeddings
        updated_cell_hypergraph['cell'].num_nodes = cell_embeddings.size(0)
        updated_cell_hypergraph['net'].h = fused_net_embeddings
        updated_cell_hypergraph['net'].num_nodes = num_nets

        updated_grid_hypergraph['grid'].h = grid_embeddings
        updated_grid_hypergraph['grid'].num_nodes = grid_embeddings.size(0)
        updated_grid_hypergraph['net'].h = fused_net_embeddings
        updated_grid_hypergraph['net'].num_nodes = num_nets

        return updated_cell_hypergraph, updated_grid_hypergraph

    @staticmethod
    def _pad_rows(x, num_rows):
        """
        Zero-pad a 2-D tensor along dim 0 up to ``num_rows`` rows.

        Args:
            x (torch.Tensor): Tensor of shape ``(n, d)``
            num_rows (int): Desired number of rows

        Returns:
            torch.Tensor: ``x`` itself when it already has at least ``num_rows`` rows,
            otherwise ``x`` followed by ``num_rows - n`` rows of zeros (same dtype/device).
        """
        missing = num_rows - x.size(0)
        if missing <= 0:
            return x
        return torch.cat([x, x.new_zeros(missing, x.size(1))], dim=0)

    def _node_to_hyperedge(self, hypergraph, src_type, dst_type, edge_type, transform_fn):
        """
        Perform node-to-hyperedge message passing.

        Transforms every source embedding with ``transform_fn`` and takes the mean
        over incoming edges at each destination node.

        Args:
            hypergraph (HeteroData): Input hypergraph
            src_type (str): Source node type
            dst_type (str): Destination node type
            edge_type (tuple): Edge type as (src_type, relation, dst_type)
            transform_fn (nn.Module): Transformation function

        Returns:
            torch.Tensor: Hyperedge embeddings with ``hypergraph[dst_type].num_nodes``
            rows (all zeros when ``edge_type`` is absent).
        """
        if edge_type in hypergraph.edge_types:
            edge_index = hypergraph[edge_type].edge_index

            # Get source node embeddings and apply transformation
            src_embeddings = transform_fn(hypergraph[src_type].h)

            # Message passing from src to dst
            src_nodes = edge_index[0]
            dst_nodes = edge_index[1]

            # Gather source node embeddings and aggregate using mean.
            # Destinations without incoming edges receive zeros.
            src_node_embeddings = src_embeddings[src_nodes]
            return scatter_mean(src_node_embeddings, dst_nodes, dim=0,
                                dim_size=hypergraph[dst_type].num_nodes)

        # If no edges exist, return zeros
        src_h = hypergraph[src_type].h
        return torch.zeros(hypergraph[dst_type].num_nodes, self.hidden_dim,
                           device=src_h.device, dtype=src_h.dtype)

    def _hyperedge_to_node(self, hypergraph, src_type, dst_type, edge_type, src_embeddings, transform_fn):
        """
        Perform hyperedge-to-node message passing.

        Args:
            hypergraph (HeteroData): Input hypergraph
            src_type (str): Source node type
            dst_type (str): Destination node type
            edge_type (tuple): Edge type as (src_type, relation, dst_type)
            src_embeddings (torch.Tensor): Source node embeddings; must have a row for
                every source index referenced by ``edge_type``'s ``edge_index``
            transform_fn (nn.Module): Transformation function

        Returns:
            torch.Tensor: Updated node embeddings with ``hypergraph[dst_type].num_nodes``
            rows (all zeros when ``edge_type`` is absent).
        """
        if edge_type in hypergraph.edge_types:
            edge_index = hypergraph[edge_type].edge_index

            # Apply transformation
            transformed_embeddings = transform_fn(src_embeddings)

            # Message passing from src to dst
            src_nodes = edge_index[0]
            dst_nodes = edge_index[1]

            # Gather source node embeddings and aggregate using mean
            src_node_embeddings = transformed_embeddings[src_nodes]
            return scatter_mean(src_node_embeddings, dst_nodes, dim=0,
                                dim_size=hypergraph[dst_type].num_nodes)

        # If no edges exist, return zeros
        return torch.zeros(hypergraph[dst_type].num_nodes, self.hidden_dim,
                           device=src_embeddings.device, dtype=src_embeddings.dtype)

    def _hyperedge_through_path(self, hypergraph, src_type, mid_type, dst_type,
                                src_to_mid_edge_type, mid_to_dst_edge_type,
                                src_embeddings, transform_fn):
        """
        Perform node-to-node message passing through hyperedges (two-hop path).

        The source embeddings are transformed once, mean-aggregated into the
        intermediate nodes, then mean-aggregated again into the destination nodes.

        Args:
            hypergraph (HeteroData): Input hypergraph
            src_type (str): Source node type
            mid_type (str): Intermediate node type (hyperedge)
            dst_type (str): Destination node type
            src_to_mid_edge_type (tuple): First hop edge type
            mid_to_dst_edge_type (tuple): Second hop edge type
            src_embeddings (torch.Tensor): Source node embeddings
            transform_fn (nn.Module): Transformation function

        Returns:
            torch.Tensor: Updated node embeddings with ``hypergraph[dst_type].num_nodes``
            rows (all zeros when either hop's edge type is absent).
        """
        if src_to_mid_edge_type in hypergraph.edge_types and mid_to_dst_edge_type in hypergraph.edge_types:
            src_to_mid_edge_index = hypergraph[src_to_mid_edge_type].edge_index
            mid_to_dst_edge_index = hypergraph[mid_to_dst_edge_type].edge_index

            # Apply transformation
            transformed_embeddings = transform_fn(src_embeddings)

            # First hop: src to mid
            src_nodes_1 = src_to_mid_edge_index[0]
            mid_nodes_1 = src_to_mid_edge_index[1]
            mid_embeddings = scatter_mean(transformed_embeddings[src_nodes_1], mid_nodes_1, dim=0,
                                          dim_size=hypergraph[mid_type].num_nodes)

            # Second hop: mid to dst
            mid_nodes_2 = mid_to_dst_edge_index[0]
            dst_nodes_2 = mid_to_dst_edge_index[1]
            return scatter_mean(mid_embeddings[mid_nodes_2], dst_nodes_2, dim=0,
                                dim_size=hypergraph[dst_type].num_nodes)

        # If no path exists, return zeros
        return torch.zeros(hypergraph[dst_type].num_nodes, self.hidden_dim,
                           device=src_embeddings.device, dtype=src_embeddings.dtype)


class MultiViewHGNN(nn.Module):
    """
    Multi-View Heterogeneous Graph Neural Network for MIHC.

    Runs the cell-based and grid-based hypergraphs through a stack of
    ``num_layers`` MultiViewHGNNLayer modules in sequence. When both bottleneck
    subgraphs are given, it also produces one embedding per view. That embedding
    is the probability-weighted mean of the final node embeddings, projected by a
    view-specific two-layer MLP, and is intended for contrastive learning.

    Args:
        hidden_dim (int): Dimension of hidden features
        num_layers (int): Number of HGNN layers
        num_heads (int): Number of attention heads
        dropout (float): Dropout rate
    """

    def __init__(self, hidden_dim, num_layers=4, num_heads=4, dropout=0.1):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        stacked_layers = [MultiViewHGNNLayer(hidden_dim, num_heads, dropout) for _ in range(num_layers)]
        self.mv_hgnn_layers = nn.ModuleList(stacked_layers)

        # Separate readout heads per view. They are created cell-first, then grid,
        # so parameter initialisation order and state_dict keys are stable.
        self.cell_subgraph_readout = self._build_readout(hidden_dim)
        self.grid_subgraph_readout = self._build_readout(hidden_dim)

        # NOTE(review): registered but never applied inside this class.
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _build_readout(width):
        """Build a ``width -> width`` MLP: Linear, ReLU, Linear."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )

    def forward(self, cell_hypergraph, grid_hypergraph, cell_bottleneck=None, grid_bottleneck=None):
        """
        Forward pass of the Multi-View HGNN.

        Args:
            cell_hypergraph (HeteroData): Cell-based hypergraph
            grid_hypergraph (HeteroData): Grid-based hypergraph
            cell_bottleneck (HeteroData, optional): Cell-based bottleneck subgraph
            grid_bottleneck (HeteroData, optional): Grid-based bottleneck subgraph

        Returns:
            tuple: ``(cell_hypergraph, grid_hypergraph, cell_subgraph_embedding,
            grid_subgraph_embedding)``. Both subgraph embeddings are ``None`` unless
            both bottlenecks were supplied.
        """
        for mv_layer in self.mv_hgnn_layers:
            cell_hypergraph, grid_hypergraph = mv_layer(cell_hypergraph, grid_hypergraph)

        # Subgraph readout needs both views; otherwise skip it entirely.
        if cell_bottleneck is None or grid_bottleneck is None:
            return cell_hypergraph, grid_hypergraph, None, None

        cell_view_embedding = self._generate_subgraph_embedding(
            cell_hypergraph, 'cell', cell_bottleneck, self.cell_subgraph_readout
        )
        grid_view_embedding = self._generate_subgraph_embedding(
            grid_hypergraph, 'grid', grid_bottleneck, self.grid_subgraph_readout
        )
        return cell_hypergraph, grid_hypergraph, cell_view_embedding, grid_view_embedding

    def _generate_subgraph_embedding(self, hypergraph, node_type, bottleneck, readout_fn):
        """
        Pool one view's node embeddings into a single subgraph embedding.

        Each node embedding is weighted by ``bottleneck[node_type].prob`` when that
        attribute exists, or by 1 otherwise. The weighted sum is divided by the
        total weight (plus 1e-8 to avoid division by zero), and the result is
        passed through ``readout_fn`` as a batch of one.

        Args:
            hypergraph (HeteroData): Input hypergraph
            node_type (str): Node type ('cell' or 'grid')
            bottleneck (HeteroData): Bottleneck subgraph
            readout_fn (nn.Module): Readout function

        Returns:
            torch.Tensor: Subgraph embedding with the leading batch dim removed
        """
        embeddings = hypergraph[node_type].h
        bottleneck_nodes = bottleneck[node_type]
        if hasattr(bottleneck_nodes, 'prob'):
            weights = bottleneck_nodes.prob
        else:
            weights = torch.ones_like(embeddings[:, 0])

        total = (embeddings * weights[:, None]).sum(dim=0)
        pooled = total / (weights.sum() + 1e-8)

        return readout_fn(pooled[None]).squeeze(0)