import math


class Sched:
    """Stateless schedule functions that map a step index to a value.

    Every method is a static, pure function. Callers are expected to pass
    ``0 <= current_step <= total_step``; only :meth:`linear_cosine` checks
    the upper bound. Step counts used as divisors (``total_step``,
    ``warmup_step``) must be non-zero, otherwise ``ZeroDivisionError`` is
    raised by the division itself.
    """

    @staticmethod
    def linear(
        base_value: float,
        current_step: float,
        total_step: float,
        start_factor: float,
        end_factor: float,
    ) -> float:
        """Linearly interpolate the multiplier applied to ``base_value``.

        Returns ``base_value * start_factor`` at step 0 and
        ``base_value * end_factor`` at ``current_step == total_step``.
        """
        return base_value * (
            start_factor + (end_factor - start_factor) / total_step * current_step
        )

    @staticmethod
    def cosine(
        base_value: float,
        current_step: float,
        total_step: float,
        min_value: float,
    ) -> float:
        """Half-cosine anneal from ``base_value`` down to ``min_value``.

        Returns ``base_value`` at step 0 and ``min_value`` at
        ``current_step == total_step``. Note that ``min_value`` is an
        absolute value, not a factor of ``base_value``.
        """
        return (
            min_value
            + (base_value - min_value)
            * (1 + math.cos(math.pi * current_step / total_step))
            / 2
        )

    @staticmethod
    def linear_cosine(
        base_value: float,
        current_step: float,
        warmup_step: float,
        total_step: float,
        start_factor: float,
        mid_factor: float,
        min_value: float,
    ) -> float:
        """Linear warm-up followed by cosine annealing (phases concatenated).

        For ``current_step < warmup_step`` the value ramps linearly from
        ``base_value * start_factor`` towards ``base_value * mid_factor``.
        From ``warmup_step`` up to and including ``total_step`` the value
        follows :meth:`cosine`, which starts at ``base_value`` itself
        (independent of ``mid_factor``) and ends at ``min_value``.

        Raises:
            ValueError: if ``current_step`` lies beyond ``total_step``.
        """
        if current_step < warmup_step:
            return Sched.linear(
                base_value, current_step, warmup_step, start_factor, mid_factor
            )
        elif warmup_step <= current_step <= total_step:
            # Re-base the step so the cosine phase starts counting at zero.
            return Sched.cosine(
                base_value,
                current_step - warmup_step,
                total_step - warmup_step,
                min_value,
            )
        else:
            # Raise a real exception object; raising a bare string is itself
            # a TypeError in Python 3 and hides the actual problem.
            raise ValueError(
                f"current_step ({current_step}) must not exceed "
                f"total_step ({total_step})"
            )

    @staticmethod
    def exponent(
        base_value: float,
        current_step: float,
        total_step: float,
        scale: float = 5,
        factor: float = 0.5,
    ) -> float:
        """Exponential decay of ``base_value``.

        The result is ``base_value * factor ** (progress * scale)`` where
        ``progress = current_step / total_step``; at the final step the
        value has been multiplied by ``factor`` exactly ``scale`` times.
        """
        return base_value * factor ** (current_step / total_step * scale)

    @staticmethod
    def linear_exponent(
        base_value: float,
        current_step: float,
        warmup_step: float,
        total_step: float,
        scale: float = 5,
        factor: float = 0.5,
    ) -> float:
        """Linear warm-up multiplied by exponential decay.

        Unlike :meth:`linear_cosine`, which concatenates two phases, this
        schedule multiplies a 0 -> 1 warm-up ramp by an exponential decay
        computed over the whole ``total_step`` range, so the decay is
        already in effect during warm-up.
        """
        if current_step < warmup_step:
            warmup_gain = Sched.linear(1, current_step, warmup_step, 0, 1)
        else:
            warmup_gain = 1
        decay_gain = Sched.exponent(1, current_step, total_step, scale, factor)
        return base_value * warmup_gain * decay_gain
