import torch
import pytorch_lightning as pl
import numpy as np
from torch.optim.lr_scheduler import _LRScheduler

class CosineScheduler(_LRScheduler):
    """Learning-rate schedule with freeze, linear warmup and cosine decay phases.

    The whole schedule is precomputed in ``__init__`` as a NumPy array holding
    one value per scheduler step. Every parameter group of the optimizer
    receives the same value; the learning rates configured on the optimizer
    itself are not used by ``get_lr``.
    """

    def __init__(self, optimizer, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        """
        Build the schedule and attach the scheduler to ``optimizer``.

        :param optimizer: PyTorch optimizer whose learning rate is driven.
        :param base_value: learning rate at the end of warmup, i.e. the
            starting point of the cosine decay.
        :param final_value: learning rate the cosine decay approaches; also
            returned for every step at or beyond ``total_iters``.
        :param total_iters: total number of scheduled steps
            (freeze + warmup + decay).
        :param warmup_iters: number of linear warmup steps.
        :param start_warmup_value: learning rate at the first warmup step.
        :param freeze_iters: number of leading steps with a learning rate of 0.
        :raises ValueError: if any step count is negative, or if
            ``warmup_iters + freeze_iters`` exceeds ``total_iters``.
        """
        # Validate explicitly instead of relying on ``assert``, which is
        # stripped when Python runs with -O.
        if total_iters < 0 or warmup_iters < 0 or freeze_iters < 0:
            raise ValueError(
                "total_iters, warmup_iters and freeze_iters must be non-negative, "
                f"got total_iters={total_iters}, warmup_iters={warmup_iters}, "
                f"freeze_iters={freeze_iters}"
            )
        decay_iters = total_iters - warmup_iters - freeze_iters
        if decay_iters < 0:
            raise ValueError(
                f"warmup_iters ({warmup_iters}) + freeze_iters ({freeze_iters}) "
                f"must not exceed total_iters ({total_iters})"
            )

        self.base_value = base_value
        self.final_value = final_value
        self.total_iters = total_iters
        self.warmup_iters = warmup_iters
        self.start_warmup_value = start_warmup_value
        self.freeze_iters = freeze_iters

        # Freeze phase: the learning rate is held at zero.
        freeze_schedule = np.zeros((freeze_iters))

        # Warmup phase: linear ramp from start_warmup_value to base_value
        # (both endpoints included).
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

        # Cosine decay phase: starts at base_value and moves towards
        # final_value. When decay_iters is 0 the array is empty, so the
        # division below never operates on an element.
        iters = np.arange(decay_iters)
        decay_schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / decay_iters))

        # Full schedule, indexed by scheduler step; its length equals
        # total_iters by construction.
        self.schedule = np.concatenate((freeze_schedule, warmup_schedule, decay_schedule))

        # The parent constructor must run last: it performs an initial step
        # that calls get_lr(), which reads the attributes set above.
        super().__init__(optimizer)

    def get_lr(self):
        """
        Return the learning rate for the current step, one entry per parameter group.

        ``self.last_epoch`` is used as the index into the precomputed schedule.
        Once it reaches ``total_iters`` the rate stays at ``final_value``.
        """
        step = self.last_epoch
        if step >= self.total_iters:
            lr = float(self.final_value)
        else:
            # Cast to a built-in float so the optimizer's param groups do not
            # end up holding NumPy scalar objects.
            lr = float(self.schedule[step])
        return [lr] * len(self.optimizer.param_groups)