from omegaconf import DictConfig, OmegaConf

from pado.core.base.optimizer import PadoOptimizer
from pado.core.base.lr_scheduler import PadoScheduler
from pado.optim.lr_scheduler import register_scheduler

# Public API of this module.
__all__ = ["ConstantLR"]


@register_scheduler("ConstantLR")
class ConstantLR(PadoScheduler):
    """Keep the learning rate constant, optionally after a linear warmup.

    While ``num_iters < warmup_iters`` the rate is
    ``initial_lr * (num_iters + 1) / warmup_iters``. If ``num_iters`` starts
    at 0, it ramps from ``initial_lr / warmup_iters`` up to ``initial_lr``.
    After warmup, the rate is exactly ``initial_lr``.
    """

    def __init__(self,
                 optimizer: PadoOptimizer,
                 warmup_iters: int = 0,
                 mode: str = "min") -> None:
        """Create the scheduler.

        Args:
            optimizer: Optimizer whose learning rate is scheduled.
            warmup_iters: Number of linear-warmup iterations; ``0`` disables
                warmup.
            mode: Forwarded unchanged to ``PadoScheduler``.
        """
        # NOTE(review): the positional ``0`` is presumably the base class's
        # keep-phase length (``from_config`` drops ``keep_iters``); confirm
        # against the ``PadoScheduler`` signature. ``min_lr`` is fixed at 0.
        super().__init__(optimizer, warmup_iters, 0, min_lr=0, mode=mode)

    def _get_lr(self, initial_lr: float, param_group_index=None, **kwargs) -> float:
        """Return the learning rate for the current iteration.

        Args:
            initial_lr: Base learning rate of the parameter group.
            param_group_index: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.

        Returns:
            ``initial_lr`` scaled linearly during warmup, otherwise
            ``initial_lr`` itself.
        """
        # Rates at or below ``min_lr`` are returned unchanged (no warmup).
        if initial_lr <= self.min_lr:
            return initial_lr

        # With warmup_iters == 0 this branch is never taken for a
        # non-negative num_iters, so the division cannot hit zero.
        if self.num_iters < self.warmup_iters:
            return initial_lr * (self.num_iters + 1) / self.warmup_iters
        return initial_lr

    @classmethod
    def from_config(cls, cfg: DictConfig, optimizer: PadoOptimizer) -> "ConstantLR":
        """Build a ``ConstantLR`` from an OmegaConf config.

        Args:
            cfg: Config whose (resolved) keys are passed as keyword arguments
                to ``__init__``.
            optimizer: Optimizer to schedule.

        Returns:
            A configured ``ConstantLR`` instance.
        """
        # ``to_container`` produces a new plain container, so removing keys
        # below does not modify the caller's ``cfg``.
        kwargs = OmegaConf.to_container(cfg, resolve=True)
        # ``__init__`` has no ``keep_iters`` parameter, so drop it if present.
        kwargs.pop("keep_iters", None)
        return cls(optimizer, **kwargs)
