from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from typing import Union

class TransformerLR(_LRScheduler):
    """Linear warmup followed by inverse-square-root decay.

    Each base learning rate is multiplied by::

        sqrt(warmup_epochs) * min(step ** -0.5, step * warmup_epochs ** -1.5)

    where ``step = last_epoch + 1``. The ``sqrt(warmup_epochs)`` factor
    normalizes the multiplier so that it peaks at exactly 1.0 when
    ``step == warmup_epochs``; i.e. the optimizer's configured learning rate
    is the peak learning rate.

    Args:
        optimizer: Wrapped optimizer.
        warmup_epochs: Number of scheduler steps over which the multiplier
            grows linearly before the decay phase starts. Must be positive.
        last_epoch: Index of the last step, forwarded to the base scheduler.
        verbose: Forwarded to the base scheduler only when truthy.

    Raises:
        ValueError: If ``warmup_epochs`` is not positive.
    """

    def __init__(self, optimizer, warmup_epochs=1000, last_epoch=-1, verbose=False):
        # A zero warmup would divide by zero in ``warmup_epochs ** -1.5`` and a
        # negative one would make ``** 0.5`` complex; fail early and clearly.
        if warmup_epochs <= 0:
            raise ValueError(
                f"warmup_epochs must be positive, got {warmup_epochs!r}"
            )
        self.warmup_epochs = warmup_epochs
        self.normalize = self.warmup_epochs**0.5
        # Attributes above must be set before the base constructor runs,
        # because get_lr() reads them.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose=verbose)
        else:
            # Skip forwarding the default so PyTorch versions that deprecate
            # ``verbose`` do not warn on every construction.
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the scaled learning rate for every parameter group."""
        # Offset by one so that ``step ** -0.5`` is never evaluated at zero
        # when ``last_epoch`` is 0.
        step = self.last_epoch + 1
        warmup_scale = step * self.warmup_epochs**-1.5
        decay_scale = step**-0.5
        scale = self.normalize * min(decay_scale, warmup_scale)
        return [base_lr * scale for base_lr in self.base_lrs]

class ZeroToLinearLR(_LRScheduler):
    """Hold the learning rate at zero, then decay it linearly to zero.

    For each parameter group, with ``step = last_epoch + 1``:

    * ``step < zero_step``: the learning rate is 0.
    * ``zero_step <= step < num_training_steps``: the learning rate is
      ``base_lr * (num_training_steps - step) / (num_training_steps - zero_step)``,
      i.e. it starts at ``base_lr`` when ``step == zero_step`` and reaches 0 at
      ``num_training_steps``.
    * ``step >= num_training_steps``: the learning rate stays at 0.

    Args:
        optimizer: Wrapped optimizer.
        num_training_steps: Step at which the learning rate reaches zero.
        zero_steps: Number of initial steps with a zero learning rate. Either
            a single int applied to every parameter group, or a sequence with
            one entry per parameter group.
        last_epoch: Index of the last step, forwarded to the base scheduler.

    Raises:
        ValueError: If ``zero_steps`` is a sequence whose length differs from
            the number of parameter groups.
    """

    def __init__(self, optimizer: Optimizer, num_training_steps,
                 zero_steps: Union[int, list] = 1000, last_epoch: int = -1):
        self.num_training_steps = num_training_steps
        group_count = len(optimizer.param_groups)
        if isinstance(zero_steps, int):
            zero_steps = [zero_steps] * group_count
        else:
            zero_steps = list(zero_steps)
            # zip() in get_lr() would silently drop the unmatched groups.
            if len(zero_steps) != group_count:
                raise ValueError(
                    f"Expected {group_count} zero_steps, but got {len(zero_steps)}"
                )
        self.zero_steps = zero_steps
        # Attributes above must be set before the base constructor runs,
        # because get_lr() reads them.
        super().__init__(optimizer, last_epoch)

    def _get_group_lr(self, step, base_lr, zero_step):
        """Return the learning rate of one parameter group at ``step``."""
        # Outside the [zero_step, num_training_steps) window the rate is zero.
        # The upper bound keeps the rate from going negative once training
        # runs past num_training_steps, and also guarantees a strictly
        # positive denominator below (zero_step <= step < num_training_steps).
        if step < zero_step or step >= self.num_training_steps:
            return 0.
        remaining = self.num_training_steps - step
        decay_span = self.num_training_steps - zero_step
        return base_lr * remaining / decay_span

    def get_lr(self):
        """Return the learning rate for every parameter group."""
        step = self.last_epoch + 1
        return [
            self._get_group_lr(step, base_lr, zero_step)
            for base_lr, zero_step in zip(self.base_lrs, self.zero_steps)
        ]