import math

import matplotlib.pyplot as plt
import torch


class GSTempScheduler:
    """Cyclic cosine-annealed temperature schedule with an optional hold phase.

    Each cycle lasts ``T_0 + hold`` iterations. During the first ``T_0``
    iterations the temperature follows a half-cosine from ``high`` down to
    ``low``. During the remaining ``hold`` iterations it stays at ``low``.
    The cycle then repeats.

    Args:
        T_0: Number of annealing iterations per cycle.
        high: Temperature at the first iteration of each cycle.
        low: Temperature at the end of annealing and throughout the hold phase.
        hold: Number of iterations per cycle spent at ``low`` (default 0).
    """

    def __init__(self, T_0, high, low, hold=0):
        self.T_0 = T_0
        self.hold = hold
        period = int(T_0 + hold)
        # Force a floating dtype. With an integer ``low``, ``torch.full`` would
        # infer int64, and the cosine values assigned below would be truncated.
        self.schedule = torch.full((period,), low, dtype=torch.get_default_dtype())
        # Over x in [0, pi], 0.5 * (cos(x) + 1) decays from 1 to 0, so the
        # annealed segment runs from ``high`` down to ``low``.
        x = torch.linspace(0, math.pi, T_0)
        self.schedule[:T_0] = 0.5 * (x.cos() + 1) * (high - low) + low

    def temp(self, it):
        """Return the scheduled temperature (a 0-dim tensor) for iteration ``it``."""
        return self.schedule[it % (self.T_0 + self.hold)]

    def store(self, it):
        """Return True when iteration ``it`` falls in the hold phase of its cycle."""
        return (it % (self.T_0 + self.hold)) >= self.T_0
