from .NumpyOperators.SQuant_ICLR22.ops import SQuant_operator
from .NumpyOperators.TQuant.ops import TQuant_Operator
from .NumpyOperators.BasicQuant.ops import Basic_Quant_Operator
from .NumpyOperators.MQuant.ops import MQuant_Operator
import numpy as np
import tensorflow as tf
from typing import Any, Dict, Tuple


@tf.keras.utils.register_keras_serializable()
class QWeightInitializer(tf.keras.initializers.Initializer):
    """Initializer that returns a fixed (optionally SQuant-processed) weight tensor.

    When ``operator`` equals ``"squant"`` (case-insensitive), ``weight_tensor``
    is passed through ``SQuant_operator`` once, at construction time, and the
    stored value is ``output["quantized weights"] / output["scale"]``. For any
    other operator the tensor is stored unchanged. In both cases the stored
    value is converted with ``np.float32``.

    Args:
        layer_name: Name of the owning layer; used to name the returned constant.
        weight_tensor: Weights to initialise the layer with.
        operator: Name of the quantization operator; only ``"squant"`` triggers
            processing here.
        layer_type: Forwarded to ``SQuant_operator`` as ``layer_type``.
        bits: Forwarded to ``SQuant_operator`` as ``bits``.
        *args: Forwarded to the base-class constructor.
        **kwargs: Forwarded to the base-class constructor.
    """

    def __init__(
        self,
        layer_name: str,
        weight_tensor: np.ndarray,
        operator: str,
        layer_type: str,
        bits: int,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.layer_name = layer_name
        self.operator = operator
        self.layer_type = layer_type
        self.bits = bits
        if operator.lower() == "squant":
            output = SQuant_operator(W=weight_tensor, layer_type=layer_type, bits=bits)
            self.weight_tensor = np.float32(
                output["quantized weights"] / output["scale"]
            )
        else:
            self.weight_tensor = np.float32(weight_tensor)

    def __call__(
        self, shape: tuple, dtype=None, *args: Any, **kwargs: Any
    ) -> tf.Tensor:
        """Return the stored weights as a constant tensor.

        Args:
            shape: Accepted for interface compatibility; not used, the stored
                tensor determines the shape of the result.
            dtype: Requested dtype of the result. ``None`` lets
                ``tf.constant`` infer it from the stored tensor.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            A constant tensor named ``"<layer_name>/weight_initializer"``.
        """
        # Honour the dtype requested by the caller so the result matches the
        # variable being created instead of always using the stored dtype.
        return tf.constant(
            value=self.weight_tensor,
            dtype=dtype,
            name=self.layer_name + "/weight_initializer",
        )

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments needed to rebuild this initializer.

        Note that ``"weight_tensor"`` holds the stored tensor, i.e. the value
        *after* any operator processing done in ``__init__``.
        """
        config = super().get_config()
        config["layer_name"] = self.layer_name
        config["weight_tensor"] = self.weight_tensor
        config["operator"] = self.operator
        config["layer_type"] = self.layer_type
        config["bits"] = self.bits
        return config

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QWeightInitializer":
        """Rebuild an initializer from ``get_config`` output.

        The tensor found in ``config`` has already been processed by the
        operator, so it is restored as-is. Going through ``__init__`` would
        run ``SQuant_operator`` a second time on its own output.
        """
        restored = cls.__new__(cls)
        restored.layer_name = config["layer_name"]
        restored.operator = config["operator"]
        restored.layer_type = config["layer_type"]
        restored.bits = config["bits"]
        # Same normalisation as __init__; also turns a nested list coming from
        # a JSON round trip back into an array.
        restored.weight_tensor = np.float32(config["weight_tensor"])
        return restored


@tf.keras.utils.register_keras_serializable()
class QBiasInitializer(tf.keras.initializers.Initializer):
    """Initializer that returns a stored bias tensor as a constant.

    Args:
        layer_name: Name of the owning layer; used to name the returned constant.
        bias_tensor: Bias values, stored unchanged.
        *args: Forwarded to the base-class constructor.
        **kwargs: Forwarded to the base-class constructor.
    """

    def __init__(
        self, layer_name: str, bias_tensor: np.ndarray, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.layer_name = layer_name
        self.bias_tensor = bias_tensor

    def __call__(
        self, shape: tuple, dtype=None, *args: Any, **kwargs: Any
    ) -> tf.Tensor:
        """Return the stored bias as a constant tensor.

        Args:
            shape: Accepted for interface compatibility; not used, the stored
                tensor determines the shape of the result.
            dtype: Requested dtype of the result. ``None`` lets
                ``tf.constant`` infer it from the stored tensor.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            A constant tensor named ``"<layer_name>/bias_initializer"``.
        """
        # Honour the requested dtype so the constant matches the variable
        # being created even when the stored array uses another float width.
        return tf.constant(
            value=self.bias_tensor,
            dtype=dtype,
            name=self.layer_name + "/bias_initializer",
        )

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments needed to rebuild this initializer."""
        config = super().get_config()
        config["layer_name"] = self.layer_name
        config["bias_tensor"] = self.bias_tensor
        return config


@tf.keras.utils.register_keras_serializable()
class QGammaInitializer(tf.keras.initializers.Initializer):
    """Initializer that returns a stored gamma tensor as a constant.

    Args:
        layer_name: Name of the owning layer; used to name the returned constant.
        gamma_tensor: Gamma values, stored unchanged.
        *args: Forwarded to the base-class constructor.
        **kwargs: Forwarded to the base-class constructor.
    """

    def __init__(
        self, layer_name: str, gamma_tensor: np.ndarray, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.layer_name = layer_name
        self.gamma_tensor = gamma_tensor

    def __call__(
        self, shape: tuple, dtype=None, *args: Any, **kwargs: Any
    ) -> tf.Tensor:
        """Return the stored gamma as a constant tensor.

        Args:
            shape: Accepted for interface compatibility; not used, the stored
                tensor determines the shape of the result.
            dtype: Requested dtype of the result. ``None`` lets
                ``tf.constant`` infer it from the stored tensor.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            A constant tensor named ``"<layer_name>/gamma_initializer"``.
        """
        # Honour the requested dtype so the constant matches the variable
        # being created even when the stored array uses another float width.
        return tf.constant(
            value=self.gamma_tensor,
            dtype=dtype,
            name=self.layer_name + "/gamma_initializer",
        )

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments needed to rebuild this initializer."""
        config = super().get_config()
        config["layer_name"] = self.layer_name
        config["gamma_tensor"] = self.gamma_tensor
        return config


@tf.keras.utils.register_keras_serializable()
class QBetaInitializer(tf.keras.initializers.Initializer):
    """Initializer that returns a stored beta tensor as a constant.

    Args:
        layer_name: Name of the owning layer; used to name the returned constant.
        beta_tensor: Beta values, stored unchanged.
        *args: Forwarded to the base-class constructor.
        **kwargs: Forwarded to the base-class constructor.
    """

    def __init__(
        self, layer_name: str, beta_tensor: np.ndarray, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.layer_name = layer_name
        self.beta_tensor = beta_tensor

    def __call__(
        self, shape: tuple, dtype=None, *args: Any, **kwargs: Any
    ) -> tf.Tensor:
        """Return the stored beta as a constant tensor.

        Args:
            shape: Accepted for interface compatibility; not used, the stored
                tensor determines the shape of the result.
            dtype: Requested dtype of the result. ``None`` lets
                ``tf.constant`` infer it from the stored tensor.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            A constant tensor named ``"<layer_name>/beta_initializer"``.
        """
        # Honour the requested dtype so the constant matches the variable
        # being created even when the stored array uses another float width.
        return tf.constant(
            value=self.beta_tensor,
            dtype=dtype,
            name=self.layer_name + "/beta_initializer",
        )

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments needed to rebuild this initializer."""
        config = super().get_config()
        config["layer_name"] = self.layer_name
        config["beta_tensor"] = self.beta_tensor
        return config


def get_qlayer_initializers(
    layer: tf.keras.layers.Layer, config: Dict[str, Any]
) -> Tuple[tf.keras.initializers.Initializer, tf.keras.initializers.Initializer]:
    """
    Build the weight and bias initializers for the quantized version of a layer.

    If ``config["fp weights init"]`` is falsy the layer is meant to be trained
    from scratch and the Keras identifiers ``"glorot_uniform"`` and ``"zeros"``
    are returned. Otherwise the initializers wrap the full-precision weights
    currently held by ``layer``: the first entry of ``layer.get_weights()`` is
    given to ``QWeightInitializer`` together with ``config["operator"]`` and
    ``config["weights bits"]``, and, when the layer's config reports a truthy
    ``"use_bias"``, the second entry is given to ``QBiasInitializer``.

    Args:
        layer: Layer whose type and weights drive the initializers.
        config: Settings dict; reads ``"fp weights init"`` and, when that is
            truthy, ``"operator"`` and ``"weights bits"``.

    Returns:
        A two-element list ``[weight_init, bias_init]``. NOTE(review): the
        annotation says ``Tuple`` but a list is returned.
    """
    # DepthwiseConv2D is tested before Conv2D, so a layer matching both checks
    # is labelled as depthwise; anything else is treated as Dense.
    if isinstance(layer, tf.keras.layers.DepthwiseConv2D):
        kind = "DepthwiseConv2D"
    elif isinstance(layer, tf.keras.layers.Conv2D):
        kind = "Conv2D"
    else:
        kind = "Dense"

    # Training from scratch: plain Keras initializer identifiers.
    if not config["fp weights init"]:
        return ["glorot_uniform", "zeros"]

    fp_weights = layer.get_weights()
    kernel_initializer = QWeightInitializer(
        weight_tensor=fp_weights[0],
        layer_name=layer.name,
        operator=config["operator"],
        layer_type=kind,
        bits=config["weights bits"],
    )

    # The bias stays at "zeros" unless the layer explicitly uses one.
    bias_initializer = "zeros"
    if layer.get_config().get("use_bias"):
        bias_initializer = QBiasInitializer(
            bias_tensor=fp_weights[1], layer_name=layer.name
        )
    return [kernel_initializer, bias_initializer]
