import math
import torch.nn as nn


class VGG16(nn.Module):
    """VGG-16 with batch norm, global average pooling and one linear head.

    The convolutional trunk is built from ``self.cfg``. Each integer is a 3x3
    conv (padding 1, no bias) followed by BatchNorm2d and ReLU. Each ``'M'``
    is a 2x2 max-pool with stride 2. The trunk output is averaged to 1x1 and
    fed to a single ``nn.Linear(512, num_classes)``.

    Args:
        num_classes: Number of output logits. Defaults to 100, the value
            previously hard-coded here.
        in_channels: Number of channels of the input images. Defaults to 3.
    """

    def __init__(self, num_classes=100, in_channels=3):
        super().__init__()

        self.num_classes = num_classes
        self.in_channels = in_channels

        # Conv output widths in order; 'M' marks a 2x2/stride-2 max-pool.
        self.cfg = [
            64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512,
            'M', 512, 512, 512, 'M'
        ]

        self.features = self.make_layers()
        # The last conv stage emits 512 channels, which become the feature
        # vector after global average pooling.
        self.classifier = nn.Linear(512, self.num_classes)
        self._initialize_weights()

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` goes through ``self.features``, so it must have
        ``in_channels`` channels. Its spatial size must survive the five
        2x2/stride-2 max-pools, so H and W must each be at least 32.
        """
        x = self.features(x)
        # Global average pool to (batch, C, 1, 1). The functional form avoids
        # constructing a fresh pooling module on every forward call.
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = x.flatten(1)
        return self.classifier(x)

    def make_layers(self):
        """Build the convolutional trunk described by ``self.cfg``.

        Returns:
            An ``nn.Sequential`` of Conv2d/BatchNorm2d/ReLU triples and
            MaxPool2d layers, starting from ``self.in_channels`` channels.
        """
        layers = []
        in_channels = self.in_channels
        for c in self.cfg:
            if c == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                # No conv bias: the following BatchNorm supplies the shift.
                conv = nn.Conv2d(in_channels,
                                 c,
                                 kernel_size=3,
                                 padding=1,
                                 bias=False)
                layers += [conv, nn.BatchNorm2d(c), nn.ReLU(inplace=True)]
                in_channels = c
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Initialise all parameters in place.

        - Conv2d: Kaiming-normal with fan-out, i.e. N(0, sqrt(2 / fan_out))
          with ``fan_out = kH * kW * out_channels``. Any bias is zeroed.
        - BatchNorm2d: weight 0.5, bias 0.
        - Linear: weight N(0, 0.01), bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0.0, math.sqrt(2.0 / fan_out))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # NOTE(review): gamma starts at 0.5 rather than PyTorch's
                # default 1.0. Preserved from the original; confirm intent.
                nn.init.constant_(m.weight, 0.5)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
