import torch.nn as nn
import torch.nn.functional as F


class MultiLeNetBackbone(nn.Module):
    """Convolutional feature extractor that maps inputs to 50 features.

    The input is first reshaped to ``(-1, *shape)``, run through three
    5x5 convolutions (with two 2x2 max-pools and channel dropout), then
    flattened and projected by ``Linear(180, 50)`` followed by ReLU,
    dropout and batch normalisation.

    Args:
        shape: Per-sample shape the input is reshaped to before the
            convolutions. Any sequence of ints works (tuple, list,
            ``torch.Size``). The first convolution has a single input
            channel and the linear layer expects exactly 180 flattened
            features, so the shape must be compatible with both; for
            example ``(1, 40, 40)`` yields 20 x 3 x 3 = 180.
        dropout: Drop probability used by every dropout layer.

    Attributes:
        output_size: Width of the feature vector returned by ``forward`` (50).
    """

    def __init__(self, shape, dropout):
        super().__init__()
        # Stored untouched so callers reading ``.shape`` see what they passed.
        self.shape = shape
        self.conv_net = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5), nn.MaxPool2d(kernel_size=2), nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5), nn.Dropout2d(dropout), nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(20, 20, kernel_size=5), nn.Dropout2d(dropout)
        )
        self.linear_net = nn.Sequential(
            nn.Linear(180, 50), nn.ReLU(), nn.Dropout(dropout),
            nn.BatchNorm1d(50)
        )
        self.output_size = 50

    def last_layer(self):
        """Return an iterator over the parameters of ``linear_net``."""
        return self.linear_net.parameters()

    def forward(self, x):
        """Compute features for ``x``.

        Args:
            x: Tensor whose elements can be arranged as ``(-1, *self.shape)``,
                e.g. already-shaped images or flattened samples.

        Returns:
            Tensor of shape ``(batch, 50)``.
        """
        # Unpacking accepts list/tuple/torch.Size shapes alike (tuple
        # concatenation would reject a list), and ``reshape`` falls back to a
        # copy where ``view`` would raise on a non-viewable memory layout.
        y = self.conv_net(x.reshape(-1, *self.shape))
        return self.linear_net(y.flatten(start_dim=1))


class MultiLeNetClassification(nn.Module):
    """Head mapping 50 input features to 10 log-probabilities.

    Args:
        dropout: Drop probability of the dropout layer between the two
            linear layers.

    Attributes:
        input_size: Number of input features expected by the head (50).
    """

    def __init__(self, dropout):
        super().__init__()
        self.input_size = 50

        # Layers are created in this order so that parameter initialisation
        # and the ``net.<index>`` state-dict keys stay stable.
        stack = [
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(50, 10),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return log-softmax scores over dimension 1 for ``x``."""
        logits = self.net(x)
        return F.log_softmax(logits, dim=1)


class MultiLeNetRegression(nn.Module):
    """Head mapping 50 input features to a single unbounded output.

    Args:
        dropout: Drop probability of the dropout layer between the two
            linear layers.

    Attributes:
        input_size: Number of input features expected by the head (50).
    """

    def __init__(self, dropout):
        super().__init__()
        self.input_size = 50

        # Layers are created in this order so that parameter initialisation
        # and the ``net.<index>`` state-dict keys stay stable.
        stack = [
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(50, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw output of the linear stack for ``x``."""
        prediction = self.net(x)
        return prediction


class MultiLeNetBinaryClassification(nn.Module):
    """Head mapping 50 input features to one sigmoid-activated output.

    Args:
        dropout: Drop probability of the dropout layer applied to the
            input features before the linear layer.

    Attributes:
        input_size: Number of input features expected by the head (50).
    """

    def __init__(self, dropout):
        super().__init__()
        self.input_size = 50

        # Dropout comes first, so the linear layer sits at index 1 of ``net``.
        stack = [
            nn.Dropout(dropout),
            nn.Linear(50, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the sigmoid output, a value in (0, 1) per sample."""
        probability = self.net(x)
        return probability

