import warnings
from typing import Any, Dict, Optional, Tuple

import numpy as np

from .conv import process_conv_weight
from .utils import (
    get_scale,
    flatten_layer,
    dense_to_conv,
    transpose_depthwise,
    conv_to_dense,
)


def Progressive_SQuant_Algorithm_conv(
    W: np.ndarray, bits: int
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], bool]:
    """
    Apply SQuant to a regular (non-depthwise) convolutional kernel.

    The quantization scale is computed by ``get_scale`` on the flattened
    kernel (``flatten_layer``). The original-shaped kernel and that scale are
    then handed to ``process_conv_weight``.

    Args:
        W: convolutional kernel weights.
        bits: target bit-width.

    Returns:
        ``(new_W, scale, True)`` on success, or ``(None, None, False)`` when
        ``get_scale`` reports a ``quantization_range`` of ``-1`` (the layer
        is not quantized).
    """
    scale, quantization_range = get_scale(W=flatten_layer(W=W), bits=bits)
    if quantization_range == -1:
        # get_scale signals "cannot quantize this layer" with -1.
        return None, None, False
    new_W = process_conv_weight(W=W, scale=scale, bits=bits)
    return new_W, scale, True


def Progressive_SQuant_Algorithm_dense(
    W: np.ndarray, bits: int
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], bool]:
    """
    Apply SQuant to a dense / fully-connected weight matrix.

    The scale is computed by ``get_scale`` directly on ``W``. The weights are
    then converted to the convolutional layout (``dense_to_conv``) so the
    shared ``process_conv_weight`` routine can be reused. Afterwards they are
    converted back with ``conv_to_dense``.

    Args:
        W: dense layer weights.
        bits: target bit-width.

    Returns:
        ``(new_W, scale, True)`` on success, or ``(None, None, False)`` when
        ``get_scale`` reports a ``quantization_range`` of ``-1`` (the layer
        is not quantized).
    """
    scale, quantization_range = get_scale(W=W, bits=bits)
    if quantization_range == -1:
        # get_scale signals "cannot quantize this layer" with -1.
        return None, None, False
    # Round-trip through the conv layout to reuse the conv processing code.
    new_W = process_conv_weight(W=dense_to_conv(W=W), scale=scale, bits=bits)
    new_W = conv_to_dense(W=new_W)
    return new_W, scale, True


def Progressive_SQuant_Algorithm_depthwise(
    W: np.ndarray, bits: int
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], bool]:
    """
    Apply SQuant to a depthwise convolutional kernel.

    The kernel is passed through ``transpose_depthwise`` so that the regular
    convolution path (``Progressive_SQuant_Algorithm_conv``) can process it.
    On success, the result is passed through ``transpose_depthwise`` again to
    restore the original layout.
    NOTE(review): this assumes ``transpose_depthwise`` is its own inverse;
    confirm this in ``.utils``.

    Args:
        W: depthwise kernel weights.
        bits: target bit-width.

    Returns:
        ``(new_W, scale, True)`` on success, or ``(None, None, False)`` when
        the conv path declines to quantize the layer.
    """
    new_W, scale, quantized = Progressive_SQuant_Algorithm_conv(
        W=transpose_depthwise(W=W), bits=bits
    )
    if quantized:
        # Only transpose back when there is a kernel; on failure new_W is None.
        new_W = transpose_depthwise(W=new_W)
    return new_W, scale, quantized


def Progressive_SQuant_Algorithm(
    W: np.ndarray, layer_type: str, bits: int
) -> Dict[str, Any]:
    """
    Dispatch ``W`` to the SQuant operator that matches ``layer_type``.

    The supported layer types are ``"Conv2D"``, ``"Dense"`` and
    ``"DepthwiseConv2D"``. Any other type is left unquantized, and a
    ``UserWarning`` is emitted.

    Args:
        W: layer weights.
        layer_type: layer class name, e.g. ``"Conv2D"``.
        bits: target bit-width.

    Returns:
        A dict with these keys:
        ``"original weights"`` (the input ``W``),
        ``"quantized weights"`` (the processed weights, or ``None``),
        ``"scale"`` (the scale, or ``None``) and
        ``"quantized"`` (``True`` only if the operator succeeded).
    """
    if layer_type == "Conv2D":
        new_W, scale, quantized = Progressive_SQuant_Algorithm_conv(W=W, bits=bits)
    elif layer_type == "Dense":
        new_W, scale, quantized = Progressive_SQuant_Algorithm_dense(W=W, bits=bits)
    elif layer_type == "DepthwiseConv2D":
        new_W, scale, quantized = Progressive_SQuant_Algorithm_depthwise(W=W, bits=bits)
    else:
        # Use warnings instead of print so that callers can filter, record or
        # escalate this diagnostic, rather than having it written to stdout.
        warnings.warn(
            f"SQuant: unsupported layer type {layer_type!r}; "
            "weights left unquantized",
            stacklevel=2,
        )
        new_W = None
        scale = None
        quantized = False
    return {
        "original weights": W,
        "quantized weights": new_W,
        "scale": scale,
        "quantized": quantized,
    }
