import torch
from torch import nn
import numpy as np
from layers import CLOPLayer
import torchvision.models as models


class MNISTClassifier(nn.Module):
    """Small CNN classifier producing log-probabilities over 10 classes.

    Three conv+ReLU stages (32, 64, 128 channels; the first two followed by
    2x2 max-pooling) feed a 3-layer MLP head ending in ``LogSoftmax``.

    Args:
        img_size: Shape of a single input, without the batch dimension, e.g.
            ``(1, 28, 28)``. Its first entry must be 1 because the first
            convolution has ``in_channels=1``. Used once, at construction, to
            infer the flattened feature size fed to the head.
        regul: Regularisation mode. ``"batch_norm"`` inserts ``BatchNorm2d``
            after every conv+ReLU pair; ``"clop"`` applies ``CLOPLayer`` and
            ``"dropout"`` applies ``nn.Dropout2d`` to the conv features; any
            other value (including ``None``) applies none of them.
        p: Passed to ``CLOPLayer`` and used as the ``nn.Dropout2d`` drop
            probability.
    """

    # (in_channels, out_channels, kernel_size, followed_by_max_pool)
    _CONV_STAGES = (
        (1, 32, 5, True),
        (32, 64, 5, True),
        (64, 128, 3, False),
    )

    def __init__(self, img_size, regul=None, p=0):
        super(MNISTClassifier, self).__init__()

        self.regul = regul
        self.conv_feat = self._build_conv_feat(use_batch_norm=regul == "batch_norm")

        self.dropout = nn.Dropout2d(p)
        self.clop = CLOPLayer(p)

        # Infer the conv output shape with a dummy forward pass. Run it in eval
        # mode and without autograd: in training mode BatchNorm2d would fold
        # this all-zeros batch into its running mean/var before training starts.
        was_training = self.conv_feat.training
        self.conv_feat.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, *img_size)
            self.conv_feat_size = self.conv_feat(dummy).shape[1:]
        self.conv_feat.train(was_training)
        # Plain int (np.prod returns a NumPy scalar) for nn.Linear / reshape.
        self.dense_feature_size = int(np.prod(self.conv_feat_size))

        self.classifier = nn.Sequential(
            nn.Linear(in_features=self.dense_feature_size, out_features=512),
            nn.ReLU(True),
            nn.Linear(in_features=512, out_features=100),
            nn.ReLU(True),
            nn.Linear(in_features=100, out_features=10),
            # Explicit dim: the implicit-dim form is deprecated and warns. The
            # head's input is (batch, 10), for which the implicit choice was 1.
            nn.LogSoftmax(dim=1),
        )

    def _build_conv_feat(self, use_batch_norm):
        """Build the convolutional trunk as a single ``nn.Sequential``.

        Each stage appends Conv2d, ReLU, optional BatchNorm2d, optional
        MaxPool2d, in that order. Keeping this order keeps the Sequential
        indices (and therefore the state_dict keys) stable.
        """
        layers = []
        for in_ch, out_ch, kernel, pool in self._CONV_STAGES:
            # padding = kernel // 2 preserves H and W for these odd kernels.
            layers.append(
                nn.Conv2d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=kernel,
                    padding=kernel // 2,
                )
            )
            layers.append(nn.ReLU(True))
            if use_batch_norm:
                layers.append(nn.BatchNorm2d(out_ch))
            if pool:
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class log-probabilities, shape ``(batch, 10)``, for ``x``."""
        x = self.conv_feat(x)
        if self.regul == "clop":
            x = self.clop(x)
        elif self.regul == "dropout":
            x = self.dropout(x)
        # reshape (not view) also copes with a non-contiguous tensor coming out
        # of the regulariser; -1 infers the batch size.
        x = x.reshape(-1, self.dense_feature_size)
        return self.classifier(x)


class VGG11(nn.Module):
    """Randomly initialised torchvision VGG-11 trunk with a 10-way head.

    Args:
        regul: ``"batch_norm"`` builds the trunk from ``vgg11_bn`` instead of
            ``vgg11``; ``"clop"`` applies ``CLOPLayer`` and ``"dropout"``
            applies ``nn.Dropout2d`` to the convolutional features; any other
            value (including ``None``) applies neither.
        p: Passed to ``CLOPLayer`` and used as the ``nn.Dropout2d`` drop
            probability.

    ``forward`` returns log-probabilities of shape ``(batch, 10)``.
    """

    def __init__(self, regul=None, p=0.7):
        super(VGG11, self).__init__()
        self.regul = regul
        if self.regul == "batch_norm":
            vgg = models.vgg11_bn(pretrained=False)
        else:
            vgg = models.vgg11(pretrained=False)
        self.features = vgg.features
        # The head expects 25088 = 512 * 7 * 7 flattened features. Pooling
        # adaptively to 7x7 (as torchvision's own VGG does between features
        # and classifier) makes that hold for any input size and is an
        # identity when the features are already 7x7. It has no parameters,
        # so state_dict keys are unaffected.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.classifier = nn.Sequential(
            nn.Linear(in_features=512 * 7 * 7, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=10, bias=True),
        )
        self.dropout = nn.Dropout2d(p)
        self.clop = CLOPLayer(p)
        # NOTE(review): not used in forward. It is kept because it is a
        # registered submodule whose parameters/buffers appear in state_dict,
        # so removing it would break strict loading of existing checkpoints.
        self.batchnorm = nn.BatchNorm2d(512)

    def forward(self, x):
        """Return class log-probabilities, shape ``(batch, 10)``, for ``x``."""
        x = self.features(x)
        if self.regul == "clop":
            x = self.clop(x)
        elif self.regul == "dropout":
            x = self.dropout(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return torch.log_softmax(x, 1)


class VGG9(nn.Module):
    """Truncated VGG-11 trunk with a ``CLOPLayer`` inserted at a chosen depth.

    The trunk keeps torchvision's VGG-11 feature layers up to and including
    the 4th max-pooling layer; a 3-layer MLP head produces 10 logits, and
    ``forward`` returns their log-softmax.

    Args:
        p: Passed to ``CLOPLayer``.
        position: Index into ``self.features`` before which ``CLOPLayer`` is
            applied: ``0`` applies it to the raw input and
            ``len(self.features)`` applies it after the last feature layer.
            With torchvision's layout the plain trunk has 16 layers, so the
            default ``16`` places CLOP after the final pool. The
            ``batch_norm=True`` trunk has 22 layers (extra BatchNorm2d layers
            shift the indices), so pass ``position=22`` for that placement.
        batch_norm: Build the trunk from ``vgg11_bn`` instead of ``vgg11``.

    Raises:
        ValueError: If ``position`` is outside ``[0, len(self.features)]``.
    """

    def __init__(self, p=0.7, position=16, batch_norm=False):
        super(VGG9, self).__init__()
        self.position = position
        if batch_norm:
            vgg = models.vgg11_bn(pretrained=False)
        else:
            vgg = models.vgg11(pretrained=False)
        # Cut right after the 4th MaxPool2d. A fixed slice such as [:-5] only
        # lands there for plain vgg11; in vgg11_bn it keeps an extra conv+BN
        # and drops the trailing ReLU, so locate the cut by layer type.
        self.features = vgg.features[: self._end_of_fourth_pool(vgg.features)]
        if not 0 <= position <= len(self.features):
            raise ValueError(
                f"position must be in [0, {len(self.features)}], got {position}"
            )

        self.classifier = nn.Sequential(
            # 18432 = 512 * 6 * 6; presumably 96x96 inputs downsampled 16x by
            # the four pools -- TODO confirm against the data pipeline.
            nn.Linear(in_features=18432, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=10, bias=True),
        )

        self.clop = CLOPLayer(p)

    @staticmethod
    def _end_of_fourth_pool(features):
        """Return the slice end that keeps ``features`` through its 4th MaxPool2d."""
        pool_indices = [
            i for i, layer in enumerate(features) if isinstance(layer, nn.MaxPool2d)
        ]
        return pool_indices[3] + 1

    def forward(self, x):
        """Run ``features[:position]``, CLOP, the remaining features, then the head.

        Returns class log-probabilities of shape ``(batch, 10)``.
        """
        layers = list(self.features)
        for layer in layers[: self.position]:
            x = layer(x)
        x = self.clop(x)
        # Iterate to the real end of the trunk rather than a hard-coded count,
        # so the batch-norm trunk runs all of its layers.
        for layer in layers[self.position :]:
            x = layer(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return torch.log_softmax(x, 1)
