import tensorflow as tf


def stochastic_round(x: tf.Variable, precision: float = 1.0) -> tf.Variable:
    """Stochastically round ``x`` to a multiple of ``precision``.

    Each element is scaled by ``1 / precision`` and then rounded up with a
    probability equal to its fractional part (and down otherwise), so the
    result equals ``x`` in expectation. With the default ``precision=1.0``
    values are rounded to integers.

    Args:
        x: Floating-point values to round.
        precision: Spacing of the rounding grid (must be non-zero).

    Returns:
        A tensor shaped like ``x`` whose entries are multiples of
        ``precision``.
    """
    scale = 1.0 / precision
    scale_x = x * scale
    lower = tf.math.floor(scale_x)
    fraction = scale_x - lower

    # Draw the noise in the dtype of the data: tf.random.uniform defaults to
    # float32, which makes the comparison below fail for float64/float16
    # inputs because TensorFlow does not promote dtypes implicitly.
    noise = tf.random.uniform(tf.shape(x), dtype=fraction.dtype)

    result = tf.where(
        fraction < noise,
        lower,
        tf.math.ceil(scale_x),
    )
    return result / scale


def deterministic_rounding(x: tf.Variable) -> tf.Variable:
    """Round every element of ``x`` to the nearest integer value.

    This simply delegates to ``tf.math.round``, which TensorFlow documents as
    rounding exact halves to the nearest even value.
    """
    nearest = tf.math.round(x)
    return nearest


def QuantizeLinearTensor(
    x: tf.Variable,
    scale: tf.Variable,
    q_min: int,
    q_max: int,
    power: int = 1,
    deterministic: bool = True,
) -> tf.Variable:
    """
    Quantize a tensor x with a scale, clipping the result to [q_min; q_max].

    Linear quantization (power == 1):
        x -> clip(round(x * scale))
    Power quantization (power != 1), applied to the magnitude with the sign
    of x restored afterwards:
        x -> clip(sign(x) * round(|x|^power * scale))

    Args:
        x: Values to quantize.
        scale: Multiplicative quantization scale.
        q_min: Smallest value allowed in the output.
        q_max: Largest value allowed in the output.
        power: Exponent applied to |x| before scaling; 1 disables it.
        deterministic: If True use round-to-nearest, otherwise use
            stochastic rounding.

    Returns:
        The quantized values, guaranteed to lie within [q_min; q_max].

    NOTE(review): no straight-through estimator is applied here (an earlier
    attempt was left commented out), so gradients are whatever TensorFlow
    defines for the rounding and clipping ops - confirm this is intended.
    """
    if power != 1:
        x_sign = tf.math.sign(x)
        x_q = tf.math.pow(tf.math.abs(x), power)
    else:
        x_sign = None
        x_q = x
    x_q = x_q * scale

    if deterministic:
        x_round = deterministic_rounding(x_q)
    else:
        x_round = stochastic_round(x_q)

    # Restore the sign BEFORE clipping. Clipping the unsigned magnitude and
    # applying the sign afterwards lets negative inputs escape the range
    # (e.g. q_min = 0) and never reach q_min when the range is asymmetric.
    if x_sign is not None:
        x_round = x_round * x_sign

    return tf.clip_by_value(t=x_round, clip_value_min=q_min, clip_value_max=q_max)


def DeQuantizeLinearTensor(
    x: tf.Variable,
    scale: tf.Variable,
    power: int = 1,
) -> tf.Variable:
    """
    Map quantized values x back to real values by undoing the scale (and power).

    Linear case (power == 1):
        x -> x / scale
    Power case (power != 1):
        x -> sign(x) * |x / scale|^(1 / power)

    The scale is floored at 1e-4 before dividing, so the divisor is always
    positive and never zero. Only ordinary differentiable ops are used here;
    no custom gradient is attached.
    """
    # Guard against division by a zero or vanishing scale.
    safe_scale = tf.maximum(scale, 1.0e-4)
    rescaled = x / safe_scale
    if power == 1:
        return rescaled

    # Invert the power on the magnitude, then put the sign of x back.
    magnitude = tf.math.pow(tf.math.abs(rescaled), 1.0 / power)
    return tf.math.sign(x) * magnitude


if __name__ == "__main__":
    import numpy as np

    # Smoke test: quantize random data to [-8, 8] and map it back.
    samples = np.random.normal(size=(2, 2, 1))
    # Pick the scale so that the largest magnitude lands exactly on 8.
    max_scale = 8 / np.max(np.abs(samples))
    print(samples)
    quantized = QuantizeLinearTensor(samples, max_scale, -8, 8)
    print(quantized)
    print(DeQuantizeLinearTensor(quantized, max_scale))
