import tensorflow as tf
from typing import Callable, Tuple


def compute_newton_sqrt(x: tf.Variable, max_iter: int = 5) -> tf.Variable:
    """
    Approximate ``sqrt(x)`` element-wise with a rounded Newton (Heron) iteration.

    The initial guess is ``2 ** round(log2(x) / 2)``, a power of two near
    ``sqrt(x)``. Each refinement step then applies
    ``estimate <- round((estimate + round(x / estimate)) / 2)``, feeding the
    new estimate into the next step.

    The algorithm assumes that x is positive non-zero: ``log(x)`` is -inf or
    NaN otherwise, which poisons the initial guess.

    Args:
        x: floating-point tensor of strictly positive values.
        max_iter: number of Newton refinement steps. With 0, the initial
            power-of-two guess is returned unchanged.

    Returns:
        A tensor shaped like ``x`` holding the rounded square-root estimates.
    """
    # Power of two closest (in log space) to sqrt(x): a cheap starting point.
    estimate = tf.math.pow(
        2.0, tf.math.round((tf.math.log(x) / tf.math.log(2.0)) / 2.0)
    )
    for _ in range(max_iter):
        # Newton step on the *current* estimate so successive steps converge.
        estimate = tf.math.round((estimate + tf.math.round(x / estimate)) / 2.0)
    return estimate


@tf.custom_gradient
def compute_i_sqrt(x: tf.Variable) -> Tuple[tf.Variable, Callable]:
    """
    Compute an integer-only square root with a smooth surrogate gradient.

    This is based on algorithm 4 of [1].
    NOTE(review): reference [1] is not defined in this file; confirm the source.

    The input is cast to float32 and negative entries are clamped to 0 (ReLU).
    Positive entries get ``compute_newton_sqrt(max(x, 1))``; all other entries
    give 0.

    The backward pass uses ``1 / (2 * sqrt(max(x, 1)))`` for positive entries
    and 0 elsewhere.

    Returns:
        A ``(result, grad_fn)`` pair, as required by ``tf.custom_gradient``.
    """
    x_pos = tf.nn.relu(tf.cast(x=x, dtype=tf.float32))
    is_positive = x_pos > 0
    # Clamp to at least 1 so neither the log in the forward pass nor the sqrt
    # in the backward pass ever sees a zero.
    x_clamped = tf.maximum(x_pos, 1)
    zeros = tf.zeros_like(x_pos)

    def grad(upstream):
        slope = tf.where(is_positive, 1 / (2 * tf.math.sqrt(x_clamped)), zeros)
        return upstream * slope

    root = tf.where(is_positive, compute_newton_sqrt(x_clamped), zeros)
    return root, grad


if __name__ == "__main__":
    # Demo inputs:
    #   - perfect squares (16, 25) and a non-square (26);
    #   - a negative value, which the ReLU clamps to 0;
    #   - the 0 and 1 edge cases.
    for sample in (16, 26, 25, -6, 0, 1):
        print(f"sqrt({sample}) = {compute_i_sqrt([sample])}")
