# Copyright (c) OpenMMLab. All rights reserved.
import math

from mmengine.optim.scheduler import CosineAnnealingParamScheduler

from mmpretrain.registry import PARAM_SCHEDULERS


class WeightDecaySchedulerMixin:
    """A mixin that makes a parameter scheduler act on ``weight_decay``.

    It inserts ``'weight_decay'`` as the second positional argument (the
    parameter name taken by mmengine's parameter schedulers). Subclasses
    therefore only pass the remaining scheduler arguments. List this mixin
    *before* the scheduler base class in the bases, so that its ``__init__``
    runs first in the MRO and the base receives the extra argument.
    """

    def __init__(self, optimizer, *args, **kwargs):
        """Forward construction to the next class in the MRO.

        Args:
            optimizer: The wrapped optimizer; forwarded as the first
                positional argument.
            *args: Remaining positional scheduler arguments, forwarded after
                the ``'weight_decay'`` parameter name.
            **kwargs: Keyword scheduler arguments, forwarded unchanged.
        """
        super().__init__(optimizer, 'weight_decay', *args, **kwargs)


@PARAM_SCHEDULERS.register_module()
class CosineAnnealingWeightDecay(WeightDecaySchedulerMixin,
                                 CosineAnnealingParamScheduler):
    """Anneal the weight decay of every parameter group along a cosine curve.

    All constructor arguments are forwarded to
    ``CosineAnnealingParamScheduler`` through ``WeightDecaySchedulerMixin``,
    which fixes the scheduled key to ``'weight_decay'``.

    A parameter group whose base (initial) weight decay is 0 is never
    annealed: its weight decay stays 0 for the whole run.
    """

    def _get_value(self) -> list:
        """Return the next weight decay for each parameter group.

        Uses the chainable (recursive) cosine-annealing form. Each new value
        is derived from the group's current weight decay rather than being
        recomputed from the base value alone.

        Returns:
            list: One weight decay per entry of
            ``self.optimizer.param_groups``, in the same order.
        """
        key = self.param_name
        param_groups = self.optimizer.param_groups
        step = self.last_step

        # Very first step: keep whatever the optimizer currently holds.
        if step == 0:
            return [group[key] for group in param_groups]

        # At these steps cos(pi * (step - 1) / T_max) == -1, so the
        # multiplicative ratio used below would have a zero denominator.
        # An additive update is applied instead.
        additive_step = (step - 1 - self.T_max) % (2 * self.T_max) == 0

        new_values = []
        for base, group in zip(self.base_values, param_groups):
            if base == 0:
                # A group that started without weight decay keeps none.
                new_values.append(0)
                continue

            current = group[key]
            # Lower bound of the schedule: absolute eta_min, or a fraction
            # of this group's base value when eta_min_ratio is set.
            floor = (self.eta_min if self.eta_min_ratio is None else
                     base * self.eta_min_ratio)
            if additive_step:
                new_values.append(current + (base - floor) *
                                  (1 - math.cos(math.pi / self.T_max)) / 2)
            else:
                ratio = (1 + math.cos(math.pi * step / self.T_max)) / (
                    1 + math.cos(math.pi * (step - 1) / self.T_max))
                new_values.append(ratio * (current - floor) + floor)
        return new_values
