from typing import List

import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR, MultiStepLR
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer

from options import OptimizationConfig


class GenericScheduler:
    """Drive a torch LR scheduler at a fixed step interval with an LR floor.

    The wrapped scheduler is advanced only on steps that are multiples of
    ``step_interval`` and only while its most recent learning rate is still
    strictly above ``min_lr``.
    """

    def __init__(self, scheduler: _LRScheduler, min_lr: float, step_interval: int):
        """
        Args:
            scheduler: The torch scheduler to advance.
            min_lr: Once the last reported LR is <= this value, stepping stops.
            step_interval: Advance the scheduler on every ``step_interval``-th
                step. Must be positive.

        Raises:
            ValueError: If ``step_interval`` is not positive (a zero interval
                would otherwise fail later with ``ZeroDivisionError`` inside
                ``check_and_step``).
        """
        if step_interval <= 0:
            raise ValueError(f"step_interval must be a positive integer, got {step_interval}")
        self.scheduler = scheduler
        self.min_lr = min_lr
        self.step_interval = step_interval

    def check_and_step(self, step: int):
        """Advance the wrapped scheduler if ``step`` is on the interval and LR > min_lr.

        Note that ``step == 0`` is a multiple of every interval, so a call with
        step 0 advances the scheduler.
        """
        if step % self.step_interval == 0 and self.min_lr < self.get_last_lr():
            self.scheduler.step()

    def get_last_lr(self) -> float:
        """Return the most recent learning rate of the last parameter group."""
        return self.scheduler.get_last_lr()[-1]


def load_lr_scheduler(optimizer: Optimizer, cfg: OptimizationConfig) -> GenericScheduler:
    """Build the learning-rate scheduler selected by ``cfg.lr_scheduler_type``.

    Prints a summary of the scheduler settings, then wraps the chosen torch
    scheduler in a ``GenericScheduler`` that is advanced every
    ``step_interval`` steps until the LR reaches ``cfg.min_lr``.

    Supported types: ``"cosine"``, ``"exponential"``, ``"multistep"``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        cfg: Config read for ``lr_scheduler_type``, ``epochs``, ``lr``,
            ``min_lr``, ``step_interval``, ``gamma`` and (multistep only)
            ``milestones``. ``warmup_epochs`` is only printed here.

    Returns:
        The wrapped scheduler.

    Raises:
        ValueError: If ``cfg.lr_scheduler_type`` is unknown, or if the
            exponential decay factor must be derived (``cfg.gamma == 0``)
            while ``cfg.min_lr`` is not positive.
    """
    print('-' * 50)
    print(f"Using LR Scheduler: {cfg.lr_scheduler_type}")
    print(f'Warmup Epochs: {cfg.warmup_epochs}')
    print(f'Step Interval: {cfg.step_interval}')
    print(f'Min LR: {cfg.min_lr}')
    if cfg.lr_scheduler_type == 'multistep':
        print(f'Milestones: {cfg.milestones}')
        print(f'Gamma: {cfg.gamma}')
    print('-' * 50)

    def init_cosine():
        # NOTE(review): T_max counts scheduler steps, but the wrapper only steps
        # once per `step_interval`; with step_interval > 1 the cosine curve may
        # not reach eta_min by `cfg.epochs` — confirm the intended units.
        return GenericScheduler(
            CosineAnnealingLR(optimizer, T_max=cfg.epochs, eta_min=cfg.min_lr),
            min_lr=cfg.min_lr,
            step_interval=cfg.step_interval
        )

    def init_exponential():
        # If step_interval is not explicitly set (<= 0), aim for ~100 decay steps
        # over training. Clamp to 1 so short runs (epochs < 100) don't produce a
        # zero interval, which would divide by zero below and in check_and_step.
        if cfg.step_interval > 0:
            step_interval = cfg.step_interval
        else:
            step_interval = max(1, cfg.epochs // 100)

        if cfg.gamma == 0:
            # Derive gamma so that lr * gamma ** approx_num_steps ~= min_lr.
            if cfg.min_lr <= 0:
                # A non-positive target gives gamma == 0 (or NaN), which would
                # collapse the LR to zero on the very first scheduler step.
                raise ValueError(
                    f"Exponential LR decay needs min_lr > 0 to derive gamma, got {cfg.min_lr}; "
                    f"set a positive min_lr or an explicit gamma"
                )
            approx_num_steps = cfg.epochs / step_interval
            gamma = float(np.power(cfg.min_lr / cfg.lr, 1 / approx_num_steps))
        else:
            gamma = cfg.gamma

        return GenericScheduler(
            ExponentialLR(optimizer, gamma),
            min_lr=cfg.min_lr,
            step_interval=step_interval
        )

    def init_multistep():
        # NOTE(review): milestones are counted in scheduler steps (one per
        # `step_interval`), not raw steps — confirm this matches the config.
        return GenericScheduler(
            MultiStepLR(optimizer, milestones=cfg.milestones, gamma=cfg.gamma),
            min_lr=cfg.min_lr,
            step_interval=cfg.step_interval
        )

    # Map of supported scheduler names to their factory functions.
    schedulers = {
        "cosine": init_cosine,
        "exponential": init_exponential,
        "multistep": init_multistep
    }

    # Look up the factory first so a KeyError raised inside an initializer is
    # not misreported as an unknown scheduler type.
    try:
        init_fn = schedulers[cfg.lr_scheduler_type]
    except KeyError:
        raise ValueError(f"Unknown LR Scheduler Type: {cfg.lr_scheduler_type}") from None
    return init_fn()
