import torch.optim.lr_scheduler as lr_scheduler
from .registry import Registry

# Registry of learning-rate scheduler classes. The classes below are added to it
# with the ``@SCHEDULERS.register_module()`` decorator, and ``build_scheduler``
# instantiates them from a config via ``SCHEDULERS.build``.
SCHEDULERS = Registry("schedulers")


@SCHEDULERS.register_module()
class MultiStepLR(lr_scheduler.MultiStepLR):
    """``torch.optim.lr_scheduler.MultiStepLR`` with milestones given as fractions.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Iterable of rates; each milestone step is
            ``rate * total_steps`` rounded to the nearest integer step.
        total_steps: Total number of scheduler steps.
        gamma: Multiplicative LR decay applied at each milestone.
        last_epoch: Index of the last step; forwarded to torch.
        verbose: Forwarded to torch.
    """

    def __init__(self, optimizer, milestones, total_steps, gamma=0.1, last_epoch=-1, verbose=False):
        # torch's MultiStepLR decays only when the integer step counter is
        # *equal* to a milestone. A non-integral product such as
        # 0.6 * 1234 (~740.4), or float noise like 0.07 * 100 ==
        # 7.000000000000001, would never match, so the decay would silently
        # never fire. Snap every milestone to the nearest integer step.
        milestone_steps = [int(round(rate * total_steps)) for rate in milestones]
        super().__init__(optimizer=optimizer,
                         milestones=milestone_steps,
                         gamma=gamma,
                         last_epoch=last_epoch,
                         verbose=verbose)


@SCHEDULERS.register_module()
class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
    """Step decay at fractional milestones, preceded by a linear warmup.

    The LR multiplier at step ``s`` is ``warmup(s) * gamma ** k``. Here ``k`` is
    the number of leading milestones already reached. Counting stops at the
    first milestone with ``s < milestone``, so milestones are expected in
    ascending order.

    During warmup (``s <= warmup_rate * total_steps``), ``warmup(s)`` rises
    linearly from ``warmup_scale`` at ``s == 0`` to ``1`` at the end of
    warmup. Afterwards it is ``1``.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Fractions of ``total_steps`` at which the LR is
            multiplied by ``gamma``.
        total_steps: Total number of scheduler steps.
        gamma: Multiplicative decay factor per milestone.
        warmup_rate: Fraction of ``total_steps`` spent warming up; ``0``
            disables warmup.
        warmup_scale: LR multiplier at step 0 of the warmup.
        last_epoch: Forwarded to ``LambdaLR``.
        verbose: Forwarded to ``LambdaLR``.
    """

    def __init__(self, optimizer, milestones, total_steps, gamma=0.1,
                 warmup_rate=0.05, warmup_scale=1e-6, last_epoch=-1, verbose=False):
        milestone_steps = [rate * total_steps for rate in milestones]
        warmup_steps = warmup_rate * total_steps

        def multi_step_with_warmup(s):
            factor = 1.0
            for milestone in milestone_steps:
                if s < milestone:
                    break
                factor *= gamma

            # Require a positive warmup length: with warmup_rate == 0, the
            # condition ``s <= 0`` holds at s == 0 and the division below
            # would raise ZeroDivisionError during LambdaLR construction.
            if warmup_steps > 0 and s <= warmup_steps:
                warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (1 - warmup_scale)
            else:
                warmup_coefficient = 1.0
            return warmup_coefficient * factor

        super().__init__(optimizer=optimizer,
                         lr_lambda=multi_step_with_warmup,
                         last_epoch=last_epoch,
                         verbose=verbose)


