import torch.nn as nn
from Models.Layers.Layers import MaskedLinear, MaskedConv2d

# Layer plans for the VGG feature extractor, keyed by letter.
# An int is the output-channel count of a 3x3 conv layer; 'M' marks a
# 2x2 / stride-2 max-pool (see VGG.__init__). Each line below is one
# pooling stage.
cfg = {
    'A': [64, 'M',
          128, 'M',
          256, 256, 'M',
          512, 512, 'M',
          512, 512, 'M'],
    'B': [64, 64, 'M',
          128, 128, 'M',
          256, 256, 'M',
          512, 512, 'M',
          512, 512, 'M'],
    'D': [64, 64, 'M',
          128, 128, 'M',
          256, 256, 256, 'M',
          512, 512, 512, 'M',
          512, 512, 512, 'M'],
    'E': [64, 64, 'M',
          128, 128, 'M',
          256, 256, 256, 256, 'M',
          512, 512, 512, 512, 'M',
          512, 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    """VGG-style CNN built from ``MaskedConv2d`` / ``MaskedLinear`` layers.

    The feature extractor follows a layer plan (e.g. one of the module-level
    ``cfg`` lists). Each int ``v`` adds a 3x3, padding-1 ``MaskedConv2d`` with
    ``v`` output channels, followed by ``ReLU``; when ``batch_norm`` is true,
    a ``BatchNorm2d`` sits between the conv and the ReLU. Each ``'M'`` adds a
    2x2, stride-2 max-pool. The classifier is a single
    ``MaskedLinear(512, num_classes)``.

    NOTE(review): the classifier expects exactly 512 features per sample
    after flattening. For the standard plans (last conv has 512 channels,
    five pooling stages) this presumably means 32x32 inputs. Confirm against
    the datasets this model is used with.
    """

    def __init__(self, cfg, num_classes, batch_norm=False, in_channels=3):
        """Build the network and apply the default (Kaiming-normal) init.

        Args:
            cfg: layer plan, a sequence of ints (conv output channels) and
                ``'M'`` (max-pool) markers.
            num_classes: number of outputs of the final linear layer.
            batch_norm: if true, insert ``BatchNorm2d`` after every conv.
            in_channels: channel count of the input images. Defaults to 3,
                the value that used to be hard-coded.
        """
        super().__init__()
        layers = []
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            layers.append(MaskedConv2d(in_channels, v, kernel_size=3, padding=1))
            if batch_norm:
                layers.append(nn.BatchNorm2d(v))
            layers.append(nn.ReLU(inplace=True))
            # The next conv consumes this conv's output channels.
            in_channels = v

        # Registration order matters: ``set_masks`` walks ``self.modules()``,
        # which yields the convs in ``self.layers`` before ``self.fc``.
        self.layers = nn.Sequential(*layers)
        self.fc = MaskedLinear(512, num_classes)

        self._initialize_weights()

    def forward(self, x):
        """Return class logits for a batch ``x`` (expected N x C x H x W)."""
        x = self.layers(x)
        # Flatten every dimension except the batch dimension.
        x = x.flatten(1)
        return self.fc(x)

    def _init_masked_weights(self, weight_init):
        """Re-initialise all parameters of the network.

        ``weight_init`` is applied in place to the weight of every
        ``MaskedLinear``/``MaskedConv2d``, and their biases (if any) are
        zeroed. Batch-norm layers are reset to weight 1 and bias 0.
        """
        for m in self.modules():
            if isinstance(m, (MaskedLinear, MaskedConv2d)):
                weight_init(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                # Non-affine batch norm has no weight/bias to reset.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _initialize_weights(self):
        """Kaiming-normal init (PyTorch defaults) for masked layers."""
        self._init_masked_weights(nn.init.kaiming_normal_)

    def _initialize_weights_normal(self):
        """Normal(mean=0, std=0.1) init for masked layers."""
        self._init_masked_weights(
            lambda w: nn.init.normal_(w, mean=0.0, std=0.1))

    def _initialize_weights_uniform(self):
        """Xavier-uniform (gain 1.0) init for masked layers."""
        self._init_masked_weights(
            lambda w: nn.init.xavier_uniform_(w, gain=1.0))

    def set_masks(self, weight_mask, bias_mask):
        """Install masks on every masked layer, in ``self.modules()`` order.

        The i-th ``MaskedLinear``/``MaskedConv2d`` encountered receives
        ``set_mask(weight_mask[i], bias_mask[i])``. Convs are visited in plan
        order and the final ``fc`` layer comes last. Missing entries raise
        ``IndexError``. Extra entries are ignored.
        """
        masked_layers = (m for m in self.modules()
                         if isinstance(m, (MaskedLinear, MaskedConv2d)))
        for i, m in enumerate(masked_layers):
            m.set_mask(weight_mask[i], bias_mask[i])

def vgg11(input_shape, num_classes):
    """Build a VGG on layer plan ``cfg['A']`` (8 convs) without batch norm.

    ``input_shape`` is accepted but not used.
    """
    plan = cfg['A']
    return VGG(plan, num_classes)

def vgg11_bn(input_shape, num_classes):
    """Build a VGG on layer plan ``cfg['A']`` (8 convs) with batch norm.

    ``input_shape`` is accepted but not used.
    """
    plan = cfg['A']
    return VGG(plan, num_classes, batch_norm=True)

def vgg13(input_shape, num_classes):
    """Build a VGG on layer plan ``cfg['B']`` (10 convs) without batch norm.

    ``input_shape`` is accepted but not used.
    """
    plan = cfg['B']
    return VGG(plan, num_classes)

def vgg13_bn(input_shape, num_classes):
    """Build a VGG on layer plan ``cfg['B']`` (10 convs) with batch norm.

    ``input_shape`` is accepted but not used.
    """
    plan = cfg['B']
    return VGG(plan, num_classes, batch_norm=True)

def vgg16(input_shape, num_classes):
    """Build a VGG on layer plan ``cfg['D']`` (13 convs) without batch norm.

    ``input_shape`` is accepted but not used.
    """
    plan = cfg['D']
    return VGG(plan, num_classes)

def vgg16_bn(input_shape, num_classes):
    """Build a VGG on layer plan ``cfg['D']`` (13 convs) with batch norm.

    ``input_shape`` is accepted but not used.
    """
    plan = cfg['D']
    return VGG(plan, num_classes, batch_norm=True)

def vgg19(input_shape, num_classes):
    """Build a VGG on layer plan ``cfg['E']`` (16 convs) without batch norm.

    ``input_shape`` is accepted but not used.
    """
    plan = cfg['E']
    return VGG(plan, num_classes)

def vgg19_bn(input_shape, num_classes):
    """Build a VGG on layer plan ``cfg['E']`` (16 convs) with batch norm.

    ``input_shape`` is accepted but not used.
    """
    plan = cfg['E']
    return VGG(plan, num_classes, batch_norm=True)