from typing import List, Sequence, Tuple

from omegaconf import DictConfig, OmegaConf

from pado.core.base.optimizer import PadoOptimizer
from pado.core.base.lr_scheduler import PadoScheduler
from pado.optim.lr_scheduler import register_scheduler

# Public names exported by ``from <this module> import *``.
__all__ = ["StepLR"]


@register_scheduler("StepLR")
class StepLR(PadoScheduler):
    """Step-decay learning-rate scheduler with optional linear warmup.

    For the first ``warmup_iters`` iterations the LR ramps linearly up to the
    initial LR. After warmup, the initial LR is multiplied by
    ``multiply_factor`` once for every entry of ``steps`` that has been
    reached. The result is then clamped from below by ``min_lr``.

    ``steps`` are compared against the iteration count measured *after*
    warmup, i.e. ``num_iters - warmup_iters``.
    """

    def __init__(self,
                 optimizer: PadoOptimizer,
                 steps: Sequence[int],
                 multiply_factor: float = 0.1,
                 warmup_iters: int = 0,
                 min_lr: float = 1e-8,
                 mode: str = "min") -> None:
        """Create the scheduler.

        Args:
            optimizer: Optimizer whose LR is scheduled; forwarded to
                ``PadoScheduler``.
            steps: Iteration offsets (counted after warmup) at which the LR is
                multiplied by ``multiply_factor``. Order does not matter, and
                zero entries are dropped.
            multiply_factor: Factor applied once per reached step.
            warmup_iters: Number of linear-warmup iterations. Must be smaller
                than the first non-zero step.
            min_lr: Lower bound applied to the post-warmup LR.
            mode: Forwarded to ``PadoScheduler``; its meaning is defined there.

        Raises:
            ValueError: If ``steps`` has no non-zero entry, or if
                ``warmup_iters`` is not smaller than the first step.
        """
        # NOTE(review): the positional ``0`` is forwarded unchanged; its meaning
        # is defined by PadoScheduler.__init__ — confirm.
        super().__init__(optimizer, warmup_iters, 0, min_lr, mode)

        self.steps = self._normalize_steps(steps, warmup_iters)
        self.multiply_factor = multiply_factor

    @staticmethod
    def _normalize_steps(steps: Sequence[int], warmup_iters: int) -> List[int]:
        """Return a fresh ascending list of ``steps`` with zero entries removed.

        Sorting happens before zeros are dropped, so unsorted input such as
        ``[10, 0]`` is handled. A new list is always returned so the caller's
        sequence is never aliased.

        Raises:
            ValueError: If no non-zero step remains, or if ``warmup_iters`` is
                not smaller than the first step.
        """
        normalized = sorted(s for s in steps if s != 0)
        if not normalized:
            raise ValueError("StepLR requires at least one non-zero step.")
        # NOTE(review): _get_lr measures steps from the end of warmup, yet this
        # compares the warmup length against the first post-warmup offset.
        # Kept for compatibility — confirm which convention is intended.
        if warmup_iters >= normalized[0]:
            raise ValueError("Warmup steps should be smaller than first step[0].")
        return normalized

    def state_dict(self) -> dict:
        """Return the base scheduler state plus ``steps`` and ``multiply_factor``."""
        d = super().state_dict()
        # Copy, so later mutation of self.steps cannot alter an exported state.
        d["steps"] = list(self.steps)
        d["multiply_factor"] = self.multiply_factor
        return d

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore state produced by :meth:`state_dict`.

        If ``"steps"`` is missing (or ``None``), the current steps are kept
        rather than being replaced by ``None``, which would break ``_get_lr``.
        """
        super().load_state_dict(state_dict)
        steps = state_dict.get("steps")
        if steps is not None:
            self.steps = list(steps)
        # NOTE(review): a missing key falls back to 0.1 rather than the current
        # value; kept as-is in case older checkpoints rely on it — confirm.
        self.multiply_factor = state_dict.get("multiply_factor", 0.1)

    def _get_lr(self, initial_lr: float, param_group_index=None, **kwargs) -> float:
        """Compute the LR for the current ``self.num_iters``.

        During warmup the LR is ``initial_lr * (num_iters + 1) / warmup_iters``
        and is not clamped by ``min_lr``. After warmup, ``initial_lr`` is
        multiplied by ``multiply_factor`` for each step already reached, then
        clamped to at least ``min_lr``.
        """
        if self.num_iters < self.warmup_iters:
            # Linear warmup; reaches initial_lr on the last warmup iteration.
            return initial_lr * (self.num_iters + 1) / self.warmup_iters

        curr_iters = self.num_iters - self.warmup_iters
        lr = initial_lr
        for s in self.steps:
            if curr_iters >= s:
                lr *= self.multiply_factor
        return max(lr, self.min_lr)

    @classmethod
    def from_config(cls, cfg: DictConfig, optimizer: PadoOptimizer) -> "StepLR":
        """Build a ``StepLR`` from a config whose keys match ``__init__`` kwargs.

        Interpolations in ``cfg`` are resolved before the values are passed on.
        """
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(optimizer, **cfg)
