import torch
from torch.autograd import Function
import torch.nn as nn

class BNNBinarizeF(Function):
    """Scaled sign binarization with a straight-through gradient estimator.

    Forward: ``sign(x) * mean(|x|)``, with entries that are exactly zero
    mapped to ``+1`` (note: ``+1``, not ``+mean(|x|)`` -- preserved from the
    original implementation).

    Backward: straight-through estimator -- the incoming gradient is passed
    unchanged where ``|x| <= 1`` and zeroed elsewhere. ``beta`` is a
    non-differentiable hyper-parameter and receives no gradient.
    """

    @staticmethod
    def forward(ctx, inputs, beta):
        # beta is stored on ctx but not otherwise used in this Function.
        ctx.beta = float(beta)
        ctx.save_for_backward(inputs)
        scale = inputs.abs().mean()
        # sign() yields 0 for exact zeros; add 1 there so zeros become +1.
        # Cast the mask to the input dtype so half/double inputs are not
        # promoted to float32.
        zero_mask = (inputs == 0).to(inputs.dtype)
        return inputs.sign() * scale + zero_mask

    @staticmethod
    def backward(ctx, grad_output):
        (inputs,) = ctx.saved_tensors
        # Straight-through estimator: keep the gradient where the *input*
        # lies in [-1, 1], zero it outside. masked_fill keeps grad_output's
        # device and dtype (no hard-coded 'cuda').
        grad_input = grad_output.masked_fill(inputs.abs() > 1, 0)
        # autograd requires one gradient per forward input; beta gets None.
        return grad_input, None


# Functional entry point: ``bnn_binarize(inputs, beta)`` invokes BNNBinarizeF
# through autograd (custom Functions are called via ``.apply``, never by
# calling forward() directly).
bnn_binarize = BNNBinarizeF.apply
