"""mlp.py
Multi-layer perceptron (MLP) PyTorch model classes and preset constructor functions.
"""

import torch.nn as nn


class FullyConnectedBlock(nn.Module):
    """Square hidden layer: Linear -> optional BatchNorm1d -> ReLU.

    Args:
        width: Number of input and output features of the layer.
        bn: If True, apply ``nn.BatchNorm1d`` to the linear output before the
            ReLU. The linear layer is then built without a bias, because the
            normalization subtracts the batch mean and would cancel it anyway.
    """

    def __init__(self, width, bn=False):
        super().__init__()
        self.linear = nn.Linear(width, width, bias=not bn)
        self.bn = bn
        if bn:
            self.bn_layer = nn.BatchNorm1d(width)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply linear map, optional batch norm, and ReLU to ``x``."""
        out = self.linear(x)
        if self.bn:
            # Normalize the linear layer's output, not the raw input ``x``;
            # normalizing ``x`` would bypass ``self.linear`` entirely.
            out = self.bn_layer(out)
        return self.relu(out)


class MLP(nn.Module):
    """Fully connected network: input Linear, ``depth - 2`` hidden blocks, output Linear.

    Each input sample is flattened before the first layer, so any trailing
    shape is accepted as long as it holds ``num_inputs`` elements per sample.
    The output layer has no activation.

    Args:
        block: Module class used for every hidden layer; instantiated as
            ``block(width, bn=bn)``.
        num_inputs: Number of features per flattened input sample.
        num_outputs: Number of output features.
        width: Number of features in every hidden layer.
        depth: Layer count: the input and output Linear layers plus
            ``depth - 2`` hidden blocks. Values of 2 or less produce no
            hidden blocks (the network still has both Linear layers).
        bn: If True, add ``nn.BatchNorm1d`` after the input layer (whose bias
            is then dropped) and pass ``bn=True`` to every hidden block.
    """

    def __init__(self, block=FullyConnectedBlock, num_inputs=32*32*3, num_outputs=1, width=1000,
                 depth=5, bn=False):
        super().__init__()
        self.block = block
        self.bn = bn
        # A bias in front of BatchNorm would be cancelled by the mean
        # subtraction, so it is omitted when bn is enabled.
        self.linear_first = nn.Linear(num_inputs, width, bias=not self.bn)
        if bn:
            self.bn_first = nn.BatchNorm1d(width)
        self.relu = nn.ReLU()
        self.layers = self._make_layer(block, width, depth - 2, self.bn)
        self.linear_last = nn.Linear(width, num_outputs)

    def _make_layer(self, block, width, depth, bn):
        """Return ``depth`` instances of ``block(width, bn=bn)`` in an ``nn.Sequential``."""
        return nn.Sequential(*(block(width, bn=bn) for _ in range(depth)))

    def forward(self, x):
        """Flatten ``x`` per sample and return the raw network outputs."""
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # permuted or channels-last tensors; it still avoids a copy when it can.
        out = x.reshape(x.size(0), -1)

        out = self.linear_first(out)
        if self.bn:
            out = self.bn_first(out)
        out = self.relu(out)

        out = self.layers(out)

        return self.linear_last(out)


def mlp_100_3(num_outputs=10):
    """Width-100, depth-3 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=100, depth=3)


def mlp_100_4(num_outputs=10):
    """Width-100, depth-4 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=100, depth=4)


def mlp_100_5(num_outputs=10):
    """Width-100, depth-5 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=100, depth=5)


def mlp_100_6(num_outputs=10):
    """Width-100, depth-6 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=100, depth=6)


def mlp_100_7(num_outputs=10):
    """Width-100, depth-7 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=100, depth=7)


def mlp_100_8(num_outputs=10):
    """Width-100, depth-8 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=100, depth=8)


def mlp_100_9(num_outputs=10):
    """Width-100, depth-9 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=100, depth=9)


def mlp_100_10(num_outputs=10):
    """Width-100, depth-10 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=100, depth=10)


def mlp_200_3(num_outputs=10):
    """Width-200, depth-3 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=200, depth=3)


def mlp_200_4(num_outputs=10):
    """Width-200, depth-4 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=200, depth=4)


def mlp_200_5(num_outputs=10):
    """Width-200, depth-5 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=200, depth=5)


def mlp_200_6(num_outputs=10):
    """Width-200, depth-6 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=200, depth=6)


def mlp_200_7(num_outputs=10):
    """Width-200, depth-7 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=200, depth=7)


