import numpy as np


def transpose_depthwise(W: np.ndarray) -> np.ndarray:
    """
    Swap the last two axes of a 4-D kernel.

    A depthwise kernel is stored as
            K x K x Ci x 1
    which behaves like a regular convolution kernel of shape
            K x K x 1 x Ci
    Rather than duplicating every helper with switched axes, the depthwise
    kernel is transposed once so it can be processed like a regular conv
    kernel, then transposed again (same call) to recover the depthwise layout.
    """
    # Axes 0 and 1 stay in place; axes 2 and 3 trade positions.
    swapped_axes = (0, 1, 3, 2)
    return W.transpose(swapped_axes)


def flatten_layer(W: np.ndarray) -> np.ndarray:
    """
    Convert a 4-D conv kernel into a 2-D matrix.

    In order to limit the number of primitives, conv kernels are handled as
    matrices: the first three axes are merged into one and the last axis is
    kept as the column axis.

    Args:
        W: kernel with exactly four axes, unpacked as (K1, K2, Ci, Co).

    Returns:
        Array of shape (K1 * K2 * Ci, Co) holding the same values as W.

    Raises:
        ValueError: if W does not have exactly four axes (shape unpacking).
    """
    K1, K2, Ci, Co = W.shape
    # The target shape is passed positionally: the `newshape` keyword is
    # deprecated in recent NumPy releases, while the positional form is
    # accepted by every version.
    return np.reshape(W, (K1 * K2 * Ci, Co))


def get_bounds(W: np.ndarray) -> np.ndarray:
    """
    Return the largest absolute value of W along axis 0.

    Under a symmetric prior the weight range is [-bound, +bound], so a single
    magnitude per remaining index describes it.
    """
    magnitudes = np.abs(W)
    return magnitudes.max(axis=0)


def get_quantization_range(bits: int) -> int:
    """
    Return the maximum integer value Q of a kernel quantized on `bits` bits:
            bits |   Q
            2    |   1

    Only bits == 2 is supported. For any other value a warning is printed
    and -1 is returned.
    """
    supported_bits = (2,)
    if bits not in supported_bits:
        print(
            f"[WARNING] in SQuant : unsupported bit-wise representation {bits}\n\tsupported values are : 2"
        )
        return -1
    # Symmetric signed range: one bit for the sign, minus one for zero.
    return 2 ** (bits - 1) - 1


def get_scale(W: np.ndarray, bits: int) -> "tuple[np.ndarray, int]":
    """
    Compute the per-column scaling factor used to quantize W.

    For every index of the bound returned by get_bounds (max |W| along
    axis 0) the factor is
            3 * quantization_range / (2 * upper_bound)
    so the largest magnitude of each column maps to 1.5 * quantization_range.
    Entries whose bound is below 1e-12 keep a factor of 1 to avoid dividing
    by (almost) zero.

    Note: basic_Q in this file multiplies the weights by this value, so it
    plays the role of 1/s in the notation Q(W) = ⌊W/s⌉.

    Args:
        W: weights; bounds are taken along axis 0.
        bits: bit-width forwarded to get_quantization_range. If it is not
            supported, that helper returns -1 and the value is used as-is.

    Returns:
        A (scale, quantization_range) pair: the float array of scaling
        factors (same shape as the bounds) and the integer range.
    """
    upper_bound = get_bounds(W)
    quantization_range = get_quantization_range(bits=bits)
    scale = np.ones(upper_bound.shape)
    # Build the "safe to divide" mask once and reuse it on both sides.
    non_zero = upper_bound >= 1e-12
    scale[non_zero] = 3 * (quantization_range / upper_bound[non_zero]) / 2
    return scale, quantization_range


def basic_Q(W: np.ndarray, s: np.ndarray) -> np.ndarray:
    """
    Basic quantization operator.

    The weights are multiplied by s, rounded to the nearest integer and
    clipped to the ternary range [-1, 1].
    """
    rounded = np.round(W * s)
    return np.clip(rounded, -1, 1)


if __name__ == "__main__":
    # Demo: repeatedly quantize the residual of a random matrix and print how
    # the absolute reconstruction error evolves with the expansion order.
    K = 10
    W = np.random.uniform(low=-1, high=1, size=(10, 3))
    print(W)
    accumulated_W = np.zeros_like(W)
    for i in range(K):
        # What is still missing from the current reconstruction.
        residual = W - accumulated_W
        scale, _ = get_scale(W=residual, bits=2)
        # Quantize the residual, then map it back to the weight scale.
        accumulated_W = accumulated_W + basic_Q(W=residual, s=scale) / scale
        error = np.abs(W - accumulated_W).sum()
        print(f"error at order {i} : {error:.5f}")
