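'''
!! CAUTION !!
This class is not a general-purpose DNN implementation; it is intended only for the interface function.
It is meant to keep the interface function (of classical PINNs) positive.
NOTE: the final ReLU that would strictly enforce positivity is currently disabled (commented out),
so the network output is not guaranteed to be non-negative.

'''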


## Import libraries.
from collections import OrderedDict

import torch
import torch.nn as nn


class DNN_positive(torch.nn.Module):
    """Fully connected network for the (classical-PINN) interface function.

    The network is ``Linear -> act -> Linear -> act -> ... -> Linear`` with
    layer widths taken from ``args.layers_interface`` (given as
    ``[2, ..., 1]``). A single activation module instance is shared between
    all hidden layers (so for "prelu" every hidden layer uses the same
    learnable slope).

    NOTE: despite the class name, the output is NOT constrained to be
    positive: the trailing ReLU that would enforce it is currently disabled.
    The only related tweak is that, when ``args.init_flag`` is "True" and the
    selected shape is "custom", the first bias entry of the output layer is
    initialised to 5.

    Attributes read from ``args``:
        init_flag: "True" / "False" (a bool is accepted as well).
        layers_interface: sequence of layer widths, at least two entries.
        activation_interface: one of "relu", "sigmoid", "softplus",
            "linear" (no activation), "tanh", "leakyrelu", "softmax",
            "gelu", "prelu".
        PSI_list, arg_conformal: ``PSI_list[arg_conformal]`` is the shape
            name; only whether it equals "custom" matters here.

    Raises:
        ValueError: if ``args.activation_interface`` is not supported.
    """

    def __init__(self, args):
        super().__init__()

        init_flag = args.init_flag
        layers = args.layers_interface
        shape = args.PSI_list[args.arg_conformal]

        self.depth = len(layers) - 1  # 'layers' will be given as [2, @@, ..., 1].
        self.activation = self._make_activation(args.activation_interface)

        layer_list = []
        for i in range(self.depth - 1):
            layer_list.append((f"layer_{i}", nn.Linear(layers[i], layers[i + 1])))
            # "linear" yields None; nn.Sequential would try to call a None
            # entry in forward(), so no activation entry is added at all.
            if self.activation is not None:
                layer_list.append((f"activation_{i}", self.activation))

        output_layer = nn.Linear(layers[-2], layers[-1])
        # Accept both the string flag ("True"/"False") and a real bool. Any
        # other value falls back to the default initialisation rather than
        # leaving the network without an output layer.
        if str(init_flag) == "True" and shape == "custom":
            with torch.no_grad():
                output_layer.bias[0] = 5  # Setting the first bias element to 5
        layer_list.append((f"layer_{self.depth - 1}", output_layer))

        # Appending nn.ReLU(inplace=True) here would force a non-negative
        # interface output; it is currently disabled.

        self.layers = nn.Sequential(OrderedDict(layer_list))

    @staticmethod
    def _make_activation(name):
        """Return a fresh activation module for ``name`` (``None`` for "linear").

        Raises:
            ValueError: if ``name`` is not a supported activation.
        """
        factories = {
            "relu": lambda: nn.ReLU(inplace=True),
            "sigmoid": nn.Sigmoid,
            "softplus": nn.Softplus,
            "linear": lambda: None,
            "tanh": nn.Tanh,
            "leakyrelu": lambda: nn.LeakyReLU(0.2, inplace=True),
            "softmax": lambda: nn.Softmax(dim=1),
            "gelu": nn.GELU,
            "prelu": nn.PReLU,
        }
        try:
            factory = factories[name]
        except KeyError:
            raise ValueError(f"Unexpected activation: {name}") from None
        return factory()

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the raw (unconstrained) output."""
        return self.layers(x)