import torch
import torch.nn as nn

class GatedCGM(nn.Module):
    """
    A Contextual Grounding Module that uses gating (multiplicative filtering)
    instead of cross-attention.

    Each visual stream (point and image) has its own small MLP. For every
    token, the MLP reads the token feature concatenated with the global text
    description and emits a per-channel sigmoid gate ``g`` in (0, 1). The
    stream is then updated as ``x + x * g``.
    """

    def __init__(self, fusion_dim: int, text_dim: int, hidden_dim_scale: int = 4):
        """
        Args:
            fusion_dim: Channel size of the visual features (D_fusion).
            text_dim: Channel size of the global description vector (D_text).
            hidden_dim_scale: The hidden width of each gate MLP is
                ``(fusion_dim + text_dim) // hidden_dim_scale``.
        """
        super().__init__()
        self.fusion_dim = fusion_dim
        self.text_dim = text_dim

        # The global description vector is expected to be computed by the
        # caller and passed to forward(); this module does not create it.
        gate_input_dim = fusion_dim + text_dim
        gate_hidden_dim = gate_input_dim // hidden_dim_scale

        # A separate gating network for each visual modality, as their
        # features might need to be modulated differently. The point gate is
        # built first so parameter-initialisation order stays stable.
        self.point_gate_generator = self._make_gate(
            gate_input_dim, gate_hidden_dim, fusion_dim
        )
        self.image_gate_generator = self._make_gate(
            gate_input_dim, gate_hidden_dim, fusion_dim
        )

    @staticmethod
    def _make_gate(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
        """Build one gate MLP: Linear -> LayerNorm -> GELU -> Linear -> Sigmoid."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
            nn.Sigmoid(),
        )

    @staticmethod
    def _modulate(features, description_global, gate_generator):
        """Gate one visual stream with the global description.

        The description is broadcast to this stream's own token count, so the
        point and image streams do not need to share a sequence length.
        """
        _, num_tokens, _ = features.shape
        # expand() creates a broadcast view: [B, D_text] -> [B, N, D_text].
        desc_broadcast = description_global.unsqueeze(1).expand(-1, num_tokens, -1)

        # Concatenate each token feature with the global description.
        gate_input = torch.cat([features, desc_broadcast], dim=-1)
        gate_weights = gate_generator(gate_input)  # [B, N, D_fusion], in (0, 1)

        # Multiplicative filtering plus a residual connection: the result is
        # features * (1 + gate), so it approaches the unmodified input as the
        # gate goes to zero.
        return features + (features * gate_weights)

    def forward(
        self,
        point_repr: torch.Tensor,
        image_repr: torch.Tensor,
        description_global: torch.Tensor,
    ):
        """
        Args:
            point_repr (Tensor): Shape [B, Np, D_fusion]
            image_repr (Tensor): Shape [B, Ni, D_fusion]. Ni may equal Np but
                does not have to.
            description_global (Tensor): Shape [B, D_text]

        Returns:
            Tuple ``(point_repr_mod, image_repr_mod)`` with the same shapes as
            ``point_repr`` and ``image_repr`` respectively.
        """
        # --- Modulate the Point Representation ---
        point_repr_mod = self._modulate(
            point_repr, description_global, self.point_gate_generator
        )

        # --- Modulate the Image Representation ---
        image_repr_mod = self._modulate(
            image_repr, description_global, self.image_gate_generator
        )

        return point_repr_mod, image_repr_mod