import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData


class PredictionModule(nn.Module):
    """
    Prediction Module for MIHC.

    Two independent MLP heads map node embeddings to one non-negative
    congestion score per node: one head reads the cell-based hypergraph,
    the other reads the grid-based hypergraph.

    Args:
        hidden_dim (int): Dimension of hidden features
        dropout (float): Dropout rate
    """

    def __init__(self, hidden_dim, dropout=0.1):
        super().__init__()
        # The cell head is built before the grid head so that parameter
        # initialisation draws from the RNG in a fixed order.
        self.cell_predictor = self._build_head(hidden_dim, dropout)
        self.grid_predictor = self._build_head(hidden_dim, dropout)

    @staticmethod
    def _build_head(hidden_dim, dropout):
        """Return an MLP hidden_dim -> hidden_dim -> hidden_dim // 2 -> 1 with ReLU + dropout."""
        half_dim = hidden_dim // 2
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_dim, 1),
        )

    def forward(self, cell_hypergraph, grid_hypergraph):
        """
        Predict per-node congestion for both hypergraph views.

        Args:
            cell_hypergraph (HeteroData): Cell-based hypergraph; its ``'cell'``
                node store must carry the embeddings in attribute ``h``.
            grid_hypergraph (HeteroData): Grid-based hypergraph; its ``'grid'``
                node store must carry the embeddings in attribute ``h``.

        Returns:
            tuple: ``(cell_congestion, grid_congestion)``, each the head output
            with its trailing size-1 dimension squeezed away, clamped at zero.
        """
        cell_scores = self.cell_predictor(cell_hypergraph['cell'].h).squeeze(-1)
        grid_scores = self.grid_predictor(grid_hypergraph['grid'].h).squeeze(-1)

        # Congestion is constrained to be non-negative. F.relu passes no
        # gradient back for entries whose pre-activation is negative.
        return F.relu(cell_scores), F.relu(grid_scores)


