import torch.nn as nn
import torch.nn.functional as F
import math

class ConvNet2Resize(nn.Module):
    """Two conv/pool stages followed by a hidden fully connected layer.

    Each stage is ``Conv2d(3x3, padding=1, stride=1, no bias) -> ReLU ->
    BatchNorm2d -> MaxPool2d(2)``. The padded convolutions keep the spatial
    size and each pool halves it (rounding down), so an ``H x W`` input
    reaches the classifier as ``10 * (H // 4) * (W // 4)`` features.

    Args:
        in_channels: Number of channels of the input images.
        in_size: Input spatial size, either an int for square
            ``in_size x in_size`` inputs or an ``(height, width)`` pair.
        out_size: Number of outputs of the final linear layer.
        hidden: Width of the hidden fully connected layer.

    Attributes:
        linear_size: Spatial size of the final feature map: an int when
            ``in_size`` is an int, otherwise an ``(height, width)`` tuple.

    Raises:
        ValueError: If either side of the input is smaller than 4 pixels,
            which would leave no features for the fully connected layer.
    """

    def __init__(self, in_channels, in_size, out_size, hidden):
        super().__init__()
        is_pair = isinstance(in_size, (tuple, list))
        if is_pair:
            height, width = in_size
        else:
            height = width = in_size
        # Padded 3x3 convs preserve H and W; each MaxPool2d(2) floors them by 2.
        feat_h, feat_w = height // 2 // 2, width // 2 // 2
        if feat_h < 1 or feat_w < 1:
            # Fail here rather than building a zero-width Linear that only
            # breaks later inside forward().
            raise ValueError(
                f"in_size={in_size!r} is too small: the final feature map would "
                f"be {feat_h}x{feat_w}; each side must be at least 4 pixels"
            )

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, padding=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(8, 10, kernel_size=3, padding=1, stride=1, bias=False),
            nn.ReLU(),
            nn.BatchNorm2d(10),
            nn.MaxPool2d(kernel_size=2),
        )
        # Keep the historical int value for int inputs (backward compatible).
        self.linear_size = (feat_h, feat_w) if is_pair else feat_h
        self.fc = nn.Sequential(
            nn.Linear(10 * feat_h * feat_w, hidden),
            nn.ReLU(),
        )
        self.out = nn.Linear(hidden, out_size)

    def forward(self, inputs):
        """Map a batch of images to ``(batch, out_size)`` outputs.

        Args:
            inputs: Tensor of shape ``(batch, in_channels, H, W)`` whose
                ``H, W`` match the ``in_size`` given at construction.
        """
        feat = self.conv1(inputs)
        feat = self.conv2(feat)
        feat = feat.flatten(1)  # (batch, 10 * (H // 4) * (W // 4))
        h = self.fc(feat)
        return self.out(h)

# class ConvNet2Resize(nn.Module):
#     def __init__(self, inp_dim, out_dim, last_linear_dim):
#         super().__init__()
#         self.inp_dim = inp_dim
#         self.out_dim = out_dim
#         self.n_channel = self.inp_dim[0]
#         assert self.inp_dim[1] == self.inp_dim[2]
#         self.cnn_config = {'kernel_size': 3, 'padding': 1, 'stride': 1}
#         self.pool_config = {'kernel_size': 2}
#         self.conv1 = nn.Sequential(
#             nn.Conv2d(self.n_channel, 8, bias=False, **self.cnn_config),
#             nn.ReLU(),
#             nn.BatchNorm2d(8),
#             nn.MaxPool2d(**self.pool_config),
#         )
#         self.conv2 = nn.Sequential(
#             nn.Conv2d(8, 10, bias=False, **self.cnn_config),
#             nn.ReLU(),
#             nn.BatchNorm2d(10),
#             nn.MaxPool2d(**self.pool_config),
#         )
#         self.last_w = self._compute_last_linear_dim(self.inp_dim[1])
#         self.last_h = self._compute_last_linear_dim(self.inp_dim[2])
#         self.last_inp_dim = 10 * self.last_w * self.last_h
#         print('Last layer size: ({}, {})'.format(self.last_w, self.last_h))
#         self.fc = nn.Sequential(nn.Linear(self.last_inp_dim, last_linear_dim))
#         self.out = nn.Linear(last_linear_dim, self.out_dim)
    
#     def _compute_last_linear_dim(self, inp_dim):
#         res = inp_dim
#         for _ in range(2):
#             res = (res - self.cnn_config['kernel_size'] + 2 * self.cnn_config['padding']) / self.cnn_config['stride']
#             res = int(math.floor(res))
#             res = (res - self.pool_config['kernel_size'] + 1) / self.pool_config['kernel_size']
#             res = int(math.floor(res)) + 1
#         return res

#     def forward(self, inputs):
#         feat = self.conv1(inputs) 
#         feat = self.conv2(feat)
#         feat = feat.view(-1, self.last_inp_dim)
#         h = F.relu(self.fc(feat))
#         out = self.out(h)
#         return out


