import torch
import torch.nn as nn
import torch.nn.init as init
import math

from torch.nn.modules.activation import Sigmoid

class deepNN(nn.Module):
    """Fully connected classifier for flattened 28x28 inputs.

    Hidden widths are 512 -> 256 -> 128 with ReLU between layers; the final
    ``nn.Linear`` emits raw logits (no softmax).

    Args:
        n_classes: number of output units; also stored as ``output_size``.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.output_size = n_classes
        # Empty and unused by forward(); kept as an attribute, presumably for
        # parity with the other models in this file that expose ``features``.
        self.features = nn.Sequential()

        widths = [28 * 28, 512, 256, 128]
        layers = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(in_features=width_in, out_features=width_out), nn.ReLU()]
        layers.append(nn.Linear(in_features=widths[-1], out_features=n_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        # Collapse every dimension after the batch dimension before the MLP.
        return self.classifier(torch.flatten(x, 1))


# PyTorch port of the MNIST model in https://github.com/GeorgiosSmyrnis/multiclass_minimization_icml2020/blob/master/mnist_training.py
class CNN2D(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a one-hidden-layer classifier.

    Each stage is a 5x5 convolution (stride 1) then a 3x3 max-pool. For a
    single-channel 28x28 input the spatial size goes 28 -> 24 -> 8 -> 4 -> 1,
    so the flattened features have exactly ``conv_filters`` entries, which is
    what the first ``nn.Linear`` expects.

    Args:
        n_classes: number of outputs; stored as ``output_size``. When it is 1,
            forward() returns sigmoid probabilities instead of logits.
        hidden_size: width of the hidden fully connected layer.
        conv_filters: number of channels produced by each convolution.
    """

    def __init__(self, n_classes, hidden_size=500, conv_filters=16):
        super().__init__()
        self.output_size = n_classes

        def conv_stage(channels_in):
            return [
                nn.Conv2d(in_channels=channels_in, out_channels=conv_filters, kernel_size=5, stride=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=3),
            ]

        self.features = nn.Sequential(*conv_stage(1), *conv_stage(conv_filters))

        self.classifier = nn.Sequential(
            nn.Linear(in_features=conv_filters, out_features=hidden_size),
            nn.ReLU(),
            nn.Linear(in_features=hidden_size, out_features=n_classes),
        )

        # Only applied in forward() when there is a single output unit.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        logits = self.classifier(torch.flatten(self.features(x), 1))
        if self.output_size != 1:
            return logits
        # Binary case: probabilities of shape (batch,), cast to float64
        # (presumably to match a double-precision target in the loss — confirm).
        return self.sigmoid(logits).squeeze(1).double()


