from spaghettini import quick_register
import matplotlib.pyplot as plt

import numpy as np


@quick_register
def saturated_linear(epoch, k1, k2, a1, a2):
    """Piecewise-linear schedule that ramps from ``a1`` to ``a2``.

    The value is:
    - ``a1`` while ``epoch < k1``;
    - linearly interpolated between ``(k1, a1)`` and ``(k2, a2)`` for
      ``k1 <= epoch <= k2``;
    - ``a2`` otherwise (i.e. once ``epoch`` is past ``k2``).

    When ``k1 == k2`` the ramp degenerates to a step. ``a1`` is then returned
    at ``epoch == k1`` and ``a2`` only for later epochs.

    Args:
        epoch: Current epoch (any value comparable with ``k1``/``k2``).
        k1: Epoch at which the ramp starts.
        k2: Epoch at which the ramp ends.
        a1: Value before / at the start of the ramp.
        a2: Value at the end of / after the ramp.

    Returns:
        The scheduled value for ``epoch``.
    """
    if epoch < k1:
        return a1

    if not (k1 <= epoch <= k2):
        return a2

    if k1 == k2:
        # Zero-length ramp: no slope to compute, keep the starting value.
        return a1

    slope = (a2 - a1) / (k2 - k1)
    return a1 + slope * (epoch - k1)


@quick_register
def constant(epoch, a1):
    """Schedule that ignores ``epoch`` and always yields ``a1``.

    Args:
        epoch: Current epoch. Unused; presumably accepted so this schedule
            can be called the same way as ``saturated_linear``.
        a1: The value to return.

    Returns:
        ``a1``, unchanged.
    """
    del epoch  # intentionally unused
    value = a1
    return value


if __name__ == "__main__":
    # Visual sanity check for the schedules in this module.
    # Run from root:
    #     python -m src.dl.schedules.annealing_schedules
    test_num = 0

    if test_num == 0:
        # Degenerate ramp (k1 == k2): the value is a1 up to and including
        # epoch k1, then a2 for every later epoch.
        epochs = np.arange(0, 100)
        k1, k2 = 12, 12
        a1, a2 = 10, 0
        coeffs = [
            saturated_linear(epoch=epoch, k1=k1, k2=k2, a1=a1, a2=a2)
            for epoch in epochs
        ]
        plt.plot(epochs, coeffs)
        plt.show()
