from typing import List, Iterable, Union
from omegaconf import DictConfig
import logging
from pado.core.base.optimizer import PadoOptimizer

# Public API exported by `from <this module> import *`.
__all__ = ["PadoScheduler", "PadoSchedulerList"]

# Logger named "pado" (not __name__); PadoScheduler.update_best reports best-value changes through it.
logger = logging.getLogger("pado")


class PadoScheduler(object):
    """
    Base class for LR schedulers in Pado framework.

    The base class stores the warm-up / keep lengths and ``min_lr`` for
    subclasses, records each param group's ``initial_lr``, and tracks the best
    criterion value seen so far together with a patience counter.
    Subclasses provide the LR curve via ``_get_lr`` and construction from a
    config via ``from_config``.
    """

    def __init__(self,
                 optimizer: PadoOptimizer,
                 warmup_iters: int = 0,
                 keep_iters: int = 0,
                 min_lr: float = 1e-8,
                 mode: str = "min") -> None:
        """
        Args:
            optimizer: optimizer whose ``param_groups[i]["lr"]`` values this
                scheduler overwrites on every ``step``.
            warmup_iters: warm-up length in iterations (interpreted by subclasses).
            keep_iters: keep length in iterations (interpreted by subclasses).
            min_lr: minimum LR (presumably a floor applied by subclasses).
            mode: ``"min"`` if a smaller criterion is better, ``"max"`` if a
                larger one is better. Case-insensitive.

        Raises:
            ValueError: if ``mode`` is neither ``"min"`` nor ``"max"``.
        """
        self.optimizer = optimizer
        self.best = None
        # Validate before touching the optimizer's param groups.
        self.mode = self._normalize_mode(mode)

        # Record each group's starting LR; _get_lr computes from it.
        # An "initial_lr" already present in a group is kept as-is.
        for param_group in optimizer.param_groups:
            param_group.setdefault("initial_lr", param_group["lr"])

        self.initial_lrs = [g["initial_lr"] for g in optimizer.param_groups]

        self.warmup_iters = warmup_iters
        self.keep_iters = keep_iters
        self.min_lr = min_lr

        self._num_iters = -1  # should be set outside; step() increments before use
        self._patience_count = 0
        self._step_called = False  # simple flag to check if step() is called at least once

    @staticmethod
    def _normalize_mode(mode: str) -> str:
        """Lower-case ``mode`` and raise ValueError unless it is ``"min"`` or ``"max"``."""
        mode = mode.lower()
        if mode not in ("min", "max"):
            raise ValueError(f"Scheduler mode should be either `min` or `max`, got {mode}.")
        return mode

    @property
    def current_patience(self) -> int:
        """Number of consecutive ``update_best`` calls that did not update the best."""
        return self._patience_count

    @property
    def num_iters(self) -> int:
        """Current iteration counter (-1 before the first ``step``)."""
        return self._num_iters

    def set_num_iters(self, iters: int) -> None:
        """Set the iteration counter; values below -1 are clamped to -1."""
        self._num_iters = max(iters, -1)

    def state_dict(self) -> dict:
        """Return a plain dict with everything ``load_state_dict`` restores."""
        return {
            "best": self.best,
            "warmup_iters": self.warmup_iters,
            "keep_iters": self.keep_iters,
            "min_lr": self.min_lr,
            "mode": self.mode,
            "num_iters": self._num_iters,  # will be overridden by outer Trainer.
            "patience_count": self._patience_count,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        """
        Restore state produced by ``state_dict``.

        Keys missing from ``state_dict`` keep their current value. They are not
        reset to hard-coded defaults, so loading a partial or older checkpoint
        does not silently discard the configuration given to the constructor.
        ``mode`` is normalized and validated exactly as in ``__init__``.

        Raises:
            ValueError: if the stored ``mode`` is neither ``"min"`` nor ``"max"``.
                Nothing is modified in that case.
        """
        # Validate first so an invalid state dict leaves this scheduler untouched.
        mode = self._normalize_mode(state_dict.get("mode", self.mode))

        self.best = state_dict.get("best", self.best)
        self.warmup_iters = state_dict.get("warmup_iters", self.warmup_iters)
        self.keep_iters = state_dict.get("keep_iters", self.keep_iters)
        self.min_lr = state_dict.get("min_lr", self.min_lr)
        self.mode = mode

        self._num_iters = state_dict.get("num_iters", self._num_iters)
        self._patience_count = state_dict.get("patience_count", self._patience_count)
        self._step_called = True

    def update_best(self, criterion_value) -> bool:
        """
        Update best and return whether the best is updated.

        Args:
            criterion_value: latest criterion value. It is compared through
                ``min``/``max`` according to ``mode`` and presumably is a float
                (it is formatted with ``:.6f`` in the logs).

        Returns:
            True if ``criterion_value`` is now the best. A tie with the current
            best also counts and resets the patience counter. Otherwise False,
            and the patience counter is incremented.
        """
        if self.best is None:
            # The first value seen becomes the best unconditionally.
            # NOTE(review): a NaN here becomes `best`, and because every comparison
            # with NaN is False no later value can replace it — confirm whether
            # NaN criteria should be rejected upstream.
            self.best = criterion_value
            self._patience_count = 0
            logger.info(f"... best set, {self.best:.6f}")
            return True

        prev_best = self.best
        if self.mode == "max":  # larger better
            self.best = max(self.best, criterion_value)
        else:  # smaller better
            self.best = min(self.best, criterion_value)
        is_updated = (self.best == criterion_value)
        if is_updated:
            self._patience_count = 0
            logger.info(f"... best updated, (old -> new): {prev_best:.6f} -> {self.best:.6f}")
        else:
            self._patience_count += 1
            logger.info(f"... best NOT updated, (best / new): {prev_best:.6f} / {criterion_value:.6f}\n"
                        f"... best was before {self.current_patience} checks.")
        return is_updated

    def step(self, criterion=None) -> None:
        """
        Advance one iteration and write the new LR into every param group.

        Args:
            criterion: if given, ``update_best`` is called with it before the
                LRs are recomputed.
        """
        self._num_iters += 1

        if criterion is not None:
            _ = self.update_best(criterion)

        for i, param_group in enumerate(self.optimizer.param_groups):
            param_group["lr"] = self._get_lr(param_group["initial_lr"], group_index=i)

        self._step_called = True

    def _get_lr(self, initial_lr: float, group_index: int, **kwargs) -> float:
        """
        Compute the current LR for the param group at ``group_index``.

        Must be implemented by subclasses. It may read ``self.num_iters``,
        ``self.warmup_iters``, ``self.keep_iters`` and ``self.min_lr``.
        """
        raise NotImplementedError

    @classmethod
    def from_config(cls, cfg: DictConfig, optimizer: PadoOptimizer):
        """Build a scheduler for ``optimizer`` from ``cfg``. Must be implemented by subclasses."""
        raise NotImplementedError


class PadoSchedulerList(object):
    """
    Container to hold multiple PadoSchedulers.

    Every method fans out to the held schedulers in their stored order.
    """

    def __init__(self, schedulers: Union[PadoScheduler, Iterable[PadoScheduler]]):
        """
        Args:
            schedulers: a single PadoScheduler or an iterable of them. A single
                scheduler is wrapped into a one-element list.
        """
        if isinstance(schedulers, PadoScheduler):
            schedulers = [schedulers]
        self.schedulers = list(schedulers)

    def __len__(self) -> int:
        return len(self.schedulers)

    def set_num_iters(self, iters: int) -> None:
        """Set the iteration counter of every scheduler to ``iters``."""
        for sched in self.schedulers:
            sched.set_num_iters(iters)

    def update_best(self, criterion_value) -> List[bool]:
        """Call ``update_best`` on every scheduler; return their flags in order."""
        return [sched.update_best(criterion_value=criterion_value) for sched in self.schedulers]

    def step(self, criterion=None) -> None:
        """Call ``step(criterion=criterion)`` on every scheduler."""
        for sched in self.schedulers:
            sched.step(criterion=criterion)

    def state_dict(self) -> List:
        """Return one state dict per scheduler, in order."""
        return [sched.state_dict() for sched in self.schedulers]

    def load_state_dict(self, state_dict: Union[dict, List]) -> None:
        """
        Restore per-scheduler states; entries must be in the order they were saved.

        A bare dict, as returned by a single ``PadoScheduler.state_dict()``, is
        treated as a one-element list. This mirrors the constructor accepting a
        single scheduler, and stops the dict's keys from being zipped against
        the schedulers.

        Raises:
            ValueError: if the number of states differs from the number of schedulers.
        """
        if isinstance(state_dict, dict):
            state_dict = [state_dict]
        if len(state_dict) != len(self):
            raise ValueError(f"#schedulers in state dict mismatch, {len(self)} vs {len(state_dict)}.")
        for state, sched in zip(state_dict, self.schedulers):
            sched.load_state_dict(state)
