import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.builder import get_builder
from args import args

class SmallCNN(nn.Module):
    """Three-layer convolutional network with a linear classifier head.

    Architecture: three ``conv3x3 -> batchnorm -> ReLU`` stages at constant
    width (stride 1), a 16x16 average pool, then a fully connected layer that
    maps the flattened pooled features to class logits.

    Args:
        builder: Object providing ``conv3x3`` and ``batchnorm`` layer
            factories. Defaults to ``get_builder()``, resolved when the model
            is constructed rather than when this module is imported.
        num_classes: Number of output logits.
        width: Channel count of every convolutional stage.

    NOTE(review): the classifier expects ``width * 4`` features, i.e. a 2x2
    map after the 16x16 pool; this presumably means 32x32 inputs with
    size-preserving convolutions -- confirm against ``builder.conv3x3``.
    """

    def __init__(self, builder=None, num_classes=10, width=128):
        super().__init__()
        # Resolve the default builder per instance. A ``get_builder()`` call in
        # the signature would run once at import time, and that single builder
        # would be shared by every model built without an explicit builder.
        if builder is None:
            builder = get_builder()
        planes = width
        self.conv1 = builder.conv3x3(3, planes, stride=1)
        self.conv2 = builder.conv3x3(planes, planes, stride=1)
        self.conv3 = builder.conv3x3(planes, planes, stride=1)
        # Despite the name this is a Linear layer; the attribute name is kept
        # so existing state_dict keys / checkpoints still load.
        self.conv4 = nn.Linear(planes * 4, num_classes)
        self.bn1 = builder.batchnorm(planes)
        self.bn2 = builder.batchnorm(planes)
        self.bn3 = builder.batchnorm(planes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for input batch ``x``."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.avg_pool2d(x, 16)
        # Flatten channels and pooled spatial positions into one feature axis.
        x = x.view(x.shape[0], -1)
        return self.conv4(x)


class CNN5(nn.Module):
    """Five-layer convolutional network with a 1x1-convolution classifier head.

    Architecture: five ``conv3x3 -> batchnorm -> ReLU`` stages at constant
    width (stride 1), a 16x16 average pool, then a ``conv1x1`` that maps the
    flattened pooled features to class logits.

    Args:
        builder: Object providing ``conv3x3``, ``conv1x1`` and ``batchnorm``
            layer factories. Defaults to ``get_builder()``, resolved when the
            model is constructed rather than when this module is imported.
        num_classes: Number of output logits.
        width: Channel count of every convolutional stage.

    NOTE(review): the classifier expects ``width * 4`` features, i.e. a 2x2
    map after the 16x16 pool; this presumably means 32x32 inputs with
    size-preserving convolutions -- confirm against ``builder.conv3x3``.
    """

    def __init__(self, builder=None, num_classes=10, width=128):
        super().__init__()
        # Resolve the default builder per instance. A ``get_builder()`` call in
        # the signature would run once at import time, and that single builder
        # would be shared by every model built without an explicit builder.
        if builder is None:
            builder = get_builder()
        planes = width

        self.conv1 = builder.conv3x3(3, planes, stride=1)
        self.conv2 = builder.conv3x3(planes, planes, stride=1)
        self.conv3 = builder.conv3x3(planes, planes, stride=1)
        self.conv4 = builder.conv3x3(planes, planes, stride=1)
        self.conv5 = builder.conv3x3(planes, planes, stride=1)
        # 1x1 conv used as a fully connected layer over the pooled features.
        self.conv6 = builder.conv1x1(planes * 4, num_classes)
        self.bn1 = builder.batchnorm(planes)
        self.bn2 = builder.batchnorm(planes)
        self.bn3 = builder.batchnorm(planes)
        self.bn4 = builder.batchnorm(planes)
        self.bn5 = builder.batchnorm(planes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for input batch ``x``."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.avg_pool2d(x, 16)
        # Fold pooled spatial positions into channels: (B, planes * 4, 1, 1),
        # so the 1x1 conv sees every pooled feature.
        x = x.view(x.shape[0], -1)[:, :, None, None]
        # flatten(1) drops only the trailing 1x1 spatial dims. A bare squeeze()
        # would also drop the batch dim for batch size 1 (and the class dim
        # when num_classes == 1).
        return self.conv6(x).flatten(1)
