from typing import Tuple
import tensorflow as tf


def compute_i_poly(
    q: tf.Variable, S: tf.Variable, a: float, b: float, c: float
) -> Tuple[tf.Variable, tf.Variable]:
    """
    Evaluate the second-order polynomial a(x+b)^2 + c on a quantized input.

    Follows Algorithm 1 of the I-BERT paper. The real-valued input is
    represented as x = q * S, and the result is returned in the same
    (quantized value, scale) form so that

        q_out * S_out ≈ a(q*S + b)^2 + c

    The approximation comes from rounding the offsets b / S and
    c / (a * S^2) to the nearest integer.

    Note: the original author suspects a mistake in Algorithm 1 of the
    paper, which also rounds the output scale. That rounding serves no
    purpose and loses additional information, so the scale returned here
    is left unrounded.

    Args:
        q: quantized input values.
        S: scaling factor associated with q.
        a: coefficient multiplying the squared term.
        b: offset added to the input before squaring.
        c: constant term of the polynomial.

    Returns:
        A tuple (q_out, S_out) holding the quantized output and its
        scaling factor.
    """
    # The output scale is deliberately kept exact (see the note above).
    scale_out = a * (S**2)

    # Express both offsets on the integer grids of the input and output.
    offset_b = tf.math.round(b / S)
    offset_c = tf.math.round(c / scale_out)

    shifted = q + offset_b
    return (shifted**2 + offset_c, scale_out)


if __name__ == "__main__":
    # Sanity check: compare the integer-only polynomial against a float
    # reference on random quantized inputs and report the mean absolute
    # difference.
    import numpy as np

    # Coefficients of a(x+b)^2 + c. They are defined once and shared by the
    # float reference and by compute_i_poly so that both sides evaluate the
    # same polynomial (the call previously hard-coded c=-3 while the
    # reference used c=-1, which biased the reported difference).
    a = 2
    b = 1
    c = -1

    error = []
    for _ in range(100):
        # Random integers in [-7, 7] (the upper bound of randint is
        # exclusive).
        q = np.random.randint(low=-7, high=8, size=(8, 8, 3))
        # Random positive scale; the +0.5 keeps it away from zero.
        S = np.abs(np.random.normal()) + 0.5

        # Float reference evaluated on the de-quantized input q * S.
        poly = a * ((q * S + b) ** 2) + c
        q_out, S_out = compute_i_poly(q=q, S=S, a=a, b=b, c=c)
        # De-quantize the integer-only result for comparison.
        i_poly = q_out * S_out
        error.append(np.mean(np.abs(poly - i_poly)))
    print(f"average difference = {np.mean(error)}")
