from torch.autograd import Function
import numpy as np
import torch
import quant_cuda

class TestFunction(Function):
    """Autograd wrapper around ``quant_cuda.forward`` with an identity backward.

    The forward pass hands the input to the ``quant_cuda`` extension together
    with a fixed table of the sixteen values 0.0 through 15.0; the backward
    pass returns the incoming gradient untouched.
    """

    @staticmethod
    def forward(ctx, x):
        """Call ``quant_cuda.forward`` on ``x`` with the 0..15 value table.

        NOTE(review): presumably the table holds the quantization levels that
        ``x`` is mapped onto -- confirm against the ``quant_cuda`` extension.
        """
        # Sixteen consecutive floats 0.0 .. 15.0, moved to the current CUDA
        # device before being handed to the extension.
        levels = torch.arange(16.0).cuda()
        return quant_cuda.forward(x, levels)

    @staticmethod
    def backward(ctx, gradOutput):
        """Return ``gradOutput`` unchanged as the gradient for ``x``."""
        return gradOutput


# Smoke test: draw 100 uniform samples from [-20, 20), run their magnitudes
# through TestFunction, then multiply the original signs back in.
a = torch.as_tensor(
    np.random.uniform(low=-20, high=20, size=100),
    dtype=torch.float,
).cuda()
sign = torch.sign(a)
print(a)
# Only the absolute values are passed to the function.
a = torch.abs(a)
print("Test...")
c = torch.mul(TestFunction.apply(a), sign)
print(c)