import numpy as np

class OneCycle():
    """One-cycle schedule for the learning rate and momentum of an optimizer.

    The schedule has three phases, selected by the number of completed
    ``step()`` calls:

    1. warm-up (``mid_epoch`` steps): the learning rate rises linearly from
       ``max_lr / min_lr_divider`` to ``max_lr`` while momentum falls from
       ``max_mom`` to ``min_mom``;
    2. cool-down (until ``end_epoch``): the learning rate falls back towards
       its starting value and momentum rises back towards ``max_mom``;
    3. annealing (afterwards): the learning rate keeps shrinking by
       ``anneal`` times the regular increment, never going below zero, and
       momentum stays where it is.

    The optimizer is not modified by the constructor; values are written
    into its parameter groups on every ``step()`` call.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dict-like
            groups. Each group receives an ``'lr'`` entry and either a
            ``'momentum'`` entry or an updated first element of ``'betas'``.
        max_lr: peak learning rate reached at the end of warm-up.
        n_epochs: total number of ``step()`` calls planned; the first 90%
            (rounded) are used for warm-up plus cool-down.
        anneal: fraction of the regular learning-rate increment subtracted
            per step during the annealing phase.
        min_mom: lowest momentum, reached at the end of warm-up.
        max_mom: starting (and highest) momentum.
        use_beta: when true, write the momentum into the first element of
            ``'betas'`` (Adam-style optimizers) instead of ``'momentum'``.
        min_lr_divider: the starting learning rate is
            ``max_lr / min_lr_divider``.
    """

    def __init__(self, optimizer, max_lr, n_epochs,
                 anneal=0.1, min_mom=0.85, max_mom=0.95,
                 use_beta=False, min_lr_divider=4):

        min_lr = max_lr/min_lr_divider

        self.optimizer = optimizer

        # Range limits, kept so that _compute can clamp the running values.
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.min_mom = min_mom
        self.max_mom = max_mom

        # pre-compute incremental values
        self.epoch = 0  # number of completed step() calls
        self.end_epoch = round(n_epochs*0.9)
        # At least one warm-up step, so the divisions below cannot hit zero
        # for very small n_epochs.
        self.mid_epoch = max(1, self.end_epoch//2)

        self.lr = min_lr
        self.lr_inc = (max_lr-min_lr)/self.mid_epoch
        self.anneal = anneal

        self.mom = max_mom
        self.mom_inc = (max_mom-min_mom)/self.mid_epoch
        self.use_beta = use_beta  # Account for Adam(W)

    @staticmethod
    def _clamp(value, bound_a, bound_b):
        """Limit ``value`` to the closed interval spanned by the two bounds."""
        low, high = sorted((bound_a, bound_b))
        return min(max(value, low), high)

    def _update(self):
        """Write the current learning rate and momentum into every group."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
            if not self.use_beta:  # Update momentum for SGD
                param_group["momentum"] = self.mom
            else:
                # Only the first beta is scheduled; the second is preserved.
                (_, beta2) = param_group["betas"]
                param_group["betas"] = (self.mom, beta2)

    def _compute(self):
        """Advance ``self.lr`` and ``self.mom`` by one step of the schedule.

        ``self.epoch`` is the number of steps already completed, so exactly
        ``mid_epoch`` increments happen during warm-up and the peak learning
        rate is actually reached.
        """
        if self.epoch < self.mid_epoch:
            lr = self.lr + self.lr_inc
            mom = self.mom - self.mom_inc
        elif self.epoch < self.end_epoch:
            lr = self.lr - self.lr_inc
            mom = self.mom + self.mom_inc
        else:
            # Annealing: shrink the learning rate slowly, but never let it
            # become negative; momentum is left untouched.
            self.lr = max(self.lr - self.lr_inc*self.anneal, 0.0)
            return
        # The cool-down can be one step longer than the warm-up (odd
        # end_epoch); clamping keeps both values inside their configured
        # ranges instead of overshooting.
        self.lr = self._clamp(lr, self.min_lr, self.max_lr)
        self.mom = self._clamp(mom, self.min_mom, self.max_mom)

    def step(self):
        """Advance the schedule by one step and apply it to the optimizer."""
        self._compute()
        self._update()
        self.epoch += 1


class LrFinder():
    """Increase the learning rate of every parameter group linearly.

    The per-step increment is derived from the learning rate of the last
    parameter group and the requested ``max_lr``; the same increment is then
    added to every group on each ``step()`` call.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dict-like
            groups that each carry an ``'lr'`` entry.
        max_lr: learning rate the last group reaches after ``n_epochs`` steps.
        n_epochs: number of ``step()`` calls over which to spread the increase.
    """

    def __init__(self, optimizer, max_lr, n_epochs):
        self.optimizer = optimizer
        groups = optimizer.param_groups
        # Walk through all groups; the value left over afterwards is the
        # learning rate of the last group, which serves as the start point.
        for group in groups:
            start_lr = group['lr']
        span = max_lr - start_lr
        self.lr_inc = span / n_epochs

    def step(self):
        """Add the fixed increment to the learning rate of every group."""
        delta = self.lr_inc
        for group in self.optimizer.param_groups:
            group['lr'] += delta


class CosineLR():
    """Linear warm-up followed by cosine decay of the learning rate.

    The base learning rate is read from the last parameter group of the
    optimizer at construction time and stored as ``min_lr``; despite the
    name it is the peak of the schedule (warm-up ramps up to it and the
    cosine factor never exceeds one). The same learning rate is written to
    every parameter group. The constructor already performs the first
    ``step()``, so the optimizer starts with the first warm-up value.

    NOTE: calling ``step()`` more often than ``n_epochs`` lets the cosine
    continue past its minimum, so the learning rate rises again.

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dict-like
            groups that each carry an ``'lr'`` entry.
        warmup_length: number of steps over which the learning rate is
            ramped up linearly to the base value.
        n_epochs: step index at which the cosine decay reaches zero.
    """

    def __init__(self, optimizer, warmup_length, n_epochs):
        self.optimizer = optimizer
        # The loop leaves the learning rate of the last group in min_lr.
        for param_group in self.optimizer.param_groups:
            self.min_lr = param_group['lr']
        self.epoch = 0
        self.warmup_length = warmup_length
        self.n_epochs = n_epochs
        self.step()

    def step(self):
        """Compute the learning rate for the current epoch and apply it."""
        if self.epoch < self.warmup_length:
            lr = self.min_lr * (self.epoch + 1) / self.warmup_length
        else:
            e = self.epoch - self.warmup_length
            es = self.n_epochs - self.warmup_length
            if es == 0:
                # No decay phase configured (warm-up spans all epochs):
                # hold the base value, which is what the cosine formula
                # yields at e == 0, instead of dividing by zero.
                lr = self.min_lr
            else:
                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * self.min_lr
        self.epoch += 1
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr