from .utils import get_scale, basic_Q, flatten_layer, transpose_depthwise
import numpy as np
from typing import Dict, Any, Tuple


def TQ_dense(W: np.ndarray, bits: int) -> Tuple[np.ndarray, np.ndarray, bool]:
    """
    Quantize the weights of a dense/fully-connected layer.

    The scale and the quantization range are obtained from ``get_scale``;
    the weights are then quantized by ``basic_Q`` using that scale.

    Args:
        W: weights of the layer.
        bits: bit-width forwarded to ``get_scale``.

    Returns:
        A ``(quantized weights, scale, quantized)`` triple. When ``get_scale``
        reports a negative quantization range nothing is quantized and the
        triple is ``(None, None, False)``.
    """
    scale, span = get_scale(W=W, bits=bits)
    # A negative range is treated as "this kernel cannot be quantized".
    can_quantize = not (span < 0)
    if can_quantize:
        return basic_Q(W=W, s=scale), scale, True
    return None, None, False


def TQ_conv(W: np.ndarray, bits: int) -> Tuple[np.ndarray, np.ndarray, bool]:
    """
    Quantize the weights of a convolutional layer that is not depthwise.

    The kernel is passed through ``flatten_layer`` and handed to the
    fully-connected primitive ``TQ_dense``; on success the quantized weights
    are reshaped back to the shape of the input kernel.

    Args:
        W: weights of the convolutional layer.
        bits: bit-width forwarded to ``TQ_dense``.

    Returns:
        A ``(quantized weights, scale, quantized)`` triple. When ``TQ_dense``
        does not quantize, its result is returned untouched (the weights and
        scale are then ``None`` and the flag is ``False``).
    """
    original_shape = W.shape
    new_W, scale, quantized = TQ_dense(W=flatten_layer(W), bits=bits)
    if quantized:
        # The shape is passed positionally: the ``newshape=`` keyword of
        # np.reshape is deprecated since NumPy 2.1 (renamed to ``shape=``),
        # while the positional form is accepted by every NumPy version.
        new_W = np.reshape(new_W, original_shape)
    return new_W, scale, quantized


def TQ_depthwise(W: np.ndarray, bits: int) -> Tuple[np.ndarray, np.ndarray, bool]:
    """
    Quantize the weights of a depthwise convolutional layer.

    The kernel is rearranged with ``transpose_depthwise`` and quantized by
    the convolution primitive ``TQ_conv``. On success the same
    ``transpose_depthwise`` helper is applied to the quantized weights
    (presumably restoring the original layout -- confirm against the helper).

    Args:
        W: weights of the depthwise layer.
        bits: bit-width forwarded to ``TQ_conv``.

    Returns:
        A ``(quantized weights, scale, quantized)`` triple; the result of
        ``TQ_conv`` is returned unchanged when quantization did not happen.
    """
    rearranged = transpose_depthwise(W=W)
    weights, scale, quantized = TQ_conv(W=rearranged, bits=bits)
    if not quantized:
        # Nothing to transpose back: forward the failure result as is.
        return weights, scale, quantized
    return transpose_depthwise(W=weights), scale, quantized


def TQuant(W: np.ndarray, layer_type: str, bits: int) -> Dict[str, Any]:
    """
    Dispatch ternary quantization to the routine matching ``layer_type``.

    Any ``bits`` value other than 2 triggers a printed warning and is
    replaced by 2 before quantizing. Supported layer types are "Conv2D",
    "Dense" and "DepthwiseConv2D"; any other value prints a warning and
    leaves the layer unquantized.

    Args:
        W: weights of the layer.
        layer_type: name of the layer kind, used to pick the routine.
        bits: requested bit-width (forced to 2).

    Returns:
        A dictionary with the keys "original weights" (the input ``W``),
        "quantized weights", "scale" and "quantized". For an unsupported
        layer type the weights and scale are ``None`` and "quantized" is
        ``False``.
    """
    if bits != 2:
        print(
            f"[WARNING] you asked to quantize a kernel with TQuant in {bits}bits. Automatic switch to ternary quantization."
        )
        bits = 2

    # Start from the "not quantized" outcome; a supported layer overrides it.
    ternary_W, factor, success = None, None, False
    if layer_type == "Conv2D":
        ternary_W, factor, success = TQ_conv(W=W, bits=bits)
    elif layer_type == "Dense":
        ternary_W, factor, success = TQ_dense(W=W, bits=bits)
    elif layer_type == "DepthwiseConv2D":
        ternary_W, factor, success = TQ_depthwise(W=W, bits=bits)
    else:
        print(f"[WARNING] in TQuant : unsupported layer type : {layer_type}")

    report = {
        "original weights": W,
        "quantized weights": ternary_W,
        "scale": factor,
        "quantized": success,
    }
    return report
