import torch


class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Implements the learning rate schedule defined in the AlphaFold 2
    supplement. A linear warmup is followed by a plateau at the maximum
    learning rate and then exponential decay.

    Note that the initial learning rate of the optimizer in question is
    ignored; use this class' base_lr parameter to specify the starting
    point of the warmup.
    """

    def __init__(
        self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.0,
        max_lr: float = 0.001,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ):
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if v < 0:
                raise ValueError(f"{k} must be nonnegative")

        if warmup_no_steps > start_decay_after_n_steps:
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        super(AlphaFoldLRScheduler, self).__init__(
            optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def state_dict(self):
        state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}

        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        if step_no <= self.warmup_no_steps:
            lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
        elif step_no > self.start_decay_after_n_steps:
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor**exp)
        else:  # plateau
            lr = self.max_lr

        return [lr for group in self.optimizer.param_groups]


class TestAF2LRScheduler(AlphaFoldLRScheduler):
    def __init__(
        self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.0,
        max_lr: float = 0.0001,
        warmup_no_steps: int = 10,
        start_decay_after_n_steps: int = 100,
        decay_every_n_steps: int = 10,
        decay_factor: float = 0.95,
    ):
        super().__init__(
            optimizer,
            last_epoch,
            verbose,
            base_lr,
            max_lr,
            warmup_no_steps,
            start_decay_after_n_steps,
            decay_every_n_steps,
            decay_factor,
        )
