import torch
import torch.nn as nn
import torch.nn.functional as F

class Cifar10CNN(nn.Module):
    """Small two-stage convolutional classifier for CIFAR-10-sized images.

    Architecture: two ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2)`` stages with 32 channels each, followed by one linear layer.
    Each convolution preserves the spatial size and each pool halves it
    (flooring), so a square ``image_size x image_size`` input reaches the
    linear layer as ``32 * (image_size // 4) ** 2`` features.

    Args:
        num_classes: Number of output logits. Defaults to 10 (CIFAR-10).
        in_channels: Number of channels in the input images. Defaults to 3.
        image_size: Height/width of the square input images that the final
            linear layer is sized for. Defaults to 32 (CIFAR-10).

    Raises:
        ValueError: If ``image_size`` is smaller than 4, which would leave no
            spatial extent after the second pooling stage.
    """

    def __init__(
        self,
        *,
        num_classes: int = 10,
        in_channels: int = 3,
        image_size: int = 32,
    ) -> None:
        super().__init__()
        if image_size < 4:
            raise ValueError(f"image_size must be >= 4, got {image_size}")

        channels = 32  # width of both convolutional stages

        # Module construction order (layer1, layer2, fc) is kept fixed so that
        # parameter initialization under a given seed and state_dict keys
        # stay stable.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Two MaxPool2d(2) stages floor-halve the side length twice; for the
        # default 32x32 input this gives 32 * 8 * 8 = 2048 features.
        pooled_side = image_size // 2 // 2
        self.fc = nn.Linear(channels * pooled_side * pooled_side, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits (no softmax) for a batch of images.

        Args:
            x: Batched images; the linear layer is sized for shape
                ``(N, in_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        # Flatten everything except the batch dimension. Unlike
        # view(x.size(0), -1), this also works for an empty batch.
        x = torch.flatten(x, 1)
        return self.fc(x)