class ConvNet2(nn.Module):
    """Two unpadded 3x3 convolutions with one max-pool, then two linear layers.

    Pipeline: ``conv1 -> ReLU -> MaxPool2d(2) -> conv2 -> ReLU -> flatten ->
    fc1 -> ReLU -> fc2``. Each unpadded 3x3 convolution shrinks a side by 2
    and the pool halves it (rounding down), so a side of length ``s`` ends up
    as ``(s - 2) // 2 - 2``.

    Args:
        input_size: Input spatial size, an int for square inputs or an
            ``(height, width)`` pair.
        out_size: Number of outputs of the final linear layer.
        in_channels: Number of input image channels.
        hidden: Width of the hidden linear layer.
        n_kernels: Channels produced by ``conv1``; ``conv2`` produces twice
            as many.

    Raises:
        ValueError: If any side of ``input_size`` is below 8 pixels, so the
            conv/pool stack could not produce a non-empty feature map.
    """

    def __init__(self,
                 input_size,
                 out_size,
                 in_channels=3,
                 hidden=64,
                 n_kernels=32):
        super().__init__()
        if isinstance(input_size, (tuple, list)):
            height, width = input_size
        else:
            height = width = input_size
        # conv1 (-2) -> pool (// 2) -> conv2 (-2); there is no pool after conv2.
        feat_h = (height - 2) // 2 - 2
        feat_w = (width - 2) // 2 - 2
        if feat_h < 1 or feat_w < 1:
            # Otherwise fc1 would get zero/negative in_features and the model
            # would fail with an unrelated-looking error.
            raise ValueError(
                f"input_size={input_size!r} is too small: the final feature map "
                f"would be {feat_h}x{feat_w}; each side must be at least 8 pixels"
            )

        self.conv1 = nn.Conv2d(in_channels, n_kernels, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(n_kernels, n_kernels * 2, 3)
        self.fc1 = nn.Linear(n_kernels * 2 * feat_h * feat_w, hidden)
        self.fc2 = nn.Linear(hidden, out_size)

    def forward(self, x):
        """Map a ``(batch, in_channels, H, W)`` tensor to ``(batch, out_size)``."""
        out = F.relu(self.conv1(x))
        out = self.pool(out)
        out = F.relu(self.conv2(out))
        out = out.flatten(1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)


class ConvNet3(nn.Module):
    """Three unpadded 3x3 convolutions with two max-pools, then two linear layers.

    Pipeline: ``conv1 -> ReLU -> MaxPool2d(2) -> conv2 -> ReLU -> MaxPool2d(2)
    -> conv3 -> ReLU -> flatten -> fc1 -> ReLU -> fc2``. Each unpadded 3x3
    convolution shrinks a side by 2 and each pool halves it (rounding down),
    so a side of length ``s`` ends up as ``((s - 2) // 2 - 2) // 2 - 2``.

    Args:
        input_size: Input spatial size, an int for square inputs or an
            ``(height, width)`` pair.
        out_size: Number of outputs of the final linear layer.
        in_channels: Number of input image channels.
        hidden: Width of the hidden linear layer.
        n_kernels: Channels produced by ``conv1``; ``conv2`` and ``conv3``
            produce twice as many.

    Raises:
        ValueError: If any side of ``input_size`` is below 18 pixels, so the
            conv/pool stack could not produce a non-empty feature map.
    """

    def __init__(self,
                 input_size,
                 out_size,
                 in_channels=3,
                 hidden=64,
                 n_kernels=32):
        super().__init__()
        if isinstance(input_size, (tuple, list)):
            height, width = input_size
        else:
            height = width = input_size
        # conv1 (-2) -> pool (// 2) -> conv2 (-2) -> pool (// 2) -> conv3 (-2).
        feat_h = ((height - 2) // 2 - 2) // 2 - 2
        feat_w = ((width - 2) // 2 - 2) // 2 - 2
        if feat_h < 1 or feat_w < 1:
            # Otherwise fc1 would get zero/negative in_features and the model
            # would fail with an unrelated-looking error.
            raise ValueError(
                f"input_size={input_size!r} is too small: the final feature map "
                f"would be {feat_h}x{feat_w}; each side must be at least 18 pixels"
            )

        self.conv1 = nn.Conv2d(in_channels, n_kernels, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(n_kernels, n_kernels * 2, 3)
        self.conv3 = nn.Conv2d(n_kernels * 2, n_kernels * 2, 3)
        self.fc1 = nn.Linear(n_kernels * 2 * feat_h * feat_w, hidden)
        self.fc2 = nn.Linear(hidden, out_size)

    def forward(self, x):
        """Map a ``(batch, in_channels, H, W)`` tensor to ``(batch, out_size)``."""
        out = F.relu(self.conv1(x))
        out = self.pool(out)
        out = F.relu(self.conv2(out))
        out = self.pool(out)
        out = F.relu(self.conv3(out))
        out = out.flatten(1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)

# class ConvNet2(nn.Module):
#     def __init__(self, in_channels, in_size, out_size, hidden):
#         super().__init__()
#         self.conv1 = nn.Sequential(
#             nn.Conv2d(in_channels, 8, kernel_size=3, padding=),
#             nn.ReLU(),
#             nn.BatchNorm2d(8),
#             nn.MaxPool2d(kernel_size=2),
#         )
#         self.conv2 = nn.Sequential(
#             nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=1, bias=False),
#             nn.ReLU(),
#             nn.BatchNorm2d(16),
#             nn.MaxPool2d(kernel_size=2),
#         )
#         self.linear_size = in_size // 2 // 2
#         self.fc = nn.Sequential(
#             nn.Linear(16 * self.linear_size ** 2, hidden),
#             nn.ReLU(),
#         )
#         self.out = nn.Linear(hidden, out_size)
    
#     def forward(self, inputs):
#         feat = self.conv1(inputs) 
#         feat = self.conv2(feat)
#         feat = feat.view(feat.shape[0], -1)
#         h = self.fc(feat)
#         out = self.out(h)
#         return out