import abc
from torch import nn

class BaseBatchNorm2d(abc.ABC, nn.BatchNorm2d):
    """Abstract ``nn.BatchNorm2d`` that can be tied to a companion module.

    When a ``conv_module`` is supplied, a forward hook is registered on it
    that stores, after every forward pass, the product of the output's
    dimension-2 and dimension-3 sizes in ``conv_module.output_area``.

    Concrete subclasses must implement :meth:`get_from_batchnorm` and
    :meth:`n_remaining`; the class cannot be instantiated otherwise.
    """

    def __init__(self, num_features, conv_module=None):
        """Initialise the batch-norm layer and optionally hook ``conv_module``.

        Args:
            num_features: Forwarded to ``nn.BatchNorm2d`` as ``num_features``.
            conv_module: Optional module to associate with this layer. If not
                ``None`` it must provide ``register_forward_hook``.
        """
        super().__init__(num_features=num_features)

        def _store_output_area(hooked, _inputs, output):
            # Sizes of dims 2 and 3 of the produced tensor (presumably the
            # spatial height/width of a conv output — confirm with callers).
            dim2, dim3 = output.size(2), output.size(3)
            hooked.output_area = dim2 * dim3

        if conv_module is not None:
            conv_module.register_forward_hook(_store_output_area)

        # NOTE(review): if conv_module is an nn.Module, this assignment goes
        # through nn.Module.__setattr__, which registers it as a child module
        # of this layer — confirm that is intended.
        self._conv_module = conv_module

    @staticmethod
    @abc.abstractmethod
    def get_from_batchnorm(batchnorm_layer, conv_module, is_skip_connection=False, is_first_layer=False):
        """Abstract factory hook; the base implementation returns ``None``.

        The exact construction semantics are defined by subclasses
        (presumably: build an instance from an existing batch-norm layer
        and its associated conv module).
        """

    @abc.abstractmethod
    def n_remaining(self, reduction='sum'):
        """Abstract; the base implementation returns ``None``.

        The meaning of the result and of ``reduction`` is defined by
        subclasses.
        """

class BaseConv2d(abc.ABC, nn.Conv2d):
    """Abstract ``nn.Conv2d`` whose ``bias`` defaults to ``False``.

    Concrete subclasses must implement :meth:`get_from_conv`; the class
    cannot be instantiated otherwise.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        """Forward every argument unchanged to ``nn.Conv2d.__init__``."""
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    @staticmethod
    @abc.abstractmethod
    def get_from_conv(conv_module, is_skip_connection=False, is_first_layer=False, keep_full_precision=False):
        """Abstract factory hook; the base implementation returns ``None``.

        The exact construction semantics are defined by subclasses
        (presumably: build an instance from an existing conv module).
        """

class BaseActivation(abc.ABC, nn.Module):
    """Abstract ``nn.Module`` base for activation layers.

    Concrete subclasses must implement :meth:`get_from_bn_conv`; the class
    cannot be instantiated otherwise.
    """

    @staticmethod
    @abc.abstractmethod
    def get_from_bn_conv(batchnorm_layer, conv_module, is_skip_connection=False, is_first_layer=False):
        """Abstract factory hook; the base implementation returns ``None``.

        The exact construction semantics are defined by subclasses
        (presumably: build an instance from a batch-norm layer and its
        associated conv module).
        """