class ContrastiveLoss(nn.Module):
    """
    Contrastive Loss for aligning cell-based and grid-based subgraph embeddings.

    Row ``i`` of the cell embeddings and row ``i`` of the grid embeddings form
    the positive pair; every other row in the opposite view acts as a
    negative. Cross-entropy is applied in both directions and averaged.

    Args:
        temperature (float): Temperature parameter dividing the cosine similarities
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, cell_embeddings, grid_embeddings):
        """
        Compute the symmetric contrastive loss.

        Args:
            cell_embeddings (torch.Tensor): Cell-based subgraph embeddings [batch_size, hidden_dim]
            grid_embeddings (torch.Tensor): Grid-based subgraph embeddings [batch_size, hidden_dim]

        Returns:
            torch.Tensor: Scalar loss, the mean of the cell->grid and grid->cell
            cross-entropy terms.
        """
        # Unit-normalise each row so the dot products are cosine similarities.
        anchors = F.normalize(cell_embeddings, dim=1)
        positives = F.normalize(grid_embeddings, dim=1)

        logits = anchors @ positives.T / self.temperature

        # The matching pair for row i sits in column i.
        targets = torch.arange(logits.size(0), device=logits.device)

        cell_to_grid = F.cross_entropy(logits, targets)
        grid_to_cell = F.cross_entropy(logits.T, targets)
        return (cell_to_grid + grid_to_cell) / 2


class InformationBottleneckLoss(nn.Module):
    """
    Information Bottleneck Loss for subgraph extraction.

    The loss is ``sym_KL(cell_probs, grid_probs) - beta * MI``. Here ``MI`` is
    the average of a correlation-based mutual-information proxy, computed
    separately for the cell view and the grid view.

    Args:
        beta (float): Trade-off parameter between compression and relevant information
    """

    def __init__(self, beta=0.1):
        super(InformationBottleneckLoss, self).__init__()
        self.beta = beta

    def forward(self, cell_probs, grid_probs, cell_congestion, grid_congestion):
        """
        Forward pass of the Information Bottleneck Loss.

        Args:
            cell_probs (torch.Tensor): Cell node probabilities
            grid_probs (torch.Tensor): Grid node probabilities; must be
                broadcast-compatible with ``cell_probs`` (the KL term compares
                them element-wise)
            cell_congestion (torch.Tensor): Cell-based congestion predictions,
                element-aligned with ``cell_probs``
            grid_congestion (torch.Tensor): Grid-based congestion predictions,
                element-aligned with ``grid_probs``

        Returns:
            torch.Tensor: Scalar information bottleneck loss
        """
        # Consistency term between the cell-based and grid-based bottleneck
        # distributions (minimised).
        kl_divergence = self._symmetrized_kl_divergence(cell_probs, grid_probs)

        # Relevance term: how much each bottleneck selection tells us about
        # the congestion target (maximised, hence the minus sign below).
        mutual_info_cell = self._estimate_mutual_information(cell_probs, cell_congestion)
        mutual_info_grid = self._estimate_mutual_information(grid_probs, grid_congestion)
        mutual_info = (mutual_info_cell + mutual_info_grid) / 2

        return kl_divergence - self.beta * mutual_info

    def _symmetrized_kl_divergence(self, p, q):
        """
        Compute 0.5 * (KL(P||Q) + KL(Q||P)) with P = softmax(p), Q = softmax(q).

        The softmax is taken over dim 0, so the inputs are treated as logits.
        ``p`` and ``q`` must be broadcast-compatible, because the divergence is
        computed element-wise and summed over all elements.

        NOTE(review): the callers pass values documented as probabilities. If
        they lie in [0, 1], the softmax output is close to uniform (element
        ratios are at most e), which weakens this term. Confirm this is intended.

        Args:
            p (torch.Tensor): First set of scores
            q (torch.Tensor): Second set of scores

        Returns:
            torch.Tensor: Scalar symmetrized KL divergence
        """
        p = F.softmax(p, dim=0)
        q = F.softmax(q, dim=0)

        # Small constant keeps the log finite if a softmax entry underflows to 0.
        eps = 1e-8
        p = p + eps
        q = q + eps

        kl_p_q = torch.sum(p * torch.log(p / q))
        kl_q_p = torch.sum(q * torch.log(q / p))

        return (kl_p_q + kl_q_p) / 2

    def _estimate_mutual_information(self, subgraph_probs, congestion):
        """
        Estimate mutual information between the thresholded subgraph selection and congestion.

        The selection mask is ``subgraph_probs > 0.5``. Its Pearson correlation
        ``rho`` with ``congestion`` is mapped through the Gaussian identity
        ``I = -0.5 * log(1 - rho^2)``.

        Gradients: when ``subgraph_probs`` requires grad, a straight-through
        term lets the MI gradient reach the probabilities; the forward value is
        unchanged. Degenerate inputs (a constant mask or constant congestion)
        give ``rho = 0`` with finite, zero gradients instead of NaN.

        Args:
            subgraph_probs (torch.Tensor): Subgraph node probabilities
            congestion (torch.Tensor): Congestion values, element-aligned with
                ``subgraph_probs``

        Returns:
            torch.Tensor: Scalar estimated mutual information
        """
        eps = 1e-8

        # Hard 0/1 selection. The comparison is not differentiable, so add
        # (x - x.detach()). That term is exactly 0 in the forward pass but
        # routes d(MI)/d(mask) back to the probabilities (straight-through).
        mask = (subgraph_probs > 0.5).float()
        if subgraph_probs.requires_grad:
            mask = mask + (subgraph_probs - subgraph_probs.detach()).to(mask.dtype)

        mask_centered = mask - mask.mean()
        congestion_centered = congestion - congestion.mean()

        covariance = (mask_centered * congestion_centered).mean()
        mask_var = (mask_centered ** 2).mean()
        congestion_var = (congestion_centered ** 2).mean()

        # sqrt has an infinite derivative at 0. The variance product is exactly
        # 0 whenever either side is constant (all probs on one side of 0.5, or
        # all-zero congestion), and backward would then produce NaN. Clamping
        # leaves the forward value unchanged for products >= eps**2.
        std_product = torch.sqrt(torch.clamp(mask_var * congestion_var, min=eps ** 2))
        correlation = covariance / (std_product + eps)

        # Gaussian relation between correlation and MI. The clamp guards
        # against rounding pushing rho^2 marginally above 1.
        mutual_info = -0.5 * torch.log(torch.clamp(1 - correlation ** 2, min=0.0) + eps)

        return mutual_info