import torch
import timeit
import torch.utils.benchmark as benchmark


@torch.jit.script
def binact(x):
    """Binarize ``x`` with ``sign`` while routing gradients through a surrogate.

    The surrogate is piecewise: ``-1`` for ``x < -1``, ``x*x + 2*x`` on
    ``[-1, 0)``, ``-x*x + 2*x`` on ``[0, 1)`` and ``1`` for ``x >= 1``.
    The returned tensor is ``sign(x) - surrogate + surrogate`` with the first
    two terms detached, so its value is ``sign(x)`` (up to floating-point
    rounding) and its gradient is that of the surrogate.
    """
    hard = torch.sign(x)

    # 0/1 float32 indicators for the three breakpoints of the surrogate.
    below_neg_one = (x < -1).type(torch.float32)
    below_zero = (x < 0).type(torch.float32)
    below_one = (x < 1).type(torch.float32)

    # The two quadratic pieces used on either side of zero.
    rising = x * x + 2 * x
    falling = -x * x + 2 * x

    # Assemble the surrogate from left to right; each step keeps the value
    # built so far where the indicator is 1 and switches to the next piece
    # where it is 0.
    soft = (-1) * below_neg_one + rising * (1 - below_neg_one)
    soft = soft * below_zero + falling * (1 - below_zero)
    soft = soft * below_one + 1 * (1 - below_one)

    # Only the trailing, non-detached ``soft`` term participates in autograd.
    return hard.detach() - soft.detach() + soft


class BinAct(torch.autograd.Function):
    """Sign binarization with a hand-written surrogate gradient.

    The forward pass returns ``torch.sign(x)``. The backward pass scales the
    incoming gradient by ``max(2 * (1 - |x|), 0)``, which is zero wherever
    ``|x| >= 1``.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``sign(x)`` and stash ``x`` for the backward pass."""
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, g):
        """Return the gradient of the loss with respect to ``x``.

        ``g`` is the gradient flowing in from the output of ``forward``.
        """
        x, = ctx.saved_tensors
        # forward() has exactly one input, so exactly one gradient is
        # returned (no trailing placeholder ``None``).
        return g * torch.clamp(2 * (1 - torch.abs(x)), min=0.0)


# Benchmark input on the GPU. Scaling the standard-normal samples by 0.1
# concentrates them inside (-1, 1), where the surrogate gradient is non-zero.
x = torch.randn(1000, 1024, device='cuda', requires_grad=True)
with torch.no_grad():
    x.mul_(0.1)

# Gradient obtained through the TorchScript implementation.
binact(x).sum().backward()
a = x.grad.clone()
x.grad = None

# Gradient obtained through the custom autograd.Function.
BinAct.apply(x).sum().backward()
b = x.grad.clone()
x.grad = None

# Report the largest gradient discrepancy and the inputs where it occurs.
# The reduction runs over the full difference: selecting the mismatching
# elements first would pass an empty tensor to max(), which raises, whenever
# the two gradients agree exactly.
mismatch = a != b
print((a - b).abs().max(), x[mismatch])

# The forward outputs of both implementations must agree.
a = binact(x)
b = BinAct.apply(x)
assert a.allclose(b)


# CUDA kernels are launched asynchronously, so every timed statement ends with
# an explicit synchronize; without it mostly the launch overhead is measured.
# The callables are passed through ``globals`` so the timers do not depend on
# this file being executed as ``__main__``.
t0 = timeit.Timer(
    stmt='binact(x).sum().backward(); torch.cuda.synchronize()',
    globals={'x': x, 'binact': binact, 'torch': torch})

t1 = timeit.Timer(
    stmt='BinAct.apply(x).sum().backward(); torch.cuda.synchronize()',
    globals={'x': x, 'BinAct': BinAct, 'torch': torch})

# Each timer is run twice to show the difference before/after warmup.
print(f'binact(x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'binact(x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'BinAct.apply(x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'BinAct.apply(x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')