"""
Common "internal" utilities for modules.
"""

import copy
from typing import Callable, List, Optional, Sequence

import torch


def get_activation(e):
    """Resolve an activation specification.

    A string is normalised (lower-cased, then stripped of surrounding
    whitespace) and looked up among the known activation names; the matching
    module class is returned. Any non-string value is handed back unchanged,
    so callers may pass a module class (or other factory) directly.

    Raises:
        ValueError: if ``e`` is a string that matches no known activation.
    """
    from .op import NNLayerExp, NNLayerClipExp, NNLayerClipSiLU
    table = {
        "relu": torch.nn.ReLU,
        "silu": torch.nn.SiLU,
        "gelu": torch.nn.GELU,
        "sigmoid": torch.nn.Sigmoid,
        "tanh": torch.nn.Tanh,
        "exp": NNLayerExp,
        "clipexp": NNLayerClipExp,
        "clipsilu": NNLayerClipSiLU,
    }
    if not isinstance(e, str):
        # Already a class/factory: pass it through untouched.
        return e
    key = e.lower().strip()
    found = table.get(key)  # every value is a class, so None means "missing"
    if found is None:
        raise ValueError(f"unrecognized activation \"{e}\", known ones are {list(table)}")
    return found


def make_mlp(indim : int, hidden_dims : Sequence[int], outdim : int,
             dropout : float = 0.0,
             activation : Callable[[], torch.nn.Module] = torch.nn.ReLU,
             initial_layers : Optional[Sequence[torch.nn.Module]] = None,
             following_layers : Optional[Sequence[torch.nn.Module]] = None,
             return_as_list : bool = False):
    """Build a multi-layer perceptron.

    The resulting layer order is: ``initial_layers``, then for each hidden
    width ``Linear -> [Dropout] -> activation()``, then a final ``Linear`` to
    ``outdim``, then ``following_layers``.

    Args:
        indim: Input feature size of the first ``Linear``.
        hidden_dims: Widths of the hidden layers. Any sequence is accepted;
            when empty, the network is a single ``Linear(indim, outdim)``.
        outdim: Output feature size of the last ``Linear``.
        dropout: If > 0, a ``Dropout(p=dropout)`` is inserted after each
            hidden ``Linear`` (before its activation). No dropout is added
            after the output layer.
        activation: Zero-argument factory (typically a ``torch.nn.Module``
            subclass) called once per hidden layer to create a fresh module.
        initial_layers: Modules prepended to the network. They are deep-copied,
            so the same sequence can be reused to build several networks.
        following_layers: Modules appended after the output layer.
            NOTE(review): unlike ``initial_layers`` these are NOT copied, so
            the very same instances (and any parameters they hold) end up in
            the returned network and are shared if reused across calls.
        return_as_list: Return a plain ``list`` of modules instead of a
            ``torch.nn.Sequential``.

    Returns:
        A ``torch.nn.Sequential`` of the layers, or a ``list`` of them when
        ``return_as_list`` is true.
    """
    # Deep-copy into a fresh list: works for tuples as well as lists and never
    # hands out the caller's module instances.
    layers = [] if initial_layers is None else copy.deepcopy(list(initial_layers))
    # Unpacking (rather than list concatenation) accepts any sequence type.
    dims = [indim, *hidden_dims, outdim]
    # Every consecutive dim pair except the last forms a hidden layer.
    for din, dout in zip(dims[:-2], dims[1:-1]):
        layers.append(torch.nn.Linear(din, dout))
        if dropout > 0:
            layers.append(torch.nn.Dropout(p=dropout))
        layers.append(activation())
    # Output projection: no dropout / activation after it.
    layers.append(torch.nn.Linear(dims[-2], dims[-1]))
    if following_layers is not None:
        layers.extend(following_layers)
    return layers if return_as_list else torch.nn.Sequential(*layers)
