import numpy as np
from typing import Tuple


def transpose_depthwise(W: np.ndarray) -> np.ndarray:
    """
    Swap the last two axes of a 4-D kernel.

    A depthwise kernel has shape :
            K x K x Ci x 1
    which behaves like a regular conv kernel of shape :
            K x K x 1 x Ci
    Rather than duplicating every function with switched axes, the depthwise
    kernel is transposed once so it can be processed as a regular conv, and
    transposed once more afterwards to recover the processed depthwise kernel
    (the permutation is its own inverse).
    """
    # Keep the two leading axes in place and exchange the two trailing ones.
    axis_order = (0, 1, 3, 2)
    return W.transpose(*axis_order)


def flatten_layer(W: np.ndarray) -> np.ndarray:
    """
    In order to limit the number of primitives, we convert conv kernels to matrices.

    Args:
        W: 4-D kernel of shape K1 x K2 x Ci x Co.

    Returns:
        A matrix of shape (K1 * K2 * Ci) x Co: the first three axes are merged
        and the last axis is kept as is.
    """
    K1, K2, Ci, Co = W.shape
    # The target shape is passed positionally: the `newshape` keyword of
    # np.reshape was deprecated in NumPy 2.1 (renamed to `shape`), whereas the
    # positional form is accepted by every NumPy version.
    return np.reshape(W, (K1 * K2 * Ci, Co))


def get_quantization_range(bits: int) -> int:
    """
    Return the maximum integer value of the quantized kernel for a bit-width.

    Supported mapping (bits | range):
            2    |   1
    Any other bit-width prints a warning and is handled as if it were 2 bits.
    """
    supported_bits = (2,)
    if bits not in supported_bits:
        # Unsupported request: report it, then fall back to the 2-bit setting.
        print(
            f"[WARNING] in SQuant : unsupported bit-wise representation {bits}\n\tsupported value is : 2"
        )
        bits = 2
    return 2 ** (bits - 1) - 1


def get_scale(W: np.ndarray, bits: int) -> Tuple[np.ndarray, int, np.ndarray]:
    """
    Compute the per-column quantization threshold and output rescaling.

    The goal is for roughly a third of the weights (the smallest in magnitude)
    to be assigned 0 and the remaining ones -1 or 1. To do so, the threshold of
    each column is the 33rd percentile of its absolute weight values (the
    percentile is used as is, no extra factor is applied).

    Args:
        W: weight matrix; the code indexes it as rows x columns and treats each
            column independently.
        bits: requested bit-width, forwarded to get_quantization_range.

    Returns:
        A tuple (scale, quantization_range, inverse_output_scale) where
        - scale holds, per column, the 33rd percentile of |W|,
        - quantization_range is the value given by get_quantization_range,
        - inverse_output_scale holds, per column, 1 / mean(|w|) taken over the
          entries whose magnitude is strictly above the threshold.

    Degenerate columns, where no magnitude is strictly above the threshold
    (e.g. a single-row matrix or a constant-magnitude column), fall back to the
    threshold itself instead of the NaN that the mean of an empty selection
    would give. An all-zero column keeps the neutral value 1.
    """
    quantization_range = get_quantization_range(bits=bits)
    # Magnitudes are needed for the percentile and for every column below.
    abs_W = np.abs(W)
    scale = np.percentile(a=abs_W, q=33, axis=0)
    output_scale = np.ones(shape=(W.shape[-1],))
    for col in range(W.shape[-1]):
        column = abs_W[:, col]
        above_threshold = column[column > scale[col]]
        if above_threshold.size:
            output_scale[col] = np.mean(above_threshold)
        elif scale[col] != 0:
            # Nothing is strictly above the threshold, so the largest magnitude
            # of the column equals the threshold: use it as the representative
            # magnitude. A NaN threshold (NaN input) still propagates here.
            output_scale[col] = scale[col]
        # else: all-zero column, keep 1 so that the inverse stays finite.
    return scale, quantization_range, 1 / output_scale


def M_Q(W: np.ndarray, s: np.ndarray) -> np.ndarray:
    """
    Basic quantization operator, applied column by column.

    For column j, every entry whose magnitude is strictly below s[j] becomes 0;
    the remaining negative entries become -1 and the remaining positive
    entries become 1. The input W is left untouched: the work is done on a copy.
    """
    quantized = np.copy(W)
    for col in range(W.shape[-1]):
        column = quantized[:, col]  # view: writes land in `quantized`
        column[np.abs(column) < s[col]] = 0
    # Both masks are taken before writing; setting negatives to -1 cannot
    # create a positive entry, so the result matches sequential updates.
    is_negative = quantized < 0
    is_positive = quantized > 0
    quantized[is_negative] = -1
    quantized[is_positive] = 1
    return np.clip(quantized, -1, 1)


if __name__ == "__main__":
    # Demo: repeatedly quantize the residual of a random matrix and add the
    # rescaled result to a running approximation, printing the remaining
    # absolute error after each order.
    K = 10
    W = np.random.uniform(low=-1, high=1, size=(10, 3))
    print(W)
    accumulated_W = 0
    for i in range(K):
        # What the current approximation still fails to capture.
        residual = W - accumulated_W
        scale, _, output_scale = get_scale(W=residual, bits=2)
        quantized_residual = M_Q(W=residual, s=scale)
        # get_scale returns the inverse of the output scale, hence the division.
        accumulated_W = accumulated_W + quantized_residual / output_scale
        print(f"error at order {i} : {np.sum(np.abs(W-accumulated_W)):.5f}")
