from torch.optim import lr_scheduler
from timm.scheduler import CosineLRScheduler
from .Build_Scheduler import schedulers


@schedulers.register_module()
class MultiStepLR(lr_scheduler.MultiStepLR):
    """Step decay at milestones given as fractions of ``total_steps``.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Iterable of fractions (e.g. ``[0.6, 0.8]``). Each one is
            multiplied by ``total_steps`` to obtain the step at which the
            learning rate is multiplied by ``gamma``.
        total_steps: Total number of scheduler steps used to scale
            ``milestones``.
        gamma: Multiplicative decay factor applied at each milestone.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded unchanged to ``torch.optim.lr_scheduler.MultiStepLR``.
    """

    def __init__(
            self,
            optimizer,
            milestones,
            total_steps,
            gamma=0.1,
            last_epoch=-1,
            verbose=False,
    ):
        super().__init__(
            optimizer=optimizer,
            # torch's MultiStepLR only decays when the integer step counter
            # matches a milestone exactly. Fractional products such as
            # 0.5 * 15 == 7.5 or 0.07 * 100 == 7.000000000000001 would never
            # match, and the decay would be silently skipped. Snap each
            # milestone to the nearest integer step.
            milestones=[round(rate * total_steps) for rate in milestones],
            gamma=gamma,
            last_epoch=last_epoch,
            verbose=verbose,
        )


@schedulers.register_module()
class MultiStepWithWarmupLR(lr_scheduler.LambdaLR):
    """Multi-step decay combined with a linear warmup.

    The per-step factor is ``warmup_coefficient * gamma ** k``, where ``k`` is
    the number of milestones already reached.

    Args:
        optimizer: Wrapped optimizer.
        milestones: Iterable of fractions of ``total_steps`` at which the
            learning rate is multiplied by ``gamma``. Order does not matter.
        total_steps: Total number of scheduler steps used to scale
            ``milestones`` and ``warmup_rate``.
        gamma: Multiplicative decay factor per reached milestone.
        warmup_rate: Fraction of ``total_steps`` spent in warmup. ``0``
            disables warmup.
        warmup_scale: Warmup coefficient at step 0. It rises linearly to
            ``1.0`` at the end of warmup.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded unchanged to ``torch.optim.lr_scheduler.LambdaLR``.
    """

    def __init__(
            self,
            optimizer,
            milestones,
            total_steps,
            gamma=0.1,
            warmup_rate=0.05,
            warmup_scale=1e-6,
            last_epoch=-1,
            verbose=False,
    ):
        milestones = [rate * total_steps for rate in milestones]
        warmup_steps = warmup_rate * total_steps

        def multi_step_with_warmup(s):
            # Apply gamma once for every milestone already reached. Checking
            # every milestone, instead of stopping at the first unreached
            # one, keeps the count correct even when milestones are unsorted.
            factor = 1.0
            for milestone in milestones:
                if s >= milestone:
                    factor *= gamma

            # Linear warmup from warmup_scale (s == 0) to 1.0 (s == warmup_steps).
            # A zero-length warmup is skipped explicitly. Without this guard the
            # formula divides 0 by 0 at step 0, which LambdaLR evaluates during
            # construction.
            if warmup_steps > 0 and s <= warmup_steps:
                warmup_coefficient = 1 - (1 - s / warmup_rate / total_steps) * (
                        1 - warmup_scale
                )
            else:
                warmup_coefficient = 1.0
            return warmup_coefficient * factor

        super().__init__(
            optimizer=optimizer,
            lr_lambda=multi_step_with_warmup,
            last_epoch=last_epoch,
            verbose=verbose,
        )


@schedulers.register_module()
class PolyLR(lr_scheduler.LambdaLR):
    """Polynomial decay: factor ``(1 - s / (total_steps + 1)) ** power``.

    The factor reaches ``0`` at step ``total_steps + 1`` and stays there.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Length of the decay schedule in scheduler steps.
        power: Exponent of the polynomial.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded unchanged to ``torch.optim.lr_scheduler.LambdaLR``.
    """

    def __init__(self, optimizer, total_steps, power=0.9, last_epoch=-1, verbose=False):
        def poly_decay(s):
            # Clamp the base at zero. Past total_steps + 1 it would go
            # negative, and in Python a negative float raised to a
            # non-integer power is a complex number, which would corrupt
            # the learning rate.
            return max(0.0, 1 - s / (total_steps + 1)) ** power

        super().__init__(
            optimizer=optimizer,
            lr_lambda=poly_decay,
            last_epoch=last_epoch,
            verbose=verbose,
        )


