import torch
import torch.nn as nn
from torch.autograd import Variable
import pdb

class sgn_with_grad(torch.autograd.Function):
    """Sign function with a clipped pass-through gradient.

    Forward returns ``torch.sign(x)``. Backward forwards the incoming
    gradient unchanged where ``|x| <= 1`` and zeroes it where ``|x| > 1``.
    """

    @staticmethod
    def forward(ctx, x):
        """Return the elementwise sign of ``x``; ``x`` is saved for backward."""
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` with entries zeroed where ``|x| > 1``."""
        # The first parameter must be named ``ctx`` because the body reads
        # ``ctx.saved_tensors``; a mismatched name raises NameError here.
        x, = ctx.saved_tensors
        # masked_fill is out-of-place, so grad_output itself is never mutated.
        return grad_output.masked_fill(x.abs() > 1, 0)

def quantize_tensor(x, bits, s, zp=0):
    """Quantize ``x`` onto a signed ``bits``-bit integer grid.

    The grid spans ``[-2**(bits-1), 2**(bits-1) - 1]``. Values are computed as
    ``x / s - zero_point``, clamped to the grid and rounded; the result keeps
    the floating dtype of ``x`` but holds integral values.

    Args:
        x: Tensor to quantize. It is not modified.
        bits: Bit width; must be greater than 1.
        s: Quantization step (scale).
        zp: Offset in the same units as ``x``; it is converted to grid units
            with ``int(zp / s)`` (truncation toward zero).

    Returns:
        Tuple ``(q_x, scale, zero_point)``.

    Raises:
        ValueError: If ``bits <= 1``; this function has no 1-bit path.
    """
    if bits <= 1:
        raise ValueError(
            "quantize_tensor requires bits > 1, got {!r}".format(bits)
        )

    qmin = 0. - 2**(bits-1)
    qmax = 2.**bits - 1 + qmin

    scale = s
    zero_point = int(zp / scale)
    # x / scale allocates a fresh tensor, so the caller's x is left untouched.
    q_x = torch.clamp(x / scale - zero_point, qmin, qmax).round()
    return q_x, scale, zero_point

def dequantize_tensor(q_x, scale, zp):
    """Map quantized values back to real values: ``scale * (q_x + zp)``.

    Args:
        q_x: Quantized values.
        scale: Quantization step.
        zp: Zero point, expressed in grid units.
    """
    shifted = q_x + zp
    return scale * shifted


class quantize(torch.autograd.Function):
    """Quantize then dequantize in forward; pass the gradient through unchanged.

    Relies on ``quantize_tensor`` and ``dequantize_tensor`` defined in this
    module.
    """

    @staticmethod
    def forward(ctx, input, bits, s, zp):
        """Return ``input`` quantized and immediately dequantized.

        Args:
            input: Tensor to quantize.
            bits: Bit width forwarded to ``quantize_tensor``.
            s: Quantization step forwarded to ``quantize_tensor``.
            zp: Offset forwarded to ``quantize_tensor``.
        """
        q_x, scale, zero_point = quantize_tensor(input, bits, s, zp)
        return dequantize_tensor(q_x, scale, zero_point)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` for ``input`` and ``None`` for the rest."""
        # One gradient slot per forward argument: input, bits, s, zp.
        return grad_output, None, None, None

class binarize(torch.autograd.Function):
    """Elementwise sign in forward; identity gradient in backward.

    Note that ``torch.sign`` maps exact zeros to 0, so the output takes
    values in ``{-1, 0, 1}``.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``torch.sign(input)``."""
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` unchanged."""
        return grad_output

class ternarize(torch.autograd.Function):
    """Ternarize to ``{-1, 0, 1}`` in forward; identity gradient in backward.

    A separate threshold is computed for each slice along dim 0 as
    ``0.7 * mean(|slice|)``. Entries whose magnitude is at or below their
    slice's threshold become 0; the rest become their sign.
    """

    @staticmethod
    def forward(ctx, input):
        """Return the ternarized copy of ``input`` (any rank >= 1)."""
        per_slice = input.view(input.shape[0], -1)
        # Shape the thresholds as (N, 1, ..., 1) with one trailing singleton
        # per remaining dim, so they broadcast for any rank and not only for
        # 4-D tensors.
        trailing = [1] * (input.dim() - 1)
        thred = (per_slice.abs().mean(-1) * 0.7).view(-1, *trailing)
        return torch.sign(input).masked_fill(input.abs() <= thred, 0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` unchanged."""
        return grad_output

class round_back(torch.autograd.Function):
    """Floor in forward; identity gradient in backward.

    Despite the name, the forward pass applies ``torch.floor`` rather than
    rounding to nearest.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``torch.floor(input)``."""
        return torch.floor(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` unchanged."""
        return grad_output

class no_grad_mul(torch.autograd.Function):
    """Multiply by ``s`` in forward without scaling the gradient in backward.

    The gradient with respect to ``input`` is the incoming gradient as-is
    (not multiplied by ``s``); ``s`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, input, s):
        """Return ``input * s``."""
        scaled = input * s
        return scaled

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_output, None)``: pass-through for input, none for s."""
        grads = (grad_output, None)
        return grads

class pact(torch.autograd.Function):
    """Clip activations to ``[0, alpha]`` with gradients for ``x`` and ``alpha``.

    Forward computes ``min(max(x, 0), alpha)``. Backward passes the gradient
    to ``x`` where ``0 <= x <= alpha`` and accumulates it into ``alpha`` where
    ``x >= alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` clamped below at 0 and above at ``alpha``.

        Args:
            x: Input tensor.
            alpha: Upper clipping bound; must be a tensor because it is saved
                for backward.
        """
        ctx.save_for_backward(x, alpha)
        return x.clamp(min=0.0).min(alpha)

    @staticmethod
    def backward(ctx, dLdy):
        """Return ``(dL/dx, dL/dalpha)``."""
        # saved_tensors replaces the deprecated saved_variables alias.
        x, alpha = ctx.saved_tensors
        below = x < 0
        above = x > alpha
        # 1 inside [0, alpha], 0 outside.
        pass_mask = 1.0 - below.float() - above.float()
        dLdx = dLdy * pass_mask
        # Reduce to alpha's own shape so that a 0-dim, (1,)-shaped or
        # broadcast (e.g. per-channel) alpha all receive a gradient that
        # autograd accepts; for a 0-dim alpha this is a plain full sum.
        dLdalpha = (dLdy * x.ge(alpha).float()).sum_to_size(alpha.shape)
        return dLdx, dLdalpha

