import math
try:
    from torch.optim.lr_scheduler import LRScheduler
except:
    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.optim import Optimizer


class CosineScheduler(LRScheduler):
    r"""
    Linear warmup followed by a cosine decay that bottoms out at 10% of the base lr.

    Each optimizer param group is scheduled relative to the lr it had when this
    scheduler was constructed (``base_lr``); ``current_lr`` / ``get_last_lr()``
    report the same schedule evaluated with the scalar ``start_lr`` argument.

    * warmup, ``cur_step < n_warmup_steps``:
      :math:`\text{lr}=\text{base_lr}\cdot\dfrac{\text{cur_step}}{\text{n_warmup_steps}}`
    * decay otherwise, with :math:`p=\dfrac{\text{cur_step}-\text{n_warmup_steps}}{\max(1,\ \text{n_steps}-\text{n_warmup_steps})}`:
      :math:`\text{lr}=\text{base_lr}\cdot\left(0.1+0.9\cdot\dfrac{1+\cos(\pi p)}{2}\right)`

    Past the end of the schedule (``p > 1``) ``lr_end_restart`` selects the behaviour:

    * ``0``: ``p`` is clamped to 1, so the lr stays at 10% of ``base_lr``;
    * ``1`` (and any value other than 0 or 2): ``p`` is left as-is, so the cosine keeps
      oscillating between 10% and 100% of ``base_lr``;
    * ``2``: ``p`` is replaced by ``2 * int(p) + (p - 1)``, which jumps back to full lr
      right after ``n_steps`` and then oscillates smoothly with no further jumps.
      NOTE(review): if a warm restart at every cycle was intended, this does not do it — confirm.

    While ``cur_step < resume_step + resume_no_optimze`` (``resume_step`` is the
    ``cur_step`` given to the constructor) every lr is 0 and "resume no optimize" is
    printed each time an lr is evaluated.
    """  # noqa

    def __init__(
        self,
        optimizer: Optimizer,
        start_lr: float,
        n_warmup_steps: int,
        n_steps: int,
        cur_step: int = 0,
        lr_end_restart: int = 0,
        resume_no_optimze=0,
    ) -> None:
        # Shape of the schedule.
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = n_steps
        self.lr_end_restart = lr_end_restart
        # Position in the schedule; resume_step remembers where this run (re)started.
        self.optimizer = optimizer
        self.cur_step = cur_step
        self.resume_step = cur_step
        self.resume_no_optimze = resume_no_optimze
        # Reference lrs: a scalar used for reporting, plus one per param group.
        self._current_lr = None
        self._start_lr = start_lr
        self.start_lr: list[float] = [group["lr"] for group in self.optimizer.param_groups]

        # Write the lr for the starting step into the optimizer immediately.
        self.step(self.cur_step)

    def get_lr_warmup(self, cur_step: int, base_lr: float) -> float:
        """Linear ramp: 0 at step 0, approaching ``base_lr`` at ``n_warmup_steps``."""
        return base_lr * cur_step / self.n_warmup_steps

    def get_lr_decay(self, cur_step: int, base_lr: float) -> float:
        """Cosine decay from ``base_lr`` down to ``0.1 * base_lr`` over the post-warmup span."""
        span = max(1, (self.n_steps - self.n_warmup_steps))
        progress = (cur_step - self.n_warmup_steps) / span
        if progress > 1:
            mode = self.lr_end_restart
            if mode == 0:
                # Hold at the bottom of the curve.
                progress = 1
            elif mode == 2:
                progress = int(progress) * 2 + (progress - 1)
            # mode 1 (or anything else): keep following the cosine past the end.

        cosine_term = 1.0 + math.cos(progress * math.pi)
        return max(base_lr * 0.1, base_lr * (0.1 + 0.45 * cosine_term))

    def get_lr(self, base_lr):
        """Return the lr for ``self.cur_step`` scaled to ``base_lr`` (0 inside the resume window)."""
        assert self.cur_step >= 0
        in_resume_window = self.resume_step + self.resume_no_optimze > self.cur_step
        if in_resume_window:
            print("resume no optimize")
            return 0

        warming_up = self.cur_step < self.n_warmup_steps
        schedule = self.get_lr_warmup if warming_up else self.get_lr_decay
        return schedule(self.cur_step, base_lr)

    @property
    def current_lr(self):
        """Lr most recently computed for the scalar ``start_lr`` (``None`` before any step)."""
        return self._current_lr

    def step(self, num_iter=None) -> None:
        """Advance to ``num_iter`` (or one past the current step) and write the new lrs into the optimizer."""
        self.cur_step = self.cur_step + 1 if num_iter is None else num_iter

        self._current_lr = self.get_lr(self._start_lr)
        for group, base_lr in zip(self.optimizer.param_groups, self.start_lr):
            group["lr"] = self.get_lr(base_lr)

    def state_dict(self):
        """Serializable snapshot of the reference lrs, schedule bounds and current step."""
        return dict(
            _start_lr=self._start_lr,
            start_lr=self.start_lr,
            warmup_iter=self.n_warmup_steps,
            end_iter=self.n_steps,
            num_iter=self.cur_step,
        )

    def load_state_dict(self, state_dict):
        """Restore the fields written by ``state_dict`` and re-apply the lr for the restored step."""
        for attr_name, key in (
            ("_start_lr", "_start_lr"),
            ("start_lr", "start_lr"),
            ("n_warmup_steps", "warmup_iter"),
            ("n_steps", "end_iter"),
            ("cur_step", "num_iter"),
        ):
            setattr(self, attr_name, state_dict[key])

        self.step(self.cur_step)

    def get_last_lr(self) -> float:
        """Scalar lr from the latest step (a float, not a per-group list as in torch)."""
        assert self._current_lr is not None
        return self._current_lr


