from torch import nn


class CelebaNet(nn.Module):
    """
    Convolutional classifier for the CelebA dataset producing raw logits.

    Four conv stages (3x3 conv without padding -> 2x2 max-pool -> ReLU) are
    followed by a three-layer fully connected head with ReLU between layers.
    No sigmoid/softmax is applied to the output.

    NOTE(review): with the default 40 outputs this presumably predicts the 40
    CelebA binary attributes (a multi-label task) -- confirm which loss the
    training code pairs it with (e.g. BCEWithLogitsLoss vs CrossEntropyLoss).

    Args:
        name: Human-readable model name, stored on ``self.name``.
        args: Arbitrary config object, stored on ``self.args`` (unused here).
        num_classes: Number of output logits (default 40).
        input_size: Spatial size of the input images, either an int (square)
            or an ``(height, width)`` pair (default 256, i.e. 256x256). It
            fixes ``Linear1.in_features``, so inputs must match it.

    Raises:
        ValueError: If ``input_size`` is too small to survive the four conv
            stages (the minimum is 46 per side).
    """

    def __init__(self, name='CelebaNet', args=None, num_classes=40, input_size=256):
        super().__init__()

        self.name = name
        self.args = args

        # Shape comments assume the default 256x256 input: (C, H, W).
        self.ConvLayer1 = nn.Sequential(
            nn.Conv2d(3, 64, 3),  # in: 3, 256, 256 -> 64, 254, 254
            nn.MaxPool2d(2),  # -> 64, 127, 127
            nn.ReLU(),  # -> 64, 127, 127
        )
        self.ConvLayer2 = nn.Sequential(
            nn.Conv2d(64, 128, 3),  # in: 64, 127, 127 -> 128, 125, 125
            nn.MaxPool2d(2),  # -> 128, 62, 62
            nn.ReLU(),  # -> 128, 62, 62
        )
        self.ConvLayer3 = nn.Sequential(
            nn.Conv2d(128, 256, 3),  # in: 128, 62, 62 -> 256, 60, 60
            nn.MaxPool2d(2),  # -> 256, 30, 30
            nn.ReLU(),  # -> 256, 30, 30
        )
        self.ConvLayer4 = nn.Sequential(
            nn.Conv2d(256, 512, 3),  # in: 256, 30, 30 -> 512, 28, 28
            nn.MaxPool2d(2),  # -> 512, 14, 14
            nn.ReLU(),  # -> 512, 14, 14
            nn.Dropout(0.2),
        )
        feat_h, feat_w = self._feature_hw(input_size)
        # For the default 256x256 input this is 512 * 14 * 14, as before.
        self.Linear1 = nn.Linear(512 * feat_h * feat_w, 1024)
        self.Linear2 = nn.Linear(1024, 256)
        self.Linear3 = nn.Linear(256, num_classes)

    @staticmethod
    def _feature_hw(input_size):
        """Return the (height, width) of ConvLayer4's output for ``input_size``."""
        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        for _ in range(4):
            # A 3x3 conv with no padding trims 2 pixels; MaxPool2d(2) then
            # halves with floor rounding.
            height, width = (height - 2) // 2, (width - 2) // 2
        if height < 1 or width < 1:
            raise ValueError(
                f"input_size {input_size!r} is too small for CelebaNet (minimum 46 per side)"
            )
        return height, width

    def forward(self, x):
        """Map a batch of images to logits.

        x must have 3 channels (fixed by ConvLayer1) and the spatial size
        given as ``input_size`` at construction (fixed by Linear1). The
        return value has shape (N, num_classes).
        """
        x = self.ConvLayer1(x)
        x = self.ConvLayer2(x)
        x = self.ConvLayer3(x)
        x = self.ConvLayer4(x)
        x = x.flatten(1)
        # ReLU between the dense layers; without it Linear1..3 collapse
        # into a single affine map and the hidden layers add no capacity.
        x = nn.functional.relu(self.Linear1(x))
        x = nn.functional.relu(self.Linear2(x))
        return self.Linear3(x)
