import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class TopAwareFusion(nn.Module):
    """Refine bottom-level features with guidance from top-level features.

    Computes a learnable convex blend of the bottom embedding and a projection
    of the top embedding::

        z_hat_h = alpha * z_h + (1 - alpha) * Proj(z_{h-1})

    where ``z_h`` is ``z_bottom``, ``z_{h-1}`` is ``z_top`` and
    ``alpha = sigmoid(alpha_raw)`` is a single learnable scalar in (0, 1).

    Args:
        dim_top: Feature dimension of the top-level embeddings.
        dim_bottom: Feature dimension of the bottom-level embeddings.
        alpha_init: Initial value of ``alpha``; must lie strictly in (0, 1).
        num_groups: Maximum number of GroupNorm groups in the projection head.
            GroupNorm requires the channel count to be divisible by the group
            count, so the effective number of groups is
            ``gcd(num_groups, dim_bottom)``. This equals ``num_groups`` whenever
            ``dim_bottom`` is a multiple of it.

    Raises:
        ValueError: If ``alpha_init`` is not strictly between 0 and 1, or if
            ``num_groups`` is not positive.
    """

    def __init__(
        self,
        dim_top: int,
        dim_bottom: int,
        alpha_init: float = 0.9,
        num_groups: int = 8,
    ):
        super().__init__()
        if not 0.0 < alpha_init < 1.0:
            raise ValueError(f"alpha_init must be in (0, 1), got {alpha_init}")
        if num_groups < 1:
            raise ValueError(f"num_groups must be positive, got {num_groups}")

        # Projection head: Linear -> GroupNorm -> ReLU.
        # Use a group count that divides dim_bottom so any width is accepted.
        groups = math.gcd(num_groups, dim_bottom)
        self.proj = nn.Sequential(
            nn.Linear(dim_top, dim_bottom, bias=False),
            nn.GroupNorm(groups, dim_bottom),
            nn.ReLU(inplace=True),
        )

        # Store alpha as an unconstrained logit; sigmoid maps it back into (0, 1)
        # so the blend stays convex throughout training.
        logit = math.log(alpha_init / (1.0 - alpha_init))
        self.alpha_raw = nn.Parameter(torch.tensor(logit, dtype=torch.float32))

    def forward(self, z_top: torch.Tensor, z_bottom: torch.Tensor) -> torch.Tensor:
        """Blend bottom-level embeddings with projected top-level embeddings.

        Args:
            z_top (Tensor): Top-level embeddings, shape [N, dim_top].
            z_bottom (Tensor): Bottom-level embeddings, shape [N, dim_bottom].

        Returns:
            Tensor: Refined bottom-level embeddings, shape [N, dim_bottom].
        """
        alpha = torch.sigmoid(self.alpha_raw)  # scalar within (0, 1)
        z_top_proj = self.proj(z_top)          # [N, dim_bottom]
        return alpha * z_bottom + (1.0 - alpha) * z_top_proj


