if __name__ == "__main__":
    from I_Poly import compute_i_poly
else:
    from .I_Poly import compute_i_poly
from numpy import float32
import tensorflow as tf
from typing import Tuple


def compute_i_erf(q: tf.Variable, S: tf.Variable) -> Tuple[tf.Variable, tf.Variable]:
    """
    Compute the integer-only approximation of erf (Algorithm 2 of I-BERT [1]).

    ``q`` is the quantized input and ``S`` its scaling factor, so the
    real-valued input is ``x = q * S``. The approximation used by I-BERT is
    ``sign(x) * (a * (min(|x|, -b) + b) ** 2 + c)`` with ``a = -0.2888``,
    ``b = -1.769``, ``c = 1``; the polynomial part is delegated to
    ``compute_i_poly`` (presumably it evaluates ``a * (q * S + b) ** 2 + c``
    in integer arithmetic -- confirm against its implementation).

    Args:
        q: quantized input values.
        S: scaling factor of ``q``.

    Returns:
        ``(q_out, S_out)``: the quantized approximation and its scaling
        factor, intended to satisfy ``erf(q * S) ~= q_out * S_out``.

    note: I-BERT inverts the way of applying the scale; they write x / s
    instead of x * s.
    NOTE(review): an earlier note stated "I-BERT seems to have a mistake in
    algorithm 2, we should use b=0 in I_poly", but this function passes
    b=-1.769 to ``compute_i_poly``. Confirm which of the two is intended.

    [1] S. Kim, A. Gholami, Z. Yao, M. W. Mahoney, K. Keutzer,
        "I-BERT: Integer-only BERT Quantization", ICML 2021.
    """
    a = -0.2888
    b = -1.769
    c = 1.0
    # erf is odd: work on |q| and restore the sign at the end
    # (sign(0) == 0, so zero inputs map to zero outputs).
    q_sign = tf.math.sign(q)
    # Clip |x| at -b, i.e. |q| at -b / S. Since a < 0, the parabola
    # a * (x + b) ** 2 + c peaks at x = -b with value c = 1, so clipping makes
    # the approximation saturate at 1 instead of curving back down.
    q = tf.minimum(x=tf.math.abs(q), y=-b / S)
    qL, sL = compute_i_poly(q=q, S=S, a=a, b=b, c=c)
    q_out = qL * q_sign
    S_out = sL
    return (q_out, S_out)


if __name__ == "__main__":
    # Quick self-check: compare the integer-only approximation against
    # scipy's reference erf on random inputs and report the mean abs error.
    import numpy as np
    from scipy.special import erf

    rng = np.random.default_rng(seed=0)  # fixed seed -> reproducible output
    # Scale mapping q in [-127, 127] onto x = q * S in [-6, 6].
    S = 6 / 127
    error = []
    for _ in range(100):
        # `high` is exclusive, so 128 is needed to include +127 and keep the
        # sampled range symmetric.
        q = rng.integers(low=-127, high=128, size=(8, 8, 3)).astype(np.float32)

        exact_erf = erf(q * S)
        q_out, S_out = compute_i_erf(q=q, S=S)
        # Dequantize; np.asarray converts the TF result to an ndarray before
        # comparing it with the NumPy reference.
        i_erf = np.asarray(q_out * S_out)
        error.append(np.mean(np.abs(exact_erf - i_erf)))
    print(f"average difference = {np.mean(error)}")
