import numpy as np
from typing import Tuple


def transpose_depthwise(W: np.ndarray) -> np.ndarray:
    """
    Swap the last two axes of a 4-D kernel.

    A depthwise kernel has shape:
            K x K x Ci x 1
    which behaves like a regular conv kernel of shape:
            K x K x 1 x Ci
    Rather than duplicating every helper with switched axes, the
    depthwise kernel is transposed once so it can be processed as a
    regular conv, then transposed again (this function is its own
    inverse) to recover the processed depthwise kernel.
    """
    # Axes 0 and 1 (spatial) stay in place; axes 2 and 3 trade places.
    swapped_axes = (0, 1, 3, 2)
    return W.transpose(swapped_axes)


def flatten_layer(W: np.ndarray) -> np.ndarray:
    """
    Convert a 4-D conv kernel into a 2-D matrix.

    In order to limit the number of primitives, conv kernels are handled
    as matrices: the first three axes are merged (in C order) into the
    rows and the last axis is kept as the columns.

    Args:
        W: 4-D array of shape (K1, K2, Ci, Co); the names suggest
            (kernel height, kernel width, input channels, output channels).

    Returns:
        Array of shape (K1 * K2 * Ci, Co) holding the same values.

    Raises:
        ValueError: if W does not have exactly four axes (from the
            shape unpacking).
    """
    K1, K2, Ci, Co = W.shape
    # The target shape is passed positionally: the `newshape` keyword is
    # deprecated in recent NumPy releases, while the positional form is
    # accepted by every version.
    return np.reshape(W, (K1 * K2 * Ci, Co))


def get_bounds(W: np.ndarray) -> np.ndarray:
    """
    Compute the range of the weight distribution under a symmetric prior.

    The bound is the largest absolute value found along axis 0, so the
    implied range for each remaining index is [-bound, +bound].
    """
    magnitudes = np.absolute(W)
    return np.max(magnitudes, axis=0)


def get_quantization_range(bits: int) -> int:
    """
    Return the maximum integer value of the quantized kernel.

    The range is Q = 2 ** (bits - 1) - 1, e.g.
            bits |   Q
            8    |  127
            4    |   7
            3    |   3
            2    |   1

    Args:
        bits: bit-width of the representation; every integer from 2 to 8
            is supported, and -1 is treated as 8.

    Returns:
        Q for a supported bit-width. For any other value a warning is
        printed to stdout and -1 is returned (callers must check for it).
    """
    if bits == -1:
        bits = 8
    supported = range(2, 9)
    if bits in supported:
        return 2 ** (bits - 1) - 1
    # The list of supported values is built from the same range used for
    # the check above, so the message cannot drift from the actual logic.
    supported_text = ",".join(str(b) for b in supported)
    print("[WARNING] in Basic Quant : unsupported bit-wise ")
    print(f"representation {bits}\n\tsupported values are : {supported_text}")
    return -1


def get_scale(W: np.ndarray, bits: int) -> Tuple[np.ndarray, int]:
    """
    Compute the quantization scale of W for a given bit-width.

    The returned scale is the multiplier
            Q / max|W|
    where Q is the quantization range for `bits` and max|W| is the bound
    given by get_bounds. It is the factor that basic_Q multiplies W by
    (i.e. the reciprocal of s in the notation Q(W) = ⌊W/s⌉).
    Entries whose bound is below 1e-12 keep a scale of 1, which avoids
    dividing by (almost) zero.

    Returns the pair (scale, Q). Q is passed through unchecked, so the -1
    that get_quantization_range returns for an unsupported bit-width
    ends up in both elements of the result.
    """
    bounds = get_bounds(W)
    q_max = get_quantization_range(bits=bits)
    # Locate the non-degenerate entries once and reuse the index.
    usable = np.where(bounds >= 1e-12)
    scale = np.ones(bounds.shape)
    scale[usable] = q_max / bounds[usable]
    return scale, q_max


def basic_Q(W: np.ndarray, s: np.ndarray) -> np.ndarray:
    """
    Basic quantization operator: scale the weights, then round.

    Computes
            Q(W) = ⌊W * s⌉
    i.e. `s` is used as a multiplier. Rounding is done by np.round.
    """
    scaled = W * s
    return np.round(scaled)
