import tensorflow as tf

if __name__ == "__main__":
    from QOverflow import ModuLayer
else:
    from .QOverflow import ModuLayer
from typing import Any, Dict


@tf.keras.utils.register_keras_serializable()
class QAdd(tf.keras.layers.Add):
    """Element-wise addition layer with an optional overflow stage on the sum.

    Behaves like ``tf.keras.layers.Add``. When ``accumulator_bits`` differs
    from 32, the summed tensor is additionally passed through a ``ModuLayer``
    built from the constructor arguments.

    NOTE(review): ``ModuLayer`` lives in ``QOverflow``; it presumably emulates
    the wrap-around of an accumulator that is ``accumulator_bits`` wide --
    confirm against that module.

    Args:
        accumulator_bits: Forwarded to ``ModuLayer`` as ``representation_bits``.
            The value 32 disables the extra stage entirely.
        ste_overflow: Forwarded to ``ModuLayer`` as ``use_ste``.
        cyclical_alpha: Forwarded to ``ModuLayer`` as ``alpha``.
        *args: Extra positional arguments for ``tf.keras.layers.Add``.
        **kwargs: Extra keyword arguments for ``tf.keras.layers.Add``.
    """

    def __init__(
        self,
        accumulator_bits: int = 32,
        ste_overflow: bool = False,
        cyclical_alpha: float = 1.0,
        *args,
        **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.accumulator_bits = accumulator_bits
        self.ste_overflow = ste_overflow
        self.cyclical_alpha = cyclical_alpha

        # The overflow stage only exists for non-32-bit accumulators; with the
        # default width the layer is a plain Add and the attribute is absent.
        if self.accumulator_bits != 32:
            self.cyclical_activation = ModuLayer(
                representation_bits=accumulator_bits,
                use_ste=ste_overflow,
                alpha=cyclical_alpha,
            )

    def update_cyclical_parameters(self, alpha: Any = None, convex: Any = None) -> None:
        """Overwrite attributes on the wrapped ``ModuLayer``, if there is one.

        Layers built with ``accumulator_bits == 32`` have no ``ModuLayer``;
        for those the call is a no-op instead of raising ``AttributeError``,
        so callers can update every layer of a model uniformly.

        Args:
            alpha: New value for ``cyclical_activation.alpha``; ``None`` leaves
                it unchanged.
            convex: New value for ``cyclical_activation.convex_combine``;
                ``None`` leaves it unchanged.
        """
        activation = getattr(self, "cyclical_activation", None)
        if activation is None:
            # No overflow stage was created in __init__; nothing to update.
            return

        if alpha is not None:
            activation.alpha = alpha
        if convex is not None:
            activation.convex_combine = convex

    def _merge_function(self, inputs):
        """Sum ``inputs`` and apply the overflow stage when it is enabled.

        Args:
            inputs: Sequence of tensors to add together.

        Returns:
            The element-wise sum, passed through ``cyclical_activation`` when
            ``accumulator_bits`` is not 32.
        """
        output = inputs[0]
        for tensor in inputs[1:]:
            # Non-mutating add so the first input is never modified in place.
            output = output + tensor

        if self.accumulator_bits != 32:
            output = self.cyclical_activation(output)
        return output

    def get_config(
        self,
    ) -> Dict[str, Any]:
        """Return the base ``Add`` config extended with this layer's arguments.

        NOTE(review): the reported ``cyclical_alpha`` is the constructor value;
        changes made through ``update_cyclical_parameters`` are not reflected.
        """
        config = super().get_config()
        config.update(
            {
                "accumulator_bits": self.accumulator_bits,
                "ste_overflow": self.ste_overflow,
                "cyclical_alpha": self.cyclical_alpha,
            }
        )
        return config


if __name__ == "__main__":
    import numpy as np

    def _report(layer, lhs, rhs):
        # Show the layer's output, then the plain NumPy sum as a reference.
        print(layer([lhs, rhs]))
        print(lhs + rhs)

    demo_shape = (3, 2, 1)

    # Case 1: default layer (accumulator_bits == 32), operands -1 and +1.
    _report(QAdd(), -np.ones(shape=demo_shape), np.ones(shape=demo_shape))

    # Case 2: 4-bit accumulator, operands -20 and +200 (reference sum is 180);
    # the layer output additionally goes through its ModuLayer.
    wide_lhs = -20 * np.ones(shape=demo_shape)
    wide_rhs = 200 * np.ones(shape=demo_shape)
    narrow_layer = QAdd(accumulator_bits=4)
    _report(narrow_layer, wide_lhs, wide_rhs)

    # Case 3: same layer and operands after setting convex_combine to 0.
    narrow_layer.update_cyclical_parameters(convex=0)
    _report(narrow_layer, wide_lhs, wide_rhs)
