import math
import logging
from omegaconf import DictConfig, OmegaConf

from pado.core.base.optimizer import PadoOptimizer
from pado.core.base.lr_scheduler import PadoScheduler
from pado.optim.lr_scheduler import register_scheduler

__all__ = ["ReduceLROnPlateau"]

logger = logging.getLogger("pado")


@register_scheduler("ReduceLROnPlateau")
class ReduceLROnPlateau(PadoScheduler):
    """Multiply the learning rate by ``multiply_factor`` every time the monitored value plateaus.

    LR schedule (see ``_get_lr``):
      * linear warmup during the first ``warmup_iters`` iterations,
      * the initial LR held for the next ``keep_iters`` iterations,
      * afterwards ``initial_lr * multiply_factor ** drop_count``, floored at ``min_lr``.

    ``drop_count`` is advanced by ``update_best``: each run of ``patience``
    consecutive checks without a strict improvement increments it by one.
    """

    def __init__(self,
                 optimizer: PadoOptimizer,
                 multiply_factor: float = 0.1,
                 patience: int = 1,
                 warmup_iters: int = 0,
                 keep_iters: int = 0,
                 min_lr: float = 1e-8,
                 mode: str = "min") -> None:
        """
        Args:
            optimizer: optimizer handed to the base scheduler.
            multiply_factor: factor applied to the LR once per drop.
            patience: consecutive non-improving checks before a drop (clamped to >= 1).
            warmup_iters: number of linear-warmup iterations.
            keep_iters: number of iterations the initial LR is held after warmup.
            min_lr: lower bound applied to the LR in the decay phase.
            mode: "max" if larger criterion values are better; any other value
                means smaller is better.
        """
        super().__init__(optimizer, warmup_iters, keep_iters, min_lr, mode)
        self.multiply_factor = multiply_factor
        self.patience = max(patience, 1)
        self._drop_count = 0
        # Initialized here so update_best() cannot raise AttributeError when
        # `best` is already set (e.g. restored through load_state_dict).
        self._patience_count = 0

    def state_dict(self) -> dict:
        """Return the base scheduler state extended with this scheduler's plateau state."""
        d = super().state_dict()
        d["multiply_factor"] = self.multiply_factor
        d["patience"] = self.patience
        d["drop_count"] = self._drop_count
        # Persist the in-progress counter so a resumed run keeps counting toward
        # the next drop instead of starting over from zero.
        d["patience_count"] = self._patience_count
        return d

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore state produced by ``state_dict``.

        Hyper-parameters missing from (older) checkpoints fall back to this
        instance's current values instead of hard-coded defaults, so configured
        values are not silently overwritten. Missing progress counters start at 0.
        """
        super().load_state_dict(state_dict)
        self.multiply_factor = state_dict.get("multiply_factor", self.multiply_factor)
        # Same >= 1 clamp as __init__.
        self.patience = max(state_dict.get("patience", self.patience), 1)
        self._drop_count = state_dict.get("drop_count", 0)
        self._patience_count = state_dict.get("patience_count", 0)

    def update_best(self, criterion_value) -> bool:
        """
        Update best and return whether the best is updated.

        The first call (``best is None``) records ``criterion_value`` as best,
        resets both counters and returns True.

        Afterwards, a check counts as an update only when ``criterion_value`` is
        strictly better than the current best: greater in "max" mode, smaller
        otherwise. A tie is a plateau; it consumes patience instead of resetting it.
        After ``patience`` consecutive non-improving checks, ``drop_count`` is
        incremented, which lowers the LR returned by ``_get_lr``.
        """
        if self.best is None:
            self.best = criterion_value
            self._drop_count = 0
            self._patience_count = 0
            logger.info(f"... best set, {self.best:.6f}")
            return True

        prev_best = self.best
        # Strict comparison: equal values are NOT an improvement. A NaN criterion
        # compares False either way and is therefore treated as non-improving.
        if self.mode == "max":  # larger better
            is_updated = bool(criterion_value > prev_best)
        else:  # smaller better
            is_updated = bool(criterion_value < prev_best)

        if is_updated:
            self.best = criterion_value
            self._patience_count = 0
            logger.info(f"... best updated, (old -> new): {prev_best:.6f} -> {self.best:.6f}")
        else:
            self._patience_count += 1
            # NOTE(review): `current_patience` is not defined in this class;
            # presumably provided by PadoScheduler — confirm.
            s = f"... best NOT updated, (best / new): {prev_best:.6f} / {criterion_value:.6f}\n" \
                f"... best was before: {self.current_patience} checks.\n"
            if self._patience_count >= self.patience:
                # Restart the counter after a drop so the reduced LR gets a full
                # `patience` window before the next drop can happen.
                self._patience_count = 0
                self._drop_count += 1
                s += f"... drop count increased ({self._drop_count} times dropped)"
            else:
                s += f"... patience increased ({self._patience_count} / {self.patience})"
            logger.info(s)
        return is_updated

    def _get_lr(self, initial_lr: float, param_group_index=None, **kwargs) -> float:
        """Return the LR for the current iteration.

        ``param_group_index`` and ``kwargs`` are accepted for interface
        compatibility and are not used by this scheduler.
        """
        if self.num_iters < self.warmup_iters:
            # Linear warmup: reaches initial_lr at iteration warmup_iters - 1.
            lr = initial_lr * (self.num_iters + 1) / self.warmup_iters
        elif self.num_iters < self.warmup_iters + self.keep_iters:
            lr = initial_lr
        else:
            lr = initial_lr * math.pow(self.multiply_factor, self._drop_count)
            lr = max(lr, self.min_lr)
        return lr

    @classmethod
    def from_config(cls, cfg: DictConfig, optimizer: PadoOptimizer) -> "ReduceLROnPlateau":
        """Build from a config whose (resolved) keys are passed as ``__init__`` keyword arguments."""
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(optimizer, **cfg)
