from typing import Dict, Any, Callable, Tuple
import tensorflow as tf


@tf.keras.utils.register_keras_serializable()
class ModuLayer(tf.keras.layers.Layer):
    """Wrap inputs into a signed modular range, blended with the identity.

    The forward pass computes

        mod_x  = floormod(x + zero_point, module) - zero_point
        output = x * convex_combine + mod_x * (1 - convex_combine)

    where ``zero_point = 2 ** (representation_bits - 1)`` and
    ``module = 2 ** representation_bits``, so ``mod_x`` lies in
    ``[-zero_point, zero_point)``.  This presumably models the wrap-around of
    a signed ``representation_bits``-bit integer -- confirm with the callers.

    The backward pass is user-defined (see ``call``).

    Args:
        representation_bits: Number of bits that defines ``zero_point`` and
            ``module``.
        use_ste: If true, the backward pass returns the upstream gradient
            unchanged (straight-through).
        alpha: Scale applied to ``|upstream gradient|`` for inputs that fall
            outside ``[-zero_point, zero_point]`` when ``use_ste`` is false.
        *args, **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(
        self, representation_bits: int, use_ste: bool, alpha: float, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.representation_bits = representation_bits
        self.use_ste = use_ste
        self.alpha = alpha
        # Blend factor between identity (1.0) and pure modulo (0.0).  It
        # starts at 1.0 and nothing in this class changes it, so it is
        # presumably adjusted from outside (e.g. annealed during training).
        # It is intentionally not part of get_config() because it is not a
        # constructor argument.
        self.convex_combine = 1.0

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments needed to re-create the layer."""
        # `super()` must be called: the bare `super` type has no get_config.
        config = super().get_config()
        config["representation_bits"] = self.representation_bits
        config["use_ste"] = self.use_ste
        config["alpha"] = self.alpha
        return config

    def build(self, input_shape: tuple):
        """Create the two constant, non-trainable weights of shape ``(1,)``."""
        # Half of the modulus; shifts the wrapped range so it is centred on 0.
        self.zero_point = self.add_weight(
            name=self.name + "/zero_point",
            shape=(1,),
            initializer=tf.constant_initializer(
                value=2 ** (self.representation_bits - 1)
            ),
            trainable=False,
        )
        # The modulus itself.
        self.module = self.add_weight(
            name=self.name + "/module",
            shape=(1,),
            initializer=tf.constant_initializer(value=2**self.representation_bits),
            trainable=False,
        )

    @tf.custom_gradient
    def call(self, input: tf.Tensor) -> Tuple[tf.Tensor, Callable]:
        """Apply the blended modular wrap and supply its custom gradient.

        Args:
            input: Tensor to wrap; it is combined element-wise with the
                ``(1,)``-shaped weights via broadcasting.

        Returns:
            A pair ``(output, grad_fn)`` as required by
            ``tf.custom_gradient``; Keras callers only see ``output``.
        """
        mod_x = (
            tf.math.floormod(input + self.zero_point, self.module)
            - self.zero_point
        )

        def custom_grad(dX: tf.Tensor) -> tf.Tensor:
            """Gradient of the output w.r.t. ``input``.

            The original author's note: "we want to avoid overflow".

            With ``use_ste`` the upstream gradient is passed through as is.
            Otherwise the gradient is replaced by ``-alpha * |dX|`` where
            ``input > zero_point`` and by ``+alpha * |dX|`` where
            ``input < -zero_point``, then blended with ``dX`` using
            ``convex_combine``.
            """
            if self.use_ste:
                return dX
            # Plain tensors are used on purpose: wrapping these in
            # tf.Variable would allocate new state on every backward pass and
            # fails when the gradient is traced inside a tf.function/graph.
            upstream = tf.where(
                input > self.zero_point, -(self.alpha * tf.abs(dX)), dX
            )
            upstream = tf.where(
                input < -self.zero_point, self.alpha * tf.abs(upstream), upstream
            )

            return dX * self.convex_combine + upstream * (1 - self.convex_combine)

        output = input * self.convex_combine + mod_x * (1 - self.convex_combine)
        return output, custom_grad


if __name__ == "__main__":
    # Visual sanity check: compare the layer's forward pass against a NumPy
    # reference of the same signed modular wrap, then plot the blend between
    # identity and wrap for several blend factors.
    import matplotlib.pyplot as plt
    import numpy as np

    bits = 8
    num_points = 1000

    inputs = tf.keras.layers.Input(shape=(num_points,))
    # use_ste and alpha are required constructor arguments; they only affect
    # the backward pass, which this demo never runs.
    layer = ModuLayer(representation_bits=bits, use_ste=True, alpha=1.0)
    # The layer starts with convex_combine == 1.0, i.e. as the identity.  Set
    # it to 0.0 so the model output is the pure modular wrap being checked.
    layer.convex_combine = 0.0
    my_model = tf.keras.Model(inputs, layer(inputs))

    x = np.linspace(-500, 500, num_points)
    module = 2**bits
    zero_point = module // 2
    mod_x = np.mod(x + zero_point, module) - zero_point
    # The model expects a (batch, num_points) float32 tensor: add the batch
    # axis for the call and drop it again from the result.
    mod_x_tf = my_model(x[np.newaxis, :].astype(np.float32)).numpy().ravel()
    # Total absolute deviation from the NumPy reference (float32 rounding only).
    print(np.sum(np.abs(mod_x - mod_x_tf)))
    plt.plot(x, np.zeros(shape=x.shape))
    for t in np.linspace(0, 1, 11):
        # ".2f" = two decimals; the former "2f" was a field width of 2.
        plt.plot(x, x * t + mod_x_tf * (1 - t), label=f"t={t:.2f}")
    plt.legend()
    plt.show()
    plt.close()
