import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class simple_net(nn.Module):
    """Four bias-free convolutions followed by a two-layer classifier head.

    conv2 and conv4 use stride 2, so the spatial resolution is divided by
    four overall.  ``fc1`` is sized for ``64 * (64 * widen_factor)`` input
    features, i.e. an 8x8 final feature map, which is what a 32x32 input
    produces.

    Args:
        in_channel: Number of channels of the input images.
        widen_factor: Integer multiplier for the widths of conv2-conv4
            (``32*w``, ``64*w``, ``64*w``); conv1 always has 32 filters.
        n_fc: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channel: int, widen_factor: int, n_fc: int,
                 num_classes: int) -> None:
        super().__init__()
        widths = [32, 32 * widen_factor, 64 * widen_factor, 64 * widen_factor]

        self.conv1 = nn.Conv2d(in_channel, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        # kernel 4 / stride 2 / padding 1 halves an even spatial size exactly.
        self.conv2 = nn.Conv2d(widths[0], widths[1], kernel_size=4, stride=2,
                               padding=1, bias=False)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(widths[1], widths[2], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(widths[2], widths[3], kernel_size=4, stride=2,
                               padding=1, bias=False)
        # Number of input features of fc1: 64 (= 8 * 8) spatial positions
        # times the channel count of conv4.
        self.nChannels = 64 * widths[3]
        self.relu4 = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(self.nChannels, n_fc)
        self.relu5 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(n_fc, num_classes)

        # Conv weights get Kaiming-normal (fan_out, ReLU gain) initialisation;
        # Linear layers keep their default weights but start with zero bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits, one row of ``num_classes`` values per sample.

        Raises:
            RuntimeError: (from ``fc1``) if the spatial size of ``x`` does not
                produce exactly ``self.nChannels`` features per sample.
        """
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        out = self.relu3(self.conv3(out))
        out = self.relu4(self.conv4(out))
        # Flatten by the activation's real C*H*W rather than the expected
        # size: reshaping to (-1, self.nChannels) would silently fold surplus
        # features into the batch axis for a wrong-sized input, whereas this
        # lets fc1 reject the mismatch.
        n_features = out.size(-3) * out.size(-2) * out.size(-1)
        out = out.reshape(-1, n_features)
        out = self.relu5(self.fc1(out))
        return self.fc2(out)



class simple_net_stl(nn.Module):
    """Wider variant of the four-conv classifier, sized for 96x96 inputs.

    conv2 and conv4 use stride 2, so the spatial resolution is divided by
    four overall.  ``fc1`` is sized for ``24 * 24 * (192 * widen_factor)``
    input features, i.e. a 24x24 final feature map, which is what a 96x96
    input produces.

    Args:
        in_channel: Number of channels of the input images.
        widen_factor: Integer multiplier for the widths of conv2-conv4
            (``96*w``, ``192*w``, ``192*w``); conv1 always has 96 filters.
        n_fc: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channel: int, widen_factor: int, n_fc: int,
                 num_classes: int) -> None:
        super().__init__()
        widths = [96, 96 * widen_factor, 192 * widen_factor, 192 * widen_factor]

        self.conv1 = nn.Conv2d(in_channel, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        # kernel 4 / stride 2 / padding 1 halves an even spatial size exactly.
        self.conv2 = nn.Conv2d(widths[0], widths[1], kernel_size=4, stride=2,
                               padding=1, bias=False)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(widths[1], widths[2], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(widths[2], widths[3], kernel_size=4, stride=2,
                               padding=1, bias=False)
        # Number of input features of fc1: 24x24 spatial positions times the
        # channel count of conv4.
        self.nChannels = 24 * 24 * widths[3]
        self.relu4 = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(self.nChannels, n_fc)
        self.relu5 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(n_fc, num_classes)

        # Conv weights get Kaiming-normal (fan_out, ReLU gain) initialisation;
        # Linear layers keep their default weights but start with zero bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits, one row of ``num_classes`` values per sample.

        Raises:
            RuntimeError: (from ``fc1``) if the spatial size of ``x`` does not
                produce exactly ``self.nChannels`` features per sample.
        """
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        out = self.relu3(self.conv3(out))
        out = self.relu4(self.conv4(out))
        # Flatten by the activation's real C*H*W rather than the expected
        # size: reshaping to (-1, self.nChannels) would silently regroup
        # samples (e.g. four 48x48 images into one row) for a wrong-sized
        # input, whereas this lets fc1 reject the mismatch.
        n_features = out.size(-3) * out.size(-2) * out.size(-1)
        out = out.reshape(-1, n_features)
        out = self.relu5(self.fc1(out))
        return self.fc2(out)

    
