"""
Modules used to downsample features with 3x3 convolutions
"""
import torch.nn as nn
import torch

class DoubleConv(nn.Module):
    """
    Two stacked Conv2d + ReLU stages.

    The first convolution uses the caller-supplied kernel size, stride and
    padding. The second is always a 3x3 convolution with padding 1 (and the
    Conv2d default stride) that maps ``out_channels`` to ``out_channels``.

    Args:
        in_channels: input channel num
        out_channels: output channel num (used by both convolutions)
        kernel_size: kernel size of the first convolution
        stride: stride of the first convolution
        padding: padding of the first convolution
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding):
        super().__init__()
        # Stages are created in execution order so they occupy Sequential
        # indices 0..3 (conv, relu, conv, relu).
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv + ReLU stages and return the result."""
        out = self.double_conv(x)
        return out


class DownsampleConv(nn.Module):
    """
    Sequence of DoubleConv blocks built from a config mapping.

    Args:
        config: mapping with the following keys
            'input_dim': channel num of the input to the first block
            'kernal_size': per-block kernel size of the first convolution
                (the key spelling is kept as-is for compatibility with
                existing configs)
            'dim': per-block output channel num
            'stride': per-block stride of the first convolution
            'padding': per-block padding of the first convolution

    Raises:
        ValueError: if the per-block lists do not all have the same length.
    """

    # Config keys holding one entry per block, in DoubleConv argument order.
    _PER_BLOCK_KEYS = ('kernal_size', 'dim', 'stride', 'padding')

    def __init__(self, config):
        super(DownsampleConv, self).__init__()
        self.layers = nn.ModuleList([])
        input_dim = config['input_dim']

        per_block = {key: list(config[key]) for key in self._PER_BLOCK_KEYS}
        lengths = {key: len(values) for key, values in per_block.items()}
        # zip() would silently drop blocks when the lists disagree in
        # length, so reject such a config explicitly.
        if len(set(lengths.values())) > 1:
            raise ValueError(
                "config lists must have the same length, got %s" % lengths)

        for ksize, dim, stride, padding in zip(per_block['kernal_size'],
                                               per_block['dim'],
                                               per_block['stride'],
                                               per_block['padding']):
            self.layers.append(DoubleConv(input_dim,
                                          dim,
                                          kernel_size=ksize,
                                          stride=stride,
                                          padding=padding))
            # The next block consumes what this block produces.
            input_dim = dim

    def forward(self, x, return_all_feature=True):
        """
        Apply every block in order.

        Args:
            x: input passed to the first block
            return_all_feature: when True also return the intermediate
                features

        Returns:
            ``(out, features)`` when ``return_all_feature`` is True, where
            ``features`` is a list starting with the original input followed
            by the output of each block; otherwise only ``out``.
        """
        if not return_all_feature:
            for layer in self.layers:
                x = layer(x)
            return x

        mid_feature = [x]
        for layer in self.layers:
            x = layer(x)
            mid_feature.append(x)
        return x, mid_feature

if __name__ == "__main__":
    # Smoke test: push a random batch through two stride-2 blocks.
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    config = {
        'input_dim': 128,
        'kernal_size': [3, 3],
        'dim': [128, 128],
        'stride': [2, 2],
        'padding': [1, 1],
    }
    model = DownsampleConv(config)
    model.to(device)

    # Random input of shape (4, 128, 128, 128), created on CPU and then
    # moved to the selected device.
    batch = torch.rand(4, 128, 128, 128)
    batch = batch.to(device)

    output, _ = model(batch, return_all_feature=True)
    print(output)
