import torch.nn as nn
import torch.nn.functional as F



class LeNet(nn.Module):
    """LeNet-style CNN: two (5x5 conv -> ReLU -> 2x2 max-pool) stages and three linear layers.

    Args:
        input_size: Side length of the square input images; used to size the
            first fully connected layer.
        out_size: Number of units produced by the last linear layer.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, input_size: int, out_size: int, in_channels: int = 3) -> None:
        super().__init__()
        # An unpadded 5x5 convolution shortens each side by 4 pixels and the
        # following 2x2 max-pool halves it (rounding down).
        side_after_stage1 = (input_size - 4) // 2
        side_after_stage2 = (side_after_stage1 - 4) // 2
        flat_features = 16 * side_after_stage2 ** 2

        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(flat_features, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_size)

    def forward(self, x):
        """Run the network and return the raw outputs of the last linear layer."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Collapse everything except the leading (batch) dimension.
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


class ConvNet2(nn.Module):
    """Two (padded 5x5 conv -> ReLU -> 2x2 max-pool) stages followed by a two-layer MLP.

    Args:
        input_size: Side length of the square input images; used to size the
            first fully connected layer.
        out_size: Number of units produced by the last linear layer.
        in_channels: Number of channels of the input images.
        n_kernels: Output channels of the first convolution; the second
            convolution produces twice as many.
        hidden: Width of the hidden fully connected layer.
    """

    def __init__(
        self,
        input_size: int,
        out_size: int,
        in_channels: int = 3,
        n_kernels: int = 32,
        hidden: int = 32,
    ) -> None:
        super().__init__()
        wide_channels = n_kernels * 2
        # padding=2 keeps the spatial size through each 5x5 convolution, so
        # only the two 2x2 max-pools shrink the side (each halves it).
        pooled_side = input_size // 2 // 2

        self.conv1 = nn.Conv2d(in_channels, n_kernels, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(n_kernels, wide_channels, kernel_size=5, padding=2)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(wide_channels * pooled_side ** 2, hidden)
        self.fc2 = nn.Linear(hidden, out_size)

    def forward(self, x):
        """Run the network and return the raw outputs of the last linear layer."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Collapse everything except the leading (batch) dimension.
        flat = features.view(features.size(0), -1)
        return self.fc2(F.relu(self.fc1(flat)))

class ConvNet2BatchNormDropout(nn.Module):
    """Two-stage CNN with optional batch normalisation and dropout around the hidden layer.

    Each stage is a padded 5x5 convolution (32 then 64 channels), optionally
    followed by BatchNorm2d, then ReLU and a 2x2 max-pool.

    Args:
        in_channels: Number of channels of the input images.
        h: Input image height; used to size the first fully connected layer.
        w: Input image width; used to size the first fully connected layer.
        hidden: Width of the hidden fully connected layer.
        out_size: Number of units produced by the last linear layer.
        use_bn: When true, a BatchNorm2d layer is created for, and applied
            after, each convolution.
        dropout: Drop probability passed to ``F.dropout`` before each of the
            two linear layers.
    """

    def __init__(self, in_channels, h=32, w=32, hidden=2048, out_size=10, use_bn=True, dropout=.0):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)

        # The normalisation layers only exist when they are requested.
        self.use_bn = use_bn
        if self.use_bn:
            self.bn1 = nn.BatchNorm2d(32)
            self.bn2 = nn.BatchNorm2d(64)

        # The padded convolutions keep the spatial size; each of the two
        # 2x2 max-pools halves height and width.
        pooled_h = h // 2 // 2
        pooled_w = w // 2 // 2
        self.fc1 = nn.Linear(pooled_h * pooled_w * 64, hidden)
        self.fc2 = nn.Linear(hidden, out_size)

        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.dropout = dropout

    def forward(self, x):
        """Run the network and return the raw outputs of the last linear layer."""
        x = self.conv1(x)
        if self.use_bn:
            x = self.bn1(x)
        x = self.maxpool(self.relu(x))

        x = self.conv2(x)
        if self.use_bn:
            x = self.bn2(x)
        x = self.maxpool(self.relu(x))

        # Keep the leading (batch) dimension, flatten the rest.
        x = x.flatten(1)
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.relu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.fc2(x)

class ConvNet2BN(nn.Module):
    """Two (padded 5x5 conv -> BatchNorm -> ReLU -> 2x2 max-pool) stages and a two-layer MLP.

    Args:
        in_channels: Number of channels of the input images.
        h: Input image height; used to size the first fully connected layer.
        w: Input image width; used to size the first fully connected layer.
        hidden: Width of the hidden fully connected layer.
        out_size: Number of units produced by the last linear layer.
    """

    def __init__(self, in_channels, h, w, hidden, out_size):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)

        # The padded convolutions keep the spatial size; each of the two
        # 2x2 max-pools halves height and width.
        pooled_h = h // 2 // 2
        pooled_w = w // 2 // 2
        self.fc1 = nn.Linear(pooled_h * pooled_w * 64, hidden)
        self.fc2 = nn.Linear(hidden, out_size)

        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Run the network and return the raw outputs of the last linear layer."""
        stage1 = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        stage2 = self.maxpool(self.relu(self.bn2(self.conv2(stage1))))
        # Keep the leading (batch) dimension, flatten the rest.
        flat = stage2.flatten(1)
        return self.fc2(self.relu(self.fc1(flat)))

class ConvNet3(nn.Module):
    """Three unpadded 3x3 convolutions (max-pool after the first two) and a two-layer MLP.

    Args:
        input_size: Side length of the square input images; used to size the
            first fully connected layer.
        out_size: Number of units produced by the last linear layer.
        in_channels: Number of channels of the input images.
        n_kernels: Output channels of the first convolution; the second and
            third convolutions produce twice as many.
    """

    def __init__(
        self,
        input_size: int,
        out_size: int,
        in_channels: int = 3,
        n_kernels: int = 32,
    ) -> None:
        super().__init__()
        wide_channels = n_kernels * 2

        self.conv1 = nn.Conv2d(in_channels, n_kernels, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(n_kernels, wide_channels, kernel_size=3)
        self.conv3 = nn.Conv2d(wide_channels, wide_channels, kernel_size=3)

        # Each unpadded 3x3 convolution shortens the side by 2 pixels; the
        # first two are followed by a 2x2 max-pool that halves it.
        side = (input_size - 2) // 2   # conv1 -> pool
        side = (side - 2) // 2         # conv2 -> pool
        side = side - 2                # conv3 (no pooling)

        self.fc1 = nn.Linear(wide_channels * side * side, 64)
        self.fc2 = nn.Linear(64, out_size)

    def forward(self, x):
        """Run the network and return the raw outputs of the last linear layer."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        features = F.relu(self.conv3(features))
        # Collapse everything except the leading (batch) dimension.
        flat = features.view(features.size(0), -1)
        return self.fc2(F.relu(self.fc1(flat)))

class ConvNet3BatchNorm(nn.Module):
    """Three unpadded 3x3 conv + BatchNorm + ReLU stages (max-pool after the first two) and a two-layer MLP.

    Args:
        input_size: Side length of the square input images; used to size the
            first fully connected layer.
        out_size: Number of units produced by the last linear layer.
        in_channels: Number of channels of the input images.
        n_kernels: Output channels of the first convolution; the second and
            third convolutions produce twice as many.
    """

    def __init__(
        self,
        input_size: int,
        out_size: int,
        in_channels: int = 3,
        n_kernels: int = 32,
    ) -> None:
        super().__init__()
        wide_channels = n_kernels * 2

        self.conv1 = nn.Conv2d(in_channels, n_kernels, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bn1 = nn.BatchNorm2d(n_kernels)
        self.conv2 = nn.Conv2d(n_kernels, wide_channels, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(wide_channels)
        self.conv3 = nn.Conv2d(wide_channels, wide_channels, kernel_size=3)
        self.bn3 = nn.BatchNorm2d(wide_channels)

        # Each unpadded 3x3 convolution shortens the side by 2 pixels; the
        # first two are followed by a 2x2 max-pool that halves it.
        side = (input_size - 2) // 2   # conv1 -> pool
        side = (side - 2) // 2         # conv2 -> pool
        side = side - 2                # conv3 (no pooling)

        self.fc1 = nn.Linear(wide_channels * side * side, 64)
        self.fc2 = nn.Linear(64, out_size)

    def forward(self, x):
        """Run the network and return the raw outputs of the last linear layer."""
        features = self.pool(F.relu(self.bn1(self.conv1(x))))
        features = self.pool(F.relu(self.bn2(self.conv2(features))))
        features = F.relu(self.bn3(self.conv3(features)))
        # Collapse everything except the leading (batch) dimension.
        flat = features.view(features.size(0), -1)
        return self.fc2(F.relu(self.fc1(flat)))

class ConvNet_CIFAR100_BN(nn.Module):
    """VGG-style CNN: three double-convolution blocks followed by a three-layer classifier.

    Every block is two 3x3 convolutions with padding 1 (spatial size is
    preserved), each followed by ReLU, and then a 2x2 max-pool that halves
    the side. Despite the ``_BN`` suffix in the class name, this
    implementation contains no batch-normalisation layers.

    References: https://zhenye-na.github.io/2018/09/28/pytorch-cnn-cifar10.html

    Args:
        out_size: Number of classes, i.e. units of the last linear layer.
        in_channels: Number of channels of the input images.
        input_size: Side length of the square input images. The default of
            32 yields the 256 * 4 * 4 = 4096 flattened features that were
            previously hard-coded.

    Raises:
        ValueError: If ``input_size`` is smaller than 8, in which case the
            three max-pools would leave no spatial positions.
    """

    def __init__(self, out_size, in_channels=3, input_size=32):
        super().__init__()

        if input_size < 8:
            raise ValueError(
                f"input_size must be at least 8, got {input_size}"
            )

        # Three 2x2 max-pools each halve the side (rounding down); the padded
        # convolutions leave it unchanged.
        pooled_side = input_size // 2 // 2 // 2
        flat_features = 256 * pooled_side * pooled_side

        self.conv_layer = nn.Sequential(

            # Conv Layer block 1
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Conv Layer block 2
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout2d(p=0.05),

            # Conv Layer block 3
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.fc_layer = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(flat_features, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(512, out_size)
        )

    def forward(self, x):
        """Run the network and return log-softmax scores over dimension 1."""

        # conv layers
        x = self.conv_layer(x)

        # flatten everything except the leading (batch) dimension
        x = x.view(x.size(0), -1)

        # fc layer
        x = self.fc_layer(x)

        return F.log_softmax(x, dim=1)

