import math
import torch.nn.functional as F
import torch
import torch.nn as nn


class MLP(nn.Module):
    '''
    Fully connected ReLU network built from bias-free linear layers.

    The network is stored in ``self.fc`` as a single ``nn.Sequential``:
    an input linear layer + ReLU, ``depth - 2`` hidden linear layers each
    followed by a ReLU, and a final linear layer whose raw output is
    returned (no activation is applied after it).
    '''

    def __init__(self,
                 input_dim: int = 28 * 28,
                 width: int = 50,
                 depth: int = 3,
                 num_classes: int = 10):
        '''
        :param input_dim: number of features per sample after flattening
        :param width: number of units in every hidden layer
        :param depth: total number of linear layers; values below 2 still
            produce the input and output layers (no hidden layers are added)
        :param num_classes: size of the output layer
        '''
        super().__init__()
        self.input_dim = input_dim
        self.width = width
        self.depth = depth
        self.num_classes = num_classes

        # The hidden layers are created before the input/output layers;
        # creation order determines the order in which the random weight
        # initialisation draws from the RNG, so it is kept stable.
        hidden = self.get_layers()
        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, self.width, bias=False),
            nn.ReLU(inplace=True),
            *hidden,
            nn.Linear(self.width, self.num_classes, bias=False),
        )

    def get_layers(self) -> list:
        '''
        Build the hidden part of the network.

        :return: a flat list of ``depth - 2`` pairs of
            (``nn.Linear(width, width, bias=False)``, ``nn.ReLU()``);
            empty when ``depth <= 2``
        '''
        layers = []
        for _ in range(self.depth - 2):
            layers.append(nn.Linear(self.width, self.width, bias=False))
            layers.append(nn.ReLU())

        return layers

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Flatten each sample to ``input_dim`` features and apply the network.

        :param x: tensor whose first dimension is the batch and whose
            remaining dimensions hold ``input_dim`` elements in total
        :return: tensor of shape ``(x.size(0), num_classes)``
        '''
        # reshape (rather than view) also accepts non-contiguous inputs,
        # e.g. a transposed/permuted batch; it returns a view when the
        # strides allow it and copies only when necessary.
        x = x.reshape(x.size(0), self.input_dim)
        return self.fc(x)