class StableDrop(LRScheduler):
    r"""
    Warmup / stable / drop learning-rate schedule.

    Each optimizer param group is scheduled relative to the lr it had when this
    scheduler was constructed (``base_lr``); ``current_lr`` / ``get_last_lr()``
    report the same schedule evaluated with the scalar ``max_lr`` argument.
    Phases, checked in this order:

    * resume warmup: for the ``resume_no_optimze`` steps starting at the ``cur_step``
      given to the constructor, a fresh linear ramp from 0 towards ``base_lr``;
    * warmup, ``cur_step < n_warmup_steps``:
      :math:`\text{lr}=\text{base_lr}\cdot\dfrac{\text{cur_step}}{\text{n_warmup_steps}}`
    * drop, ``cur_step > n_steps - n_drop_steps``: linear decay that reaches 10% of
      ``base_lr`` at ``n_steps`` and stays there,
      :math:`\text{lr}=\text{base_lr}\cdot\left(0.1+\max\left(0.9\cdot\dfrac{\text{n_steps}-\text{cur_step}}{\text{n_drop_steps}},\,0\right)\right)`;
      with ``n_drop_steps <= 0`` there is no ramp and the lr drops straight to 10%
      once ``cur_step > n_steps``;
    * stable otherwise: ``lr = base_lr``.
    """  # noqa

    def __init__(
        self,
        optimizer: Optimizer,
        max_lr: float,
        n_warmup_steps: int,
        n_steps: int,
        n_drop_steps: int = 0,
        cur_step: int = 0,
        resume_no_optimze: int = 0,
    ):
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = n_steps
        self.n_drop_steps = n_drop_steps
        self.optimizer = optimizer
        self.cur_step = cur_step
        self._cur_lr = None
        self.max_lr = max_lr
        self.start_lr: list[float] = []
        # Step at which this run (re)started; anchors the resume-warmup window.
        self.resume_step = cur_step
        self.resume_no_optimze = resume_no_optimze
        for group in self.optimizer.param_groups:
            self.start_lr.append(group["lr"])

        # Write the lr for the starting step into the optimizer immediately.
        self.step(self.cur_step)

    def get_lr_warmup(self, cur_iter: int, base_lr: float, warmup_iter: int) -> float:
        """Linear ramp: 0 at ``cur_iter == 0``, approaching ``base_lr`` at ``warmup_iter``."""
        return base_lr * cur_iter / warmup_iter

    def get_lr_stable(self, cur_iter: int, base_lr: float) -> float:
        """Constant phase: the lr is ``base_lr``."""
        return base_lr

    def get_lr_drop(self, cur_iter: int, base_lr: float) -> float:
        """Linear drop over the last ``n_drop_steps`` steps, floored at ``0.1 * base_lr``."""
        if self.n_drop_steps <= 0:
            # No ramp configured. This phase is then only reached once cur_iter > n_steps,
            # where the ramp would already sit at its 10% floor; avoids dividing by zero.
            return base_lr * 0.1
        return base_lr * (0.1 + max(0.9 * (self.n_steps - cur_iter) / self.n_drop_steps, 0))

    def get_lr(self, base_lr: float):
        """Return the lr for ``self.cur_step`` scaled to ``base_lr``."""
        assert self.cur_step >= 0
        steps_since_resume = self.cur_step - self.resume_step
        # The resume window is [resume_step, resume_step + resume_no_optimze). Steps before
        # resume_step are outside it; treating them as inside would give a negative lr, or
        # divide by zero when resume_no_optimze == 0.
        if 0 <= steps_since_resume < self.resume_no_optimze:
            return self.get_lr_warmup(steps_since_resume, base_lr, self.resume_no_optimze)

        if self.cur_step < self.n_warmup_steps:
            return self.get_lr_warmup(self.cur_step, base_lr, self.n_warmup_steps)

        if self.cur_step > self.n_steps - self.n_drop_steps:
            return self.get_lr_drop(self.cur_step, base_lr)

        return self.get_lr_stable(self.cur_step, base_lr)

    def get_last_lr(self) -> float:
        """Scalar lr from the latest step (a float, not a per-group list as in torch)."""
        assert self._cur_lr is not None
        return self._cur_lr

    @property
    def current_lr(self):
        """Lr most recently computed for ``max_lr`` (``None`` before any step)."""
        return self._cur_lr

    def step(self, cur_iter=None) -> None:
        """Advance to ``cur_iter`` (or one past the current step) and write the new lrs into the optimizer."""
        if cur_iter is None:
            self.cur_step = self.cur_step + 1
        else:
            self.cur_step = cur_iter

        self._cur_lr = self.get_lr(self.max_lr)
        for group, base_lr in zip(self.optimizer.param_groups, self.start_lr):
            group["lr"] = self.get_lr(base_lr)

    def state_dict(self):
        """Serializable snapshot of the reference lrs, schedule bounds and current step."""
        return {
            "_start_lr": self.max_lr,
            "start_lr": self.start_lr,
            "warmup_iter": self.n_warmup_steps,
            "end_iter": self.n_steps,
            "num_iter": self.cur_step,
            # Persist the drop length too so a checkpoint fully describes the schedule.
            "drop_iter": self.n_drop_steps,
        }

    def load_state_dict(self, state_dict):
        """Restore the fields written by ``state_dict`` and re-apply the lr for the restored step.

        Checkpoints written before ``drop_iter`` was saved keep the constructor's
        ``n_drop_steps``.
        """
        self.max_lr: float = state_dict["_start_lr"]
        self.start_lr: list[float] = state_dict["start_lr"]
        self.n_warmup_steps: int = state_dict["warmup_iter"]
        self.n_steps: int = state_dict["end_iter"]
        self.cur_step: int = state_dict["num_iter"]
        self.n_drop_steps: int = state_dict.get("drop_iter", self.n_drop_steps)

        self.step(self.cur_step)
