import torch.nn as nn
from typing import Union


class Identity(nn.Module):
    """Pass-through layer: ``forward`` hands back exactly what it receives.

    Used by ``ActivationUtil`` as the layer for the ``'linear'`` activation
    name, i.e. "no non-linearity".
    """

    def __init__(self):
        """Initialise the underlying ``nn.Module``; the layer has no parameters."""
        super().__init__()

    def forward(self, inputs):
        """Return ``inputs`` unchanged (the same object, not a copy)."""
        return inputs


class ActivationUtil(object):
    """Factory helpers for building activation layers."""

    @staticmethod
    def get_common_activation_layer(act_obj: Union[str, nn.Module] = "relu") -> nn.Module:
        """Build an activation layer from a name, a module instance, or a factory.

        Args:
            act_obj: One of:

                * a case-insensitive activation name: ``'relu'``,
                  ``'sigmoid'``, ``'linear'`` (identity), ``'prelu'``,
                  ``'elu'`` or ``'leakyrelu'`` (negative slope 0.2);
                * an already constructed ``nn.Module``, returned unchanged;
                * a module class or other zero-argument callable, which is
                  called to create the layer.

        Returns:
            The activation layer. ReLU, ELU and LeakyReLU are created with
            ``inplace=True``. A fresh layer is created on every call for
            string names.

        Raises:
            ValueError: If ``act_obj`` is a string that is not one of the
                supported names.
        """
        if isinstance(act_obj, str):
            # Factories rather than instances: each call must yield a new
            # layer (PReLU in particular owns a learnable parameter).
            factories = {
                'relu': lambda: nn.ReLU(inplace=True),
                'sigmoid': lambda: nn.Sigmoid(),
                'linear': lambda: Identity(),
                'prelu': lambda: nn.PReLU(),
                'elu': lambda: nn.ELU(inplace=True),
                'leakyrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            }
            factory = factories.get(act_obj.lower())
            if factory is None:
                # Fail loudly instead of silently returning None, which would
                # only surface later as an obscure "NoneType" error.
                raise ValueError(
                    f"Unsupported activation name {act_obj!r}; "
                    f"expected one of {sorted(factories)}"
                )
            return factory()

        if isinstance(act_obj, nn.Module):
            # Already a layer: calling it here would invoke forward() with no
            # input, so hand it back as is.
            return act_obj

        # A module class (e.g. nn.Tanh) or any zero-argument factory.
        return act_obj()
