import torch
import torch.nn as nn
import torch.nn.functional as F

from .process import ProcessModule
from .interpretable import InterpretableSubgraphModule
from .multi_view import MultiViewHGNN
from .prediction import PredictionModule, ContrastiveLoss, InformationBottleneckLoss


class MIHC(nn.Module):
    """
    Multi-view Interpretable Hypergraph Neural Networks with Information Bottleneck
    for Chip Congestion Prediction.

    This is the main model that chains all sub-modules in order:
    1. Process Module: Converts raw data into hypergraphs
    2. Interpretable Subgraph Module: Extracts critical bottleneck subgraphs
    3. Multi-View HGNN: Fuses information from both views
    4. Prediction Module: Maps node embeddings to congestion predictions

    Args:
        cell_feature_dim (int): Dimension of cell node features
        grid_feature_dim (int): Dimension of grid node features
        hidden_dim (int): Dimension of hidden features (default: 128)
        num_layers (int): Number of HGNN layers (default: 4)
        num_heads (int): Number of attention heads (default: 4)
        dropout (float): Dropout rate (default: 0.1)
        bottleneck_enable (bool): Whether to enable the bottleneck mechanism. It is
            forwarded to the interpretable subgraph module. When False,
            ``compute_loss`` also skips the information bottleneck term
            (default: True)
        temperature (float): Temperature parameter for contrastive loss
            (default: 0.07)
        beta (float): Trade-off parameter for IB loss (default: 0.1)
    """

    def __init__(self, cell_feature_dim, grid_feature_dim, hidden_dim=128,
                 num_layers=4, num_heads=4, dropout=0.1, bottleneck_enable=True,
                 temperature=0.07, beta=0.1):
        super().__init__()

        # Process Module
        self.process_module = ProcessModule(
            cell_feature_dim=cell_feature_dim,
            grid_feature_dim=grid_feature_dim,
            hidden_dim=hidden_dim
        )

        # Interpretable Subgraph Module
        self.interpretable_module = InterpretableSubgraphModule(
            hidden_dim=hidden_dim,
            bottleneck_enable=bottleneck_enable
        )

        # Multi-View HGNN
        self.multi_view_hgnn = MultiViewHGNN(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout
        )

        # Prediction Module
        self.prediction_module = PredictionModule(
            hidden_dim=hidden_dim,
            dropout=dropout
        )

        # Loss functions (used only by compute_loss)
        self.contrastive_loss = ContrastiveLoss(temperature=temperature)
        self.ib_loss = InformationBottleneckLoss(beta=beta)

        # Configuration: also gates the IB term in compute_loss
        self.bottleneck_enable = bottleneck_enable

    def forward(self, cell_hypergraph, grid_hypergraph):
        """
        Forward pass of the MIHC model.

        Args:
            cell_hypergraph (HeteroData): Cell-based hypergraph
            grid_hypergraph (HeteroData): Grid-based hypergraph

        Returns:
            dict: Predictions and intermediate results, with the keys
            'cell_congestion', 'grid_congestion', 'cell_probs', 'grid_probs',
            'cell_subgraph_embedding', 'grid_subgraph_embedding',
            'updated_cell_hypergraph' and 'updated_grid_hypergraph'.
            This dict is the ``predictions`` argument expected by
            ``compute_loss``.
        """
        # Process Module: Convert raw data into hypergraphs
        cell_hypergraph, grid_hypergraph = self.process_module(cell_hypergraph, grid_hypergraph)

        # Interpretable Subgraph Module: Extract bottleneck subgraphs
        (augmented_cell_hypergraph,
         augmented_grid_hypergraph,
         cell_bottleneck,
         grid_bottleneck,
         cell_probs,
         grid_probs) = self.interpretable_module(cell_hypergraph, grid_hypergraph)

        # Multi-View HGNN: Fuse information from both views
        (updated_cell_hypergraph,
         updated_grid_hypergraph,
         cell_subgraph_embedding,
         grid_subgraph_embedding) = self.multi_view_hgnn(
            augmented_cell_hypergraph,
            augmented_grid_hypergraph,
            cell_bottleneck,
            grid_bottleneck
        )

        # Prediction Module: Map node embeddings to congestion predictions
        cell_congestion, grid_congestion = self.prediction_module(
            updated_cell_hypergraph,
            updated_grid_hypergraph
        )

        return {
            'cell_congestion': cell_congestion,
            'grid_congestion': grid_congestion,
            'cell_probs': cell_probs,
            'grid_probs': grid_probs,
            'cell_subgraph_embedding': cell_subgraph_embedding,
            'grid_subgraph_embedding': grid_subgraph_embedding,
            'updated_cell_hypergraph': updated_cell_hypergraph,
            'updated_grid_hypergraph': updated_grid_hypergraph
        }

    def compute_loss(self, predictions, targets):
        """
        Compute the total training loss and its individual components.

        The total loss is the sum of three terms:

        - a supervised MSE term for each view (cell and grid);
        - an information bottleneck term, applied only when ``bottleneck_enable``
          is set and both 'cell_probs' and 'grid_probs' are not None;
        - a cross-view contrastive term, applied only when both subgraph
          embeddings are not None and there are at least two cell subgraph
          embeddings (rows along dim 0).

        Skipped terms are reported as zero scalars with the same dtype and
        device as the cell predictions.

        Args:
            predictions (dict): Output of ``forward``. It must contain
                'cell_congestion', 'grid_congestion', 'cell_probs',
                'grid_probs', 'cell_subgraph_embedding' and
                'grid_subgraph_embedding'.
            targets (dict): Ground truth, containing 'cell_congestion' and
                'grid_congestion'.

        Returns:
            dict: Loss tensors under the keys 'total_loss', 'supervision_loss',
            'ib_loss', 'contrastive_loss', 'cell_sup_loss' and 'grid_sup_loss'.
        """
        # Get predictions
        cell_congestion = predictions['cell_congestion']
        grid_congestion = predictions['grid_congestion']
        cell_probs = predictions['cell_probs']
        grid_probs = predictions['grid_probs']
        cell_subgraph_embedding = predictions['cell_subgraph_embedding']
        grid_subgraph_embedding = predictions['grid_subgraph_embedding']

        # Get targets
        cell_congestion_target = targets['cell_congestion']
        grid_congestion_target = targets['grid_congestion']

        # Compute supervised loss (Mean Squared Error)
        cell_sup_loss = F.mse_loss(cell_congestion, cell_congestion_target)
        grid_sup_loss = F.mse_loss(grid_congestion, grid_congestion_target)
        supervision_loss = cell_sup_loss + grid_sup_loss

        # Compute information bottleneck loss if enabled
        if self.bottleneck_enable and cell_probs is not None and grid_probs is not None:
            # NOTE(review): the IB term receives the *predicted* congestion
            # maps, not the targets. Confirm this matches the inputs that
            # InformationBottleneckLoss expects.
            ib_loss = self.ib_loss(
                cell_probs,
                grid_probs,
                cell_congestion,
                grid_congestion
            )
        else:
            # new_zeros keeps dtype and device consistent with the real loss
            # terms. torch.tensor(0.0) would always be float32.
            ib_loss = cell_congestion.new_zeros(())

        # Compute contrastive loss if subgraph embeddings are available
        if (cell_subgraph_embedding is not None and
                grid_subgraph_embedding is not None and
                cell_subgraph_embedding.size(0) > 1):  # Need at least 2 samples for contrastive learning
            contrastive_loss = self.contrastive_loss(
                cell_subgraph_embedding,
                grid_subgraph_embedding
            )
        else:
            contrastive_loss = cell_congestion.new_zeros(())

        # Compute total loss
        total_loss = supervision_loss + ib_loss + contrastive_loss

        return {
            'total_loss': total_loss,
            'supervision_loss': supervision_loss,
            'ib_loss': ib_loss,
            'contrastive_loss': contrastive_loss,
            'cell_sup_loss': cell_sup_loss,
            'grid_sup_loss': grid_sup_loss
        }