class simple_net_mnist(nn.Module):
    """Four-conv classifier sized for 28x28 inputs.

    conv2 and conv4 use stride 2, so the spatial resolution is divided by
    four overall.  ``fc1`` is sized for ``7 * 7 * (56 * widen_factor)`` input
    features, i.e. a 7x7 final feature map, which is what a 28x28 input
    produces.

    Args:
        in_channel: Number of channels of the input images.
        widen_factor: Integer multiplier for the widths of conv2-conv4
            (``28*w``, ``56*w``, ``56*w``); conv1 always has 28 filters.
        n_fc: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channel: int, widen_factor: int, n_fc: int,
                 num_classes: int) -> None:
        super().__init__()
        widths = [28, 28 * widen_factor, 56 * widen_factor, 56 * widen_factor]

        self.conv1 = nn.Conv2d(in_channel, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        # kernel 4 / stride 2 / padding 1 halves an even spatial size exactly.
        self.conv2 = nn.Conv2d(widths[0], widths[1], kernel_size=4, stride=2,
                               padding=1, bias=False)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(widths[1], widths[2], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(widths[2], widths[3], kernel_size=4, stride=2,
                               padding=1, bias=False)
        # Number of input features of fc1: 7x7 spatial positions times the
        # channel count of conv4.
        self.nChannels = 7 * 7 * widths[3]
        self.relu4 = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(self.nChannels, n_fc)
        self.relu5 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(n_fc, num_classes)

        # Conv weights get Kaiming-normal (fan_out, ReLU gain) initialisation;
        # Linear layers keep their default weights but start with zero bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits, one row of ``num_classes`` values per sample.

        The forward pass has no side effects; in particular it writes nothing
        to stdout.

        Raises:
            RuntimeError: (from ``fc1``) if the spatial size of ``x`` does not
                produce exactly ``self.nChannels`` features per sample.
        """
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        out = self.relu3(self.conv3(out))
        out = self.relu4(self.conv4(out))
        # Flatten by the activation's real C*H*W rather than the expected
        # size: reshaping to (-1, self.nChannels) would silently fold surplus
        # features into the batch axis for a wrong-sized input, whereas this
        # lets fc1 reject the mismatch.
        n_features = out.size(-3) * out.size(-2) * out.size(-1)
        out = out.reshape(-1, n_features)
        out = self.relu5(self.fc1(out))
        return self.fc2(out)

    
