import tensorflow as tf
from typing import Any, Dict, Tuple

if __name__ == "__main__":
    from QLinear import *
    from QUpdateEMA import *
    from QUtils import *
    from QOverflow import ModuLayer
else:
    from .QLinear import *
    from .QUpdateEMA import *
    from .QUtils import *
    from .QOverflow import ModuLayer


@tf.keras.utils.register_keras_serializable()
class QBlockConv2D(tf.keras.layers.Layer):
    """Fake-quantized 2D convolution block (regular or depthwise).

    Forward pass (see ``call``):

    1. Input and kernel scales are refreshed from the current batch / kernel
       in training mode, or read from the stored non-trainable variables
       otherwise.
    2. Input and kernel are quantized with ``QuantizeLinearTensor``. With a
       32-bit accumulator they are immediately de-quantized again; with any
       other accumulator width the convolution runs on the quantized values
       and its result is passed through a ``ModuLayer`` (from ``QOverflow``),
       which presumably models accumulator overflow.
    3. A straight-through estimator (``stop_gradient(q - x) + x``) lets
       gradients reach the float input and kernel.
    4. Optional bias, batch normalization, pooling and activation are applied
       in that order.

    Args:
        name: base name used for this layer's weights and reported as
            ``name`` by ``get_config``. It is NOT forwarded to the Keras base
            class.
        filters: number of output channels (unused when ``is_depthwise``).
        kernel_size: ``(kernel_height, kernel_width)``.
        power_exponent: forwarded as ``power`` / ``exponent`` to the
            quantization and scale-update helpers.
        stride: strides passed to the convolution op. For depthwise
            convolutions a 2-tuple ``(sh, sw)`` is expanded to
            ``(1, sh, sw, 1)``.
        activation_bits: bit width for input quantization.
        weight_bits: bit width for kernel quantization.
        accumulator_bits: accumulator bit width; any value other than 32
            enables the ``ModuLayer`` overflow stage.
        is_depthwise: use ``tf.nn.depthwise_conv2d`` instead of a regular
            convolution.
        padding: padding string passed to the convolution op (upper-cased).
        use_bias: add a trainable bias after the convolution.
        activation: name of a Keras activation (lower-cased); ``"linear"``
            means no activation is applied.
        depth_multiplier: channel multiplier for depthwise convolutions.
        dilation_rate: ``(dh, dw)`` dilation passed to the convolution op.
        per_channel: keep one input scale per input channel (otherwise a
            single scale) and fold those scales into the kernel.
        deterministic: forwarded to ``QuantizeLinearTensor``.
        ste_overflow: forwarded as ``use_ste`` to ``ModuLayer``.
        cyclical_alpha: forwarded as ``alpha`` to ``ModuLayer``.
        pooling: optional Keras layer applied after batch normalization. A
            fresh copy is built from its config; the serialized dict form
            produced by ``get_config`` is accepted as well.
        batchnormalization: apply a ``BatchNormalization`` layer after the
            bias.
        weight_initializer: initializer (name or object) for the kernel.
        bias_initializer: initializer (name or object) for the bias.
    """

    def __init__(
        self,
        name: str,
        filters: int,
        kernel_size: tuple,
        power_exponent: float = 1.0,
        stride: tuple = (1, 1, 1, 1),
        activation_bits: int = 8,
        weight_bits: int = 8,
        accumulator_bits: int = 32,
        is_depthwise: bool = False,
        padding: str = "SAME",
        use_bias: bool = True,
        activation: str = "linear",
        depth_multiplier: int = 1,
        dilation_rate: tuple = (1, 1),
        per_channel: bool = False,
        deterministic: bool = True,
        ste_overflow: bool = False,
        cyclical_alpha: float = 1.0,
        pooling: Any = None,
        batchnormalization: bool = False,
        weight_initializer: tf.keras.initializers.Initializer = "glorot_uniform",
        bias_initializer: tf.keras.initializers.Initializer = "zeros",
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Forwarded as ``ema`` to update_input_scales; adjustable through
        # update_parameter("ema", ...).
        self.ema = 0.99
        self.core_name = name
        self.filters = filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.accumulator_bits = accumulator_bits
        self.is_depthwise = is_depthwise
        self.padding = padding.upper()
        self.use_bias = use_bias
        self.activation = activation.lower()
        self.depth_multiplier = depth_multiplier
        self.dilation_rate = dilation_rate
        self.per_channel = per_channel
        self.power_exponent = power_exponent
        self.deterministic = deterministic
        self.ste_overflow = ste_overflow
        self.cyclical_alpha = cyclical_alpha
        self.weight_initializer = weight_initializer
        self.bias_initializer = bias_initializer
        self.pooling = None
        if pooling is not None:
            if isinstance(pooling, dict):
                # Serialized layer, e.g. the form written by get_config and
                # read back after a config / JSON round trip.
                self.pooling = tf.keras.layers.deserialize(pooling)
            else:
                # Rebuild from config so this block owns its own instance
                # instead of sharing the caller's layer object.
                self.pooling = type(pooling).from_config(pooling.get_config())
        self.batchnormalization = batchnormalization
        if batchnormalization:
            self.batchnormalization_layer = tf.keras.layers.BatchNormalization()

        # Quantization bounds: inputs and kernels are quantized into
        # [-M_x, M_x] and [-M_w, M_w] respectively (see ``call``). M_a is not
        # read anywhere in this class.
        self.M_x = convert_num_bits_to_range(num_bits=activation_bits)
        self.M_w = convert_num_bits_to_range(num_bits=weight_bits)
        self.M_a = convert_num_bits_to_range(num_bits=accumulator_bits)

        # Only narrow accumulators get the overflow stage; code that touches
        # ``cyclical_activation`` must therefore check for its existence.
        if self.accumulator_bits != 32:
            self.cyclical_activation = ModuLayer(
                representation_bits=accumulator_bits,
                use_ste=ste_overflow,
                alpha=self.cyclical_alpha,
            )

        # Reduction axes forwarded to the scale-update helpers.
        self.axis_per_channel_weights = (0, 1, 2)
        self.axis_per_channel_activations = (0, 1, 2)
        if is_depthwise:
            self.axis_per_channel_weights = (0, 1, 3)
            # tf.nn.depthwise_conv2d takes 4-element strides.
            if len(self.stride) == 2:
                self.stride = (1, self.stride[0], self.stride[1], 1)

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments needed to rebuild this layer.

        ``pooling`` is stored in serialized (dict) form so the config stays
        serializable; ``__init__`` accepts that form back.
        """
        config = super().get_config()
        config.update(
            {
                "name": self.core_name,
                "filters": self.filters,
                "kernel_size": self.kernel_size,
                "power_exponent": self.power_exponent,
                "stride": self.stride,
                "activation_bits": self.activation_bits,
                "weight_bits": self.weight_bits,
                "accumulator_bits": self.accumulator_bits,
                "is_depthwise": self.is_depthwise,
                "padding": self.padding,
                "use_bias": self.use_bias,
                "activation": self.activation,
                "depth_multiplier": self.depth_multiplier,
                "dilation_rate": self.dilation_rate,
                "per_channel": self.per_channel,
                "deterministic": self.deterministic,
                "ste_overflow": self.ste_overflow,
                "cyclical_alpha": self.cyclical_alpha,
                "pooling": (
                    None
                    if self.pooling is None
                    else tf.keras.layers.serialize(self.pooling)
                ),
                "batchnormalization": self.batchnormalization,
                "weight_initializer": self.weight_initializer,
                "bias_initializer": self.bias_initializer,
            }
        )
        return config

    def build_parameters(
        self, kernel_shape: tuple, bias_shape: tuple, input_dim: int
    ) -> None:
        """Create the layer's trainable and non-trainable variables.

        Args:
            kernel_shape: shape of the trainable kernel.
            bias_shape: shape of the bias and of the per-output-channel
                kernel scales (and output scales, when created).
            input_dim: number of input scales (1 unless ``per_channel``).
        """
        self.kernel = self.add_weight(
            name=self.core_name + "/kernel",
            shape=kernel_shape,
            initializer=self.weight_initializer,
            trainable=True,
        )
        if self.use_bias:
            self.b = self.add_weight(
                name=self.core_name + "/bias",
                shape=bias_shape,
                initializer=self.bias_initializer,
                trainable=True,
            )
        # Scales are non-trainable; they are refreshed in training mode by
        # the update_* helpers (see get_operation_scales).
        self.weight_scales = self.add_weight(
            name=self.core_name + "/kernel_scale",
            shape=bias_shape,
            initializer="ones",
            trainable=False,
        )
        self.input_scales = self.add_weight(
            name=self.core_name + "/input_scale",
            shape=(input_dim,),
            initializer="ones",
            trainable=False,
        )
        # Output scales are only created (and only usable through
        # get_output_scales) in the 32-bit accumulator configuration.
        if self.accumulator_bits == 32:
            self.output_scales = self.add_weight(
                name=self.core_name + "/output_scale",
                shape=bias_shape,
                initializer="ones",
                trainable=False,
            )

    def update_parameter(self, param_name: str, new_value: Any) -> None:
        """Set a tunable hyper-parameter by name.

        Args:
            param_name: ``"ema"`` or ``"alpha"``.
            new_value: the new value. For ``"alpha"`` it is also pushed into
                the overflow ``ModuLayer`` when one exists.

        Raises:
            NotImplementedError: for any other ``param_name``.
        """
        if param_name == "ema":
            self.ema = new_value
        elif param_name == "alpha":
            self.cyclical_alpha = new_value
            # The ModuLayer captured alpha at construction time. Without this,
            # updating alpha would only change the value reported by
            # get_config, not the computation.
            if hasattr(self, "cyclical_activation"):
                self.update_cyclical_parameters(alpha=new_value)
        else:
            raise NotImplementedError(
                f"requested to update {param_name}, but only 'ema' and 'alpha' are parameters to update"
            )

    def build(self, input_shape: tuple) -> None:
        """Derive variable shapes from the input shape and create them.

        The channel count is read from ``input_shape[-1]`` (channels-last).
        """
        if self.is_depthwise:
            kernel_shape = (
                self.kernel_size[0],
                self.kernel_size[1],
                input_shape[-1],
                self.depth_multiplier,
            )
            bias_shape = (input_shape[-1] * self.depth_multiplier,)
        else:
            kernel_shape = (
                self.kernel_size[0],
                self.kernel_size[1],
                input_shape[-1],
                self.filters,
            )
            bias_shape = (self.filters,)
        input_dim = input_shape[-1] if self.per_channel else 1
        self.build_parameters(
            kernel_shape=kernel_shape, bias_shape=bias_shape, input_dim=input_dim
        )

    def main_operation(
        self, input_quantized: tf.Variable, weights_quantized: tf.Variable
    ) -> tf.Variable:
        """Apply the convolution (depthwise or regular) honoring ``dilation_rate``.

        The regular path uses ``tf.nn.convolution``. With all-ones dilation
        it dispatches to the same conv2d op as before. Otherwise it supports
        dilation and, like Keras, rejects strides > 1 combined with
        dilation > 1.
        """
        if self.is_depthwise:
            return tf.nn.depthwise_conv2d(
                input=input_quantized,
                filter=weights_quantized,
                strides=self.stride,
                padding=self.padding,
                dilations=self.dilation_rate,
            )
        return tf.nn.convolution(
            input=input_quantized,
            filters=weights_quantized,
            strides=self.stride,
            padding=self.padding,
            dilations=self.dilation_rate,
        )

    def get_operation_scales(
        self, training: bool, input: tf.Variable
    ) -> Tuple[tf.Variable, tf.Variable]:
        """Return ``(weight_scales, input_scales)`` for this step.

        When ``training`` is true the scales come from ``update_weight_scales``
        / ``update_input_scales`` (which presumably also write the stored
        variables — see QUpdateEMA). Otherwise the stored variables are
        returned as-is.
        """
        # NOTE(review): self.ema / self.power_exponent are read as Python
        # values here; under tf.function later changes may only take effect
        # after retracing — confirm.
        weight_scales_to_use = tf.cond(
            pred=training,
            true_fn=lambda: update_weight_scales(
                weights=self.kernel,
                weight_scales=self.weight_scales,
                exponent=self.power_exponent,
                weight_bits=self.weight_bits,
                axis_per_channel=self.axis_per_channel_weights,
            ),
            false_fn=lambda: self.weight_scales,
        )
        input_scales_to_use = tf.cond(
            pred=training,
            true_fn=lambda: update_input_scales(
                I=input,
                per_channel=self.per_channel,
                input_scales=self.input_scales,
                exponent=self.power_exponent,
                activation_bits=self.activation_bits,
                ema=self.ema,
                axis_per_channel=self.axis_per_channel_activations,
            ),
            false_fn=lambda: self.input_scales,
        )
        return weight_scales_to_use, input_scales_to_use

    def get_output_scales(self, training: bool, output: tf.Variable) -> tf.Variable:
        """Return output scales, refreshed from ``output`` when training.

        Only valid with ``accumulator_bits == 32``, the only configuration in
        which ``output_scales`` is created. Not called from ``call``.
        """
        return tf.cond(
            pred=training,
            true_fn=lambda: update_input_scales(
                I=output,
                per_channel=self.per_channel,
                input_scales=self.output_scales,
                exponent=self.power_exponent,
                activation_bits=self.activation_bits,
                ema=self.ema,
                axis_per_channel=self.axis_per_channel_activations,
            ),
            false_fn=lambda: self.output_scales,
        )

    def update_cyclical_parameters(self, alpha: Any = None, convex: Any = None) -> None:
        """Set ``alpha`` and/or ``convex_combine`` on the overflow ModuLayer.

        Arguments left as ``None`` are not changed. Requires
        ``accumulator_bits != 32``; otherwise there is no ``cyclical_activation``
        and attribute access fails.
        """
        if alpha is not None:
            self.cyclical_activation.alpha = alpha
        if convex is not None:
            self.cyclical_activation.convex_combine = convex

    def call(
        self, input: tf.Variable, training=None, *args: Any, **kwargs: Any
    ) -> tf.Variable:
        """Run the quantized convolution block.

        Args:
            input: input feature map; presumably NHWC since ``build`` reads the
                channel count from the last axis — confirm for other layouts.
            training: when truthy, scales are refreshed from the current
                input / kernel; when ``None`` the Keras learning phase is used.

        Returns:
            The block output after the optional bias, batch normalization,
            pooling and activation stages.
        """
        if training is not None:
            training = tf.cast(x=training, dtype=tf.bool)
        else:
            training = tf.cast(x=tf.keras.backend.learning_phase(), dtype=tf.bool)

        weight_scales_to_use, input_scales_to_use = self.get_operation_scales(
            training=training, input=input
        )
        if self.is_depthwise:
            # NOTE(review): weight scales have shape (C * depth_multiplier,);
            # after this expand_dims they are (C * M, 1), which only obviously
            # broadcasts against the (kh, kw, C, M) kernel when M == 1 —
            # confirm QuantizeLinearTensor handles M > 1.
            weight_scales_to_use = tf.expand_dims(weight_scales_to_use, axis=-1)

        input_quantized = QuantizeLinearTensor(
            x=input,
            scale=input_scales_to_use,
            q_min=-self.M_x,
            q_max=self.M_x,
            power=self.power_exponent,
            deterministic=self.deterministic,
        )
        if self.per_channel:
            # Fold the per-input-channel scales into the kernel:
            # (C,) -> (1, 1, C, 1) to broadcast over (kh, kw, C, out).
            kernel = self.kernel / tf.expand_dims(
                tf.expand_dims(tf.expand_dims(input_scales_to_use, axis=0), axis=0),
                axis=-1,
            )
        else:
            kernel = self.kernel
        weights_quantized = QuantizeLinearTensor(
            x=kernel,
            scale=weight_scales_to_use,
            q_min=-self.M_w,
            q_max=self.M_w,
            power=self.power_exponent,
            deterministic=self.deterministic,
        )
        if self.accumulator_bits == 32:
            # Wide accumulator: simulate quantization only (quantize then
            # de-quantize) and convolve in the float domain.
            input_quantized = DeQuantizeLinearTensor(
                x=input_quantized,
                scale=input_scales_to_use,
                power=self.power_exponent,
            )
            weights_quantized = DeQuantizeLinearTensor(
                x=weights_quantized,
                scale=weight_scales_to_use,
                power=self.power_exponent,
            )
        # Straight-through estimator: forward value is the quantized tensor,
        # gradient is the identity w.r.t. the float input / kernel.
        input_quantized = tf.stop_gradient(input_quantized - input) + input
        weights_quantized = (
            tf.stop_gradient(weights_quantized - self.kernel) + self.kernel
        )

        output = self.main_operation(
            input_quantized=input_quantized, weights_quantized=weights_quantized
        )

        if self.accumulator_bits != 32:
            output = self.cyclical_activation(output)

        if self.use_bias:
            output = tf.nn.bias_add(output, self.b)

        if self.batchnormalization:
            output = self.batchnormalization_layer(output)
        if self.pooling is not None:
            output = self.pooling(output)
        if self.activation != "linear":
            # activations.get resolves names and passes callables through; it
            # is what an Activation layer calls internally, without building
            # a new layer object on every call.
            output = tf.keras.activations.get(self.activation)(output)
        return output


if __name__ == "__main__":
    import numpy as np

    # Smoke test: wrap a single QBlockConv2D in a functional model and run
    # one forward pass in training mode on a random 10x10x3 sample.
    sample_input = tf.keras.layers.Input(shape=(10, 10, 3))
    debug_layer = QBlockConv2D(name="debug_conv", filters=5, kernel_size=(3, 3))
    debug_model = tf.keras.Model(sample_input, debug_layer(sample_input))

    random_batch = np.random.normal(size=(1, 10, 10, 3))
    prediction = debug_model(random_batch, training=True)
    print(prediction)
