import torch.nn as nn
from models.utils.utils import make_conv_block


class DepthSeperabelConv2d(nn.Module):
    """Two-stage convolution: a depthwise block followed by a 1x1 pointwise block.

    Both stages are built with ``make_conv_block``. The depthwise stage maps
    ``input_channels -> input_channels`` using the caller's ``kernel_size``,
    ``stride`` and ``padding``; the pointwise stage maps
    ``input_channels -> output_channels`` with kernel size 1, stride 1 and
    no padding.

    Args:
        input_channels: Channel count of the input (and of the depthwise output).
        output_channels: Channel count produced by the pointwise stage.
        kernel_size: Kernel size of the depthwise stage.
        stride: Stride of the depthwise stage (the pointwise stage always uses 1).
        padding: Padding of the depthwise stage only.
        activation_generator: Forwarded unchanged to ``make_conv_block`` for
            both stages.
        oper_order: Either ``'cba'`` or ``'cab'``. It is translated into the
            per-stage order strings handed to ``make_conv_block``
            (``'dba'``/``'pba'`` or ``'dab'``/``'pab'``).
            NOTE(review): presumably d/p = depthwise/pointwise conv,
            b = batch norm, a = activation -- confirm against make_conv_block.
        depthwise_acti: When false, the depthwise stage uses order ``'db'``
            (no ``'a'`` step) regardless of ``oper_order``.

    Raises:
        ValueError: If ``oper_order`` is not ``'cba'`` or ``'cab'``.
    """

    def __init__(self, input_channels, output_channels, kernel_size, stride=1, padding=0,
                 activation_generator=None, oper_order='cba', depthwise_acti=True):
        super(DepthSeperabelConv2d, self).__init__()

        if oper_order == 'cba':
            oper_order_depth = 'dba' if depthwise_acti else 'db'
            oper_order_point = 'pba'
        elif oper_order == 'cab':
            oper_order_depth = 'dab' if depthwise_acti else 'db'
            oper_order_point = 'pab'
        else:
            # Fail early with a clear message; previously an unsupported value
            # surfaced later as an UnboundLocalError on oper_order_depth.
            raise ValueError(
                "oper_order must be 'cba' or 'cab', got {!r}".format(oper_order)
            )

        self.depthwise = make_conv_block(input_channels, input_channels, kernel_size=kernel_size,
                                         stride=stride, padding=padding,
                                         activation_generator=activation_generator,
                                         oper_order=oper_order_depth)

        # The 1x1 stage must not pad: the caller's padding has already been
        # applied by the depthwise stage, and padding a kernel-size-1 conv
        # would enlarge the output by 2 * padding in each spatial dimension.
        self.pointwise = make_conv_block(input_channels, output_channels, kernel_size=1,
                                         stride=1, padding=0,
                                         activation_generator=activation_generator,
                                         oper_order=oper_order_point)

    def forward(self, x):
        """Apply the depthwise block, then the pointwise block, and return the result."""
        x = self.depthwise(x)
        x = self.pointwise(x)

        return x
