if __name__ == "__main__":
    from I_Bert.I_Exp import compute_i_exp
    from QDQ import quantize_tensor, dequantize_tensor
    from QGradients import apply_gradients_softmax
else:
    from .I_Bert.I_Exp import compute_i_exp
    from .QDQ import quantize_tensor, dequantize_tensor
    from .QGradients import apply_gradients_softmax
import tensorflow as tf
from typing import Any, Dict
import numpy as np

@tf.keras.utils.register_keras_serializable()
class QSoftmax(tf.keras.layers.Layer):
    """Softmax layer evaluated on a quantized copy of its input.

    The input is quantized with a single scale that is tracked as a
    non-trainable weight, the exponential is evaluated by ``compute_i_exp``,
    and the result is normalized over every axis except the first one.
    The forward value is the quantized softmax while the gradient is routed
    to the raw input through ``tf.stop_gradient`` before being handed to
    ``apply_gradients_softmax``.

    Args:
        num_bits: Bit width used to derive the quantization range
            ``2 ** (num_bits - 1) - 1``.
        *args: Forwarded to ``tf.keras.layers.Layer``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_bits:int=8, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.num_bits = num_bits
        # Weight given to the previously stored scale in the moving average
        # computed by `update_input_scales`. Not part of the layer config.
        self.ema = 0.99

    def get_config(self) -> Dict[str, Any]:
        """Return the base layer config extended with ``num_bits``."""
        config = super().get_config()
        config["num_bits"] = self.num_bits
        return config

    def build(self, input_shape: tuple) -> None:
        """Create the scale weight and record the normalization axes.

        Args:
            input_shape: Shape of the layer input; only its length (rank)
                is used here.
        """
        # Normalize over every axis except the first one.
        self.axis = tuple(range(1, len(input_shape)))
        self.input_scale = self.add_weight(
            name=self.name + "/input_scale",
            shape=(1,),
            initializer="ones",
            trainable=False,
        )

    def update_input_scales(self, input: tf.Variable) -> tf.Variable:
        """Refresh the stored scale from ``input`` and return the batch scale.

        The stored ``input_scale`` is replaced by a blend of its previous
        value (weight ``ema``) and the scale derived from the current tensor
        (weight ``1 - ema``).

        Args:
            input: Tensor whose largest absolute value defines the scale.

        Returns:
            The scale computed from ``input`` alone, with shape ``(1,)``.
            Note this is NOT the smoothed value written to ``input_scale``.
        """
        max_range = tf.expand_dims(
            tf.math.reduce_max(input_tensor=tf.math.abs(input)), axis=0
        )
        # Floor the range so an all-zero tensor cannot yield an infinite scale.
        max_range = tf.maximum(max_range, 1e-4)
        empirical_scale = (2 ** (self.num_bits - 1) - 1) / max_range
        new_scale = (1 - self.ema) * empirical_scale + self.ema * self.input_scale
        tf.keras.backend.update(self.input_scale, new_scale)
        return empirical_scale

    def get_scales(self, input: tf.Variable, training: Any) -> tf.Variable:
        """Select the quantization scale for the current call.

        Args:
            input: Tensor about to be quantized.
            training: Boolean predicate passed to ``tf.cond``.

        Returns:
            The scale derived from ``input`` (also updating the stored
            moving average) when ``training`` is true, otherwise the stored
            ``input_scale``.
        """
        return tf.cond(
            pred=training,
            true_fn=lambda: self.update_input_scales(input=input),
            false_fn=lambda: self.input_scale,
        )

    def reduce_sum_softmax(self, x:tf.Variable):
        """Sum ``x`` over ``self.axis``, keeping the reduced axes as size 1.

        Keeping the reduced dimensions lets the result broadcast against
        ``x`` when it is used as the softmax denominator.
        """
        return tf.math.reduce_sum(input_tensor=x, axis=self.axis, keepdims=True)

    def call(
        self, input: tf.Variable, training: Any = None, *args: Any, **kwargs: Any
    ) -> Any:
        """Compute the quantized softmax of ``input``.

        Args:
            input: Tensor to normalize.
            training: Optional training flag; when ``None`` the Keras
                learning phase is used instead.
            *args: Unused.
            **kwargs: Unused.

        Returns:
            The result of ``apply_gradients_softmax`` applied to a tensor
            whose forward value is the quantized softmax.
        """
        if training is None:
            training = tf.keras.backend.learning_phase()
        training = tf.cast(x=training, dtype=tf.bool)

        input_quantization_scale = self.get_scales(input=input, training=training)
        input_quantized = quantize_tensor(
            x=input, s=input_quantization_scale, bits=self.num_bits
        )

        # Shift by the largest quantized magnitude rather than by the actual
        # maximum of the tensor.
        # NOTE(review): presumably this keeps the values handed to
        # compute_i_exp non-positive, assuming quantize_tensor clips to
        # +/- (2 ** (num_bits - 1) - 1) - confirm against QDQ / I_Exp.
        q_max = 2 ** (self.num_bits - 1) - 1
        q_out, s_out = compute_i_exp(
            q=input_quantized - q_max,
            S=1 / input_quantization_scale,
        )

        # Dequantize the exponentials, then normalize over the non-batch axes.
        # NOTE(review): if every exponential evaluates to zero the
        # denominator is zero as well - verify this cannot occur in practice.
        q_out = q_out * s_out
        q_out = q_out / self.reduce_sum_softmax(q_out)

        # Forward value is q_out; the gradient flows to `input` unchanged.
        output = tf.stop_gradient(q_out - input) + input
        return apply_gradients_softmax(output)


if __name__ == "__main__":
    # Smoke test: compare the quantized layer against tf.nn.softmax on
    # random, unseeded data and print summary statistics of each tensor.
    samples = np.random.normal(size=(32, 10))

    model_input = tf.keras.layers.Input(shape=(10,))
    smoke_model = tf.keras.Model(model_input, QSoftmax()(model_input))

    # Seed the stored scale from the data so inference uses a calibrated
    # value instead of the initial one (127 matches the default 8 bits).
    calibrated_scale = 127 / np.max(np.abs(samples))
    smoke_model.layers[-1].set_weights([np.array([calibrated_scale])])

    quantized = smoke_model(samples, training=False).numpy()
    reference = tf.nn.softmax(samples).numpy()

    # min / mean / max of the input, the quantized output and the reference.
    for values in (samples, quantized, reference):
        print(np.min(values), np.mean(values), np.max(values))

    mean_abs_error = np.mean(np.abs(quantized - reference))
    print(f"integer inference error: {mean_abs_error:.4f}")
    