def mlp_200_8(num_outputs=10):
    """Width-200, depth-8 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=200, depth=8)


def mlp_200_9(num_outputs=10):
    """Width-200, depth-9 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=200, depth=9)


def mlp_200_10(num_outputs=10):
    """Width-200, depth-10 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=200, depth=10)


def mlp_250_3(num_outputs=10):
    """Width-250, depth-3 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=250, depth=3)


def mlp_250_4(num_outputs=10):
    """Width-250, depth-4 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=250, depth=4)


def mlp_250_5(num_outputs=10):
    """Width-250, depth-5 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=250, depth=5)


def mlp_250_6(num_outputs=10):
    """Width-250, depth-6 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=250, depth=6)


def mlp_250_7(num_outputs=10):
    """Width-250, depth-7 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=250, depth=7)


def mlp_250_8(num_outputs=10):
    """Width-250, depth-8 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=250, depth=8)


def mlp_250_9(num_outputs=10):
    """Width-250, depth-9 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=250, depth=9)


def mlp_250_10(num_outputs=10):
    """Width-250, depth-10 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=250, depth=10)

def mlp_500_3(num_outputs=10):
    """Width-500, depth-3 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=500, depth=3)


def mlp_500_4(num_outputs=10):
    """Width-500, depth-4 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=500, depth=4)


def mlp_500_5(num_outputs=10):
    """Width-500, depth-5 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=500, depth=5)


def mlp_500_6(num_outputs=10):
    """Width-500, depth-6 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=500, depth=6)


def mlp_500_7(num_outputs=10):
    """Width-500, depth-7 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=500, depth=7)


def mlp_1000_4(num_outputs=10):
    """Width-1000, depth-4 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=1000, depth=4)


def mlp_1000_5(num_outputs=10):
    """Width-1000, depth-5 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=1000, depth=5)


def mlp_1000_6(num_outputs=10):
    """Width-1000, depth-6 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=1000, depth=6)


def mlp_1000_7(num_outputs=10):
    """Width-1000, depth-7 MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=1000, depth=7)


def mlp_250_3_bn(num_outputs=10):
    """Width-250, depth-3 batch-normalized MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=250, depth=3, bn=True)


def mlp_250_4_bn(num_outputs=10):
    """Width-250, depth-4 batch-normalized MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=250, depth=4, bn=True)


def mlp_500_3_bn(num_outputs=10):
    """Width-500, depth-3 batch-normalized MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=500, depth=3, bn=True)


def mlp_500_4_bn(num_outputs=10):
    """Width-500, depth-4 batch-normalized MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=500, depth=4, bn=True)


def mlp_500_7_bn(num_outputs=10):
    """Width-500, depth-7 batch-normalized MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=500, depth=7, bn=True)


def mlp_1000_3_bn(num_outputs=10):
    """Width-1000, depth-3 batch-normalized MLP for 32*32*3 inputs."""
    return MLP(num_inputs=32 * 32 * 3, num_outputs=num_outputs, width=1000, depth=3, bn=True)


def mlp_20_3_mnist(num_outputs=10):
    """Width-20, depth-3 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=20, depth=3)


def mlp_20_4_mnist(num_outputs=10):
    """Width-20, depth-4 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=20, depth=4)


def mlp_20_5_mnist(num_outputs=10):
    """Width-20, depth-5 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=20, depth=5)


def mlp_20_6_mnist(num_outputs=10):
    """Width-20, depth-6 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=20, depth=6)


def mlp_20_7_mnist(num_outputs=10):
    """Width-20, depth-7 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=20, depth=7)


def mlp_20_8_mnist(num_outputs=10):
    """Width-20, depth-8 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=20, depth=8)


def mlp_20_9_mnist(num_outputs=10):
    """Width-20, depth-9 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=20, depth=9)


def mlp_20_10_mnist(num_outputs=10):
    """Width-20, depth-10 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=20, depth=10)

def mlp_100_3_mnist(num_outputs=10):
    """Width-100, depth-3 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=100, depth=3)


def mlp_100_4_mnist(num_outputs=10):
    """Width-100, depth-4 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=100, depth=4)


def mlp_100_5_mnist(num_outputs=10):
    """Width-100, depth-5 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=100, depth=5)


def mlp_100_6_mnist(num_outputs=10):
    """Width-100, depth-6 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=100, depth=6)


def mlp_100_7_mnist(num_outputs=10):
    """Width-100, depth-7 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=100, depth=7)


def mlp_100_8_mnist(num_outputs=10):
    """Width-100, depth-8 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=100, depth=8)


def mlp_100_9_mnist(num_outputs=10):
    """Width-100, depth-9 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=100, depth=9)


