import math
from torch.optim.lr_scheduler import _LRScheduler

class CosineAnnealingWithFlatTail(_LRScheduler):
    """Cosine-annealing learning-rate schedule followed by a constant tail.

    With ``t = self.last_epoch``, each parameter group's learning rate is

        eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2

    for ``t < min(T_max, flat_after)``. From ``t = min(T_max, flat_after)``
    onwards it is held at ``eta_min``.
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, flat_after=None):
        """
        Args:
            optimizer (Optimizer): wrapped optimizer
            T_max (int): number of epochs for cosine annealing; the cosine
                curve reaches eta_min at epoch T_max
            eta_min (float): minimum learning rate, shared by all parameter groups
            last_epoch (int): index of the last epoch, forwarded unchanged to
                the base scheduler (-1 starts a fresh schedule)
            flat_after (int): epoch from which the learning rate is held at
                eta_min (optional, defaults to T_max). A value smaller than
                T_max cuts the cosine curve short with a drop to eta_min.
                A larger value keeps the rate at eta_min from T_max onward,
                since the curve has already bottomed out there.
        """
        self.T_max = T_max
        self.eta_min = eta_min
        self.flat_after = flat_after if flat_after is not None else T_max
        # get_lr() reads the attributes above, and the base-class constructor
        # takes an initial step (calling get_lr), so they must be set first.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group at ``self.last_epoch``.

        Returns:
            list[float]: one learning rate per entry of ``self.base_lrs``.
        """
        t = self.last_epoch
        # The flat tail begins once either the cosine curve is finished (T_max)
        # or flat_after is reached, whichever comes first. Without the T_max
        # bound, epochs between T_max and flat_after would evaluate cos() past
        # pi and the learning rate would rise back toward base_lr. For
        # non-negative t, this also keeps T_max <= 0 away from the division below.
        if t >= min(self.flat_after, self.T_max):
            return [self.eta_min for _ in self.base_lrs]
        # Identical for every group, so compute it once.
        cos_term = 1 + math.cos(math.pi * t / self.T_max)
        return [
            self.eta_min + (base_lr - self.eta_min) * cos_term / 2
            for base_lr in self.base_lrs
        ]