import torch

import model


class Mlp(model.Model):
    """Multi-layer perceptron with two hidden layers and a linear classifier.

    Each hidden layer is a Linear -> ReLU -> Dropout stack of width
    HIDDEN_SIZE. Sizes and dropout probability are read through ``self`` so
    that a subclass can resize the network by overriding the class constants.
    """

    # Name of this model (how model.Model consumes it is not visible here).
    NAME = "mlp"
    # Width of both hidden layers, and input width of the classifier.
    HIDDEN_SIZE = 128
    # Dropout probability applied after each hidden activation.
    DROPOUT_PROB = 0.5

    def create_layers(self, input_size):
        """Return list of torch.nn.Module objects, layers of this network.

        Parameters:
        ===========
        input_size: tuple of int dimensions of input. Must contain exactly
            one element, the number of input features.

        Returns:
        ========
        list of six modules: two consecutive Linear -> ReLU -> Dropout stacks,
        the first mapping input features to HIDDEN_SIZE and the second mapping
        HIDDEN_SIZE to HIDDEN_SIZE.

        Raises:
        =======
        ValueError: if input_size does not unpack into exactly one element.
        """
        # Single-element unpacking doubles as a check that the input is 1-D.
        (C,) = input_size
        # Look the constants up on the instance rather than on Mlp directly,
        # otherwise overrides in subclasses would be silently ignored.
        hidden = self.HIDDEN_SIZE
        dropout = self.DROPOUT_PROB
        return [
            torch.nn.Linear(C, hidden),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=dropout),
        ]

    def create_classifier(self, targets):
        """Return the classifier module of this network.

        Parameters:
        ===========
        targets: int number of classes to predict.

        Returns:
        ========
        torch.nn.Linear mapping HIDDEN_SIZE features to ``targets`` outputs.
        """
        # Input width must match the output width of the last hidden layer
        # built by create_layers, so it uses the same instance-level constant.
        return torch.nn.Linear(self.HIDDEN_SIZE, targets)