'''VGG11/13/16/19 in PyTorch.

From pytorch-cifar:
https://raw.githubusercontent.com/kuangliu/pytorch-cifar/master/models/vgg.py
'''
import torch
import torch.nn as nn

import model_utils


# Per-variant layer plans consumed by VGG._make_layers: each int is the
# output-channel count of a 3x3 conv + BatchNorm + ReLU stage, and 'M'
# inserts a 2x2, stride-2 max-pool. Lines below group one pooling stage each.
cfg = {
  'VGG11': [
    64, 'M',
    128, 'M',
    256, 256, 'M',
    512, 512, 'M',
    512, 512, 'M',
  ],
  'VGG13': [
    64, 64, 'M',
    128, 128, 'M',
    256, 256, 'M',
    512, 512, 'M',
    512, 512, 'M',
  ],
  'VGG16': [
    64, 64, 'M',
    128, 128, 'M',
    256, 256, 256, 'M',
    512, 512, 512, 'M',
    512, 512, 512, 'M',
  ],
  'VGG19': [
    64, 64, 'M',
    128, 128, 'M',
    256, 256, 256, 256, 'M',
    512, 512, 512, 512, 'M',
    512, 512, 512, 512, 'M',
  ],
}


class VGG(nn.Module):
  """VGG-style CNN with batch norm: a conv feature stack plus one linear head.

  The feature stack is built from ``cfg[vgg_name]``. The head consumes the
  flattened feature map, so its input size equals the channel count of the
  last conv stage only when the spatial size has been reduced to 1x1. For
  example, 32x32 inputs reach 1x1 after the five 2x2 max-pools in every
  ``cfg`` entry.

  Args:
    vgg_name: key into the module-level ``cfg`` table, e.g. 'VGG11'.
    num_classes: output size of the linear classifier (default 10).
    in_channels: channel count of the input images (default 3).
  """

  def __init__(self, vgg_name, num_classes=10, in_channels=3):
    super().__init__()
    layer_cfg = cfg[vgg_name]
    # Build features before the classifier so parameter initialisation
    # consumes random numbers in the same order as the original model.
    self.features = self._make_layers(layer_cfg, in_channels=in_channels)
    # Width of the last conv stage (512 for every entry in ``cfg``).
    feature_dim = [spec for spec in layer_cfg if spec != 'M'][-1]
    self.classifier = nn.Linear(feature_dim, num_classes)

  def forward(self, x):
    """Return classifier logits for a batch ``x``.

    Assumes ``x`` is batched (N, C, H, W); dim 0 is kept as the batch axis
    when flattening, giving an output of shape (N, num_classes).
    """
    out = self.features(x)
    out = torch.flatten(out, 1)
    return self.classifier(out)

  def _make_layers(self, cfg, in_channels=3):
    """Build the conv feature stack described by ``cfg``.

    Each int ``c`` appends Conv2d(3x3, padding=1) -> BatchNorm2d(c) -> ReLU
    producing ``c`` channels; each 'M' appends a 2x2, stride-2 max-pool.

    Args:
      cfg: sequence of ints and 'M' markers.
      in_channels: channel count fed to the first conv (default 3).

    Returns:
      An ``nn.Sequential`` of the layers.
    """
    layers = []
    channels = in_channels
    for spec in cfg:
      if spec == 'M':
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        continue
      layers += [nn.Conv2d(channels, spec, kernel_size=3, padding=1),
                 nn.BatchNorm2d(spec),
                 nn.ReLU(inplace=True)]
      channels = spec
    # AvgPool2d with kernel_size=1, stride=1 is an identity op. It is kept so
    # the layer list (``features`` indexing, module repr) is unchanged.
    layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
    return nn.Sequential(*layers)


def vgg11(flags=None, **kwargs):
  """Build a ``VGG('VGG11')`` model under an RNG state chosen from ``flags``.

  ``model_utils.set_rng_state(flags)`` is applied for the duration of model
  construction, and the previous state is restored afterwards, even if
  construction raises.

  Args:
    flags: forwarded to ``model_utils.set_rng_state``; presumably carries the
      seed used for initialisation (confirm in ``model_utils``).
    **kwargs: accepted but unused.

  Returns:
    The constructed VGG11 module.
  """
  old_state = model_utils.set_rng_state(flags)
  try:
    return VGG('VGG11')
  finally:
    # Always hand the caller back its own RNG state.
    model_utils.restore_rng_state(old_state)


def vgg13(flags=None, **kwargs):
  """Build a ``VGG('VGG13')`` model under an RNG state chosen from ``flags``.

  ``model_utils.set_rng_state(flags)`` is applied for the duration of model
  construction, and the previous state is restored afterwards, even if
  construction raises.

  Args:
    flags: forwarded to ``model_utils.set_rng_state``; presumably carries the
      seed used for initialisation (confirm in ``model_utils``).
    **kwargs: accepted but unused.

  Returns:
    The constructed VGG13 module.
  """
  old_state = model_utils.set_rng_state(flags)
  try:
    return VGG('VGG13')
  finally:
    # Always hand the caller back its own RNG state.
    model_utils.restore_rng_state(old_state)


def vgg16(flags=None, **kwargs):
  """Build a ``VGG('VGG16')`` model under an RNG state chosen from ``flags``.

  ``model_utils.set_rng_state(flags)`` is applied for the duration of model
  construction, and the previous state is restored afterwards, even if
  construction raises.

  Args:
    flags: forwarded to ``model_utils.set_rng_state``; presumably carries the
      seed used for initialisation (confirm in ``model_utils``).
    **kwargs: accepted but unused.

  Returns:
    The constructed VGG16 module.
  """
  old_state = model_utils.set_rng_state(flags)
  try:
    return VGG('VGG16')
  finally:
    # Always hand the caller back its own RNG state.
    model_utils.restore_rng_state(old_state)
