import torch.nn as nn

####################
### CONVOLUTIONS ###
####################

# Lookup from an integer key (1, 2 or 3) to the torch.nn convolution class
# whose name carries the same dimensionality (Conv1d / Conv2d / Conv3d).
CONV_TYPES = dict(
    zip(
        (1, 2, 3),
        (nn.Conv1d, nn.Conv2d, nn.Conv3d),
    )
)

###################
### ACTIVATIONS ###
###################

# Name -> torch.nn activation class. The values are classes, not instances;
# get_activation() looks a name up here and instantiates the result.
# "linear" maps to nn.Identity.
_ACTIVATION_REGISTRY = dict(
    linear=nn.Identity,
    relu=nn.ReLU,
    sigmoid=nn.Sigmoid,
    tanh=nn.Tanh,
    lrelu=nn.LeakyReLU,
    prelu=nn.PReLU,
    softplus=nn.Softplus,
    gelu=nn.GELU,
    silu=nn.SiLU,
    elu=nn.ELU,
    selu=nn.SELU,
)

def get_activation(
    activation: str,
    **kwargs,
) -> nn.Module:
    """
    Retrieve and instantiate an activation layer by name.

    The name is looked up in ``_ACTIVATION_REGISTRY`` and the resulting class
    is instantiated. Only the nested ``activation_kwargs`` mapping is forwarded
    to the constructor, so callers can pass their whole configuration through
    ``**kwargs`` without unrelated keys reaching the activation class.

    Parameters
    ----------
    activation : str
        Name of the activation function to use. Must be one of the keys of
        ``_ACTIVATION_REGISTRY`` (e.g., "relu", "gelu", "lrelu"). String names
        are matched ignoring case and surrounding whitespace, so "GELU" and
        " gelu " resolve to the same entry as "gelu".
    **kwargs
        Keyword arguments from the caller. Only the value under the
        "activation_kwargs" key is used; it is unpacked into the activation
        constructor. A missing key, ``None`` or an empty mapping all mean
        "no constructor arguments". All other kwargs are ignored.

    Returns
    -------
    nn.Module
        Instantiated activation layer ready to use.

    Raises
    ------
    KeyError
        If ``activation`` is not a registered name. The message lists the
        available names.

    Examples
    --------
    Basic usage:
    >>> layer = get_activation("relu")

    With activation-specific parameters:
    >>> layer = get_activation("lrelu", activation_kwargs={"negative_slope": 0.2})

    In a model configuration:
    >>> config = {"activation": "gelu", "activation_kwargs": {"approximate": "tanh"}}
    >>> layer = get_activation(**config)
    """

    # Normalise string names so config values such as "GELU" still resolve;
    # non-string inputs are looked up unchanged and fail as they did before.
    lookup_key = activation.strip().lower() if isinstance(activation, str) else activation

    # Resolve the activation class, turning a bare KeyError into one that
    # tells the caller which names are actually registered.
    try:
        activation_cls = _ACTIVATION_REGISTRY[lookup_key]
    except KeyError:
        available = ", ".join(sorted(_ACTIVATION_REGISTRY))
        raise KeyError(
            f"Unknown activation {activation!r}. Available activations: {available}"
        ) from None

    # "activation_kwargs" may be absent or explicitly None; both collapse to
    # an empty dict so the constructor is called with its defaults.
    activation_kwargs = kwargs.get("activation_kwargs") or {}

    return activation_cls(**activation_kwargs)

############################
### DEFAULT BLOCK KWARGS ###
############################

# _DEFAULT_BLOCK_KWARGS = {
#     "ResidualBlock": {
#         'kernel_size': 3,            # Size of the convolutional kernel
#         'padding_mode': "circular",  # Padding mode for the convolutional kernel
#         'norm': True,                # Whether to use normalization
#         'dropout_rate': 0.0,         # Dropout rate
#         'bias': True,                # Whether to include bias in convolutions
#         'activation': "gelu",        # Name of activation function
#         'activation_kwargs': None,   # Keyword arguments for activation function
#     },
#     "ConvNextBlock": { 
#         'kernel_size': 7,            # Size of the convolutional kernel
#         'padding_mode': "circular",  # Padding mode for the convolutional kernel
#         'norm': True,                # Whether to use normalization
#         'dropout_rate': 0.0,         # Dropout rate
#         'bias': True,                # Whether to include bias in convolutions
#         'activation': "gelu",        # Name of activation function
#         'activation_kwargs': None,   # Keyword arguments for activation function
#     },
# }

# def get_default_block_kwargs(block_class: str) -> dict:
#     """Return default kwargs for different block types."""
#     return _DEFAULT_BLOCK_KWARGS[block_class]