import math
from omegaconf import DictConfig, OmegaConf

from pado.core.base.optimizer import PadoOptimizer
from pado.core.base.lr_scheduler import PadoScheduler
from pado.optim.lr_scheduler import register_scheduler

# Public API of this module: only the CosineLR scheduler is exported.
__all__ = ["CosineLR"]


@register_scheduler("CosineLR")
class CosineLR(PadoScheduler):
    """Warmup -> constant -> cosine-decay learning-rate scheduler.

    For a parameter group with base learning rate ``initial_lr`` and the
    scheduler's iteration counter ``n`` (``self.num_iters``), the learning
    rate is:

    * ``n < warmup_iters``: linear warmup,
      ``initial_lr * (n + 1) / warmup_iters``;
    * ``n < warmup_iters + keep_iters``: held at ``initial_lr``;
    * ``n >= max_iters``: ``min_lr``;
    * otherwise: cosine annealing from ``initial_lr`` down to ``min_lr`` over
      the ``max_iters - warmup_iters - keep_iters`` remaining iterations.

    Groups whose ``initial_lr`` is already ``<= min_lr`` keep ``initial_lr``
    for the whole schedule.
    """

    def __init__(self,
                 optimizer: PadoOptimizer,
                 max_iters: int,
                 warmup_iters: int = 0,
                 keep_iters: int = 0,
                 min_lr: float = 1e-8,
                 mode: str = "min") -> None:
        """Create the scheduler.

        Args:
            optimizer: optimizer whose parameter groups are scheduled.
            max_iters: total schedule length; from this iteration on the
                learning rate is ``min_lr``.
            warmup_iters: number of linear-warmup iterations.
            keep_iters: number of constant iterations following warmup.
            min_lr: value reached at the end of the cosine decay.
            mode: forwarded unchanged to ``PadoScheduler``; its meaning is
                defined by the base class.

        Raises:
            ValueError: if ``warmup_iters + keep_iters > max_iters``.
        """
        super().__init__(optimizer, warmup_iters, keep_iters, min_lr, mode)

        if warmup_iters + keep_iters > max_iters:
            raise ValueError(
                f"CosineLR scheduler warmup + keep > max iterations "
                f"(warmup_iters={warmup_iters}, keep_iters={keep_iters}, "
                f"max_iters={max_iters})."
            )
        self.max_iters = max_iters

    def state_dict(self) -> dict:
        """Return the base scheduler state extended with ``max_iters``."""
        d = super().state_dict()
        d["max_iters"] = self.max_iters
        return d

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore scheduler state, including ``max_iters``.

        If ``state_dict`` has no ``"max_iters"`` entry (for example, a state
        saved without it), the current ``max_iters`` is kept. Overwriting it
        with ``None`` would make ``_get_lr`` fail when it compares the
        iteration counter against ``max_iters``.
        """
        super().load_state_dict(state_dict)
        self.max_iters = state_dict.get("max_iters", self.max_iters)

    def _get_lr(self, initial_lr: float, param_group_index=None, **kwargs) -> float:
        """Return the learning rate for one parameter group at the current iteration.

        Args:
            initial_lr: the group's base learning rate.
            param_group_index: not used by this schedule (presumably part of
                the ``PadoScheduler`` hook signature).
            **kwargs: not used by this schedule.

        Returns:
            The scheduled learning rate (see the class docstring).
        """
        # Groups already at or below the floor are never rescaled.
        if initial_lr <= self.min_lr:
            return initial_lr

        if self.num_iters < self.warmup_iters:
            # (n + 1) so the first step is non-zero and step warmup_iters-1
            # reaches exactly initial_lr.
            lr = initial_lr * (self.num_iters + 1) / self.warmup_iters
        elif self.num_iters < self.warmup_iters + self.keep_iters:
            lr = initial_lr
        elif self.num_iters >= self.max_iters:
            # Checked before the cosine branch. This also guarantees
            # end_iters > 0 below: when warmup + keep == max_iters, every
            # iteration past the keep phase lands here.
            lr = self.min_lr
        else:
            curr_iters = self.num_iters - self.warmup_iters - self.keep_iters
            end_iters = self.max_iters - self.warmup_iters - self.keep_iters

            lr = self.min_lr + 0.5 * (initial_lr - self.min_lr) * (
                    1 + math.cos(math.pi * curr_iters / end_iters))
        return lr

    @classmethod
    def from_config(cls, cfg: DictConfig, optimizer: PadoOptimizer) -> "CosineLR":
        """Build a ``CosineLR`` from an OmegaConf config.

        Interpolations are resolved, and the resulting mapping is passed as
        keyword arguments to ``__init__``. The config keys must therefore
        match its parameter names (``max_iters``, ``warmup_iters``, ...).
        """
        kwargs = OmegaConf.to_container(cfg, resolve=True)
        return cls(optimizer, **kwargs)
