import math
from model.mydataclass import TrainingParams
from tensorboardX import SummaryWriter


class scheduler:
    """Common state shared by the beta annealing schedules.

    Stores the current beta value, the current step counter and the
    schedule hyper-parameters read from ``params``. Subclasses override
    :meth:`step`; the base implementation does nothing.
    """

    def __init__(self, params: TrainingParams, init_beta: int, init_step: int) -> None:
        """Copy the schedule settings from ``params`` and set the start state.

        Args:
            params: Object exposing ``beta_warmup``, ``beta_min``,
                ``beta_max`` and ``beta_anneal_period``.
            init_beta: Value stored as the initial ``beta``.
            init_step: Value stored as the initial step counter ``t``.
        """
        # Hyper-parameters taken from the training configuration.
        self.beta_min = params.beta_min
        self.beta_max = params.beta_max
        self.warmup = params.beta_warmup
        self.beta_anneal_period = params.beta_anneal_period

        # Mutable schedule state.
        self.t = init_step
        self.beta = init_beta

    def step(self):
        """Placeholder hook; leaves the state untouched and returns ``None``."""
        return None

class linear_schedule(scheduler):
    """Linear beta schedule: hold ``beta_min`` during warmup, then ramp up.

    While ``t < warmup`` every call to :meth:`step` sets beta to
    ``beta_min``. Afterwards, while ``t <= warmup + beta_anneal_period``,
    each call adds ``beta_step`` to the current beta, never moving past
    ``beta_max``. Once the ramp window is over beta is left unchanged.
    """

    def __init__(self, params: TrainingParams, init_beta: int, init_step: int) -> None:
        """Initialise the shared state and derive the per-step increment.

        Args:
            params: Object exposing ``beta_warmup``, ``beta_min``,
                ``beta_max`` and ``beta_anneal_period``.
            init_beta: Initial beta value.
            init_step: Initial step counter.
        """
        super(linear_schedule, self).__init__(params, init_beta, init_step)
        beta_range = self.beta_max - self.beta_min
        if self.beta_anneal_period:
            self.beta_step = beta_range / self.beta_anneal_period
        else:
            # A zero-length anneal period would divide by zero; treat it as a
            # single jump of the full range on the first step after warmup.
            self.beta_step = beta_range

    def step(self):
        """Advance the schedule by one step.

        Returns:
            The beta value after this step.
        """
        if self.t < self.warmup:
            self.beta = self.beta_min
        elif self.t <= self.warmup + self.beta_anneal_period:
            # The ramp window is inclusive on both ends, so it applies one
            # increment more than ``beta_anneal_period``; clamping keeps beta
            # from overshooting ``beta_max`` and absorbs float accumulation
            # error on the final increment.
            ramped = self.beta + self.beta_step
            if self.beta_step >= 0:
                self.beta = min(ramped, self.beta_max)
            else:
                self.beta = max(ramped, self.beta_max)
        self.t += 1
        return self.beta


