import tensorflow as tf
from typing import Any, Dict, Tuple
from .QLinear import *
from .QUpdateEMA import *
from .QUtils import *
from .QOverflow import ModuLayer


@tf.keras.utils.register_keras_serializable()
class QBlockDense(tf.keras.layers.Layer):
    """
    dense (fully-connected) block whose input and kernel go through the
    project's linear quantizer before the matrix product.

    the forward pass applies, in this order: input / kernel quantization
    (with a straight-through gradient), matmul, the cyclical ``ModuLayer``
    (only when ``accumulator_bits != 32``), bias, batch normalization,
    pooling and finally the activation.

    Args:
        name: prefix used for the names of the weights created in ``build``
            (stored as ``core_name``; it is not forwarded to the keras base
            class).
        units: number of output features, i.e. second dimension of the kernel.
        power_exponent: forwarded as ``exponent`` / ``power`` to the scale
            update, quantize and de-quantize helpers.
        activation_bits: bit-width turned into the clipping range of the
            input quantizer.
        weight_bits: bit-width turned into the clipping range of the kernel
            quantizer.
        accumulator_bits: when equal to 32 the quantized operands are
            de-quantized before the matmul; otherwise they are kept as
            returned by the quantizer and the matmul output goes through a
            ``ModuLayer``.
        use_bias: whether a bias vector is created and added.
        activation: activation identifier (lower-cased); ``"linear"`` skips
            the activation.
        per_channel: when True the input scales have one entry per input
            feature and the kernel is divided by them before quantization.
        deterministic: forwarded to ``QuantizeLinearTensor``.
        ste_overflow: forwarded to ``ModuLayer`` as ``use_ste``.
        cyclical_alpha: forwarded to ``ModuLayer`` as ``alpha``.
        pooling: optional layer applied after batch normalization. a layer
            instance is cloned from its config; a serialized layer config
            (dict, as produced by ``get_config``) is deserialized.
        batchnormalization: whether a ``BatchNormalization`` layer is applied
            to the output.
        weight_initializer: initializer identifier or instance for the kernel.
        bias_initializer: initializer identifier or instance for the bias.
    """

    def __init__(
        self,
        name: str,
        units: int,
        power_exponent: float = 1.0,
        activation_bits: int = 8,
        weight_bits: int = 8,
        accumulator_bits: int = 32,
        use_bias: bool = True,
        activation: str = "linear",
        per_channel: bool = False,
        deterministic: bool = True,
        ste_overflow: bool = False,
        cyclical_alpha: float = 1.0,
        pooling: Any = None,
        batchnormalization: bool = False,
        weight_initializer: Any = "glorot_uniform",
        bias_initializer: Any = "zeros",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # momentum handed to the input-scale update as `ema`; can be changed
        # later through update_parameter("ema", ...)
        self.ema = 0.99
        self.core_name = name
        self.units = units
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.accumulator_bits = accumulator_bits
        self.use_bias = use_bias
        self.activation = activation.lower()
        self.per_channel = per_channel
        self.power_exponent = power_exponent
        self.deterministic = deterministic
        self.ste_overflow = ste_overflow
        self.cyclical_alpha = cyclical_alpha
        self.weight_initializer = weight_initializer
        self.bias_initializer = bias_initializer

        if pooling is None:
            self.pooling = None
        elif isinstance(pooling, dict):
            # serialized form written by get_config: rebuild the layer so the
            # block can be restored through from_config
            self.pooling = tf.keras.layers.deserialize(pooling)
        else:
            # work on a fresh copy rather than on the caller's instance
            self.pooling = type(pooling).from_config(pooling.get_config())

        self.batchnormalization = batchnormalization
        if batchnormalization:
            self.batchnormalization_layer = tf.keras.layers.BatchNormalization()

        # ranges derived from the bit-widths; M_x and M_w are used as the
        # symmetric clipping bounds (-M, M) of the quantizers in call
        self.M_x = convert_num_bits_to_range(num_bits=activation_bits)
        self.M_w = convert_num_bits_to_range(num_bits=weight_bits)
        self.M_a = convert_num_bits_to_range(num_bits=accumulator_bits)

        # the cyclical layer only exists for non 32-bit accumulators
        if self.accumulator_bits != 32:
            self.cyclical_activation = ModuLayer(
                representation_bits=accumulator_bits,
                use_ste=ste_overflow,
                alpha=self.cyclical_alpha,
            )

        self.axis_per_channel_weights = 0
        self.axis_per_channel_activations = 0

    @staticmethod
    def _serialize_initializer(initializer: Any) -> Any:
        """
        returns a plain-data description of an initializer: string and dict
        identifiers are kept as they are, instances are serialized by keras
        """
        if isinstance(initializer, (str, dict)):
            return initializer
        return tf.keras.initializers.serialize(initializer)

    def get_config(self) -> Dict[str, Any]:
        """
        returns the constructor arguments of the block as plain data so that
        ``from_config`` can rebuild it
        """
        config = super().get_config()
        config.update(
            {
                # overwrite the keras layer name with the constructor's
                # required `name` argument
                "name": self.core_name,
                "units": self.units,
                "power_exponent": self.power_exponent,
                "activation_bits": self.activation_bits,
                "weight_bits": self.weight_bits,
                "accumulator_bits": self.accumulator_bits,
                "use_bias": self.use_bias,
                "activation": self.activation,
                "per_channel": self.per_channel,
                "deterministic": self.deterministic,
                "ste_overflow": self.ste_overflow,
                "cyclical_alpha": self.cyclical_alpha,
                # layers and initializer instances are not plain data:
                # store their serialized form instead of the live objects
                "pooling": (
                    None
                    if self.pooling is None
                    else tf.keras.layers.serialize(self.pooling)
                ),
                "batchnormalization": self.batchnormalization,
                "weight_initializer": self._serialize_initializer(
                    self.weight_initializer
                ),
                "bias_initializer": self._serialize_initializer(
                    self.bias_initializer
                ),
            }
        )
        return config

    def build(self, input_shape: tuple):
        """
        creates the trainable (kernel, bias) and non-trainable (quantization
        scales) variables of the block from the hyper-parameters and the last
        dimension of the input shape
        """
        in_features = input_shape[-1]
        self.kernel = self.add_weight(
            name=self.core_name + "/kernel",
            shape=(in_features, self.units),
            initializer=self.weight_initializer,
            trainable=True,
        )
        if self.use_bias:
            self.b = self.add_weight(
                name=self.core_name + "/bias",
                shape=(self.units,),
                initializer=self.bias_initializer,
                trainable=True,
            )
        # one kernel scale per output unit
        self.weight_scales = self.add_weight(
            name=self.core_name + "/kernel_scale",
            shape=(self.units,),
            initializer="ones",
            trainable=False,
        )
        # one input scale per input feature, or a single shared scale
        self.input_scales = self.add_weight(
            name=self.core_name + "/input_scale",
            shape=(in_features,) if self.per_channel else (1,),
            initializer="ones",
            trainable=False,
        )
        # output scales are only created for 32-bit accumulators
        if self.accumulator_bits == 32:
            self.output_scales = self.add_weight(
                name=self.core_name + "/output_scale",
                shape=(self.units,),
                initializer="ones",
                trainable=False,
            )

    def update_parameter(self, param_name: str, new_value: Any):
        """
        changes the value of a parameter given its name with the new
        provided value

        supported names are 'ema' and 'alpha'; any other name raises
        NotImplementedError
        """
        if param_name == "ema":
            self.ema = new_value
        elif param_name == "alpha":
            self.cyclical_alpha = new_value
            # the cyclical layer holds its own copy of alpha: keep it in sync,
            # otherwise the new value would never reach the forward pass
            cyclical = getattr(self, "cyclical_activation", None)
            if cyclical is not None:
                cyclical.alpha = new_value
        else:
            raise NotImplementedError(
                f"requested to update {param_name}, but only 'ema' and 'alpha' are parameters to update"
            )

    def get_operation_scales(
        self, training: bool, input: tf.Variable
    ) -> Tuple[tf.Variable, tf.Variable]:
        """
        returns the (weight scales, input scales) to use for this call.

        when `training` is true the scales come from update_weight_scales /
        update_input_scales (project helpers, presumably refreshing the
        stored scales), otherwise the stored scales are returned unchanged.
        `training` must be usable as a `tf.cond` predicate.
        """
        weight_scales_to_use = tf.cond(
            pred=training,
            true_fn=lambda: update_weight_scales(
                weights=self.kernel,
                weight_scales=self.weight_scales,
                exponent=self.power_exponent,
                weight_bits=self.weight_bits,
                axis_per_channel=self.axis_per_channel_weights,
            ),
            false_fn=lambda: self.weight_scales,
        )
        input_scales_to_use = tf.cond(
            pred=training,
            true_fn=lambda: update_input_scales(
                I=input,
                per_channel=self.per_channel,
                input_scales=self.input_scales,
                exponent=self.power_exponent,
                activation_bits=self.activation_bits,
                ema=self.ema,
                axis_per_channel=self.axis_per_channel_activations,
            ),
            false_fn=lambda: self.input_scales,
        )
        return weight_scales_to_use, input_scales_to_use

    def get_output_scales(self, training: bool, output: tf.Variable) -> tf.Variable:
        """
        returns the output scales to use: computed by update_input_scales on
        `output` when `training` is true, the stored ones otherwise.

        relies on `self.output_scales`, which build only creates when
        accumulator_bits == 32.
        """
        return tf.cond(
            pred=training,
            true_fn=lambda: update_input_scales(
                I=output,
                per_channel=self.per_channel,
                input_scales=self.output_scales,
                exponent=self.power_exponent,
                activation_bits=self.activation_bits,
                ema=self.ema,
                axis_per_channel=self.axis_per_channel_activations,
            ),
            false_fn=lambda: self.output_scales,
        )

    def update_cyclical_parameters(self, alpha: Any = None, convex: Any = None) -> None:
        """
        updates `alpha` and / or `convex_combine` of the cyclical layer;
        arguments left to None are not modified.

        blocks with a 32-bit accumulator have no cyclical layer: for them
        only the stored `cyclical_alpha` is updated.
        """
        if alpha is not None:
            # keep the value reported by get_config in sync
            self.cyclical_alpha = alpha
        cyclical = getattr(self, "cyclical_activation", None)
        if cyclical is None:
            return
        if alpha is not None:
            cyclical.alpha = alpha
        if convex is not None:
            cyclical.convex_combine = convex

    def call(self, input: tf.Variable, training=None, *args: Any, **kwargs: Any) -> Any:
        """
        forward pass of the block.

        `training` falls back to the keras learning phase when it is None;
        extra positional / keyword arguments are accepted and ignored.
        """
        if training is None:
            training = tf.keras.backend.learning_phase()
        training = tf.cast(x=training, dtype=tf.bool)

        weight_scales_to_use, input_scales_to_use = self.get_operation_scales(
            training=training, input=input
        )

        input_quantized = QuantizeLinearTensor(
            x=input,
            scale=input_scales_to_use,
            q_min=-self.M_x,
            q_max=self.M_x,
            power=self.power_exponent,
            deterministic=self.deterministic,
        )
        # per-channel mode: divide each kernel row by the scale of the
        # matching input feature before quantizing the kernel
        if self.per_channel:
            kernel = self.kernel / tf.expand_dims(input_scales_to_use, axis=-1)
        else:
            kernel = self.kernel
        weights_quantized = QuantizeLinearTensor(
            x=kernel,
            scale=weight_scales_to_use,
            q_min=-self.M_w,
            q_max=self.M_w,
            power=self.power_exponent,
            deterministic=self.deterministic,
        )
        if self.accumulator_bits == 32:
            input_quantized = DeQuantizeLinearTensor(
                x=input_quantized,
                scale=input_scales_to_use,
                power=self.power_exponent,
            )
            weights_quantized = DeQuantizeLinearTensor(
                x=weights_quantized,
                scale=weight_scales_to_use,
                power=self.power_exponent,
            )
        # straight-through estimator: the forward value is the quantized
        # tensor while the gradient flows to `input` / `self.kernel` as if
        # the quantizer were the identity
        input_quantized = tf.stop_gradient(input_quantized - input) + input
        weights_quantized = (
            tf.stop_gradient(weights_quantized - self.kernel) + self.kernel
        )

        output = tf.linalg.matmul(a=input_quantized, b=weights_quantized)

        if self.accumulator_bits != 32:
            output = self.cyclical_activation(output)

        if self.use_bias:
            output = tf.nn.bias_add(output, self.b)

        if self.batchnormalization:
            output = self.batchnormalization_layer(output)
        if self.pooling is not None:
            output = self.pooling(output)
        if self.activation != "linear":
            # resolve the identifier (name or callable) to a function instead
            # of instantiating a new Activation layer on every call
            output = tf.keras.activations.get(self.activation)(output)
        return output