@SCHEDULERS.register_module()
class PolyLR(lr_scheduler.LambdaLR):
    """Polynomial LR decay over ``total_steps`` scheduler steps.

    The LR multiplier at step ``s`` is ``(1 - s / (total_steps + 1)) ** power``.
    The ``+ 1`` keeps the multiplier strictly positive at ``s == total_steps``.
    Beyond ``total_steps + 1`` the multiplier is held at zero.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Total number of scheduler steps.
        power: Exponent of the polynomial decay.
        last_epoch: Forwarded to ``LambdaLR``.
        verbose: Forwarded to ``LambdaLR``.
    """

    def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
        def poly_decay(s):
            # Clamp the base at zero. Stepping past ``total_steps + 1`` would
            # otherwise raise a negative float to a fractional power, which
            # Python evaluates to a complex number rather than an LR factor.
            return max(0.0, 1 - s / (total_steps + 1)) ** power

        super().__init__(optimizer=optimizer,
                         lr_lambda=poly_decay,
                         last_epoch=last_epoch,
                         verbose=verbose)


@SCHEDULERS.register_module()
class ExpLR(lr_scheduler.LambdaLR):
    """Exponential LR decay that reaches a factor of ``gamma`` at ``total_steps``.

    The LR multiplier at step ``s`` is ``gamma ** (s / total_steps)``.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Number of steps over which the multiplier falls to ``gamma``.
        gamma: Multiplier reached at ``s == total_steps``.
        last_epoch: Forwarded to ``LambdaLR``.
        verbose: Forwarded to ``LambdaLR``.
    """

    def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
        def exp_decay(step):
            progress = step / total_steps
            return gamma ** progress

        super().__init__(
            optimizer=optimizer,
            lr_lambda=exp_decay,
            last_epoch=last_epoch,
            verbose=verbose,
        )


@SCHEDULERS.register_module()
class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
    """``torch.optim.lr_scheduler.CosineAnnealingLR`` parameterised by ``total_steps``.

    ``total_steps`` is forwarded as torch's ``T_max`` so this scheduler takes
    the same argument name as the other schedulers in this module.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Forwarded as ``T_max``.
        eta_min: Minimum learning rate.
        last_epoch: Forwarded to torch.
        verbose: Forwarded to torch.
    """

    def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
        super().__init__(
            optimizer,
            T_max=total_steps,
            eta_min=eta_min,
            last_epoch=last_epoch,
            verbose=verbose,
        )


@SCHEDULERS.register_module()
class OneCycleLR(lr_scheduler.OneCycleLR):
    r"""``torch.optim.lr_scheduler.OneCycleLR`` restricted to ``total_steps``.

    The torch scheduler accepts the cycle length either as ``total_steps`` or as
    the pair ``epochs``/``steps_per_epoch``. This wrapper does not expose the
    pair (it is always passed as ``None``), so the cycle length has to be given
    through ``total_steps``. Every other argument is forwarded unchanged.
    """

    def __init__(self,
                 optimizer,
                 max_lr,
                 total_steps=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 cycle_momentum=True,
                 base_momentum=0.85,
                 max_momentum=0.95,
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):
        super().__init__(
            optimizer,
            max_lr,
            total_steps=total_steps,
            # Deliberately not configurable: the cycle length comes only
            # from total_steps.
            epochs=None,
            steps_per_epoch=None,
            pct_start=pct_start,
            anneal_strategy=anneal_strategy,
            cycle_momentum=cycle_momentum,
            base_momentum=base_momentum,
            max_momentum=max_momentum,
            div_factor=div_factor,
            final_div_factor=final_div_factor,
            three_phase=three_phase,
            last_epoch=last_epoch,
            verbose=verbose,
        )


def build_scheduler(cfg, optimizer):
    """Build a learning-rate scheduler from ``cfg`` via the ``SCHEDULERS`` registry.

    Note that ``cfg`` is modified in place: its ``optimizer`` attribute is set
    to ``optimizer`` before building. ``cfg`` presumably names a registered
    scheduler class plus its constructor arguments; verify against the
    ``Registry.build`` contract.

    Args:
        cfg: Scheduler config object supporting attribute assignment.
        optimizer: Optimizer the scheduler will drive.

    Returns:
        Whatever ``SCHEDULERS.build`` returns for ``cfg``.
    """
    setattr(cfg, "optimizer", optimizer)
    scheduler = SCHEDULERS.build(cfg=cfg)
    return scheduler
