import torch
import torch.nn.functional as F
from layers.base_layer import BaseBatchNorm2d, BaseConv2d, BaseActivation
from layers.vanilla import VanillaConv2d, VanillaBatchNorm2d

class HardTanh(BaseActivation):
    '''Activation layer that delegates to a default-constructed ``torch.nn.Hardtanh``.'''

    def __init__(self):
        super().__init__()
        # Default torch.nn.Hardtanh; no custom bounds are passed.
        self.hardtanh = torch.nn.Hardtanh()

    def forward(self, x):
        '''Return ``x`` after it has been passed through the wrapped ``Hardtanh``.'''
        clipped = self.hardtanh(x)
        return clipped

    @staticmethod
    def get_from_bn_conv(batchnorm_layer, conv_module, is_skip_connection=False, is_first_layer=False):
        '''
        Build a new ``HardTanh``.

        None of the arguments are consulted; they are accepted so the call
        signature matches the other ``get_from_bn_conv`` factories in this file.
        '''
        replacement = HardTanh()
        return replacement

class BinaryActivation(BaseActivation):
    '''
    Sign activation with a piecewise-polynomial surrogate gradient.

    The forward value is ``torch.sign(x)`` (up to floating-point rounding of
    the detach trick below). The gradient is that of the piecewise function

        -1            for x < -1
        x*x + 2*x     for -1 <= x < 0
        -x*x + 2*x    for 0 <= x < 1
        1             for x >= 1
    '''
    def __init__(self):
        super(BinaryActivation, self).__init__()

    def forward(self, x):
        '''
        Binarise ``x`` with ``torch.sign`` while routing gradients through the
        piecewise surrogate described on the class.

        The output keeps the dtype of a floating-point ``x``.
        '''
        hard_sign = torch.sign(x)

        # Cast the interval masks to the input's own floating dtype instead of
        # a hard-coded float32: a float32 mask would promote half/bfloat16
        # activations to float32 and mismatch half-precision conv weights.
        # Non-floating inputs keep the previous float32 masks.
        mask_dtype = x.dtype if x.is_floating_point() else torch.float32
        below_neg_one = (x < -1).to(mask_dtype)
        below_zero = (x < 0).to(mask_dtype)
        below_one = (x < 1).to(mask_dtype)

        # Assemble the surrogate one interval at a time; each step keeps the
        # value built so far where its mask is 1 and switches to the next
        # piece where its mask is 0.
        surrogate = (-1) * below_neg_one + (x * x + 2 * x) * (1 - below_neg_one)
        surrogate = surrogate * below_zero + (-x * x + 2 * x) * (1 - below_zero)
        surrogate = surrogate * below_one + 1 * (1 - below_one)

        # Forward value comes from hard_sign; only the non-detached surrogate
        # term contributes to the backward pass.
        return hard_sign.detach() - surrogate.detach() + surrogate

    @staticmethod
    def get_from_bn_conv(batchnorm_layer, conv_module, is_skip_connection=False, is_first_layer=False):
        '''
        Build the activation layer for a converted network.

        None of the arguments are consulted.
        '''
        # NOTE(review): this returns a HardTanh, not a BinaryActivation --
        # presumably because HardBinaryConv2d applies BinaryActivation to its
        # own input; confirm this is intended.
        return HardTanh()

