
import torch.nn as nn

class MLP_AutoEncoder(nn.Module):
    """Fully connected autoencoder built from bias-free linear layers.

    The encoder maps ``input_dim -> 128 -> 128 -> 64 -> 32 -> 8 ->
    num_hidden_nodes`` and the decoder maps ``num_hidden_nodes -> 8 -> 16 ->
    input_dim``. The chosen activation follows every linear layer except the
    last layer of the encoder and the last layer of the decoder.

    Args:
        input_dim: Size of the last dimension of the input (and of the
            reconstruction).
        num_classes: Stored on the instance as ``self.num_classes``; it is not
            used to build any layer in this class.
        num_hidden_nodes: Width of the encoder output, i.e. the size of the
            intermediate representation returned by ``forward``.
        af_name: Name of the activation function. One of ``'tanh'``,
            ``'relu'``, ``'leaky_relu'``, ``'sigmoid'`` or ``'elu'``.

    Raises:
        ValueError: If ``af_name`` is not one of the supported names.
    """

    # Maps the public ``af_name`` values to the activation module classes.
    _ACTIVATIONS = {
        'tanh': nn.Tanh,
        'relu': nn.ReLU,
        'leaky_relu': nn.LeakyReLU,
        'sigmoid': nn.Sigmoid,
        'elu': nn.ELU,
    }

    def __init__(self, input_dim=2, num_classes=1, num_hidden_nodes=128, af_name='tanh'):
        super().__init__()

        self.input_dim = input_dim
        self.num_classes = num_classes
        self.num_hidden_nodes = num_hidden_nodes

        # Fail early with a clear message instead of hitting an unbound local
        # when the activation name is not recognized.
        try:
            af = self._ACTIVATIONS[af_name]
        except (KeyError, TypeError):
            supported = ', '.join(repr(name) for name in self._ACTIVATIONS)
            raise ValueError(
                f"Unsupported af_name {af_name!r}; expected one of: {supported}"
            ) from None

        self.encoder = nn.Sequential(
            nn.Linear(self.input_dim, 128, bias=False),
            af(),

            nn.Linear(128, 128, bias=False),
            af(),

            nn.Linear(128, 64, bias=False),
            af(),

            nn.Linear(64, 32, bias=False),
            af(),

            nn.Linear(32, 8, bias=False),
            af(),

            # No activation after the final encoder layer: the intermediate
            # representation is the raw linear output.
            nn.Linear(8, self.num_hidden_nodes, bias=False),
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.num_hidden_nodes, 8, bias=False),
            af(),
            nn.Linear(8, 16, bias=False),
            af(),
            # No activation on the reconstruction output.
            nn.Linear(16, self.input_dim, bias=False),
        )

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.

        Returns:
            A tuple ``(out, mid_repre)`` where ``out`` is the decoder output
            and ``mid_repre`` is the encoder output fed to the decoder.
        """
        mid_repre = self.encoder(x)
        out = self.decoder(mid_repre)
        return out, mid_repre


