import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
from torch.autograd import Function


def shift(x):
    """Round each element of ``x`` to the nearest power of two in log2 space.

    Computes ``2 ** round(log2(x))`` elementwise. Intended for positive inputs:
    an element equal to 0 has ``log2 == -inf`` and therefore maps to 0 (unsafe
    to use as a divisor), and negative elements map to NaN.
    """
    # TODO: edge case, when x contains 0 (the result is 0 there)
    exponent = torch.log2(x).round()
    return torch.pow(2., exponent)


def S(bits):
    """Return the fixed-point scale ``2 ** (bits - 1)`` as a float.

    ``1 / S(bits)`` is the step size of the ``bits``-bit grid used by ``Q``/``C``.
    """
    return pow(2.0, bits - 1)


def SR(x):
    """Stochastically round ``x`` to integer values (same dtype/device as ``x``).

    Each element becomes ``floor(x + u)`` with ``u`` drawn uniformly from
    [0, 1), so a non-integer element rounds up with probability equal to its
    fractional part.
    """
    noise = x.new_empty(x.shape).uniform_()
    return (x + noise).floor()


def C(x, bits):
    """Clamp ``x`` to the symmetric range used for ``bits``-bit values.

    In the 1-bit and >15-bit modes the range is [-1, 1]. Otherwise it is
    ``[-1 + 2**(1 - bits), 1 - 2**(1 - bits)]``, i.e. one quantization step
    (``1 / S(bits)``) inside the unit interval on each side.
    """
    full_range = bits == 1 or bits > 15
    margin = 0 if full_range else 1. / S(bits)
    return torch.clamp(x, -1 + margin, 1 - margin)


def Q(x, bits):
    """Round ``x`` onto a ``bits``-bit fixed-point grid.

    * ``bits > 15``: treated as full precision; ``x`` itself is returned.
    * ``bits == 1``: elementwise sign (``torch.sign`` maps 0 to 0).
    * otherwise: each element is rounded to the nearest multiple of
      ``1 / S(bits) == 2**(1 - bits)`` via ``torch.round`` (ties to even).

    ``bits == -1`` is rejected by an assert; e.g. ``WAGERounding.forward``
    skips Q when its bit-width is -1.
    """
    assert bits != -1
    if bits > 15:
        return x
    if bits == 1:
        return torch.sign(x)
    scale = S(bits)
    return torch.round(x * scale) / scale


def QW(x, bits):
    y = Q(C(x, bits), bits)
    return y


def QE(x, bits):
    """Quantize an error (back-propagated gradient) tensor to ``bits`` bits.

    ``x`` is divided by the power of two nearest its largest absolute value.
    That value is floored at ``2**(1 - bits)``, so an all-zero ``x`` cannot
    cause a division by zero. The result is then clamped with ``C`` and
    rounded with ``Q``.

    Args:
        x: error tensor. It is NOT modified; a new tensor is returned.
        bits: target bit-width (must not be -1, see ``Q``).

    Returns:
        The quantized error tensor.
    """
    max_entry = torch.maximum(x.abs().max(), torch.tensor(2**(1-bits), device=x.device))
    # Out-of-place on purpose: WAGERounding.backward passes autograd's
    # grad_output here, and mutating it in place would corrupt that tensor.
    x = x / shift(max_entry)
    return Q(C(x, bits), bits)


def QG(x, bits_G, bits_R, lr):
    """Stochastically quantize a gradient and fold in the learning rate.

    Steps:
    1. Divide ``x`` by the power of two nearest its largest absolute value.
       That value is floored at ``2**(1 - bits_G)``, so an all-zero ``x``
       does not divide by zero.
    2. Multiply by ``lr``.
    3. Stochastically round to integers with ``SR``.
    4. Divide by ``S(bits_G)``.

    Args:
        x: gradient tensor. It is NOT modified; a new tensor is returned.
        bits_G: gradient bit-width; sets the floor above and the output
            step ``1 / S(bits_G)``.
        bits_R: not used by this function; kept for signature compatibility.
        lr: learning-rate multiplier applied before stochastic rounding.

    Returns:
        A tensor whose values are integer multiples of ``1 / S(bits_G)``.
    """
    max_entry = torch.maximum(x.abs().max(), torch.tensor(2**(1-bits_G), device=x.device))
    # Out-of-place so the caller's gradient tensor is not silently rescaled.
    x = x / shift(max_entry)
    norm = lr * x
    # Integer-valued. shift() rounds log2 to the nearest integer, so |x| can
    # reach ~sqrt(2) here and the result may slightly exceed [-lr, lr].
    norm = SR(norm)
    return norm / S(bits_G)


def get_real_lr(x, bits_G, bits_R, lr):
    """Return the effective step size that ``QG`` applies to ``x``.

    ``QG`` returns ``SR(lr * x / shift(m)) / S(bits_G)``. Its expectation is
    ``x * lr / shift(m) / S(bits_G)``, and this function returns that scalar
    factor.

    ``m`` is floored at ``2**(1 - bits_G)`` exactly as in ``QG``. This keeps
    the two consistent for tiny gradients, and an all-zero ``x`` now yields a
    finite value instead of dividing by ``shift(0) == 0``.

    Args:
        x: gradient tensor (not modified).
        bits_G: gradient bit-width.
        bits_R: not used by this function; kept for signature compatibility.
        lr: nominal learning rate.
    """
    max_entry = torch.maximum(x.abs().max(), torch.tensor(2**(1-bits_G), device=x.device))
    return lr / shift(max_entry) / S(bits_G)


class WAGERounding(Function):
    """Autograd function: quantize activations forward, errors backward.

    Forward rounds ``x`` with ``Q(x, bits_A)``. Backward treats that rounding
    as the identity and only quantizes the incoming gradient with
    ``QE(grad, bits_E)``. A bit-width of -1 disables the corresponding step.
    """

    @staticmethod
    def forward(ctx, x, bits_A, bits_E):
        # Only the error bit-width is needed in backward. x is deliberately not
        # saved: backward never reads it, and saving it would keep the
        # activation alive until the graph is freed.
        ctx.bits_E = bits_E
        if bits_A == -1:
            return x
        return Q(x, bits_A)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: x, bits_A, bits_E (the ints get None).
        if ctx.bits_E == -1 or not ctx.needs_input_grad[0]:
            return grad_output, None, None
        return QE(grad_output, ctx.bits_E), None, None


# Functional entry point for WAGERounding: quantize_wage(x, bits_A, bits_E)
# rounds x to bits_A bits in the forward pass and quantizes the gradient to
# bits_E bits (via QE) in the backward pass; -1 disables either step.
quantize_wage = WAGERounding.apply


class WAGEQuantizer(Module):
    """Module wrapper that clamps and quantizes activations with ``quantize_wage``.

    Args:
        bits_A: activation bit-width (-1 disables clamping and rounding).
        bits_E: error (gradient) bit-width used in backward (-1 disables it).
    """

    def __init__(self, bits_A, bits_E):
        super().__init__()
        self.bits_A = bits_A
        self.bits_E = bits_E

    def forward(self, x):
        # The clamp runs outside the custom Function, so autograd still
        # differentiates through torch.clamp.
        clamped = x if self.bits_A == -1 else C(x, self.bits_A)
        return quantize_wage(clamped, self.bits_A, self.bits_E)
