import typing as tp
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPInOut(nn.Module):
    """
    Flexible MLP layer with configurable input/output dimensions and hidden layers.

    Supports:
    - Direct linear mapping (hidden_units=[])
    - Single hidden layer (hidden_units=[120])
    - Multiple hidden layers (hidden_units=[120, 256])

    Args:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension.
        hidden_units (Sequence[int], optional): Hidden layer dimensions.
            - If empty, creates direct Linear(dim_in, dim_out) with no activation
            - If [120], creates: Linear(dim_in, 120) → Act → Linear(120, dim_out)
            - If [120, 256], creates: Linear(dim_in, 120) → Act → Linear(120, 256) → Act → Linear(256, dim_out)
            - If None, uses hidden_dim parameter (for backward compatibility)
            A private list copy is stored, so later mutation of the caller's
            sequence does not affect this module.
        hidden_dim (int, optional): DEPRECATED. Single hidden dimension for backward compatibility.
            If hidden_units is None and hidden_dim is None, defaults to 4 * dim_in.
        activation (str, optional): Activation function ('gelu', 'relu', 'silu'). Default: 'relu'.
            Each hidden layer gets its own activation module instance.
        dropout (float, optional): Dropout rate applied after each activation
            (only when > 0.0). Default: 0.0.
        bias (bool, optional): Whether to use bias in linear layers. Default: True.
        **kwargs: Accepted and ignored (e.g. 'dim', 'max_length') so configs
            written for other layer types can be passed through unchanged.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.

    Example:
        >>> # Direct linear mapping (no hidden layer)
        >>> mlp = MLPInOut(dim_in=16, dim_out=128, hidden_units=[])
        >>> x = torch.randn(2, 10, 16)
        >>> out = mlp(x)  # (2, 10, 128) via Linear(16, 128)

        >>> # Single hidden layer
        >>> mlp = MLPInOut(dim_in=16, dim_out=128, hidden_units=[120])
        >>> out = mlp(x)  # Linear(16, 120) → ReLU → Linear(120, 128)

        >>> # Multiple hidden layers
        >>> mlp = MLPInOut(dim_in=16, dim_out=128, hidden_units=[64, 96])
        >>> out = mlp(x)  # Linear(16, 64) → ReLU → Linear(64, 96) → ReLU → Linear(96, 128)
    """

    # Supported activation names -> module classes. Classes (not instances) are
    # stored so every hidden layer can get its own, independently hookable module.
    _ACTIVATIONS: tp.Dict[str, tp.Type[nn.Module]] = {
        'gelu': nn.GELU,
        'relu': nn.ReLU,
        'silu': nn.SiLU,
    }

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        hidden_units: tp.Optional[tp.Sequence[int]] = None,
        hidden_dim: tp.Optional[int] = None,
        activation: str = 'relu',
        dropout: float = 0.0,
        bias: bool = True,
        **kwargs  # Accept and ignore extra kwargs like 'dim', 'max_length'
    ):
        super().__init__()

        self.dim_in = dim_in
        self.dim_out = dim_out

        # Handle backward compatibility with old hidden_dim API
        if hidden_units is None:
            if hidden_dim is not None:
                hidden_units = [hidden_dim]  # Convert old API to new
            else:
                hidden_units = [4 * dim_in]  # Original default

        # Copy so the caller's list is not aliased (and tuples are normalised).
        self.hidden_units = list(hidden_units)

        # Validate the activation up front, even for the no-hidden-layer case.
        act_cls = self._ACTIVATIONS.get(activation)
        if act_cls is None:
            raise ValueError(f"Unsupported activation: {activation}. Choose from ['gelu', 'relu', 'silu']")

        # Build MLP architecture based on hidden_units
        if not self.hidden_units:
            # Direct linear mapping (no hidden layers, no activation)
            self.mlp = nn.Linear(dim_in, dim_out, bias=bias)
        else:
            # Multi-layer MLP: dim_in → [hidden_units] → dim_out
            layers: tp.List[nn.Module] = []
            in_features = dim_in

            for out_features in self.hidden_units:
                layers.append(nn.Linear(in_features, out_features, bias=bias))
                # Fresh activation per layer: reusing one instance would register
                # the same module at several indices of the Sequential.
                layers.append(act_cls())
                if dropout > 0.0:
                    layers.append(nn.Dropout(dropout))
                in_features = out_features

            # Final layer (no activation after)
            layers.append(nn.Linear(in_features, dim_out, bias=bias))
            self.mlp = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape (..., dim_in)

        Returns:
            torch.Tensor: Output tensor of shape (..., dim_out)
        """
        return self.mlp(x)

    def __repr__(self):
        # Compact one-line summary; it replaces nn.Module's default tree repr,
        # so use repr(self.mlp) to inspect the individual layers.
        return (f"{self.__class__.__name__}(dim_in={self.dim_in}, "
                f"dim_out={self.dim_out}, hidden_units={self.hidden_units})")
