
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor


class MLP(nn.Module):
    """Multi-layer perceptron (MLP) model using pytorch.

    The layers are stored in ``self.net`` (an ``nn.Sequential``). Every hidden
    layer is emitted as ``Linear -> [Dropout] -> activation`` and the stack
    ends with a ``Linear`` output layer, followed by one extra activation when
    ``activations`` has one entry more than ``hidden_sizes``.
    """

    # activation name -> module class; a fresh instance is built per request.
    _ACTIVATIONS = {
        "ReLU": nn.ReLU,
        "Sigmoid": nn.Sigmoid,
        "Tanh": nn.Tanh,
        "LeakyReLU": nn.LeakyReLU,
        "ELU": nn.ELU,
        "SELU": nn.SELU,
        "Softplus": nn.Softplus,
        "Softsign": nn.Softsign,
        "Identity": nn.Identity,
    }

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 hidden_sizes: List[int],
                 activations: List[str],
                 dropout: bool = False,
                 dropout_ratio: Optional[List[float]] = None,
                 device: torch.device = torch.device("cpu")) -> None:
        """Initialize the MLP model.

        Parameters
        ----------
        input_size : int
            input size of the network.
        output_size : int
            output size of the network.
        hidden_sizes : List[int]
            list of hidden layer sizes; may be empty, in which case the
            network is a single linear layer.
        activations : List[str]
            list of activation function names, one per hidden layer, plus an
            optional extra one that is applied after the output layer.
        dropout : bool, optional
            use dropout or not, by default False
        dropout_ratio : List[float], optional
            dropout ratio for each hidden layer, by default None. Required
            when ``dropout`` is True.
        device : torch.device, optional
            device used to pick the default tensor type, by default
            torch.device("cpu")

        Raises
        ------
        ValueError
            if the dropout or activation settings are inconsistent with
            ``hidden_sizes``, or an activation name is not supported.
        """
        super().__init__()

        # set device (changes the global default tensor type).
        self.set_device(device=device)

        # dimensions of the network.
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_sizes = hidden_sizes
        self.activations = activations

        # dropout configuration.
        self.dropout = dropout
        self.dropout_ratio = dropout_ratio

        # validate before building anything.
        self._check_dropout()
        self._check_params()

        # create the layers.
        self.net = self._create_layers()

    def forward(self, X: Tensor) -> Tensor:
        """forward pass.

        Parameters
        ----------
        X : Tensor
            input of the network.

        Returns
        -------
        Tensor
            output of the network.
        """
        return self.net(X)

    def _create_layers(self) -> nn.Sequential:
        """create the layers.

        Returns
        -------
        nn.Sequential
            network architecture.
        """
        layers: List[nn.Module] = []

        # hidden layers: Linear -> [Dropout] -> activation. Tracking the
        # running input width lets an empty ``hidden_sizes`` fall through to
        # a single input->output linear layer instead of an IndexError.
        in_features = self.input_size
        for i, out_features in enumerate(self.hidden_sizes):
            layers.append(nn.Linear(in_features, out_features))
            # NOTE: dropout is placed before the activation; this ordering
            # is kept as-is so existing models behave the same.
            if self.dropout:
                layers.append(nn.Dropout(p=self.dropout_ratio[i]))
            layers.append(self._get_activation(self.activations[i]))
            in_features = out_features

        # output layer.
        layers.append(nn.Linear(in_features, self.output_size))

        # output activation if one extra activation was supplied.
        if len(self.activations) == len(self.hidden_sizes) + 1:
            layers.append(self._get_activation(self.activations[-1]))

        return nn.Sequential(*layers)

    def _get_activation(self, activation: str) -> nn.Module:
        """get the activation function.

        Parameters
        ----------
        activation : str
            activation function name.

        Returns
        -------
        nn.Module
            a new instance of the requested activation module.

        Raises
        ------
        ValueError
            if the activation name is not supported.
        """
        try:
            factory = self._ACTIVATIONS[activation]
        except (KeyError, TypeError):
            # TypeError covers unhashable names, which the original equality
            # chain also rejected with the same ValueError.
            raise ValueError(
                "The activation function is not supported.") from None
        return factory()

    def _check_params(self) -> None:
        """check the parameters.

        Raises
        ------
        ValueError
            if the length of hidden_features and activations are not the same
            or activations is one more than hidden_features

        """
        n_hidden = len(self.hidden_sizes)
        n_activations = len(self.activations)

        # activations must match the hidden layers, optionally plus one for
        # the output layer.
        if n_activations not in (n_hidden, n_hidden + 1):
            raise ValueError("The length of hidden_features and activations "
                             "should be the same or activations should be one "
                             "more than hidden_features.")
        # output info if one more activation is provided.
        if n_activations == n_hidden + 1:
            print("Last activate function is used after the output layer.")

    def _check_dropout(self) -> None:
        """check the dropout parameters.

        Raises
        ------
        ValueError
            if dropout is enabled and dropout_ratio is missing, or the length
            of dropout_ratio is not the same as the length of
            hidden_features.
        """
        if not self.dropout:
            return
        # a missing list used to surface as a TypeError from len(None).
        if self.dropout_ratio is None:
            raise ValueError("dropout_ratio must be provided when dropout "
                             "is enabled.")
        if len(self.hidden_sizes) != len(self.dropout_ratio):
            raise ValueError("The length of dropout_ratio should be the "
                             "same as the length of hidden_features.")

    def set_device(self, device: torch.device) -> None:
        """set device

        Sets the process-wide default tensor type, so it affects tensors
        created outside this model as well.

        Parameters
        ----------
        device : torch.device
            target device; a device string such as "cpu" is accepted too.
            A CPU device selects 'torch.FloatTensor', any other device
            selects 'torch.cuda.FloatTensor'.
        """
        # NOTE(review): torch.set_default_tensor_type is reported as
        # deprecated in recent PyTorch releases; confirm against the
        # supported version before migrating to torch.set_default_device.

        # classify by device type so strings and indexed devices ("cpu:0")
        # are not mistaken for a GPU by an exact-equality comparison.
        if torch.device(device).type == "cpu":
            torch.set_default_tensor_type('torch.FloatTensor')
        else:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')


