import tensorflow as tf


class ConstantSchedule:
    """Schedule that yields the same value regardless of the step."""

    def __init__(self, value=1.0):
        """Store the constant that every call will return.

        Args:
            value: The value handed back by ``__call__``; defaults to 1.0.
        """
        self.value = value

    def __call__(self, step):
        """Return the stored constant.

        Args:
            step: Accepted for interface parity with the other schedules;
                it is not consulted.

        Returns:
            The ``value`` given at construction time, unchanged.
        """
        constant = self.value
        return constant


class ZeroOneSchedule:
    """Step-function schedule: 0.0 up to a threshold step, 1.0 after it."""

    def __init__(self, threshold_step, init_step=1):
        """Record the switching point of the schedule.

        Args:
            threshold_step: Step that must be strictly exceeded before the
                schedule returns 1.0.
            init_step: Stored on the instance; ``__call__`` does not read it.
        """
        self.init_step = init_step
        self.threshold_step = threshold_step

    def __call__(self, step):
        """Return 1.0 when ``step`` is strictly past the threshold, else 0.0.

        Args:
            step: Current step; cast to ``tf.float32`` before comparison.

        Returns:
            ``1.0`` if ``step > threshold_step``, otherwise ``0.0``.
        """
        current = tf.cast(step, tf.float32)
        past_threshold = tf.greater(current, self.threshold_step)
        if past_threshold:
            return 1.0
        return 0.0


class LinearSchedule:
    """Linear ramp schedule: 0.0 before ``init_step``, then rising linearly.

    From ``init_step`` onwards the value is
    ``(step - init_step) / (max_steps - init_step)``, i.e. 0 at ``init_step``
    and 1 at ``max_steps``. The result is not clamped, so it keeps growing
    past 1 for steps beyond ``max_steps``.
    """

    def __init__(self, max_steps=256 * 100.0, init_step=256 * 10.0):
        """Store the ramp end points as ``tf.float32`` values.

        Args:
            max_steps: Step at which the ramp reaches 1.
            init_step: Step at which the ramp starts; earlier steps yield 0.0.
        """
        self.init_step = tf.cast(init_step, tf.float32)
        self.max_steps = tf.cast(max_steps, tf.float32)

    def __call__(self, step):
        """Evaluate the schedule at ``step``.

        Args:
            step: Current step; cast to ``tf.float32`` before use.

        Returns:
            ``0.0`` while ``step < init_step``; otherwise the float32 ratio
            ``(step - init_step) / (max_steps - init_step)``.
        """
        # Cast once and reuse for both the comparison and the ramp.
        current = tf.cast(step, tf.float32)
        if tf.less(current, self.init_step):
            return 0.0
        else:
            # The numerator must be parenthesised: without it, division binds
            # tighter than subtraction and only ``init_step`` gets divided,
            # yielding ``step - const`` instead of a normalised ramp.
            return (current - self.init_step) / (
                self.max_steps - self.init_step
            )
