from .check import check_inputs
from .algo import Progressive_SQuant_Algorithm
import numpy as np
from typing import Dict, Any


def SQuant_operator(W: np.ndarray, layer_type: str, bits: int) -> Dict[str, Any]:
    """Quantize a weight tensor with the SQuant operator.

    Implementation of the article "SQuant: On-the-Fly Data-Free Quantization
    via Diagonal Hessian Approximation", published at ICLR 2022. This
    quantization operator can be used in addition to DRE.

    Args:
        W: weight tensor, converted to a numpy array.
        layer_type: layer kind, expected to be 'Conv2D', 'DepthwiseConv2D' or
            'Dense'. For any other value (or any other argument combination
            rejected by ``check_inputs``) the weights are returned unmodified.
        bits: number of bits of the quantized representation.

    Returns:
        When ``check_inputs`` accepts the arguments, the dictionary produced
        by ``Progressive_SQuant_Algorithm``, documented as::

            {
                'original weights': W,       # np.ndarray
                'quantized weights': new_W,  # np.ndarray
                'scale': s,                  # np.ndarray
                'quantized': True,           # bool
            }

        Otherwise ``{'original weights': W, 'quantized': False}``.
    """
    inputs_are_valid = check_inputs(W=W, layer_type=layer_type, bits=bits)
    if inputs_are_valid:
        return Progressive_SQuant_Algorithm(W=W, layer_type=layer_type, bits=bits)
    # Unsupported layer type / invalid arguments: hand the weights back untouched
    # and flag that no quantization happened.
    return {"original weights": W, "quantized": False}


if __name__ == "__main__":
    # Smoke test: run the operator on a few layer types / bit-widths and report
    # whether each call produced quantized weights. `np` comes from the
    # module-level import; no local re-import is needed.
    print("testing SQuant operator (from ICLR 2022)")

    # Conv2D: quantize a single tensor and reuse the same result for both the
    # status flag and the reconstruction error, instead of quantizing two
    # different random tensors.
    output = SQuant_operator(
        W=np.random.uniform(size=(3, 3, 2, 4)), layer_type="Conv2D", bits=4
    )
    print(
        '\nquantizing Tensor of shape (3,3,2,4) and type "Conv2D" (4 bits) :',
        output["quantized"],
    )
    # The "quantized weights" / "scale" keys are only present when the
    # operator actually quantized the tensor; guard against a KeyError.
    if output["quantized"]:
        reconstruction_error = np.sum(
            np.abs(
                output["original weights"]
                - output["quantized weights"] / output["scale"]
            )
        )
        print("sum of absolute reconstruction errors :", reconstruction_error)

    # Remaining cases: (shape, layer type, bits). "Batch-Norm" is expected to
    # be rejected, so its weights come back unmodified.
    cases = (
        ((4, 4), "Dense", 4),
        ((4, 4), "Dense", 5),
        ((4, 4), "Batch-Norm", 4),
    )
    for shape, layer_type, bits in cases:
        result = SQuant_operator(
            W=np.random.uniform(size=shape), layer_type=layer_type, bits=bits
        )
        shape_label = ",".join(str(dim) for dim in shape)
        print(
            f'\nquantizing Tensor of shape ({shape_label}) and type "{layer_type}" '
            f"({bits} bits) :",
            result["quantized"],
        )