class BayesMLP(MLP):
    """Bayesian Multi-layer Perceptron (MLP) model using pytorch.

    Same architecture as ``MLP`` (without dropout). In training mode the
    forward pass additionally returns the negative log-probability of all
    linear-layer weights and biases under a ``Normal(prior_mu, prior_sigma)``
    prior.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 hidden_sizes: List[int],
                 activations: List[str],
                 prior_mu: float = 0.0,
                 prior_sigma: float = 1.0,
                 device: torch.device = torch.device("cpu")) -> None:
        """Initialize the Bayesian MLP model.

        Parameters
        ----------
        input_size : int
            input size of the network.
        output_size : int
            output size of the network.
        hidden_sizes : List[int]
            list of hidden layer sizes.
        activations : List[str]
            list of activation functions.
        prior_mu : float, optional
            mean of the prior distribution, by default 0.0
        prior_sigma : float, optional
            standard deviation of the prior distribution, by default 1.0
        device : torch.device, optional
            device forwarded to the parent constructor, by default
            torch.device("cpu")
        """
        # the parent constructor already applies the device setting, so it
        # is not repeated here.
        super().__init__(input_size=input_size,
                         output_size=output_size,
                         hidden_sizes=hidden_sizes,
                         activations=activations,
                         device=device)
        # prior distribution (one prior shared by every weight and bias).
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma

    def forward(self,
                X: Tensor,
                Train: bool = False
                ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """forward pass

        Parameters
        ----------
        X : torch.Tensor
            input of the network
        Train : bool, optional
            Train or prediction, by default False

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            the network output when ``Train`` is False; otherwise a tuple
            ``(output, prior_loss)`` where ``prior_loss`` is the negated sum
            of prior log-probabilities over all linear-layer parameters.
        """
        outputs = self.net(X)

        # prediction: only the output is needed.
        if not Train:
            return outputs

        # the prior is identical for every layer, so build it once.
        prior = torch.distributions.Normal(self.prior_mu, self.prior_sigma)

        # accumulate the log prior over each linear layer (weight, then bias).
        log_prior = 0
        for layer in self.net.modules():
            if not isinstance(layer, nn.Linear):
                continue
            log_prior = log_prior + prior.log_prob(layer.weight).sum()
            if layer.bias is not None:
                log_prior = log_prior + prior.log_prob(layer.bias).sum()

        return outputs, -log_prior


class GammaVarMLP(nn.Module):
    """Multi-layer Perceptron, which has the same architecture as the
    MLP but with the output layer that returns the alpha and beta of gamma
    distribution.

    The wrapped ``BayesMLP`` produces ``2 * output_size`` features; the first
    half along the last axis is returned as alpha and the second half as
    beta.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 hidden_sizes: Optional[List[int]] = None,
                 activations: Optional[List[str]] = None,
                 prior_mu: float = 0.0,
                 prior_sigma: float = 1.0,
                 device: torch.device = torch.device("cpu")) -> None:
        """Initialize the MLP model, which follows the Gamma header.

        Parameters
        ----------
        input_size : int
            input size of the network
        output_size : int
            output size of the network (number of alpha/beta pairs)
        hidden_sizes : List[int], optional
            hidden sizes of the network, by default [32, 32]
        activations : List[str], optional
            activation functions of the network, by default ["ReLU", "ReLU",
            "Softplus"]
        prior_mu : float, optional
            prior mean of the network, by default 0.0
        prior_sigma : float, optional
            prior sigma of the network, by default 1.0
        device : torch.device, optional
            device of the network, by default torch.device("cpu")
        """
        super().__init__()

        # build fresh default lists per instance: the wrapped network keeps
        # a reference to them, so shared list defaults would leak mutations
        # from one model into the next.
        if hidden_sizes is None:
            hidden_sizes = [32, 32]
        if activations is None:
            activations = ["ReLU", "ReLU", "Softplus"]

        # twice the requested width: one half for alpha, one half for beta.
        self.net = BayesMLP(input_size=input_size,
                            output_size=2*output_size,
                            hidden_sizes=hidden_sizes,
                            activations=activations,
                            prior_mu=prior_mu,
                            prior_sigma=prior_sigma,
                            device=device)

    def forward(self,
                X: Tensor,
                Train: bool = False
                ) -> Union[Tuple[Tensor, Tensor],
                           Tuple[Tensor, Tensor, Tensor]]:
        """forward pass

        Parameters
        ----------
        X : torch.Tensor
            input of the network
        Train : bool, optional
            whether the network is in training mode, by default False

        Returns
        -------
        tuple of torch.Tensor
            ``(alpha, beta)`` when ``Train`` is False, otherwise
            ``(alpha, beta, prior_loss)``.
        """
        # width of each of the two halves of the network output.
        half = self.net.output_size // 2

        if Train:
            out, prior_loss = self.net(X, Train=True)
        else:
            out = self.net(X)

        # split on the last (feature) axis so inputs with any number of
        # leading dimensions are handled, not only 2-D batches.
        alpha = out[..., :half]
        beta = out[..., half:]

        if Train:
            return alpha, beta, prior_loss
        return alpha, beta