class simpler_net(nn.Module):
    """Narrow four-conv classifier for very small inputs.

    conv2 and conv4 use stride 2, so the spatial resolution is divided by
    four overall.  ``fc1`` is sized for ``4 * (16 * widen_factor)`` input
    features, i.e. four spatial positions in the final feature map (a 2x2
    map, as produced by an 8x8 input).

    Args:
        in_channel: Number of channels of the input images.
        widen_factor: Integer multiplier for the widths of conv2-conv4
            (``8*w``, ``16*w``, ``16*w``); conv1 always has 8 filters.
        n_fc: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channel: int, widen_factor: int, n_fc: int,
                 num_classes: int) -> None:
        super().__init__()
        widths = [8, 8 * widen_factor, 16 * widen_factor, 16 * widen_factor]

        self.conv1 = nn.Conv2d(in_channel, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        # kernel 4 / stride 2 / padding 1 halves an even spatial size exactly.
        self.conv2 = nn.Conv2d(widths[0], widths[1], kernel_size=4, stride=2,
                               padding=1, bias=False)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(widths[1], widths[2], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(widths[2], widths[3], kernel_size=4, stride=2,
                               padding=1, bias=False)
        # Number of input features of fc1: 4 spatial positions times the
        # channel count of conv4.
        self.nChannels = 4 * widths[3]
        self.relu4 = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(self.nChannels, n_fc)
        self.relu5 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(n_fc, num_classes)

        # Conv weights get Kaiming-normal (fan_out, ReLU gain) initialisation;
        # Linear layers keep their default weights but start with zero bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits, one row of ``num_classes`` values per sample.

        Raises:
            RuntimeError: (from ``fc1``) if the spatial size of ``x`` does not
                produce exactly ``self.nChannels`` features per sample.
        """
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        out = self.relu3(self.conv3(out))
        out = self.relu4(self.conv4(out))
        # Flatten by the activation's real C*H*W rather than the expected
        # size: reshaping to (-1, self.nChannels) would silently fold surplus
        # features into the batch axis for a wrong-sized input, whereas this
        # lets fc1 reject the mismatch.
        n_features = out.size(-3) * out.size(-2) * out.size(-1)
        out = out.reshape(-1, n_features)
        out = self.relu5(self.fc1(out))
        return self.fc2(out)

class simple_cnn(nn.Module):
    """Two conv-conv-maxpool stages followed by a two-layer classifier head.

    In each stage the first 3x3 convolution is padded (size preserving) and
    the second is unpadded (trims 2 pixels), followed by 2x2 max pooling.
    ``fc1`` is sized for ``6 * 6 * (64 * widen_factor)`` input features, i.e.
    a 6x6 final feature map, which is what a 32x32 input produces
    (32 -> 30 -> 15 -> 13 -> 6).

    Args:
        in_channel: Number of channels of the input images.
        widen_factor: Integer multiplier for the widths of conv2-conv4
            (``32*w``, ``64*w``, ``64*w``); conv1 always has 32 filters.
        n_fc: Width of the hidden fully connected layer.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channel: int, widen_factor: int, n_fc: int,
                 num_classes: int) -> None:
        super().__init__()
        widths = [32, 32 * widen_factor, 64 * widen_factor, 64 * widen_factor]

        # Stage 1: padded conv, unpadded conv, 2x2 max pool.
        self.conv1 = nn.Conv2d(in_channel, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(widths[0], widths[1], kernel_size=3, stride=1,
                               padding=0, bias=False)
        self.relu2 = nn.ReLU(inplace=True)
        self.mp1 = nn.MaxPool2d(2, 2)

        # Stage 2: same layout as stage 1.
        self.conv3 = nn.Conv2d(widths[1], widths[2], kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(widths[2], widths[3], kernel_size=3, stride=1,
                               padding=0, bias=False)
        self.relu4 = nn.ReLU(inplace=True)
        self.mp2 = nn.MaxPool2d(2, 2)

        # Number of input features of fc1: 6x6 spatial positions times the
        # channel count of conv4.
        self.nChannels = 6 * 6 * widths[3]
        # NOTE(review): an earlier comment here read "128" - presumably the
        # usual value of n_fc; confirm against callers.
        self.fc1 = nn.Linear(self.nChannels, n_fc)
        self.relu5 = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(n_fc, num_classes)

        # Conv weights get Kaiming-normal (fan_out, ReLU gain) initialisation;
        # Linear layers keep their default weights but start with zero bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits, one row of ``num_classes`` values per sample.

        Raises:
            RuntimeError: (from ``fc1``) if the spatial size of ``x`` does not
                produce exactly ``self.nChannels`` features per sample.
        """
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        out = self.mp1(out)
        out = self.relu3(self.conv3(out))
        out = self.relu4(self.conv4(out))
        out = self.mp2(out)
        # Flatten by the activation's real C*H*W rather than the expected
        # size: reshaping to (-1, self.nChannels) would silently fold surplus
        # features into the batch axis for a wrong-sized input, whereas this
        # lets fc1 reject the mismatch.
        n_features = out.size(-3) * out.size(-2) * out.size(-1)
        out = out.reshape(-1, n_features)
        out = self.relu5(self.fc1(out))
        return self.fc2(out)