from typing import Optional
import torch.nn as nn
from pado.core import PadoModuleMixin

# Public API of this module: the activation wrappers and the string factory.
__all__ = [
    "ReLU", "ReLU6", "LeakyReLU",
    "SiLU", "Swish", "Hardswish",
    "ELU", "GELU", "GLU",
    "Sigmoid", "Hardsigmoid",
    "Tanh", "Hardtanh",
    "Softmax", "LogSoftmax",
    "get_activation_cls",
]


class ReLU(nn.ReLU, PadoModuleMixin):
    """``torch.nn.ReLU`` combined with ``PadoModuleMixin``.

    Args:
        inplace: Passed through unchanged to ``nn.ReLU``.
    """

    def __init__(self, inplace: bool = False) -> None:
        # Torch module first, mixin second. Presumably the mixin relies on
        # nn.Module state already existing -- keep this order.
        nn.ReLU.__init__(self, inplace)
        PadoModuleMixin.__init__(self)


class ReLU6(nn.ReLU6, PadoModuleMixin):
    """``torch.nn.ReLU6`` combined with ``PadoModuleMixin``.

    Args:
        inplace: Passed through unchanged to ``nn.ReLU6``.
    """

    def __init__(self, inplace: bool = False) -> None:
        nn.ReLU6.__init__(self, inplace)
        # The mixin is initialized only after the torch base is ready.
        PadoModuleMixin.__init__(self)


class LeakyReLU(nn.LeakyReLU, PadoModuleMixin):
    """``torch.nn.LeakyReLU`` combined with ``PadoModuleMixin``.

    Args:
        negative_slope: Passed through to ``nn.LeakyReLU``.
        inplace: Passed through to ``nn.LeakyReLU``.
    """

    def __init__(self, negative_slope: float = 0.01, inplace: bool = False) -> None:
        nn.LeakyReLU.__init__(self, negative_slope, inplace)
        PadoModuleMixin.__init__(self)


class SiLU(nn.SiLU, PadoModuleMixin):
    """``torch.nn.SiLU`` combined with ``PadoModuleMixin``.

    Args:
        inplace: Passed through unchanged to ``nn.SiLU``.
    """

    def __init__(self, inplace: bool = False) -> None:
        nn.SiLU.__init__(self, inplace)
        PadoModuleMixin.__init__(self)


# ``Swish`` is not a separate class: it is bound to the very same class object
# as ``SiLU``, so ``Swish is SiLU`` and both names construct identical modules.
Swish = SiLU


class Hardswish(nn.Hardswish, PadoModuleMixin):
    """``torch.nn.Hardswish`` combined with ``PadoModuleMixin``.

    Args:
        inplace: Passed through unchanged to ``nn.Hardswish``.
    """

    def __init__(self, inplace: bool = False) -> None:
        nn.Hardswish.__init__(self, inplace)
        PadoModuleMixin.__init__(self)


class ELU(nn.ELU, PadoModuleMixin):
    """``torch.nn.ELU`` combined with ``PadoModuleMixin``.

    Args:
        alpha: Passed through to ``nn.ELU``.
        inplace: Passed through to ``nn.ELU``.
    """

    def __init__(self, alpha: float = 1.0, inplace: bool = False) -> None:
        nn.ELU.__init__(self, alpha, inplace)
        PadoModuleMixin.__init__(self)


class GELU(nn.GELU, PadoModuleMixin):
    """``torch.nn.GELU`` combined with ``PadoModuleMixin``.

    Args:
        approximate: Forwarded to ``nn.GELU``. Per the torch docs, ``"none"``
            selects the exact form and ``"tanh"`` selects the tanh
            approximation. The argument is only passed through when it
            differs from the default ``"none"``.
    """

    def __init__(self, approximate: str = "none") -> None:
        if approximate == "none":
            # Default path: call exactly as before. This keeps torch releases
            # whose nn.GELU takes no ``approximate`` argument working.
            nn.GELU.__init__(self)
        else:
            nn.GELU.__init__(self, approximate=approximate)
        PadoModuleMixin.__init__(self)


class GLU(nn.GLU, PadoModuleMixin):
    """``torch.nn.GLU`` combined with ``PadoModuleMixin``.

    Args:
        dim: Passed through unchanged to ``nn.GLU``.
    """

    def __init__(self, dim: int = -1) -> None:
        nn.GLU.__init__(self, dim)
        PadoModuleMixin.__init__(self)


class Tanh(nn.Tanh, PadoModuleMixin):
    """``torch.nn.Tanh`` combined with ``PadoModuleMixin``; takes no options."""

    def __init__(self) -> None:
        nn.Tanh.__init__(self)
        PadoModuleMixin.__init__(self)


