## Import libraries.
from collections import OrderedDict

import torch
import torch.nn as nn


class DNN(torch.nn.Module):
    """Fully connected feed-forward network built from a list of layer widths.

    Every hidden ``nn.Linear`` layer is followed by the selected activation;
    the final ``nn.Linear`` layer is not followed by any activation.

    Args:
        layers: Layer widths ``[in_features, hidden_1, ..., out_features]``;
            must contain at least two entries (input and output width).
        activation: Name of the hidden-layer activation. One of ``"relu"``,
            ``"sigmoid"``, ``"softplus"``, ``"linear"`` (no nonlinearity),
            ``"tanh"``, ``"leakyrelu"``, ``"softmax"``, ``"gelu"``, ``"prelu"``.

    Raises:
        ValueError: If ``layers`` has fewer than two entries or
            ``activation`` is not one of the supported names.
    """

    # Factories for the supported activations. "linear" yields None, meaning
    # hidden layers are not followed by any nonlinearity.
    _ACTIVATIONS = {
        "relu": lambda: nn.ReLU(inplace=True),
        "sigmoid": lambda: nn.Sigmoid(),
        "softplus": lambda: nn.Softplus(),
        "linear": lambda: None,
        "tanh": lambda: nn.Tanh(),
        "leakyrelu": lambda: nn.LeakyReLU(0.2, inplace=True),
        "softmax": lambda: nn.Softmax(dim=1),
        "gelu": lambda: nn.GELU(),
        "prelu": lambda: nn.PReLU(),
    }

    def __init__(self, layers, activation):
        super().__init__()

        if len(layers) < 2:
            raise ValueError(
                "'layers' needs at least an input and an output width, "
                f"got {list(layers)!r}"
            )
        # Number of Linear layers, e.g. layers=[2, 20, 20, 1] gives depth 3.
        self.depth = len(layers) - 1

        try:
            factory = self._ACTIVATIONS[activation]
        except (KeyError, TypeError):
            raise ValueError(f"Unexpected activation: {activation}") from None
        # One instance is reused after every hidden layer, so for "prelu" a
        # single learnable slope is shared by all hidden layers.
        self.activation = factory()

        # nn.Sequential would try to call a None entry, so "linear" uses an
        # identity placeholder in the stack while self.activation stays None.
        hidden_act = self.activation if self.activation is not None else nn.Identity()

        layer_list = []
        for i in range(self.depth - 1):
            layer_list.append((f"layer_{i}", nn.Linear(layers[i], layers[i + 1])))
            layer_list.append((f"activation_{i}", hidden_act))

        # Output layer: no activation after it.
        layer_list.append(
            (f"layer_{self.depth - 1}", nn.Linear(layers[-2], layers[-1]))
        )

        self.layers = nn.Sequential(OrderedDict(layer_list))

    def forward(self, x):
        """Run ``x`` through the network.

        Args:
            x: Input tensor whose last dimension equals ``layers[0]``.

        Returns:
            Tensor whose last dimension equals ``layers[-1]``.
        """
        return self.layers(x)