import numpy as np


def transpose_depthwise(W: np.ndarray) -> np.ndarray:
    """
    Swap the last two axes of a 4-D kernel.

    A depthwise kernel is stored with shape
            K x K x Ci x 1
    which behaves like a regular convolution kernel of shape
            K x K x 1 x Ci
    Rather than duplicating every function with switched axes, the depthwise
    kernel is transposed once so that it can be processed as if it were a
    regular conv, and transposed once more afterwards to recover the
    processed depthwise kernel (applying this function twice restores the
    original axis order).
    """
    # Keep the two leading axes in place and exchange axes 2 and 3.
    axes_order = (0, 1, 3, 2)
    return np.transpose(W, axes_order)


def flatten_layer(W: np.ndarray) -> np.ndarray:
    """
    In order to limit the number of primitives, we convert conv kernels to matrices.

    W must be 4-dimensional with shape
            K1 x K2 x Ci x Co
    (anything else fails when the shape is unpacked). The result is a matrix of shape
            (K1 * K2 * Ci) x Co
    so the last axis is preserved and the three leading axes are merged.
    """
    K1, K2, Ci, Co = W.shape
    # The target shape is passed positionally: the ``newshape`` keyword of
    # np.reshape is deprecated in recent NumPy releases, while the positional
    # form is accepted by every version.
    return np.reshape(W, (K1 * K2 * Ci, Co))


def get_bounds(W: np.ndarray) -> np.ndarray:
    """
    Compute the range of the weight distribution under a symmetric prior.

    The returned bound b is the largest absolute value of W along its first
    axis, so the corresponding weights all lie in [-b, b].
    """
    magnitudes = np.abs(W)
    return magnitudes.max(axis=0)


def get_quantization_range(bits: int) -> int:
    """
    Return the maximum integer value Q of a kernel quantized on `bits` bits,
    i.e. Q = 2^(bits - 1) - 1:
            bits |   Q
            8    |  127
            6    |   31
            4    |   7
            3    |   3
            2    |   1
    Any other bit-width is unsupported: a warning is printed and -1 is
    returned instead of a range.
    """
    supported_bits = (2, 3, 4, 6, 8)
    if bits not in supported_bits:
        print(
            f"[WARNING] in SQuant : unsupported bit-wise representation {bits}\n\tsupported values are : 2,3,4,6,8"
        )
        return -1
    return 2 ** (bits - 1) - 1


def get_scale(W: np.ndarray, bits: int) -> tuple:
    """
    This function computes the quantization scale s such that
            Q(W) = ⌊W * s⌉
    i.e. s is the multiplier that maps the largest absolute weight onto the
    quantization range:
            s = quantization_range / max|W|
    where the maximum is the bound returned by get_bounds(W).

    Returns the pair (scale, quantization_range):
        scale              array with the same shape as get_bounds(W)
        quantization_range value returned by get_quantization_range(bits)

    Entries whose bound is below 1e-12 keep a scale of 1 so that (near-)zero
    weights never trigger a division by zero.

    NOTE: for an unsupported `bits`, get_quantization_range returns -1 and
    the resulting scales are therefore negative; no error is raised here.
    """
    upper_bound = get_bounds(W)
    quantization_range = get_quantization_range(bits=bits)
    scale = np.ones(upper_bound.shape)
    # Build the mask once and reuse it on both sides of the assignment.
    non_zero = upper_bound >= 1e-12
    scale[non_zero] = quantization_range / upper_bound[non_zero]
    return scale, quantization_range


def basic_Q(W: np.ndarray, s: np.ndarray) -> np.ndarray:
    """
    Basic quantization operator: rescale then round to the nearest integer,
            Q(W) = ⌊W * s⌉
    The scale `s` is applied as a multiplier (not a divisor) and the product
    is rounded elementwise with np.round.
    """
    rescaled = W * s
    return np.round(rescaled)


def argmax_k(array: np.ndarray, k: int) -> np.ndarray:
    """
    This function returns the indices of the k highest values of an array
    This is used in the SQuant_Flip_Algorithm

    The indices are returned in no particular order (use argmax_k_sorted when
    the order matters). `k` is expected to satisfy 0 <= k <= len(array):
        k == 0 returns an empty index array
        k < 0  raises ValueError
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k == 0:
        # Nothing to select. Without this guard np.argpartition rejects
        # kth == len(array), and the slice [-0:] would select every index.
        return np.empty(0, dtype=np.intp)
    return np.argpartition(array, len(array) - k)[-k:]


def argmax_k_sorted(array: np.ndarray, k: int) -> np.ndarray:
    """
    This function returns the sorted indices of the k highest values of an array
    This is used in the Perturbation_Update_Algorithm

    The indices are ordered from the highest value to the lowest. `k` is
    expected to satisfy 0 <= k <= len(array):
        k == 0 returns an empty index array
        k < 0  raises ValueError
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k == 0:
        # Nothing to select. Without this guard np.argpartition rejects
        # kth == len(array), and the slice [-0:] would select every index.
        return np.empty(0, dtype=np.intp)
    # Cheap partial selection of the k largest entries (unordered) ...
    idx = np.argpartition(array, len(array) - k)[-k:]
    # ... then order only those k candidates by decreasing value.
    return idx[np.argsort((-array)[idx])]


def dense_to_conv(W: np.ndarray) -> np.ndarray:
    """
    Dense layers are processed with the convolution primitives in order to
    limit the number of primitives. To make that possible, two virtual
    kernel-size axes of length one are prepended to the weights, so an input
    of shape (...) comes back with shape 1 x 1 x (...).
    """
    virtual_kernel_axes = (0, 1)
    return np.expand_dims(W, axis=virtual_kernel_axes)


def conv_to_dense(W: np.ndarray) -> np.ndarray:
    """
    Dense layers are processed with the convolution primitives in order to
    limit the number of primitives. Once processing is done, the two virtual
    kernel-size axes of length one have to be removed again: this selects
    index 0 on the first two axes and keeps the next two axes in full.
    """
    everything = slice(None)
    selector = (0, 0, everything, everything)
    return W[selector]
