import torch.nn as nn


class CelebABackbone(nn.Module):
    """Shared convolutional feature extractor producing a 512-dim vector.

    The stack uses unpadded 3x3 convolutions, BatchNorm layers and 3x3 /
    stride-2 max-pooling, followed by ``Flatten -> Linear(512, 512) ->
    BatchNorm1d``. The ``Linear(512, 512)`` only fits when the last
    convolution outputs 512x1x1 per sample; a 3x64x64 input reduces to
    exactly that.

    NOTE(review): there are no activation layers (e.g. ReLU) between the
    conv/BN stages; max-pooling is the only nonlinearity. Inserting
    activations would shift the indices inside ``self.net`` and therefore
    the ``state_dict`` keys, so confirm the intent (and checkpoint
    compatibility) before changing the architecture.
    """

    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.MaxPool2d(3, 2),
            nn.Conv2d(64, 128, 3), nn.BatchNorm2d(128),
            nn.Conv2d(128, 128, 3), nn.BatchNorm2d(128), nn.MaxPool2d(3, 2),
            nn.Conv2d(128, 256, 3), nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 3), nn.BatchNorm2d(256), nn.MaxPool2d(3, 2),
            nn.Conv2d(256, 512, 3), nn.BatchNorm2d(512),
            nn.Flatten(), nn.Linear(512, 512), nn.BatchNorm1d(512),
        )

        # Width of the feature vector returned by forward().
        self.output_size = 512

    def forward(self, input):
        """Encode a batch of images into ``output_size``-dim features.

        Args:
            input: tensor of shape (N, 3, H, W). The spatial size must reduce
                to 1x1 at the last convolution (e.g. H = W = 64). In training
                mode ``BatchNorm1d`` also needs N > 1.

        Returns:
            Tensor of shape (N, 512).
        """
        return self.net(input)

    def last_layer(self):
        """Return the parameters of the final ``nn.Linear`` in ``self.net``.

        Returns:
            Iterator over that layer's weight (512x512) and bias (512).
        """
        # Select by type rather than by a fixed negative index: the layers
        # around it (Flatten before, BatchNorm1d after) are not the linear
        # projection, and Flatten in particular has no parameters at all.
        last_linear = [m for m in self.net if isinstance(m, nn.Linear)][-1]
        return last_linear.parameters()


class CelebAClassification(nn.Module):
    """Binary classification head mapping a 512-dim feature to one probability.

    Layout: ``Linear(512, 128) -> BatchNorm1d(128) -> Linear(128, 1) ->
    Sigmoid``.

    NOTE(review): there is no activation between the BatchNorm1d and the
    second Linear; inserting one would shift ``self.net`` indices and
    ``state_dict`` keys — confirm intent before changing it.

    Args:
        index: value stored unchanged as ``self.index``; presumably identifies
            which CelebA attribute/task this head predicts — confirm with
            callers.
    """

    def __init__(self, index):
        super().__init__()

        hidden = 128
        layers = [
            nn.Linear(512, hidden),
            nn.BatchNorm1d(hidden),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

        self.index = index

    def forward(self, input):
        """Score a batch of features.

        Args:
            input: tensor of shape (N, 512).

        Returns:
            Tensor of shape (N, 1) with sigmoid outputs in [0, 1].
        """
        return self.net(input)
