from typing import Dict, Optional

import torch.nn as nn
import torch

from config import MoLConfig


class CustomLayerNorm(nn.Module):
    """LayerNorm that is either shared or modality-specific.

    One ``nn.LayerNorm`` is created for every entry of ``config.modalities``.
    At call time ``config.use_modality_specific_ln`` selects the behaviour:

    * falsy  -> a single shared layer normalizes the whole input (the
      ``'text'`` layer when present, otherwise the first registered one);
    * truthy -> each modality's layer normalizes only the positions selected
      by that modality's mask.
    """

    def __init__(self, normalized_shape, config: "MoLConfig", eps=1e-5):
        """Build one LayerNorm per configured modality.

        Args:
            normalized_shape: Forwarded unchanged to ``nn.LayerNorm``.
            config: Object exposing ``modalities`` (iterable of string keys)
                and ``use_modality_specific_ln`` (read in ``forward``).
            eps: Forwarded unchanged to ``nn.LayerNorm``.
        """
        super().__init__()
        self.config = config

        # A ModuleDict registers every per-modality layer as a submodule, so
        # their parameters show up in ``parameters()`` / ``state_dict()``.
        self.ln_layers = nn.ModuleDict({
            mod: nn.LayerNorm(normalized_shape, eps=eps)
            for mod in config.modalities
        })

    def _shared_layer(self) -> nn.Module:
        """Return the layer used when normalization is shared."""
        # nn.ModuleDict only offers a subset of the dict API (no ``get``),
        # so select the layer with a membership test and plain indexing.
        if 'text' in self.ln_layers:
            return self.ln_layers['text']
        for layer in self.ln_layers.values():
            # First registered layer, in insertion order.
            return layer
        raise ValueError(
            "CustomLayerNorm has no LayerNorm layers; "
            "config.modalities is empty.")

    def forward(
        self,
        x: torch.Tensor,
        modality_mask: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Normalize ``x`` with the shared or the per-modality layers.

        Args:
            x: Input tensor; its last dimension(s) must match
                ``normalized_shape``.
            modality_mask: Mapping from modality name to a boolean mask used
                to index ``x`` (presumably one entry per token, i.e. shaped
                like ``x`` without its last dimension -- confirm with the
                caller). Ignored in shared mode.

        Returns:
            A tensor shaped like ``x``. In modality-specific mode, positions
            not selected by any mask (or selected only by masks whose key has
            no matching layer) are left at zero.

        Raises:
            ValueError: If modality-specific normalization is enabled but
                ``modality_mask`` is ``None``, or if shared mode is requested
                and no layer exists.
        """
        if not self.config.use_modality_specific_ln:
            return self._shared_layer()(x)

        if modality_mask is None:
            raise ValueError(
                "modality_mask must be provided "
                "for modality-specific LayerNorm.")

        output = torch.zeros_like(x)

        for mod, mask in modality_mask.items():
            # Unknown modalities and empty masks are skipped.
            if mod in self.ln_layers and mask.any():
                normed = self.ln_layers[mod](x[mask])
                # Masked assignment needs matching dtypes; the layer may
                # return a different dtype than ``x`` (e.g. under autocast).
                output[mask] = normed.to(output.dtype)

        return output


class CustomRMSNorm(nn.Module):
    """
    RMSNorm that can be either modality-specific or shared,
    designed to replace LlamaRMSNorm in a MoL setup.
    """

    def __init__(
        self,
        normalized_shape,
        config: MoLConfig,
        eps: float = 1e-6
    ):
        super().__init__()
        self.config = config
        self.eps = eps

        self.weight = nn.ParameterDict({
            mod: nn.Parameter(torch.ones(normalized_shape))
            for mod in config.modalities
        })

    def forward(
        self,
        x: torch.Tensor,
        modality_mask: Optional[Dict[str, torch.Tensor]] = None
    ) -> torch.Tensor:

        if not self.config.use_modality_specific_ln:
            # The 'text' LoRA is used as a default or shared adapter
            weight = self.weight.get('text', next(iter(self.weight.values())))
            # Fallback to the first weight if 'text' isn't defined
            return self._rmsnorm(x, weight)

        if modality_mask is None:
            raise ValueError(
                "modality_mask must be provided " \
                "for modality-specific RMSNorm."
                )

        output = torch.zeros_like(x)
        input_dtype = x.dtype

        for mod, mask in modality_mask.items():
            if mod in self.weight and mask.any():
                # Use the boolean mask directly to select tokens.
                modality_tokens = x[mask] # Shape: [num_tokens_for_mod, hidden_dim]

                # Normalize only the selected tokens
                variance = modality_tokens.to(
                    torch.float32).pow(2).mean(-1, keepdim=True)
                norm_tokens = modality_tokens * torch.rsqrt(variance + self.eps)
                
                # Apply modality-specific weight and place back into the output tensor
                output[mask] = (self.weight[mod] * norm_tokens).to(input_dtype)

        return output


    def _rmsnorm(self, x: torch.Tensor, weight: nn.Parameter) -> torch.Tensor:
        """Helper function to perform RMSNorm."""
        input_dtype = x.dtype
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = x * torch.rsqrt(variance + self.eps)
        return (weight * hidden_states).to(input_dtype)
