import torch
import torch.nn as nn
from torch.nn.modules.container import Sequential
from torchvision import models
import math

class Abstract(nn.Module):
    """Feature extractor built from torchvision's VGG11 without its last child.

    All top-level children of the torchvision VGG11 model are kept in order
    except the final one (for torchvision's VGG that is the ``classifier``
    head), so ``forward`` returns a spatial feature map rather than logits.
    NOTE(review): with torchvision's VGG this is ``features`` + ``avgpool``,
    presumably yielding (N, 512, 7, 7) for (N, 3, H, W) inputs — confirm
    against the installed torchvision version.

    Args:
        pretrained: If True, load torchvision's ImageNet VGG11 checkpoint.
            If False, start from random weights and re-initialise the
            convolution (and any batch-norm) layers in ``_init_weights``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        base_model = self._build_vgg11(pretrained)
        # Drop the last child (the classification head), keep everything else.
        self.layers = nn.Sequential(*list(base_model.children())[:-1])

        if not pretrained:
            self._init_weights()

    @staticmethod
    def _build_vgg11(pretrained):
        """Return a torchvision VGG11, preferring the ``weights=`` API."""
        weights_enum = getattr(models, "VGG11_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy ``pretrained`` flag.
            return models.vgg11(pretrained=pretrained)
        # ``pretrained=`` is deprecated since torchvision 0.13. IMAGENET1K_V1 is
        # the checkpoint the legacy ``pretrained=True`` flag resolved to.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.vgg11(weights=weights)

    def _init_weights(self):
        """He-initialise conv layers (fan-out mode) and reset batch-norm params."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # fan_out = k_h * k_w * out_channels; std = sqrt(2 / fan_out)
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / n))
                if m.bias is not None:
                    # Keep the from-scratch scheme self-contained.
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                # affine=False batch-norm layers have no weight/bias to reset.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run ``X`` through the truncated VGG11 and return the feature map.

        NOTE(review): presumably expects a batch of 3-channel images
        (N, 3, H, W), as VGG11's first conv does — confirm with callers.
        """
        return self.layers(X)