# ref https://towardsdatascience.com/implementing-yann-lecuns-lenet-5-in-pytorch-5e05a0911320
class LeNet5(nn.Module):
    """LeNet-5 (tanh activations, average pooling) for single-channel images.

    The three 5x5 convolutions and two 2x2 average pools take a 32x32 input
    down to a 120x1x1 feature map (32 -> 28 -> 14 -> 10 -> 5 -> 1), which is
    what the classifier's ``in_features=120`` expects. A 28x28 input shrinks
    to 4x4 before the third convolution and fails there, so 28x28 images
    must be padded to 32x32 first.

    Args:
        n_classes: number of output logits; stored as ``output_size`` like
            the other classifiers in this module.
    """

    def __init__(self, n_classes):
        super(LeNet5, self).__init__()
        # Recorded for consistency with deepNN / CNN2D, which expose it too.
        self.output_size = n_classes

        self.features = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1),
            nn.Tanh(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=n_classes),
        )

    def forward(self, x):
        """Return raw logits of shape (batch, n_classes); no softmax is applied."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


# CIFAR-VGG implementation from https://github.com/JJGO/shrinkbench/blob/master/models/cifar_vgg.py
def pretrained_weights(model):
    """Fetch a shrinkbench CIFAR checkpoint and return its state dict on CPU.

    The URL is printed before the download. ``torch.hub`` caches the file
    locally, so later calls reuse it.

    NOTE(review): every name resolves under the ``cifar10/`` folder, including
    'vgg_bn_drop_100' (the CIFAR-100 model) — confirm that file lives there.

    Args:
        model: checkpoint base name, e.g. 'vgg_bn_drop'.

    Returns:
        The loaded state dict, mapped to CPU.
    """
    checkpoint_url = f'https://raw.githubusercontent.com/JJGO/shrinkbench-models/master/cifar10/{model}.th'
    print(checkpoint_url)
    state = torch.hub.load_state_dict_from_url(checkpoint_url, map_location='cpu')
    return state

class ConvBNReLU(nn.Module):
    """3x3 convolution with same padding, then BatchNorm2d, then in-place ReLU.

    Args:
        in_planes: number of input channels (stored as ``in_planes``).
        out_planes: number of output channels (stored as ``out_planes``).
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        kernel = 3
        self.in_planes, self.out_planes = in_planes, out_planes
        # kernel // 2 padding with the default stride of 1 keeps H and W unchanged.
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel, padding=kernel // 2)
        # eps=1e-3 instead of PyTorch's default 1e-5, presumably to match the
        # pretrained checkpoints — confirm before changing.
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-3)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class VGGBnDrop(nn.Module):
    """VGG-style CNN with BatchNorm and Dropout for 3-channel images.

    Thirteen ConvBNReLU layers in five stages, each stage ending in a 2x2
    max-pool with ``ceil_mode=True``, followed by a two-layer classifier. For
    a 32x32 input the five pools reduce the feature map to 512x1x1
    (32 -> 16 -> 8 -> 4 -> 2 -> 1), matching the classifier's
    ``nn.Linear(512, 512)``.

    The exact layer order, including the position of every Dropout, fixes
    the ``state_dict`` keys that ``load_state_dict`` matches against the
    downloaded checkpoints, so it must not be rearranged.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):

        super(VGGBnDrop, self).__init__()

        self.num_classes = num_classes

        self.features = nn.Sequential(

            ConvBNReLU(3, 64), nn.Dropout(0.3),
            ConvBNReLU(64, 64),
            nn.MaxPool2d(2, 2, ceil_mode=True),

            ConvBNReLU(64, 128), nn.Dropout(0.4),
            ConvBNReLU(128, 128),
            nn.MaxPool2d(2, 2, ceil_mode=True),

            ConvBNReLU(128, 256), nn.Dropout(0.4),
            ConvBNReLU(256, 256), nn.Dropout(0.4),
            ConvBNReLU(256, 256),
            nn.MaxPool2d(2, 2, ceil_mode=True),

            ConvBNReLU(256, 512), nn.Dropout(0.4),
            ConvBNReLU(512, 512), nn.Dropout(0.4),
            ConvBNReLU(512, 512),
            nn.MaxPool2d(2, 2, ceil_mode=True),

            ConvBNReLU(512, 512), nn.Dropout(0.4),
            ConvBNReLU(512, 512), nn.Dropout(0.4),
            ConvBNReLU(512, 512),
            nn.MaxPool2d(2, 2, ceil_mode=True),
        )

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, self.num_classes),
        )

        # To prevent pruning: flag on the final Linear, presumably read by
        # external pruning code to skip this layer — confirm with the caller.
        self.classifier[-1].is_classifier = True

    def forward(self, input):
        """Return raw logits of shape (batch, num_classes); no softmax is applied."""
        x = self.features(input)
        # torch.flatten matches the other models here and, unlike view(),
        # also accepts non-contiguous feature maps.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def reset_weights(self):
        """Re-initialize every Conv2d with He (Kaiming) normal weights and zero bias.

        Weights are drawn from N(0, std) with std = sqrt(2 / fan_in). Other
        layer types (Linear, BatchNorm1d/2d) are left untouched.
        """

        def init_weights(module):
            if isinstance(module, nn.Conv2d):
                fan_in, _ = init._calculate_fan_in_and_fan_out(module.weight)
                # He init is sqrt(2 / fan_in), not sqrt(2) / fan_in. The latter
                # is ~sqrt(fan_in) times too small and shrinks activations layer by layer.
                init.normal_(module.weight, 0, math.sqrt(2.0 / fan_in))
                # Convolutions built with bias=False have no bias tensor.
                if module.bias is not None:
                    init.zeros_(module.bias)

        self.apply(init_weights)


def vgg_bn_drop(pretrained=True):
    """Build the 10-class VGGBnDrop model, optionally with downloaded weights.

    Args:
        pretrained: when true, load the 'vgg_bn_drop' checkpoint via
            ``pretrained_weights``. When false, the model keeps PyTorch's
            default layer initialization; ``reset_weights()`` is not applied.

    Returns:
        The VGGBnDrop instance.
    """
    net = VGGBnDrop(num_classes=10)
    if not pretrained:
        return net
    net.load_state_dict(pretrained_weights('vgg_bn_drop'))
    return net


def vgg_bn_drop_100(pretrained=True):
    """Build the 100-class (CIFAR-100) VGGBnDrop model, optionally pretrained.

    Args:
        pretrained: when true, load the 'vgg_bn_drop_100' checkpoint via
            ``pretrained_weights``. When false, the model keeps PyTorch's
            default layer initialization; ``reset_weights()`` is not applied.

    Returns:
        The VGGBnDrop instance.
    """
    net = VGGBnDrop(num_classes=100)
    if not pretrained:
        return net
    net.load_state_dict(pretrained_weights('vgg_bn_drop_100'))
    return net
