from typing import Callable

from torch import nn


# Registry of supported activations: lowercase name -> torch.nn class.
# The classes themselves are stored (not instances), so callers decide how
# and with which arguments to instantiate them.
activation_fns = dict(
    relu=nn.ReLU,
    sigmoid=nn.Sigmoid,
    leaky_relu=nn.LeakyReLU,
    gelu=nn.GELU,
    tanh=nn.Tanh,
    identity=nn.Identity,
)


def get_activation_fn(fn_type: str) -> Callable:
    """Return the activation class registered under ``fn_type``.

    The lookup is case-insensitive: ``fn_type`` is lower-cased before it is
    matched against the keys of ``activation_fns``.

    Args:
        fn_type: Activation name, e.g. ``'ReLU'`` or ``'gelu'``.

    Returns:
        The registered ``torch.nn`` class itself (not an instance); call it
        to build the module.

    Raises:
        KeyError: If the lower-cased name is not a key of ``activation_fns``.
            The message lists every supported name.
    """
    name = fn_type.lower()
    if name in activation_fns:
        return activation_fns[name]

    supported = list(activation_fns.keys())
    raise KeyError(f'Activation function {name} is not supported. '
                   f'Supported list: {supported}')
