'''Small CNNs with two convolutional layers, in PyTorch.
Reference:
[1] Negrea, J., Haghifam, M., Dziugaite, G. K., Khisti, A., & Roy, D. M. (2019).
    Information-theoretic generalization bounds for SGLD via data-dependent estimates.
    arXiv preprint arXiv:1911.02151.
'''


import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

class smallCNN1(nn.Module):
    """Two-convolutional-layer CNN for MNIST and Fashion-MNIST.

    Expects inputs of shape (N, 1, 28, 28). With kernel 5 and padding 1, each
    conv shrinks the spatial side by 2, and each max-pool halves it:
    28 -> 26 -> 13 -> 11 -> 5. The classifier's ``64 * 5 * 5`` input size
    depends on this.

    Args:
        num_classes: Number of output classes (default 10).

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1), so it
    is meant to be trained with an NLL loss.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # NOTE(review): there is no activation after either conv layer; only
        # the max-pooling is nonlinear there. The He/ReLU-gain initialisation
        # below suggests ReLUs may have been intended. Confirm against [1]
        # before changing this, because adding them changes the model.
        self.l1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.l2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 5 * 5, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes),
        )
        self._init_weights()

    def _init_weights(self):
        """Re-initialise conv and linear weights; biases keep PyTorch defaults.

        Conv weights are drawn from N(0, 2 / (k_h * k_w * out_channels)), which
        is He/Kaiming-normal in fan-out mode. Linear weights use Xavier-normal
        with the ReLU gain. Keep the visiting order (l1, l2, classifier):
        changing it changes which random numbers each layer draws under a fixed
        seed.
        """
        for block in (self.l1, self.l2):
            for m in block.children():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / n))
        relu_gain = nn.init.calculate_gain('relu')
        for m in self.classifier.children():
            if isinstance(m, nn.Linear):
                # xavier_normal_ works in place; no reassignment needed.
                nn.init.xavier_normal_(m.weight, gain=relu_gain)

    def forward(self, x):
        """Return log-probabilities of shape (N, num_classes) for a batch ``x``."""
        x = self.l2(self.l1(x))
        x = torch.flatten(x, 1)
        # Log-probabilities, to be used with an NLL loss.
        return F.log_softmax(self.classifier(x), dim=1)






class smallCNN2(nn.Module):
    """Two-convolutional-layer CNN for CIFAR-10.

    Expects inputs of shape (N, 3, 32, 32). With kernel 5 and padding 1, each
    conv shrinks the spatial side by 2, and each max-pool halves it:
    32 -> 30 -> 15 -> 13 -> 6. The classifier's ``192 * 6 * 6`` input size
    depends on this.

    Args:
        num_classes: Number of output classes (default 10).

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1), so it
    is meant to be trained with an NLL loss.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # NOTE(review): there is no activation after either conv layer; only
        # the max-pooling is nonlinear there. The He/ReLU-gain initialisation
        # below suggests ReLUs may have been intended. Confirm against [1]
        # before changing this, because adding them changes the model.
        self.l1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.l2 = nn.Sequential(
            nn.Conv2d(64, 192, kernel_size=5, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(192 * 6 * 6, 384),
            nn.ReLU(inplace=True),
            nn.Linear(384, 192),
            nn.ReLU(inplace=True),
            nn.Linear(192, num_classes),
        )
        self._init_weights()

    def _init_weights(self):
        """Re-initialise conv and linear weights; biases keep PyTorch defaults.

        Conv weights are drawn from N(0, 2 / (k_h * k_w * out_channels)), which
        is He/Kaiming-normal in fan-out mode. Linear weights use Xavier-normal
        with the ReLU gain. Keep the visiting order (l1, l2, classifier):
        changing it changes which random numbers each layer draws under a fixed
        seed.
        """
        for block in (self.l1, self.l2):
            for m in block.children():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / n))
        relu_gain = nn.init.calculate_gain('relu')
        for m in self.classifier.children():
            if isinstance(m, nn.Linear):
                # xavier_normal_ works in place; no reassignment needed.
                nn.init.xavier_normal_(m.weight, gain=relu_gain)

    def forward(self, x):
        """Return log-probabilities of shape (N, num_classes) for a batch ``x``."""
        x = self.l2(self.l1(x))
        x = torch.flatten(x, 1)
        # Log-probabilities, to be used with an NLL loss.
        return F.log_softmax(self.classifier(x), dim=1)




class smallCNN3(nn.Module):
    """Two-convolutional-layer CNN for CIFAR-100.

    Expects inputs of shape (N, 3, 32, 32). With kernel 5 and padding 1, each
    conv shrinks the spatial side by 2, and each max-pool halves it:
    32 -> 30 -> 15 -> 13 -> 6. The classifier's ``192 * 6 * 6`` input size
    depends on this.

    Args:
        num_classes: Number of output classes (default 100).

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1), so it
    is meant to be trained with an NLL loss.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        # NOTE(review): there is no activation after either conv layer; only
        # the max-pooling is nonlinear there. The He/ReLU-gain initialisation
        # below suggests ReLUs may have been intended. Confirm against [1]
        # before changing this, because adding them changes the model.
        self.l1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.l2 = nn.Sequential(
            nn.Conv2d(64, 192, kernel_size=5, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(192 * 6 * 6, 384),
            nn.ReLU(inplace=True),
            nn.Linear(384, 192),
            nn.ReLU(inplace=True),
            nn.Linear(192, num_classes),
        )
        self._init_weights()

    def _init_weights(self):
        """Re-initialise conv and linear weights; biases keep PyTorch defaults.

        Conv weights are drawn from N(0, 2 / (k_h * k_w * out_channels)), which
        is He/Kaiming-normal in fan-out mode. Linear weights use Xavier-normal
        with the ReLU gain. Keep the visiting order (l1, l2, classifier):
        changing it changes which random numbers each layer draws under a fixed
        seed.
        """
        for block in (self.l1, self.l2):
            for m in block.children():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / n))
        relu_gain = nn.init.calculate_gain('relu')
        for m in self.classifier.children():
            if isinstance(m, nn.Linear):
                # xavier_normal_ works in place; no reassignment needed.
                nn.init.xavier_normal_(m.weight, gain=relu_gain)

    def forward(self, x):
        """Return log-probabilities of shape (N, num_classes) for a batch ``x``."""
        x = self.l2(self.l1(x))
        x = torch.flatten(x, 1)
        # Log-probabilities, to be used with an NLL loss.
        return F.log_softmax(self.classifier(x), dim=1)

