import math

class CustomCosineAnnealingLR:
    """Cosine-annealing learning-rate scheduler for an arbitrary param-group key.

    Applies a closed-form cosine annealing schedule. It reads and writes the
    rate under a configurable key (``lr_name``) in each of the optimizer's
    ``param_groups`` rather than a fixed key. At epoch ``t`` each group gets::

        eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2

    Here ``base_lr`` is that group's value when the scheduler was constructed.
    The formula is not clamped. Past ``T_max`` it keeps following the cosine,
    so the rate climbs back towards ``base_lr`` with period ``2 * T_max``.

    Args:
        optimizer: Object exposing a ``param_groups`` sequence of dicts. Each
            dict must contain the key ``lr_name``.
        T_max: Number of epochs for the rate to fall from ``base_lr`` to
            ``eta_min``. Must be positive.
        lr_name: Key in each param group that holds the rate to anneal.
        eta_min: Value reached at ``t == T_max``.
        last_epoch: Index of the last completed epoch; ``-1`` starts fresh.
            The constructor immediately calls ``step()``, so after
            construction ``last_epoch`` is one greater than the value passed.

    Raises:
        ValueError: If ``T_max`` is not positive.
    """

    def __init__(self, optimizer, T_max, lr_name='lr_eta', eta_min=0, last_epoch=-1):
        # Fail early with a clear message. Without this check, T_max == 0
        # surfaces as a ZeroDivisionError from get_lr() during construction.
        if T_max <= 0:
            raise ValueError(f"T_max must be positive, got {T_max!r}")
        self.optimizer = optimizer
        self.T_max = T_max
        self.lr_name = lr_name
        self.eta_min = eta_min
        self.last_epoch = last_epoch
        # Snapshot the current rates; every scheduled value is derived from these.
        # NOTE(review): when resuming (last_epoch != -1) the groups may already
        # hold an annealed rate, which would then be treated as the base rate.
        self.base_lrs = [group[lr_name] for group in optimizer.param_groups]
        # Write the rate for the starting epoch into the optimizer right away.
        self.step()

    def get_lr(self):
        """Return the scheduled rate for each param group at ``self.last_epoch``."""
        # The cosine factor is the same for every group, so compute it once.
        # It lies in [0, 1] and equals 1 at t == 0 and 0 at t == T_max.
        cosine_factor = (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
        return [self.eta_min + (base_lr - self.eta_min) * cosine_factor
                for base_lr in self.base_lrs]

    def step(self):
        """Advance one epoch and write the new rates into the optimizer.

        Each param group's ``lr_name`` entry is overwritten in place. ``zip``
        stops at the shorter sequence, so groups added to the optimizer after
        construction (which have no ``base_lrs`` entry) are left untouched.
        """
        self.last_epoch += 1
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group[self.lr_name] = lr