class cyclical_schedule(scheduler):
    """Cyclical beta schedule: repeated linear ramps from ``beta_min``.

    The anneal period is divided into ``beta_num_cycles`` cycles of
    ``cycle_period`` steps. Within a cycle beta is increased by
    ``beta_step`` per step while the cycle position ``tau`` is at most
    ``linear_period`` (half a cycle), capped at ``beta_max``, and then held.
    When ``tau`` wraps to 0 and fewer than ``beta_num_cycles`` cycles have
    been counted, beta is reset to ``beta_min`` and the cycle counter ``T``
    is incremented.
    """

    def __init__(self, params: TrainingParams, init_beta: int, init_step: int) -> None:
        """Initialise the shared state and the cycle bookkeeping.

        Args:
            params: Object exposing ``beta_warmup``, ``beta_min``,
                ``beta_max``, ``beta_anneal_period`` and ``beta_num_cycles``.
            init_beta: Initial beta value.
            init_step: Initial step counter.
        """
        super(cyclical_schedule, self).__init__(params, init_beta, init_step)
        self.beta_num_cycles = params.beta_num_cycles

        # Both periods are used as divisors (``/`` below, ``%`` in step), so
        # keep them at least one step long for very short anneal periods.
        self.cycle_period = max(self.beta_anneal_period // self.beta_num_cycles, 1)
        self.linear_period = max(int(self.cycle_period * 0.5), 1)
        self.beta_step = (self.beta_max - self.beta_min) / self.linear_period

        # Steps already spent in the cyclical phase. Clamped at zero because
        # ``tau`` is not advanced during warmup: a negative ``t - warmup``
        # would otherwise leave the first cycle out of phase by
        # ``warmup % cycle_period`` once warmup ends.
        elapsed = max(self.t - self.warmup, 0)
        self.T = elapsed // self.cycle_period + 1
        # ``tau`` starts one full cycle below its in-cycle offset (so it is
        # negative); the first post-warmup step therefore never takes the
        # reset branch, and the modulo in ``step`` wraps it into
        # ``[0, cycle_period)`` from then on.
        self.tau = elapsed - self.T * self.cycle_period

    def step(self):
        """Advance the schedule by one step.

        Returns:
            The beta value after this step.
        """
        if self.t < self.warmup:
            self.beta = self.beta_min
            self.t += 1
        else:
            if self.tau == 0 and self.T < self.beta_num_cycles:
                # Start of a new cycle: drop back to the minimum.
                self.beta = self.beta_min
                self.T += 1
            elif self.tau <= self.linear_period:
                self.beta = min(self.beta + self.beta_step, self.beta_max)
            self.tau = (self.tau + 1) % self.cycle_period
            self.t += 1

        return self.beta

class sigmoid_schedule(scheduler):
    """Saturating beta schedule driven by a geometric decay weight.

    After warmup, beta is ``beta_min + diff * (1 - anneal_rate ** k)`` where
    ``k`` is the number of steps past warmup and
    ``anneal_rate = 0.01 ** (1 / beta_anneal_period)``. The weight therefore
    reaches 0.01 after ``beta_anneal_period`` post-warmup steps.
    """

    def __init__(self, params: TrainingParams, init_beta: int, init_step: int) -> None:
        """Initialise the shared state and precompute the decay constants.

        Args:
            params: Object exposing ``beta_warmup``, ``beta_min``,
                ``beta_max`` and ``beta_anneal_period``.
            init_beta: Initial beta value.
            init_step: Initial step counter.
        """
        super().__init__(params, init_beta, init_step)
        self.diff = self.beta_max - self.beta_min
        self.anneal_rate = math.pow(0.01, 1 / self.beta_anneal_period)
        self.weight = 1

    def step(self):
        """Advance the schedule by one step.

        Returns:
            The beta value after this step.
        """
        steps_past_warmup = self.t - self.warmup
        if steps_past_warmup >= 0:
            self.weight = math.pow(self.anneal_rate, steps_past_warmup)
            self.beta = self.beta_min + self.diff * (1 - self.weight)
        else:
            # Still warming up: pin beta to its minimum.
            self.beta = self.beta_min
        self.t += 1
        return self.beta

class beta_annealing_schedule(object):
    """Front-end that picks a concrete schedule from ``params.beta_schedule_mode``.

    Recognised modes are ``"linear"``, ``"sigmoid"`` and ``"cyclical"``; any
    other value falls back to ``"sigmoid"``.
    """

    def __init__(self, params: TrainingParams, init_beta: int=0, init_step: int=0) -> None:
        """Build the schedule selected by ``params.beta_schedule_mode``.

        Args:
            params: Training configuration forwarded to the chosen schedule.
            init_beta: Initial beta value forwarded to the chosen schedule.
            init_step: Initial step counter forwarded to the chosen schedule.
        """
        schedule_types = {
            "linear": linear_schedule,
            "cyclical": cyclical_schedule,
            "sigmoid": sigmoid_schedule,
        }
        requested = params.beta_schedule_mode
        # Unrecognised modes silently fall back to the sigmoid schedule.
        known = requested in ("linear", "sigmoid", "cyclical")
        self.mode = requested if known else "sigmoid"
        self.schedule = schedule_types[self.mode](params, init_beta, init_step)

    def step(self):
        """Advance the wrapped schedule by one step and return its result."""
        return self.schedule.step()