if __name__ == "__main__":
    from I_Bert.I_Erf import compute_i_erf
    from QDQ import quantize_tensor, dequantize_tensor
    from QGradients import apply_gradients_gelu
else:
    from .I_Bert.I_Erf import compute_i_erf
    from .QDQ import quantize_tensor, dequantize_tensor
    from .QGradients import apply_gradients_gelu
import tensorflow as tf
from typing import Any, Dict
import numpy as np


@tf.keras.utils.register_keras_serializable()
class QGeLU(tf.keras.layers.Layer):
    """Quantization-aware GELU activation layer.

    The input is quantized with a scalar scale, passed through
    ``compute_i_erf`` (presumably the I-BERT integer erf approximation --
    confirm against ``I_Bert.I_Erf``) and recombined as
    ``x * (erf(x / sqrt(2)) + 1) / 2`` before being dequantized.

    During training the quantization scale is derived from the current
    batch, and an exponential moving average of it is stored in the
    non-trainable ``input_scale`` weight. Outside training the stored
    ``input_scale`` is used as is.

    Args:
        num_bits: Bit width of the quantized input. The largest quantized
            magnitude is ``2 ** (num_bits - 1) - 1``.
        *args: Forwarded positionally to ``tf.keras.layers.Layer``.
        ema: Keyword-only momentum of the moving average kept in
            ``input_scale``. Must lie in ``[0, 1]``; defaults to ``0.99``,
            the value that was previously hard-coded.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``ema`` is outside ``[0, 1]``.
    """

    def __init__(self, num_bits: int, *args, ema: float = 0.99, **kwargs) -> None:
        if not 0.0 <= ema <= 1.0:
            raise ValueError("ema must lie in [0, 1], got {!r}".format(ema))
        super().__init__(*args, **kwargs)
        self.num_bits = num_bits
        self.ema = ema

    def get_config(self) -> Dict[str, Any]:
        """Return the layer config, including ``num_bits`` and ``ema``."""
        config = super().get_config()
        config["num_bits"] = self.num_bits
        # Serialized so a reloaded layer keeps a non-default momentum.
        config["ema"] = self.ema
        return config

    def build(self, input_shape: tuple) -> None:
        """Create the non-trainable running quantization scale.

        ``input_shape`` is not used: the scale is a single value stored with
        shape ``(1,)`` and initialized to one.
        """
        self.input_scale = self.add_weight(
            name=self.name + "/input_scale",
            shape=(1,),
            initializer="ones",
            trainable=False,
        )

    def update_input_scales(self, input: tf.Variable) -> tf.Variable:
        """Update the running scale from ``input`` and return the batch scale.

        Args:
            input: Tensor whose largest absolute value defines the range.

        Returns:
            The scale computed from this batch alone (shape ``(1,)``), not
            the moving average that is written to ``input_scale``.
        """
        max_range = tf.expand_dims(
            tf.math.reduce_max(input_tensor=tf.math.abs(input)), axis=0
        )
        # Floor the range so an all-zero batch cannot cause a division by zero.
        max_range = tf.maximum(max_range, 1e-4)
        empirical_scale = (2 ** (self.num_bits - 1) - 1) / max_range
        new_scale = (1 - self.ema) * empirical_scale + self.ema * self.input_scale
        # Direct variable assignment; replaces the legacy backend ``update`` helper.
        self.input_scale.assign(new_scale)
        return empirical_scale

    def get_scales(self, input: tf.Variable, training: Any) -> tf.Variable:
        """Select the quantization scale for the current mode.

        Args:
            input: Tensor about to be quantized.
            training: Boolean predicate passed to ``tf.cond``.

        Returns:
            The batch scale (also updating ``input_scale``) when ``training``
            is true, otherwise the stored ``input_scale``.
        """
        input_scales_to_use = tf.cond(
            pred=training,
            true_fn=lambda: self.update_input_scales(input=input),
            false_fn=lambda: self.input_scale,
        )
        return input_scales_to_use

    def call(
        self, input: tf.Variable, training: Any = None, *args: Any, **kwargs: Any
    ) -> Any:
        """Apply the quantized GELU to ``input``.

        Args:
            input: Input tensor.
            training: Optional training flag. When ``None`` the Keras backend
                learning phase is used instead.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The result of ``apply_gradients_gelu`` applied to the dequantized
            activation.
        """
        if training is not None:
            training = tf.cast(x=training, dtype=tf.bool)
        else:
            training = tf.cast(x=tf.keras.backend.learning_phase(), dtype=tf.bool)

        input_quantization_scale = self.get_scales(input=input, training=training)
        input_quantized = quantize_tensor(
            x=input, s=input_quantization_scale, bits=self.num_bits
        )
        # Dequantization step: real value ~= quantized value * s.
        s = 1 / input_quantization_scale
        # Scaling by 1/sqrt(2) makes the erf argument x / sqrt(2), as in GELU.
        q_erf, s_erf = compute_i_erf(q=input_quantized, S=s / np.sqrt(2))
        # The constant 1 expressed in the erf output scale.
        q1 = tf.math.round(1 / s_erf)
        q_out = input_quantized * (q_erf + q1)
        # The trailing /2 supplies GELU's factor of one half.
        s_out = s * s_erf / 2
        output_dequantized = q_out * s_out
        # Forward value is the quantized result; the gradient w.r.t. input is identity.
        output = tf.stop_gradient(output_dequantized - input) + input
        # NOTE(review): presumably attaches a GELU-shaped gradient -- confirm in QGradients.
        return apply_gradients_gelu(output)


if __name__ == "__main__":
    # Manual visual check: plot the quantized GELU against the float GELU
    # and ReLU over the interval [-4, 4].
    import matplotlib.pyplot as plt

    layer_in = tf.keras.layers.Input(shape=(1,))
    qgelu = QGeLU(num_bits=8)
    test_model = tf.keras.Model(layer_in, qgelu(layer_in))

    # Fix the stored scale at 127 / 4, i.e. the 8-bit scale for a range of 4.
    qgelu.set_weights([np.array([127 / 4])])

    # Column vector of 1000 evenly spaced sample points.
    x = np.linspace(start=-4, stop=4, num=1000)[:, np.newaxis]
    quantized = test_model(x, training=False)
    reference = tf.nn.gelu(x)
    rectified = tf.nn.relu(x)

    # Report the value range of the inputs and of both GELU variants.
    for series in (x, quantized, reference):
        print(np.min(series), np.max(series))

    curves = (
        (quantized, "i-gelu"),
        (reference, "gelu"),
        (rectified, "relu"),
    )
    for values, label in curves:
        plt.plot(x, values, label=label)
    plt.legend()
    plt.show()
    plt.close()
