# ===========================================================================
# Project:      Sparse Model Soups
# File:         lr_schedulers.py
# Description:  All kinds of learning rate schedulers
# ===========================================================================

import math
import warnings
from bisect import bisect_right

import torch


class FixedLR(torch.optim.lr_scheduler._LRScheduler):
    """
    Scheduler that reads the learning rate for each epoch from a precomputed list.

    Args:
        optimizer: the wrapped optimizer.
        lrList: sequence indexed by epoch; entry ``i`` is applied to every param group
            at epoch ``i``. Stepping beyond its last index raises ``IndexError``.
        last_epoch: index of the last epoch, forwarded to the base scheduler.
    """

    def __init__(self, optimizer, lrList, last_epoch=-1):
        # Stored before the base constructor runs, because the base constructor
        # performs an initial step() and therefore already calls get_lr().
        self.lrList = lrList
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``lrList[last_epoch]`` once for each optimizer param group."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        epoch_index = self.last_epoch
        schedule = self.lrList
        return [schedule[epoch_index] for _group in self.optimizer.param_groups]


class SequentialSchedulers(torch.optim.lr_scheduler.SequentialLR):
    """
    ``SequentialLR`` variant that does not restart the incoming scheduler at milestones.

    The stock ``SequentialLR.step`` calls ``step(0)`` on the scheduler that takes over
    at a milestone. This subclass always calls plain ``step()``, so the incoming
    scheduler keeps working from the learning rate currently in the optimizer (the last
    one set by the previous scheduler) instead of being reset to its epoch 0.

    All arguments must be passed by keyword and are forwarded to ``SequentialLR``
    (``optimizer``, ``schedulers``, ``milestones``, ...).
    """

    def __init__(self, **kwargs):
        # Populate ``self.optimizer`` from the first scheduler up front, so it exists
        # regardless of what the installed SequentialLR.__init__ sets itself.
        self.optimizer = kwargs['schedulers'][0].optimizer
        super(SequentialSchedulers, self).__init__(**kwargs)

    def step(self):
        """Advance one epoch by stepping only the scheduler responsible for it."""
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        scheduler = self._schedulers[idx]
        # Deliberately no ``step(0)`` at milestones: let the new scheduler continue.
        scheduler.step()
        # Keep get_last_lr() in sync with the optimizer; without this it would keep
        # returning the learning rate recorded at construction time.
        self._last_lr = scheduler.get_last_lr()


class ChainedSchedulers(torch.optim.lr_scheduler.ChainedScheduler):
    """
    ``ChainedScheduler`` with ``self.optimizer`` populated before construction.

    Works around a known bug in some PyTorch releases of ``ChainedScheduler``. It does
    so by assigning the first scheduler's optimizer to ``self.optimizer`` before the
    parent constructor runs. All arguments must be passed by keyword (e.g.
    ``schedulers=[...]``).
    """

    def __init__(self, **kwargs):
        first_scheduler = kwargs['schedulers'][0]
        self.optimizer = first_scheduler.optimizer
        super().__init__(**kwargs)


class CyclicLRAdaptiveBase(torch.optim.lr_scheduler.CyclicLR):

    def __init__(self, base_lr_scale_fn=None, **kwargs):
        self.base_lr_scale_fn = base_lr_scale_fn
        super(CyclicLRAdaptiveBase, self).__init__(**kwargs)

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """

        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        cycle = math.floor(1 + self.last_epoch / self.total_size)
        x = 1. + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        # Adjust the base lrs
        if self.base_lr_scale_fn:
            for entry_idx in range(len(self.base_lrs)):
                self.base_lrs[entry_idx] = self.max_lrs[entry_idx] * self.base_lr_scale_fn(cycle)

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == 'cycle':
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == 'cycle':
                    momentum = max_momentum - base_height * self.scale_fn(cycle)
                else:
                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['momentum'] = momentum

        return lrs