class BottomGuidedAggregation(nn.Module):
    """Aggregate top-level features using attention from bottom-level similarity.

    For each row ``i`` with neighbor indices ``j`` taken from
    ``neighbor_indices[i]``::

        w_ij  = softmax_j(<z_bottom_i, z_bottom_j> / tau)
        agg_i = LayerNorm(sum_j w_ij * z_top_j)
        out_i = gamma * z_top_i + (1 - gamma) * agg_i

    ``gamma = sigmoid(gamma_raw)`` is a learnable residual gate in (0, 1). It
    keeps part of the original top feature so that repeated aggregation does
    not over-smooth.

    Args:
        dim_top: Feature dimension of the top-level embeddings.
        dim_bottom: Feature dimension of the bottom-level embeddings. Accepted
            for API compatibility; no layer in this module depends on it.
        tau: Softmax temperature for the similarity scores; must be positive.
        gamma_init: Initial value of the residual gate ``gamma``; must lie
            strictly in (0, 1).

    Raises:
        ValueError: If ``tau`` is not positive or ``gamma_init`` is not
            strictly between 0 and 1.
    """

    def __init__(
        self,
        dim_top: int,
        dim_bottom: int,
        tau: float = 1.0,
        gamma_init: float = 0.9,
    ):
        super().__init__()
        if not tau > 0:
            raise ValueError(f"tau must be positive, got {tau}")
        if not 0.0 < gamma_init < 1.0:
            raise ValueError(f"gamma_init must be in (0, 1), got {gamma_init}")
        self.tau = tau
        # LayerNorm applied on the aggregated top feature (N, dim_top).
        self.norm = nn.LayerNorm(dim_top)
        # Learnable gate for residual blending, stored as a logit so that
        # sigmoid(gamma_raw) == gamma_init at initialisation.
        logit = math.log(gamma_init / (1.0 - gamma_init))
        self.gamma_raw = nn.Parameter(torch.tensor(logit, dtype=torch.float32))

    def forward(
        self,
        z_top: torch.Tensor,              # [N, dim_top]
        z_bottom: torch.Tensor,           # [N, dim_bottom]
        neighbor_indices: torch.Tensor,   # [N, K]
    ) -> torch.Tensor:
        """Refine top-level embeddings by attending over K-NN neighbors.

        Args:
            z_top (Tensor): Top-level embeddings, shape [N, dim_top].
            z_bottom (Tensor): Bottom-level embeddings, shape [N, dim_bottom].
            neighbor_indices (Tensor): K-NN row indices into ``z_top`` /
                ``z_bottom``, shape [N, K]. It is cast to ``torch.long``.

        Returns:
            Tensor: Refined top-level embeddings, shape [N, dim_top].
        """
        # Advanced indexing requires an integer (long) index tensor.
        neighbor_indices = neighbor_indices.long()

        # Attention scores: dot-product similarity of each row's bottom feature
        # with its neighbors' bottom features, scaled by the temperature.
        bottom_neighbors = z_bottom[neighbor_indices]            # [N, K, dim_bottom]
        query = z_bottom.unsqueeze(1)                            # [N, 1, dim_bottom]
        scores = (query * bottom_neighbors).sum(-1) / self.tau   # [N, K]
        weights = F.softmax(scores, dim=1)                       # [N, K]

        # Weighted aggregation of the neighbors' top features.
        top_neighbors = z_top[neighbor_indices]                  # [N, K, dim_top]
        agg_top = (weights.unsqueeze(-1) * top_neighbors).sum(dim=1)  # [N, dim_top]
        agg_top = self.norm(agg_top)

        # Residual gate blends the original and the aggregated feature.
        gamma = torch.sigmoid(self.gamma_raw)
        return gamma * z_top + (1.0 - gamma) * agg_top


class HEFM(nn.Module):
    """Hierarchical Embedding Fusion Module (HEFM).

    Combines two refinements that run side by side on the same inputs:

    * ``TopAwareFusion`` refines the bottom-level embeddings using a projection
      of the top-level embeddings.
    * ``BottomGuidedAggregation`` refines the top-level embeddings by attending
      over K-NN neighbors, with weights from bottom-level similarity.

    Both sub-modules receive the original (unrefined) ``z_top`` and
    ``z_bottom``. Neither refinement sees the other's output.

    Args:
        dim_top: Feature dimension of the top-level embeddings.
        dim_bottom: Feature dimension of the bottom-level embeddings.
        alpha_init: Initial blend weight forwarded to ``TopAwareFusion``.
        tau: Softmax temperature forwarded to ``BottomGuidedAggregation``.
    """

    def __init__(
        self,
        dim_top: int,
        dim_bottom: int,
        alpha_init: float = 0.9,
        tau: float = 1.0,
    ):
        super().__init__()
        self.top_fusion = TopAwareFusion(dim_top, dim_bottom, alpha_init)
        self.bottom_agg = BottomGuidedAggregation(dim_top, dim_bottom, tau)

    def forward(
        self,
        z_top: torch.Tensor,            # [N, dim_top]
        z_bottom: torch.Tensor,         # [N, dim_bottom]
        neighbor_indices: torch.Tensor, # [N, K]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Refine both embedding levels.

        Args:
            z_top (Tensor): Top-level embeddings, shape [N, dim_top].
            z_bottom (Tensor): Bottom-level embeddings, shape [N, dim_bottom].
            neighbor_indices (Tensor): K-NN indices, shape [N, K].

        Returns:
            Tuple[Tensor, Tensor]: (z_top_refined [N, dim_top],
            z_bottom_refined [N, dim_bottom]).
        """
        # Both refinements read the original inputs, so their order is irrelevant.
        z_bottom_refined = self.top_fusion(z_top, z_bottom)
        z_top_refined = self.bottom_agg(z_top, z_bottom, neighbor_indices)
        return z_top_refined, z_bottom_refined