import logging

import torch
import torchvision

import model
import modules


class Vgg(model.Model):
    """VGG16-based network.

    Combines torchvision's VGG16 convolutional feature extractor with a small
    fully connected head (two Linear -> ReLU -> Dropout stages) and a separate
    linear classifier.
    """

    NAME = "vgg"
    # Channel count the head expects from ``vgg16.features``; it is used as
    # both the asserted channel dimension and the flattened vector length.
    HIDDEN = 512
    # Width of the hidden dense layers and input width of the classifier.
    LINHID = 512

    def is_pretrained(self):
        """Tell whether the VGG16 backbone should load pretrained weights.

        This class always answers False; the value is forwarded to
        ``torchvision.models.vgg16(pretrained=...)`` in ``create_layers``.
        """
        pretrained = False
        return pretrained

    def create_layers(self, input_size):
        """Build the layers of this network, listed in forward order.

        Parameters:
        ===========
        input_size: tuple of int dimensions of input. Not consulted by this
            method. NOTE(review): the shape check below requires the backbone
            to emit HIDDEN x 1 x 1, which presumably means small inputs such
            as 32x32 -- TODO confirm against callers.

        Returns:
        ========
        list of torch.nn.Module: the VGG16 feature extractor, a shape check
        plus reshape to HIDDEN values, then two Linear -> ReLU -> Dropout
        stages of width LINHID.
        """
        logging.info(
            "Creating pretrained={} VGG16...".format(self.is_pretrained())
        )
        backbone = torchvision.models.vgg16(pretrained=self.is_pretrained())
        channels = Vgg.HIDDEN
        width = Vgg.LINHID

        # Verify the backbone output shape, then flatten it for the dense head.
        stack = [
            backbone.features,
            modules.AssertShape(channels, 1, 1),
            modules.Reshape(channels),
        ]
        # Two dense stages: channels -> width, then width -> width. Modules are
        # instantiated strictly in forward order, so random weight
        # initialization draws from the RNG in a fixed sequence.
        for fan_in in (channels, width):
            stack.append(torch.nn.Linear(fan_in, width))
            stack.append(torch.nn.ReLU(inplace=True))
            stack.append(torch.nn.Dropout(p=0.5))
        return stack

    def create_classifier(self, targets):
        """Build the final classification layer.

        Parameters:
        ===========
        targets: int number of classes to predict.

        Returns:
        ========
        torch.nn.Linear mapping LINHID features to ``targets`` outputs.
        """
        return torch.nn.Linear(in_features=Vgg.LINHID, out_features=targets)