class HardBinaryConv2d(BaseConv2d):
    '''
    Convolution that binarises its input (via ``BinaryActivation``) and its
    weights (scaled sign) on every forward pass.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.activation = BinaryActivation()

    def forward(self, x):
        '''Binarise ``x`` and the weights, then run ``F.conv2d`` with this layer's settings.'''
        binarised_input = self.activation(x)
        latent = self.weight

        # Mean of |w| over dims 3, 2 and 1 in turn (dims kept), leaving one
        # scale per index along dim 0 -- one per output filter, assuming the
        # usual conv weight layout. The scale is detached, so it is treated as
        # a constant in the backward pass.
        scale = abs(latent)
        for axis in (3, 2, 1):
            scale = torch.mean(scale, dim=axis, keepdim=True)
        scale = scale.detach()

        scaled_sign = scale * torch.sign(latent)
        clipped = torch.clamp(latent, -1.0, 1.0)
        # Forward value comes from scaled_sign; only the non-detached clipped
        # term carries gradient back to the latent weights.
        effective_weights = scaled_sign.detach() - clipped.detach() + clipped

        return F.conv2d(
            binarised_input,
            effective_weights,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

    @staticmethod
    def get_from_conv(conv_module, is_skip_connection=False, is_first_layer=False, keep_full_precision=False):
        '''
        Pick the replacement layer for ``conv_module``.

        * first layer or skip connection -> ``VanillaConv2d.get_from_conv``
        * otherwise, ``keep_full_precision`` set -> ``FullPrecisionConv2dWithActivation.get_from_conv``
        * otherwise -> a new ``HardBinaryConv2d`` with ``conv_module``'s
          hyper-parameters and ``bias=False``
        '''
        if is_first_layer or is_skip_connection:
            return VanillaConv2d.get_from_conv(
                conv_module, is_skip_connection, is_first_layer, keep_full_precision
            )

        if keep_full_precision:
            return FullPrecisionConv2dWithActivation.get_from_conv(
                conv_module, is_skip_connection, is_first_layer, keep_full_precision
            )

        return HardBinaryConv2d(
            in_channels=conv_module.in_channels,
            out_channels=conv_module.out_channels,
            kernel_size=conv_module.kernel_size,
            stride=conv_module.stride,
            padding=conv_module.padding,
            dilation=conv_module.dilation,
            groups=conv_module.groups,
            bias=False,
        )

class BinaryBatchNorm2d(BaseBatchNorm2d):
    '''
    Batch-norm layer used after a ``HardBinaryConv2d``.

    Copied from layers.vanilla.VanillaBatchNorm2d with VanillaBatchNorm2d -> BinaryBatchNorm2d
    '''
    def __init__(self, num_features, conv_module=None):
        super().__init__(num_features=num_features, conv_module=conv_module)

    @staticmethod
    def get_from_batchnorm(batchnorm_layer, conv_module, is_skip_connection=False, is_first_layer=False):
        '''
        Pick the replacement for ``batchnorm_layer``.

        A ``BinaryBatchNorm2d`` is built only when the layer is neither a skip
        connection nor the first layer and ``conv_module`` is a
        ``HardBinaryConv2d``; every other case is delegated to
        ``VanillaBatchNorm2d.get_from_batchnorm``.
        '''
        is_special_position = is_skip_connection or is_first_layer
        if not is_special_position and isinstance(conv_module, HardBinaryConv2d):
            return BinaryBatchNorm2d(batchnorm_layer.num_features, conv_module=conv_module)

        return VanillaBatchNorm2d.get_from_batchnorm(
            batchnorm_layer, conv_module, is_skip_connection, is_first_layer
        )

    def n_remaining(self, reduction='sum'):
        '''
        Return an all-ones tensor of length ``num_features`` (typed like
        ``self.weight``), or its sum when ``reduction == 'sum'``.
        '''
        per_feature = torch.ones(self.num_features).type_as(self.weight)
        if reduction == 'sum':
            return per_feature.sum()
        return per_feature

class FullPrecisionConv2dWithActivation(VanillaConv2d):
    '''
    ``VanillaConv2d`` that applies a ``HardTanh`` to its input before the
    convolution.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
        super(FullPrecisionConv2dWithActivation, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.activation = HardTanh()

    def forward(self, x):
        '''Apply the ``HardTanh`` activation to ``x``, then run the parent's forward pass.'''
        x = self.activation(x)
        return super(FullPrecisionConv2dWithActivation, self).forward(x)

    @staticmethod
    def get_from_conv(conv_module, is_skip_connection=False, is_first_layer=False, keep_full_precision=False):
        '''
        Build a ``FullPrecisionConv2dWithActivation`` with ``conv_module``'s
        hyper-parameters.

        ``is_skip_connection``, ``is_first_layer`` and ``keep_full_precision``
        are accepted for signature compatibility and are not consulted.
        '''
        # The constructor's `bias` argument is an on/off flag, whereas
        # `conv_module.bias` is normally a tensor or None. Passing the tensor
        # through makes any truth test on it ambiguous (multi-element) or
        # value-dependent (single element), so derive an explicit boolean.
        # A value that is already a bool is forwarded unchanged.
        source_bias = conv_module.bias
        if isinstance(source_bias, bool):
            has_bias = source_bias
        else:
            has_bias = source_bias is not None

        return FullPrecisionConv2dWithActivation(
            conv_module.in_channels,
            conv_module.out_channels,
            conv_module.kernel_size,
            conv_module.stride,
            conv_module.padding,
            conv_module.dilation,
            conv_module.groups,
            has_bias,
        )