import math
from torch.optim.lr_scheduler import _LRScheduler

class LinearWarmupCosineDecayScheduler(_LRScheduler):
    def __init__(self, optimizer, init_value: float, peak_value: float, end_value: float,
                 warmup_steps: int, decay_steps: int, exponent: float = 1.0, last_epoch: int = -1):
        """
        自定义学习率调度器：先线性 warmup 到 peak_value，再采用余弦衰减到 end_value。

        参数：
            optimizer: 优化器对象。
            init_value: warmup 初始学习率（对应 optax 中的 init_value）。
            peak_value: warmup 后的最高学习率。
            end_value: 衰减结束时的学习率。
            warmup_steps: warmup 阶段的步数。
            decay_steps: 总衰减步数（包含 warmup 步数，故衰减阶段步数 = decay_steps - warmup_steps）。
            exponent: 衰减指数，用于调整余弦衰减曲线的形状（默认 1.0）。
            last_epoch: 上一次的 epoch 数（通常保持默认 -1）。
        """
        self.init_value = init_value
        self.peak_value = peak_value
        self.end_value = end_value
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps
        self.exponent = exponent

        # 计算余弦衰减时的 alpha 参数，与 optax 中保持一致：
        # 若 peak_value 为 0，则 alpha 直接设置为 0.0；否则 alpha = end_value / peak_value
        self.alpha = 0.0 if peak_value == 0.0 else end_value / peak_value

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # 当前步数（实际上 last_epoch 表示上一次调用 step 后的计数）
        current_step = self.last_epoch  # _LRScheduler 中 last_epoch 初始化为 -1，所以第 0 步为 last_epoch + 1 = 0

        if current_step <= self.warmup_steps:
            # 线性 warmup：从 init_value 上升到 peak_value
            lr = self.init_value + (self.peak_value - self.init_value) * (current_step / self.warmup_steps)
        elif current_step <= self.decay_steps:
            # 余弦衰减阶段：衰减步数为 decay_steps - warmup_steps
            decay_phase_steps = self.decay_steps - self.warmup_steps
            progress = (current_step - self.warmup_steps) / decay_phase_steps
            # 计算余弦值，其区间从 cos(0)=1 到 cos(pi)= -1，转换到 [0,1] 区间
            cosine_decay = (math.cos(math.pi * progress) + 1) / 2
            # 若使用 exponent 调整形状，则对余弦衰减值进行幂运算
            cosine_decay = cosine_decay ** self.exponent
            # 衰减后的学习率计算：
            # 当 progress=0 时，lr = peak_value；progress=1 时，lr = peak_value*(alpha + (1-alpha)*0**exponent)= peak_value*alpha = end_value
            lr = self.peak_value * (self.alpha + (1 - self.alpha) * cosine_decay)
        else:
            # 超过总步数后，保持 end_value
            lr = self.end_value

        # 为每个参数组返回相同的 lr
        return [lr for _ in self.optimizer.param_groups]