if __name__ == "__main__":
    from I_Poly import compute_i_poly
else:
    from .I_Poly import compute_i_poly
from typing import Tuple
import tensorflow as tf


def compute_i_exp(q: tf.Variable, S: tf.Variable) -> Tuple[tf.Variable, tf.Variable]:
    """
    Approximate exp(q * S) following Algorithm 3 (I-EXP) in [1].

    The real input x = q * S is decomposed as x = -z * ln2 + p with
    z = floor(-x / ln2) and p in (-ln2, 0], so that exp(x) = 2 ** (-z) * exp(p).
    exp(p) is evaluated by ``compute_i_poly`` with the coefficients a, b, c
    below; these are the values [1] fits on (-ln2, 0]. Multiplying by
    2 ** (-z) corresponds to the paper's right shift by z.

    Args:
        q: quantized input, x = q * S. Intended for x <= 0 (z is then
            non-negative); TODO confirm callers only pass non-positive x.
        S: scale of q.

    Returns:
        (q_out, S_out) such that exp(x) ~= q_out * S_out.

    Raises:
        tf.errors.InvalidArgumentError: (eager mode) if S is so large that
            floor(ln2 / S) == 0, which would make the range reduction divide
            by zero.

    [1] S. Kim, A. Gholami, Z. Yao, M. W. Mahoney, K. Keutzer,
        "I-BERT: Integer-only BERT Quantization", ICML 2021.
    """
    a = float(0.3585)
    b = float(1.353)
    c = float(0.344)
    # ln2 expressed in units of S (floor, as in Algorithm 3).
    q_ln2 = tf.math.floor(tf.math.log(2.0) / S)
    tf.debugging.assert_positive(
        q_ln2, message="compute_i_exp: S is too large, floor(ln2 / S) must be >= 1"
    )
    # Floor (not round) keeps qp in (-q_ln2, 0], i.e. p in (-ln2, 0], the
    # interval the polynomial coefficients were fitted on; rounding would let
    # p become positive, where the approximation undershoots exp.
    z = tf.math.floor(-q / q_ln2)
    qp = q + z * q_ln2
    qL, SL = compute_i_poly(q=qp, S=S, a=a, b=b, c=c)
    # Right shift qL >> z == floor(qL * 2 ** (-z)), keeping the output integral.
    q_out = tf.math.floor(qL * (2 ** (-z)))
    S_out = SL
    return q_out, S_out


if __name__ == "__main__":
    # Visual sanity check: compare compute_i_exp against np.exp on [-4, 0].
    import matplotlib.pyplot as plt
    import numpy as np

    grid = np.linspace(start=-4, stop=0, num=500)[:, np.newaxis]
    scale = float(4 / 127)
    # Quantize the grid at `scale` and run the approximation.
    q_res, s_res = compute_i_exp(q=np.round(grid / scale), S=scale)
    approx = q_res * s_res
    exact = np.exp(grid)

    # Ranges of: input grid, approximation, reference exp.
    for arr in (grid, approx, exact):
        print(np.min(arr), np.max(arr))

    plt.plot(grid, approx, label="i-exp")
    plt.plot(grid, exact, label="exp")
    plt.legend()
    plt.show()
    plt.close()
