from .check import check_inputs
from .algo import TQuant
from typing import Dict, Any
import numpy as np


def TQuant_Operator(W: np.ndarray, layer_type: str, bits: int) -> Dict[str, Any]:
    """
    Quantize a weight tensor with the basic operator of https://arxiv.org/abs/1806.08342

    The operator searches for a scale s and computes:
            Q(W) = ⌊W/s⌉
    i.e. the weights divided by the scale, rounded to the nearest integer.

    Args:
            W: numpy array holding the weight tensor to quantize.
            layer_type: string naming the kind of layer the weights belong to.
            bits: target bit width of the quantized representation.

    Returns:
            A dictionary containing at least the keys:
                    'original weights': the input tensor W
                    'quantized': boolean telling whether quantization succeeded
            If `check_inputs` rejects the arguments, the dictionary is exactly
            {"original weights": W, "quantized": False}; otherwise it is whatever
            `TQuant` returns for the same arguments.
    """
    inputs_are_valid = check_inputs(W=W, layer_type=layer_type, bits=bits)
    if inputs_are_valid:
        return TQuant(W=W, layer_type=layer_type, bits=bits)
    # Validation failed: hand the weights back untouched and flag the failure.
    return {"original weights": W, "quantized": False}


if __name__ == "__main__":
    # Smoke test: run the operator on a few random tensors and report, for
    # each case, the 'quantized' flag it returns. `np` comes from the module
    # top-level import, so it is not re-imported here.
    print("testing ternary quantization operator")

    # (tensor shape, layer type, target bit width)
    demo_cases = [
        ((3, 3, 2, 4), "Conv2D", 4),
        ((4, 4), "Dense", 4),
        ((4, 4), "Dense", 5),
        ((4, 4), "Batch-Norm", 4),
    ]

    for shape, layer_type, bits in demo_cases:
        result = TQuant_Operator(
            W=np.random.uniform(size=shape), layer_type=layer_type, bits=bits
        )
        shape_str = ",".join(str(dim) for dim in shape)
        # The bit width is part of the label so that cases which differ only
        # in `bits` (the two Dense runs) remain distinguishable in the output.
        print(
            f'\nquantizing Tensor of shape ({shape_str}) and type "{layer_type}" '
            f"with {bits} bits :",
            result["quantized"],
        )
