import torch


class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Learning-rate schedule defined in the AlphaFold 2 supplement.

    The schedule is indexed by the scheduler's step count
    (``self.last_epoch``) and has three phases:

    1. Linear warmup: for steps ``0..warmup_no_steps`` the learning rate
       rises linearly from ``base_lr`` to ``max_lr``.
    2. Plateau: up to and including step ``start_decay_after_n_steps`` the
       learning rate stays at ``max_lr``.
    3. Step-wise exponential decay: afterwards the learning rate is
       ``max_lr * decay_factor ** k``. ``k`` is 1 on the first step past
       the plateau and increases by one every ``decay_every_n_steps`` steps.

    The same learning rate is assigned to every parameter group. The
    optimizer's own initial learning rate is ignored; use ``base_lr`` to
    set the starting point of the warmup.

    Args:
        optimizer: Wrapped optimizer.
        last_epoch: Index of the last step, forwarded to the base scheduler
            (-1 starts a fresh schedule).
        verbose: Forwarded to the base scheduler only when truthy.
        base_lr: Learning rate at step 0, the start of the warmup.
        max_lr: Learning rate at the end of the warmup and on the plateau.
        warmup_no_steps: Length of the warmup in steps; 0 disables warmup.
        start_decay_after_n_steps: Last step of the plateau.
        decay_every_n_steps: Number of steps between successive decays;
            must be positive.
        decay_factor: Multiplicative factor applied per decay interval.

    Raises:
        ValueError: If ``warmup_no_steps`` or ``start_decay_after_n_steps``
            is negative, if ``warmup_no_steps`` exceeds
            ``start_decay_after_n_steps``, or if ``decay_every_n_steps`` is
            not positive.
    """
    def __init__(self, 
        optimizer, 
        last_epoch: int = -1, 
        verbose: bool = False,
        base_lr: float = 0.,
        max_lr: float = 0.001,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ):
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if v < 0:
                raise ValueError(f"{k} must be nonnegative")

        if warmup_no_steps > start_decay_after_n_steps:
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

        # Used as a floor divisor in get_lr(). Zero would divide by zero, and
        # a negative value would make the exponent negative and grow the LR.
        if decay_every_n_steps <= 0:
            raise ValueError("decay_every_n_steps must be positive")

        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        # Only forward ``verbose`` when it is requested. Recent PyTorch
        # releases deprecate the argument and warn whenever it is passed.
        scheduler_kwargs = {"last_epoch": last_epoch}
        if verbose:
            scheduler_kwargs["verbose"] = verbose

        # The base constructor performs the initial step, which calls get_lr().
        super(AlphaFoldLRScheduler, self).__init__(
            optimizer,
            **scheduler_kwargs,
        )

    def state_dict(self):
        """Return every instance attribute except the wrapped optimizer."""
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict):
        """Restore attributes previously produced by ``state_dict()``."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Return the scheduled learning rate, one entry per parameter group.

        Raises:
            RuntimeError: If called outside the scheduler's ``step()``.
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        # An empty warmup (warmup_no_steps == 0) is skipped entirely; otherwise
        # step 0 would divide by zero.
        if self.warmup_no_steps > 0 and step_no <= self.warmup_no_steps:
            # Interpolate from base_lr to max_lr so the warmup ends exactly at
            # max_lr and joins the plateau without a jump.
            frac = step_no / self.warmup_no_steps
            lr = self.base_lr + frac * (self.max_lr - self.base_lr)
        elif step_no > self.start_decay_after_n_steps:
            steps_since_decay = step_no - self.start_decay_after_n_steps
            # The first decay is already applied on the first step past the
            # plateau, hence the +1.
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor ** exp)
        else:  # plateau
            lr = self.max_lr

        return [lr] * len(self.optimizer.param_groups)
