import torch
import torch.nn as nn


def _append_if_not_none(list, item):
    if item is not None:
        list.append(item)


class FCNet(nn.Module):
    """Fully connected classifier applied to inputs flattened from dim 1 on.

    ``__init__`` builds an ``nn.Sequential`` in this order. When
    ``activation`` is ``None``, no activation modules are inserted.

    * ``Linear(input_size, width)`` followed by ``activation``;
    * for each ``i`` in ``1 .. depth - 2``: ``Linear(width, width)``,
      followed by ``activation`` except after the last of these;
    * ``Linear(width, num_classes)``.

    For ``depth >= 2`` this gives ``depth`` Linear layers. For ``depth < 2``
    there are still two.

    NOTE(review): when ``depth >= 3``, the last hidden ``Linear(width, width)``
    is not followed by an activation. It feeds straight into the output
    ``Linear``, so the two compose into a single affine map. Confirm this is
    intended. Inserting an activation there would also shift the
    ``nn.Sequential`` indices, and with them the ``state_dict`` keys of
    existing checkpoints.

    The same ``activation`` module instance is reused at every position.

    Args:
        input_size: feature count of each flattened sample. It must equal
            the first Linear's ``in_features``.
        depth: controls the number of hidden ``Linear(width, width)`` layers
            (``max(depth - 2, 0)`` of them).
        width: hidden layer size.
        num_classes: size of the output logits.
        activation: module inserted after Linear layers (e.g. ``nn.ReLU()``),
            or ``None`` for no activations.
        num_input_channels: accepted for signature compatibility; unused here.
    """

    def __init__(self, input_size, depth, width, num_classes, activation, num_input_channels=1):
        super().__init__()
        modules = [nn.Linear(input_size, width)]
        _append_if_not_none(modules, activation)

        # Hidden layers are indexed 1 .. depth - 2 inclusive; the last one
        # (index depth - 2) gets no activation (see NOTE(review) above).
        last_hidden = depth - 2
        for idx in range(1, depth - 1):
            modules.append(nn.Linear(width, width))
            if idx < last_hidden:
                _append_if_not_none(modules, activation)

        modules.append(nn.Linear(width, num_classes))
        self.layers = nn.Sequential(*modules)

    def forward(self, input_dict):
        """Return ``{'logits': ...}`` for ``input_dict['inputs']``.

        Dimension 0 is kept (presumably the batch axis — confirm with
        callers). All remaining dimensions are flattened into one before the
        Linear stack is applied.
        """
        flat = torch.flatten(input_dict['inputs'], start_dim=1, end_dim=-1)
        return {'logits': self.layers(flat)}