# model for CIFAR 10

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN_CIFAR(nn.Module):
    """Convolutional + MLP scorer for 32x32 RGB (CIFAR-10) images.

    Layers: conv1 (3x3) -> conv2 (3x3, stride 2) -> conv3 (1x1) -> conv4 (1x1)
    -> flatten -> fc1 (1000) -> fc2 (1000) -> classifier, with ReLU after
    every layer except the classifier (which returns raw scores).

    For a 32x32 input the conv stack gives 32 -> 30 -> 14 spatially with
    10 channels, i.e. 10 * 14 * 14 = 1960 features, which is what ``fc1``
    expects.

    Args:
        dim: Stored as ``self.input_dim``; not used to build any layer.
        num_classifier: Number of output units of the classifier head
            (default 1, one score per sample).
    """

    def __init__(self, dim, num_classifier=1):
        super(CNN_CIFAR, self).__init__()
        self.input_dim = dim
        self.num_classifier = num_classifier
        self.conv1 = nn.Conv2d(3, 96, 3)
        self.conv2 = nn.Conv2d(96, 96, 3, stride=2)
        self.conv3 = nn.Conv2d(96, 192, 1)
        self.conv4 = nn.Conv2d(192, 10, 1)
        # 1960 = 10 channels * 14 * 14 (feature map size for 32x32 inputs).
        self.fc1 = nn.Linear(1960, 1000)
        self.fc2 = nn.Linear(1000, 1000)
        self.af = F.relu
        self.classifier = nn.Linear(1000, num_classifier)

    def forward(self, x):
        """Return classifier scores of shape ``(batch, num_classifier)``.

        ``x`` is expected to be ``(batch, 3, 32, 32)``. Other spatial sizes
        raise a shape error in ``fc1``.
        """
        h = self.af(self.conv1(x))
        h = self.af(self.conv2(h))
        h = self.af(self.conv3(h))
        h = self.af(self.conv4(h))
        # Flatten per sample, keeping the batch axis explicit. The previous
        # view(-1, 1960) silently regrouped a wrongly-sized input into a
        # different batch size instead of failing.
        h = torch.flatten(h, 1)
        h = self.af(self.fc1(h))
        h = self.af(self.fc2(h))
        return self.classifier(h)


class CNN_STL(nn.Module):
    """Small convolutional scorer for STL-10-sized RGB images.

    Layers (ReLU after every conv and hidden linear layer):
    conv1 (3x3) -> maxpool -> conv2 (3x3) -> conv3 (5x5) -> maxpool
    -> conv4 (5x5) -> maxpool -> flatten -> fc1 (120) -> fc2 (84)
    -> classifier (raw scores).

    For a 96x96 input the last feature map is 32 x 8 x 8 = 2048 values,
    which is what ``fc1`` expects.

    Args:
        dim: Stored as ``self.input_dim``; not used to build any layer.
        num_classifier: Number of output units of the classifier head
            (default 1, one score per sample).

    Note:
        ``sigmoid``, ``m``, ``n`` and ``b1``-``b4`` are constructed (the
        batch-norm layers contribute parameters/buffers to ``state_dict()``)
        but ``forward`` never uses them.
        NOTE(review): presumably kept for checkpoint compatibility. Confirm
        before removing.
    """

    def __init__(self, dim, num_classifier=1):
        super(CNN_STL, self).__init__()
        self.input_dim = dim
        self.num_classifier = num_classifier
        self.conv1 = nn.Conv2d(3, 6, 3)
        self.conv2 = nn.Conv2d(6, 6, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(6, 16, 5)
        self.conv4 = nn.Conv2d(16, 32, 5)
        # 32 channels * 8 * 8 spatial (feature map size for 96x96 inputs).
        self.fc1 = nn.Linear(32 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 84)
        # Unused by forward(); see class docstring.
        self.sigmoid = nn.Sigmoid()
        self.m = nn.Dropout2d(0.2)
        self.n = nn.Dropout(0.2)
        self.b1 = nn.BatchNorm2d(6)
        self.b2 = nn.BatchNorm2d(16)
        self.b3 = nn.BatchNorm1d(120)
        self.b4 = nn.BatchNorm1d(84)

        self.classifier = nn.Linear(84, num_classifier)

    def forward(self, x):
        """Return classifier scores of shape ``(batch, num_classifier)``.

        ``x`` is expected to be ``(batch, 3, H, W)`` with a spatial size
        that reduces to 8x8 (e.g. 96x96). Other sizes raise a shape error in
        ``fc1``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))  # no pooling after conv2
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        # Flatten per sample, keeping the batch axis explicit. The previous
        # view(-1, 32 * 8 * 8) silently regrouped a wrongly-sized input into
        # a different batch size instead of failing.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.classifier(x)


def cnn_cifar(num_classifier=1):
    """Construct and return a ``CNN_CIFAR`` model.

    NOTE(review): ``num_classifier`` is forwarded as the model's ``dim``
    argument (stored as ``input_dim``); it is not passed as the width of the
    classifier head. Confirm this is intended.
    """
    model = CNN_CIFAR(dim=num_classifier)
    return model


def cnn_stl(num_classifier=1):
    """Construct and return a ``CNN_STL`` model.

    NOTE(review): ``num_classifier`` is forwarded as the model's ``dim``
    argument (stored as ``input_dim``); it is not passed as the width of the
    classifier head. Confirm this is intended.
    """
    model = CNN_STL(dim=num_classifier)
    return model
