import torch
from layers.base_layer import BaseBatchNorm2d, BaseConv2d, BaseActivation

class VanillaBatchNorm2d(BaseBatchNorm2d):
    """Batch-norm variant that reports every channel as remaining.

    Construction is delegated entirely to ``BaseBatchNorm2d``; this class only
    adds a factory and an all-ones ``n_remaining`` report.
    """

    def __init__(self, num_features, conv_module=None):
        # Forward both arguments by keyword, exactly as received.
        super().__init__(num_features=num_features, conv_module=conv_module)

    @staticmethod
    def get_from_batchnorm(batchnorm_layer, conv_module, is_skip_connection=False, is_first_layer=False):
        """Create a ``VanillaBatchNorm2d`` sized like ``batchnorm_layer``.

        Only ``batchnorm_layer.num_features`` is read. No parameters or running
        statistics are copied. ``is_skip_connection`` and ``is_first_layer`` are
        ignored here; presumably they exist to match a shared factory signature.
        """
        num_features = batchnorm_layer.num_features
        return VanillaBatchNorm2d(num_features, conv_module=conv_module)

    def n_remaining(self, reduction='sum'):
        """Return how many channels remain; for this class that is all of them.

        Args:
            reduction: ``'sum'`` returns a 0-d tensor equal to ``num_features``.
                Any other value returns the per-channel all-ones vector.

        The result is cast with ``type_as(self.weight)``, so it follows the
        weight tensor's type (and therefore requires ``self.weight`` to exist).
        """
        ones_mask = torch.ones(self.num_features).type_as(self.weight)
        if reduction != 'sum':
            return ones_mask
        return ones_mask.sum()

class VanillaConv2d(BaseConv2d):
    """Convolution variant that simply forwards its hyper-parameters to ``BaseConv2d``."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        # ``bias`` is a boolean flag (default False) saying whether a bias term is created.
        super(VanillaConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)

    @staticmethod
    def get_from_conv(conv_module, is_skip_connection=False, is_first_layer=False, keep_full_precision=False):
        """Create a ``VanillaConv2d`` with the same hyper-parameters as ``conv_module``.

        Reads ``in_channels``, ``out_channels``, ``kernel_size``, ``stride``,
        ``padding``, ``dilation``, ``groups`` and whether ``bias`` is present.
        The weights and bias values of ``conv_module`` are NOT copied by this method.

        ``is_skip_connection``, ``is_first_layer`` and ``keep_full_precision`` are
        ignored here; presumably they exist to match a shared factory signature.
        """
        return VanillaConv2d(
            conv_module.in_channels,
            conv_module.out_channels,
            conv_module.kernel_size,
            conv_module.stride,
            conv_module.padding,
            conv_module.dilation,
            conv_module.groups,
            # ``conv_module.bias`` is a Parameter tensor or None, but the constructor
            # expects a bool. Passing the tensor itself makes ``if bias:``-style
            # checks ambiguous (and raise) for multi-element bias tensors.
            conv_module.bias is not None,
        )

class VanillaActivation(BaseActivation):
    """Activation factory whose product is always a standard ``torch.nn.ReLU``."""

    @staticmethod
    def get_from_bn_conv(batchnorm_layer, conv_module, is_skip_connection=False, is_first_layer=False):
        """Return a new, non-in-place ``torch.nn.ReLU``.

        None of the arguments are inspected; presumably they exist to match a
        shared factory signature.
        """
        # inplace=False is ReLU's default, spelled out for clarity.
        activation = torch.nn.ReLU(inplace=False)
        return activation