@schedulers.register_module()
class ExpLR(lr_scheduler.LambdaLR):
    """Continuous exponential decay: factor ``gamma ** (s / total_steps)``.

    The factor equals ``gamma`` exactly at step ``total_steps``.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Number of steps over which the factor shrinks by ``gamma``.
        gamma: Decay base.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded unchanged to ``torch.optim.lr_scheduler.LambdaLR``.
    """

    def __init__(self, optimizer, total_steps, gamma=0.9, last_epoch=-1, verbose=False):
        def exponential_decay(step):
            progress = step / total_steps
            return gamma ** progress

        super().__init__(
            optimizer=optimizer,
            lr_lambda=exponential_decay,
            last_epoch=last_epoch,
            verbose=verbose,
        )


@schedulers.register_module()
class CosineAnnealingLR(lr_scheduler.CosineAnnealingLR):
    """Registry wrapper around ``torch.optim.lr_scheduler.CosineAnnealingLR``.

    Args:
        optimizer: Wrapped optimizer.
        total_steps: Schedule length, passed to torch as ``T_max``.
        eta_min: Minimum learning rate.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
        verbose: Forwarded unchanged to torch.
    """

    def __init__(self, optimizer, total_steps, eta_min=0, last_epoch=-1, verbose=False):
        # torch names the schedule length ``T_max``. This wrapper exposes it
        # as ``total_steps`` to match the other schedulers in this module.
        cosine_options = dict(T_max=total_steps, eta_min=eta_min)
        super().__init__(
            optimizer=optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
            **cosine_options,
        )


@schedulers.register_module()
class CosineLRScheduler(CosineLRScheduler):
    """Registry wrapper around ``timm.scheduler.CosineLRScheduler``.

    NOTE: this class reuses the timm class name. The base class is bound when
    the ``class`` statement is evaluated. After that, the module-level name
    ``CosineLRScheduler`` refers to this wrapper.

    Args:
        optimizer: Wrapped optimizer.
        t_initial: Length of the first cosine cycle.
        lr_min: Minimum learning rate.
        warmup_lr_init: Learning rate at the start of warmup.
        warmup_t: Length of the warmup.
        cycle_limit: Number of cosine cycles.
        t_in_epochs: Forwarded to timm.
        **kwargs: Any further options accepted by the installed timm version
            (e.g. ``cycle_mul`` or ``warmup_prefix`` in recent releases).
            They are forwarded unchanged.
    """

    def __init__(self, optimizer, t_initial, lr_min, warmup_lr_init, warmup_t, cycle_limit=1, t_in_epochs=True,
                 **kwargs):
        # Zero-argument super() resolves through the class cell rather than
        # the module-global name. It therefore stays correct whatever the
        # registry decorator returns and even if the name is later rebound.
        super().__init__(
            optimizer,
            t_initial=t_initial,
            lr_min=lr_min,
            warmup_lr_init=warmup_lr_init,
            warmup_t=warmup_t,
            cycle_limit=cycle_limit,
            t_in_epochs=t_in_epochs,
            **kwargs,
        )


@schedulers.register_module()
class OneCycleLR(lr_scheduler.OneCycleLR):
    r"""Registry wrapper around ``torch.optim.lr_scheduler.OneCycleLR``.

    Only the ``total_steps`` form of the schedule length is exposed. The
    ``epochs`` / ``steps_per_epoch`` alternative is not accepted here, so
    ``total_steps`` must be supplied; torch raises ``ValueError`` when it is
    left as ``None``. All other arguments are forwarded unchanged.
    """

    def __init__(
            self,
            optimizer,
            max_lr,
            total_steps=None,
            pct_start=0.3,
            anneal_strategy="cos",
            cycle_momentum=True,
            base_momentum=0.85,
            max_momentum=0.95,
            div_factor=25.0,
            final_div_factor=1e4,
            three_phase=False,
            last_epoch=-1,
            verbose=False,
    ):
        # Shape of the learning-rate cycle.
        cycle_shape = dict(
            max_lr=max_lr,
            total_steps=total_steps,
            pct_start=pct_start,
            anneal_strategy=anneal_strategy,
            div_factor=div_factor,
            final_div_factor=final_div_factor,
            three_phase=three_phase,
        )
        # Momentum cycling, applied inversely to the learning rate by torch.
        momentum_options = dict(
            cycle_momentum=cycle_momentum,
            base_momentum=base_momentum,
            max_momentum=max_momentum,
        )
        super().__init__(
            optimizer=optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
            **cycle_shape,
            **momentum_options,
        )