class Hardtanh(nn.Hardtanh, PadoModuleMixin):
    """``torch.nn.Hardtanh`` combined with ``PadoModuleMixin``.

    Args:
        min_val: Lower bound, passed through to ``nn.Hardtanh``.
        max_val: Upper bound, passed through to ``nn.Hardtanh``.
        inplace: Passed through to ``nn.Hardtanh``.
    """

    def __init__(
        self,
        min_val: float = -1.0,
        max_val: float = 1.0,
        inplace: bool = False,
    ) -> None:
        nn.Hardtanh.__init__(self, min_val=min_val, max_val=max_val, inplace=inplace)
        PadoModuleMixin.__init__(self)


class Sigmoid(nn.Sigmoid, PadoModuleMixin):
    """``torch.nn.Sigmoid`` combined with ``PadoModuleMixin``; takes no options."""

    def __init__(self) -> None:
        nn.Sigmoid.__init__(self)
        PadoModuleMixin.__init__(self)


class Hardsigmoid(nn.Hardsigmoid, PadoModuleMixin):
    """``torch.nn.Hardsigmoid`` combined with ``PadoModuleMixin``.

    Args:
        inplace: Passed through unchanged to ``nn.Hardsigmoid``.
    """

    def __init__(self, inplace: bool = False) -> None:
        nn.Hardsigmoid.__init__(self, inplace)
        PadoModuleMixin.__init__(self)


class Softmax(nn.Softmax, PadoModuleMixin):
    """``torch.nn.Softmax`` combined with ``PadoModuleMixin``.

    Args:
        dim: Passed through unchanged to ``nn.Softmax``. ``None`` leaves the
            dimension choice to torch.
    """

    def __init__(self, dim: Optional[int] = None) -> None:
        nn.Softmax.__init__(self, dim)
        PadoModuleMixin.__init__(self)


class LogSoftmax(nn.LogSoftmax, PadoModuleMixin):
    """``torch.nn.LogSoftmax`` combined with ``PadoModuleMixin``.

    Args:
        dim: Passed through unchanged to ``nn.LogSoftmax``. ``None`` leaves
            the dimension choice to torch.
    """

    def __init__(self, dim: Optional[int] = None) -> None:
        nn.LogSoftmax.__init__(self, dim)
        PadoModuleMixin.__init__(self)


def get_activation_cls(act_type: str, inplace: bool = False, **kwargs) -> nn.Module:
    """Construct an activation module from its string name.

    Despite the name, this returns a new module *instance*, not a class.

    Args:
        act_type: Case-insensitive activation name. Supported values:
            ``relu``, ``relu6``, ``leaky_relu``, ``tanh``, ``hard_tanh``,
            ``sigmoid``, ``hard_sigmoid``, ``swish`` / ``silu``,
            ``hard_swish``, ``elu``, ``gelu``, ``glu``, ``softmax``,
            ``log_softmax``.
        inplace: Forwarded to every activation whose constructor accepts
            ``inplace``; ignored by the others.
        **kwargs: Extra constructor arguments, e.g. ``negative_slope``,
            ``min_val`` / ``max_val``, ``alpha``, ``approximate`` or ``dim``.
            They are forwarded only to activations with configurable
            parameters, and silently ignored for ``relu``, ``relu6``,
            ``tanh``, ``sigmoid``, ``hard_sigmoid`` and ``hard_swish``.

    Returns:
        A freshly constructed activation module.

    Raises:
        ValueError: If ``act_type`` is not one of the supported names.
    """
    act_type = act_type.lower()
    # Parameter-free activations (kwargs intentionally not forwarded).
    if act_type == "relu":
        return ReLU(inplace=inplace)
    elif act_type == "relu6":
        return ReLU6(inplace=inplace)
    elif act_type == "tanh":
        return Tanh()
    elif act_type == "sigmoid":
        return Sigmoid()
    elif act_type == "hard_sigmoid":
        return Hardsigmoid(inplace=inplace)
    elif act_type == "hard_swish":
        return Hardswish(inplace=inplace)
    # Configurable activations: forward kwargs to their constructors.
    elif act_type == "leaky_relu":
        return LeakyReLU(inplace=inplace, **kwargs)
    elif act_type == "hard_tanh":
        # Previously dropped kwargs, so min_val/max_val were silently ignored.
        return Hardtanh(inplace=inplace, **kwargs)
    elif act_type in ("swish", "silu"):
        # SiLU accepts ``inplace``; honor it like every other in-place-capable branch.
        return Swish(inplace=inplace, **kwargs)
    elif act_type == "elu":
        return ELU(inplace=inplace, **kwargs)
    elif act_type == "gelu":
        return GELU(**kwargs)
    elif act_type == "glu":
        return GLU(**kwargs)
    elif act_type == "softmax":
        return Softmax(**kwargs)
    elif act_type == "log_softmax":
        return LogSoftmax(**kwargs)
    raise ValueError(f"Unsupported act_type {act_type} by `get_activation_cls` function.")