def mlp_100_10_mnist(num_outputs=10):
    """Width-100, depth-10 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=100, depth=10)


def mlp_200_3_mnist(num_outputs=10):
    """Width-200, depth-3 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=200, depth=3)


def mlp_200_4_mnist(num_outputs=10):
    """Width-200, depth-4 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=200, depth=4)


def mlp_200_5_mnist(num_outputs=10):
    """Width-200, depth-5 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=200, depth=5)


def mlp_200_6_mnist(num_outputs=10):
    """Width-200, depth-6 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=200, depth=6)


def mlp_200_7_mnist(num_outputs=10):
    """Width-200, depth-7 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=200, depth=7)


def mlp_200_8_mnist(num_outputs=10):
    """Width-200, depth-8 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=200, depth=8)


def mlp_200_9_mnist(num_outputs=10):
    """Width-200, depth-9 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=200, depth=9)


def mlp_200_10_mnist(num_outputs=10):
    """Width-200, depth-10 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=200, depth=10)


def mlp_250_3_mnist(num_outputs=10):
    """Width-250, depth-3 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=250, depth=3)


def mlp_250_4_mnist(num_outputs=10):
    """Width-250, depth-4 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=250, depth=4)


def mlp_250_5_mnist(num_outputs=10):
    """Width-250, depth-5 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=250, depth=5)


def mlp_250_6_mnist(num_outputs=10):
    """Width-250, depth-6 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=250, depth=6)


def mlp_250_7_mnist(num_outputs=10):
    """Width-250, depth-7 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=250, depth=7)


def mlp_250_8_mnist(num_outputs=10):
    """Width-250, depth-8 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=250, depth=8)


def mlp_250_9_mnist(num_outputs=10):
    """Width-250, depth-9 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=250, depth=9)


def mlp_250_10_mnist(num_outputs=10):
    """Width-250, depth-10 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=250, depth=10)


def mlp_500_4_mnist(num_outputs=10):
    """Width-500, depth-4 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=4)


def mlp_500_5_mnist(num_outputs=10):
    """Width-500, depth-5 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=5)


def mlp_500_6_mnist(num_outputs=10):
    """Width-500, depth-6 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=6)


def mlp_500_7_mnist(num_outputs=10):
    """Width-500, depth-7 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=7)


def mlp_1000_4_mnist(num_outputs=10):
    """Width-1000, depth-4 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=1000, depth=4)


def mlp_1000_5_mnist(num_outputs=10):
    """Width-1000, depth-5 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=1000, depth=5)


def mlp_1000_6_mnist(num_outputs=10):
    """Width-1000, depth-6 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=1000, depth=6)


def mlp_1000_7_mnist(num_outputs=10):
    """Width-1000, depth-7 MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=1000, depth=7)


def mlp_250_3_bn_mnist(num_outputs=10):
    """Width-250, depth-3 batch-normalized MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=250, depth=3, bn=True)


def mlp_250_4_bn_mnist(num_outputs=10):
    """Width-250, depth-4 batch-normalized MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=250, depth=4, bn=True)


def mlp_500_3_bn_mnist(num_outputs=10):
    """Width-500, depth-3 batch-normalized MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=3, bn=True)


def mlp_500_4_bn_mnist(num_outputs=10):
    """Width-500, depth-4 batch-normalized MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=4, bn=True)


def mlp_500_7_bn_mnist(num_outputs=10):
    """Width-500, depth-7 batch-normalized MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=7, bn=True)


def mlp_1000_3_bn_mnist(num_outputs=10):
    """Width-1000, depth-3 batch-normalized MLP for 1*28*28 inputs."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=1000, depth=3, bn=True)


def mlp_500_3_emnist(num_outputs=47):
    """Width-500, depth-3 MLP for 1*28*28 inputs; 47 outputs by default."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=3)


def mlp_500_4_emnist(num_outputs=47):
    """Width-500, depth-4 MLP for 1*28*28 inputs; 47 outputs by default."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=4)


def mlp_500_5_emnist(num_outputs=47):
    """Width-500, depth-5 MLP for 1*28*28 inputs; 47 outputs by default."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=5)


def mlp_500_6_emnist(num_outputs=47):
    """Width-500, depth-6 MLP for 1*28*28 inputs; 47 outputs by default."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=6)


def mlp_500_7_emnist(num_outputs=47):
    """Width-500, depth-7 MLP for 1*28*28 inputs; 47 outputs by default."""
    return MLP(num_outputs=num_outputs, num_inputs=1 * 28 * 28, width=500, depth=7)
