import tensorflow as tf


def update_weight_scales(
    weights: tf.Variable,
    weight_scales: tf.Variable,
    axis_per_channel: tuple,
    exponent: float = 1.0,
    weight_bits: int = 8,
) -> tf.Variable:
    """Update the weight scales based on the current floating weight values.

    For every slice that survives the reduction over ``axis_per_channel`` the
    new scale is::

        (2 ** (weight_bits - 1) - 1) / max(max(|weights|) ** exponent, 1e-4)

    and the result is written into ``weight_scales`` in place.

    Args:
        weights: Floating-point weights the scales are derived from.
        weight_scales: Variable that receives the new scales. Its shape must
            match ``weights`` reduced over ``axis_per_channel``; otherwise the
            assignment fails.
        axis_per_channel: Axis/axes passed to ``tf.math.reduce_max``. The axes
            not listed are kept, giving one scale per remaining index
            (presumably one per output channel -- confirm against callers).
        exponent: Power applied to the max magnitude before it is clamped.
        weight_bits: Bit width of the signed integer grid; the largest integer
            magnitude used is ``2 ** (weight_bits - 1) - 1``.

    Returns:
        ``weight_scales`` itself, after the in-place update.
    """
    max_abs = tf.math.reduce_max(
        input_tensor=tf.math.abs(weights), axis=axis_per_channel
    )
    max_range = tf.math.pow(max_abs, exponent)
    # Floor the range so an all-zero slice cannot produce an infinite scale.
    max_range = tf.maximum(max_range, 1e-4)
    new_scale = (2 ** (weight_bits - 1) - 1) / max_range
    # ``Variable.assign`` replaces ``tf.keras.backend.update``, which was only a
    # deprecated alias for ``tf.compat.v1.assign`` and is unavailable in Keras 3.
    weight_scales.assign(new_scale)
    return weight_scales


def update_input_scales(
    I: tf.Variable,
    input_scales: tf.Variable,
    axis_per_channel: tuple,
    per_channel: bool = False,
    exponent: float = 1.0,
    activation_bits: int = 8,
    ema: float = 0.9,
) -> tf.Tensor:
    """Update the activation scales based on the current input/activation values.

    This function is also used for output scales.

    The instantaneous ("empirical") scale is::

        (2 ** (activation_bits - 1) - 1) / max(max(|I|) ** exponent, 1e-4)

    ``input_scales`` is then updated in place with the exponential moving
    average ``(1 - ema) * empirical_scale + ema * input_scales``. The first
    update therefore blends with whatever value ``input_scales`` was
    initialized to.

    Args:
        I: Input (or output) activations the scale is derived from.
        input_scales: Variable holding the running (smoothed) scales; updated
            in place.
        axis_per_channel: Axis/axes passed to ``tf.math.reduce_max`` when
            ``per_channel`` is True; ignored otherwise.
        per_channel: If True, one scale is produced per index not listed in
            ``axis_per_channel``. If False, a single scale is computed over the
            whole tensor and has shape ``(1,)``.
        exponent: Power applied to the max magnitude before it is clamped.
        activation_bits: Bit width of the signed integer grid; the largest
            integer magnitude used is ``2 ** (activation_bits - 1) - 1``.
        ema: Weight given to the previous ``input_scales`` value. 0 replaces it
            with the empirical scale; 1 leaves it unchanged.

    Returns:
        The empirical scale computed from ``I``. This is not the smoothed value
        written into ``input_scales``.
    """
    abs_input = tf.math.abs(I)
    if per_channel:
        max_abs = tf.math.reduce_max(input_tensor=abs_input, axis=axis_per_channel)
        max_range = tf.math.pow(max_abs, exponent)
    else:
        # Reduce over every axis to a scalar, then add a leading axis so the
        # single global scale has shape (1,).
        max_range = tf.expand_dims(
            tf.math.pow(tf.math.reduce_max(input_tensor=abs_input), exponent),
            axis=0,
        )
    # Floor the range so an all-zero input cannot produce an infinite scale.
    max_range = tf.maximum(max_range, 1e-4)
    empirical_scale = (2 ** (activation_bits - 1) - 1) / max_range
    # Read the previous running value before overwriting it.
    new_scale = (1 - ema) * empirical_scale + ema * input_scales
    # ``Variable.assign`` replaces ``tf.keras.backend.update``, which was only a
    # deprecated alias for ``tf.compat.v1.assign`` and is unavailable in Keras 3.
    input_scales.assign(new_scale)
    return empirical_scale


if __name__ == "__main__":
    # This module only defines helper functions; running it directly does nothing.
    ...
