from typing import Dict, Optional

import torch.nn as nn
import torch
from config import MoLConfig


class CustomFFN(nn.Module):
    """Feed-forward network that is either shared or split per modality.

    When ``config.use_modality_specific_ffn`` is true, one independent
    Linear -> GELU -> Linear stack is created for every name in
    ``config.modalities`` and stored in ``self.ffns``. Otherwise a single
    stack is stored in ``self.shared_ffn``. All linear layers are bias-free.
    """

    # Modality whose FFN is preferred when modality-specific FFNs exist but
    # the caller supplies no modality mask.
    _DEFAULT_MODALITY = "text"

    def __init__(self, hidden_size, intermediate_size, config: "MoLConfig"):
        """Build the shared FFN or one FFN per configured modality.

        Args:
            hidden_size: Input and output feature size of each FFN.
            intermediate_size: Width of the hidden layer inside each FFN.
            config: Object providing ``use_modality_specific_ffn`` and, when
                that flag is true, an iterable ``modalities`` of names.
        """
        super().__init__()
        self.config = config

        if config.use_modality_specific_ffn:
            self.ffns = nn.ModuleDict({
                mod: self._build_ffn(hidden_size, intermediate_size)
                for mod in config.modalities
            })
        else:
            self.shared_ffn = self._build_ffn(hidden_size, intermediate_size)

    @staticmethod
    def _build_ffn(hidden_size, intermediate_size) -> nn.Sequential:
        """Return one bias-free Linear -> GELU -> Linear stack."""
        return nn.Sequential(
            nn.Linear(hidden_size, intermediate_size, bias=False),
            nn.GELU(),
            nn.Linear(intermediate_size, hidden_size, bias=False),
        )

    def _default_ffn(self) -> nn.Module:
        """Return the FFN used when no modality mask is supplied.

        Prefers the ``'text'`` FFN; if that modality is not configured, the
        first configured modality is used instead of failing with a bare
        ``KeyError``.

        Raises:
            RuntimeError: If no modality-specific FFN was created at all.
        """
        if self._DEFAULT_MODALITY in self.ffns:
            return self.ffns[self._DEFAULT_MODALITY]
        first_ffn = next(iter(self.ffns.values()), None)
        if first_ffn is None:
            raise RuntimeError("No FFN is defined for the forward pass.")
        return first_ffn

    def forward(
            self,
            x: torch.Tensor,
            modality_mask: Optional[Dict[str, torch.Tensor]] = None
        ) -> torch.Tensor:
        """Apply the feed-forward network to ``x``.

        Args:
            x: Input tensor whose last dimension is ``hidden_size``.
            modality_mask: Optional mapping from modality name to a mask that
                selects that modality's tokens in ``x``. Presumably boolean
                and shaped like the leading dimensions of ``x`` -- confirm
                against the caller.

        Returns:
            A tensor shaped like ``x``. On the masked path it has the dtype
            of ``x``; positions not selected by any processed mask stay
            zero, and mask entries naming a modality without an FFN are
            skipped. If masks overlap, the modality iterated last wins.

        Raises:
            RuntimeError: If no FFN is available for the chosen path.
        """
        if self.config.use_modality_specific_ffn and modality_mask is not None:
            output = torch.zeros_like(x)

            for mod, mask in modality_mask.items():
                # Skip modalities we have no FFN for, and empty selections.
                if mod not in self.ffns or not mask.any():
                    continue

                # Gather the selected tokens into a dense
                # [num_mod_tokens, hidden_dim] tensor and run only those
                # through the modality's FFN.
                ffn_result = self.ffns[mod](x[mask])

                # Indexed assignment needs matching dtypes. Under autocast
                # the FFN may emit a lower-precision dtype than ``x``, so
                # cast back before scattering (a no-op when they match).
                output[mask] = ffn_result.to(output.dtype)

            return output

        # No usable mask: run every token through a single FFN.
        if hasattr(self, 'shared_ffn'):
            return self.shared_ffn(x)
        if hasattr(self, 'ffns'):
            return self._default_ffn()(x)
        raise RuntimeError("No FFN is defined for the forward pass.")