"""
    Script for the different used models.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

# Models for the first experiment.
class OneLayerMLP(nn.Module):
    """A single fully connected layer applied to flattened inputs.

    There are no hidden layers and no activation, so ``forward`` returns the
    raw output of one ``nn.Linear``.

    Args:
        input_dim: Number of features per sample after flattening. The default
            784 presumably corresponds to 28x28 images — confirm against the
            training script.
        output_dim: Number of output features.
    """

    def __init__(self, input_dim=784, output_dim=10):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Flatten every non-batch dimension of ``x`` and apply the linear layer."""
        # reshape (not view) so non-contiguous inputs (e.g. transposed or
        # channels_last tensors) are accepted; it returns a view when possible.
        x = x.reshape(x.size(0), -1)
        return self.fc(x)

class MultiLayerMLP(nn.Module):
    """Configurable MLP applied to flattened inputs.

    Each hidden size ``h`` in ``hidden_dims`` contributes the sub-stack
    ``Linear -> BatchNorm1d(h) -> ReLU -> Dropout``. A final ``Linear``
    maps the last hidden size (or ``input_dim`` if there are no hidden layers)
    to ``output_dim``. All layers live in ``self.net`` (an ``nn.Sequential``),
    in that order.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dims: Hidden layer sizes, in order. Any iterable of ints
            (list, tuple, ...). If it is empty, the model is a single
            ``Linear``.
        output_dim: Number of output features.
        dropout: Dropout probability used after every hidden ReLU.

    Note:
        ``BatchNorm1d`` in training mode needs more than one sample per batch.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, dropout=0.5):
        super().__init__()
        # Star-unpacking accepts tuples and other iterables, not just lists.
        dims = [input_dim, *hidden_dims]
        layers = []
        # Walk consecutive (in, out) size pairs: one hidden sub-stack per pair.
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.BatchNorm1d(out_dim))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(dims[-1], output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten every non-batch dimension of ``x`` and run it through ``self.net``."""
        # reshape (not view) so non-contiguous inputs are accepted.
        x = x.reshape(x.size(0), -1)
        return self.net(x)

# Models for the second experiment: comparing MLPs of different sizes.
# --- Model definitions ---
class SmallMLP(nn.Module):
    """Smallest size variant: one linear layer on flattened inputs.

    There is no hidden layer and no activation. ``forward`` returns the raw
    output of ``self.fc``.

    Args:
        input_dim: Number of features per sample after flattening.
        output_dim: Number of output features.
    """

    def __init__(self, input_dim=784, output_dim=10):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Flatten every non-batch dimension of ``x`` and apply ``self.fc``."""
        # reshape (not view) so non-contiguous inputs are accepted.
        x = x.reshape(x.size(0), -1)
        return self.fc(x)

class MediumMLP(nn.Module):
    """Medium size variant: one hidden layer with ReLU.

    Computes ``fc2(relu(fc1(flatten(x))))``. The output has no final
    activation.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden: Width of the hidden layer.
        output_dim: Number of output features.
    """

    def __init__(self, input_dim=784, hidden=512, output_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden)
        self.fc2 = nn.Linear(hidden, output_dim)

    def forward(self, x):
        """Flatten every non-batch dimension of ``x`` and apply fc1 -> ReLU -> fc2."""
        # reshape (not view) so non-contiguous inputs are accepted.
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class LargeMLP(nn.Module):
    """Large size variant: two hidden layers, each followed by ReLU.

    Computes ``fc3(relu(fc2(relu(fc1(flatten(x))))))``. The output has no
    final activation.

    Args:
        input_dim: Number of features per sample after flattening.
        h1: Width of the first hidden layer.
        h2: Width of the second hidden layer.
        output_dim: Number of output features.
    """

    def __init__(self, input_dim=784, h1=512, h2=256, output_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, output_dim)

    def forward(self, x):
        """Flatten every non-batch dimension of ``x`` and apply fc1/fc2 (with ReLU) then fc3."""
        # reshape (not view) so non-contiguous inputs are accepted.
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

if __name__=="__main__":
    pass
