import os
import sys
sys.path.insert(0, './')

import torch
import torch.nn as nn
import torch.nn.functional as F

''' VGG '''
# Layer layouts consumed by ``VGG._make_layers``.  An integer entry is the
# number of output channels of a 3x3 convolution (followed by a normalisation
# layer and a ReLU); 'M' is a 2x2 max-pooling step with stride 2.
# Each row below is one stage, ending with its pooling step.
cfg_vgg = {
    'VGG11': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG13': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'VGG16': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'VGG19': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}


class VGG(nn.Module):
    """VGG-style convolutional network with a single linear classifier.

    The convolutional stack is built from a layout in ``cfg_vgg`` and the
    classifier is sized from that same layout and ``im_size``, so its input
    width matches what ``features`` produces for images of that size.

    Args:
        vgg_name: Key into ``cfg_vgg`` selecting the layer layout.
        channel: Number of input channels.  When it is 1, the first
            convolution uses ``padding=3`` instead of 1, which enlarges each
            spatial side by 4 (e.g. 28x28 -> 32x32).
        num_classes: Output width of the linear classifier.
        norm: ``'instancenorm'`` puts ``nn.GroupNorm`` with one group per
            channel after each convolution; any other value selects
            ``nn.BatchNorm2d``.
        im_size: ``(height, width)`` of the input images; only used to size
            the classifier.

    Raises:
        KeyError: If ``vgg_name`` is not a key of ``cfg_vgg``.
    """

    def __init__(self, vgg_name, channel, num_classes, norm='batchnorm', im_size=(32, 32)):
        super().__init__()
        self.channel = channel
        cfg = cfg_vgg[vgg_name]
        # Build ``features`` before ``classifier`` so parameters are created
        # (and random numbers consumed) in the same order as before.
        self.features = self._make_layers(cfg, norm)
        self.classifier = nn.Linear(self._feature_dim(cfg, channel, im_size), num_classes)

    @staticmethod
    def _feature_dim(cfg, channel, im_size):
        """Return the flattened per-image size of the ``features`` output.

        Mirrors the shape arithmetic of ``_make_layers``: 3x3 convolutions
        with ``padding=1`` keep the spatial size, the ``padding=3`` first
        convolution used for single-channel inputs adds 4 to each side, and
        every ``'M'`` entry halves each side, rounding down.
        """
        height, width = im_size
        out_channels = channel
        for idx, entry in enumerate(cfg):
            if entry == 'M':
                height //= 2
                width //= 2
                continue
            if channel == 1 and idx == 0:
                # kernel 3, padding 3: out = in + 2 * 3 - (3 - 1) = in + 4
                height += 4
                width += 4
            out_channels = entry
        return out_channels * height * width

    def forward(self, x, output_feat=False):
        """Run the network on a batch of images.

        Args:
            x: Input batch; passed straight to ``self.features``.
            output_feat: If true, return the flattened convolutional features
                instead of the classifier output.

        Returns:
            A ``(batch, features)`` tensor when ``output_feat`` is true,
            otherwise the classifier output of shape ``(batch, num_classes)``.
        """
        x = self.features(x)
        # ``flatten`` also handles non-contiguous feature maps, where
        # ``view`` would raise.
        x = x.flatten(1)
        if output_feat:
            return x
        return self.classifier(x)

    def _make_layers(self, cfg, norm):
        """Build the convolutional stack described by ``cfg``.

        Args:
            cfg: Sequence of layout entries; ``'M'`` adds a 2x2 max-pool, any
                other entry is the output channel count of a 3x3 convolution
                followed by a normalisation layer and an in-place ReLU.
            norm: Normalisation selector, see the class docstring.

        Returns:
            An ``nn.Sequential`` holding the layers in order.
        """
        layers = []
        in_channels = self.channel
        for idx, entry in enumerate(cfg):
            if entry == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            # Single-channel inputs get extra padding on the first convolution
            # only; every other convolution preserves the spatial size.
            padding = 3 if self.channel == 1 and idx == 0 else 1
            layers.append(nn.Conv2d(in_channels, entry, kernel_size=3, padding=padding))
            if norm == 'instancenorm':
                # One group per channel, with learnable affine parameters.
                layers.append(nn.GroupNorm(entry, entry, affine=True))
            else:
                layers.append(nn.BatchNorm2d(entry))
            layers.append(nn.ReLU(inplace=True))
            in_channels = entry
        # Identity pooling (1x1 window, stride 1), kept so the layer indices
        # inside ``features`` stay as they were.
        layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*layers)


def VGG11(num_classes, *, channel=3, norm='batchnorm', im_size=(32, 32)):
    """Build a ``VGG`` with the ``'VGG11'`` layout.

    Args:
        num_classes: Output width of the classifier.
        channel: Number of input channels (defaults to 3).
        norm: Normalisation selector forwarded to ``VGG``.
        im_size: ``(height, width)`` of the input images, forwarded to ``VGG``.

    Returns:
        The constructed ``VGG`` module.
    """
    return VGG('VGG11', channel=channel, num_classes=num_classes, norm=norm, im_size=im_size)


def VGG13(num_classes, *, channel=3, norm='batchnorm', im_size=(32, 32)):
    """Build a ``VGG`` with the ``'VGG13'`` layout.

    Args:
        num_classes: Output width of the classifier.
        channel: Number of input channels (defaults to 3).
        norm: Normalisation selector forwarded to ``VGG``.
        im_size: ``(height, width)`` of the input images, forwarded to ``VGG``.

    Returns:
        The constructed ``VGG`` module.
    """
    return VGG('VGG13', channel=channel, num_classes=num_classes, norm=norm, im_size=im_size)


def VGG16(num_classes, *, channel=3, norm='batchnorm', im_size=(32, 32)):
    """Build a ``VGG`` with the ``'VGG16'`` layout.

    Args:
        num_classes: Output width of the classifier.
        channel: Number of input channels (defaults to 3).
        norm: Normalisation selector forwarded to ``VGG``.
        im_size: ``(height, width)`` of the input images, forwarded to ``VGG``.

    Returns:
        The constructed ``VGG`` module.
    """
    return VGG('VGG16', channel=channel, num_classes=num_classes, norm=norm, im_size=im_size)


def VGG19(num_classes, *, channel=3, norm='batchnorm', im_size=(32, 32)):
    """Build a ``VGG`` with the ``'VGG19'`` layout.

    Args:
        num_classes: Output width of the classifier.
        channel: Number of input channels (defaults to 3).
        norm: Normalisation selector forwarded to ``VGG``.
        im_size: ``(height, width)`` of the input images, forwarded to ``VGG``.

    Returns:
        The constructed ``VGG`` module.
    """
    return VGG('VGG19', channel=channel, num_classes=num_classes, norm=norm, im_size=im_size)