import torch
import torch.nn as nn
from models.reduced.inn_resnet import DoNothing

# Layer layouts for each VGG depth, consumed by VGG._make_layers:
#   int -> 3x3 conv (padding 1) with that many output channels, then BatchNorm + ReLU
#   'M' -> 2x2 max-pool with stride 2
cfg = dict(
    VGG11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG16=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
           512, 512, 512, 'M', 512, 512, 512, 'M'],
    VGG19=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
           512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
)


class VGG(nn.Module):
    """VGG convolutional backbone whose pooled features are gated by a class embedding.

    ``forward(x, y)`` does the following:

    1. Runs the 3-channel image batch ``x`` through the conv stack.
    2. Global-average-pools and flattens the result to ``(batch, feature_dim)``.
    3. Embeds ``y`` through two linear layers (``num_classes -> 256 -> feature_dim``).
    4. Multiplies the two element-wise and maps the product to ``inn_classes`` outputs.

    ``feature_dim`` is the channel count of the last conv layer in the chosen config.

    Args:
        vgg_name: Key into the module-level ``cfg`` table (e.g. ``'VGG16'``), or an
            explicit list/tuple in the same format. In that format an int is a 3x3
            conv with that many output channels and ``'M'`` is a 2x2 max-pool.
        num_classes: Size of the last dimension of the second input ``y``.
        inn_classes: Number of outputs of the final linear layer.
        class_in_activation: Activation applied after each class-embedding layer.
            One of ``'relu'``, ``'sigmoid'``, ``'l-relu'``, ``'tanh'``, ``'none'``.

    Raises:
        KeyError: ``vgg_name`` is neither a list/tuple nor a key of ``cfg``.
        NotImplementedError: ``class_in_activation`` is not recognised.
    """

    # Factories rather than shared instances, so each model builds its own module.
    # The lambda around ``DoNothing`` defers the name lookup until 'none' is requested.
    _ACTIVATIONS = {
        'relu': lambda: nn.ReLU(inplace=True),
        'sigmoid': nn.Sigmoid,
        'l-relu': lambda: nn.LeakyReLU(inplace=True),
        'tanh': nn.Tanh,
        'none': lambda: DoNothing(),
    }

    def __init__(self, vgg_name, num_classes, inn_classes=2, class_in_activation='none'):
        super().__init__()
        if isinstance(vgg_name, (list, tuple)):
            layer_cfg = list(vgg_name)
        else:
            try:
                layer_cfg = cfg[vgg_name]
            except KeyError:
                raise KeyError(
                    '{!r} is not a known VGG config; expected one of {}'.format(
                        vgg_name, sorted(cfg))
                ) from None
        # Submodules are registered in the same order and under the same names as
        # before, so state_dict keys and seeded weight initialisation are unchanged.
        self.features = self._make_layers(layer_cfg)
        # Width of the conv stack's output: the last conv's channel count, or the
        # 3 input channels when the config contains no conv layers at all.
        feature_dim = next((v for v in reversed(layer_cfg) if v != 'M'), 3)
        self.class_in_0 = nn.Linear(num_classes, 256)
        # Must equal feature_dim so the element-wise product in forward() lines up.
        self.class_in_1 = nn.Linear(256, feature_dim)
        try:
            make_activation = self._ACTIVATIONS[class_in_activation]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which were also rejected before.
            raise NotImplementedError(
                '{} activation not implemented!'.format(class_in_activation)) from None
        self.class_in_activation = make_activation()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.inn_fc = nn.Linear(feature_dim, inn_classes)

    def forward(self, x, y):
        """Return ``inn_fc(pooled_features(x) * class_embedding(y))``.

        Args:
            x: Image batch for the conv stack (first conv expects 3 input channels).
            y: Class input whose last dimension is ``num_classes``. It is presumably a
                one-hot or soft label vector; NOTE(review): confirm with callers.

        Returns:
            Tensor of shape ``(batch, inn_classes)``.
        """
        x = self.features(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)  # (batch, feature_dim)
        # The same activation module is applied after both embedding layers.
        y = self.class_in_activation(self.class_in_0(y))
        y = self.class_in_activation(self.class_in_1(y))
        z = x * y  # gate each pooled feature channel by the class embedding
        f_z = self.inn_fc(z)
        return f_z

    def _make_layers(self, cfg):
        """Build the conv stack described by ``cfg`` as an ``nn.Sequential``.

        Each int ``c`` becomes ``Conv2d(3x3, padding=1) -> BatchNorm2d(c) -> ReLU``.
        Each ``'M'`` becomes a 2x2 stride-2 max-pool. The stack starts from 3 input
        channels.
        """
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                           nn.BatchNorm2d(v),
                           nn.ReLU(inplace=True)]
                in_channels = v
        # A 1x1/stride-1 average pool is an identity op. It is kept so the layout
        # and indices of ``features`` match existing code and checkpoints.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
