from dataclasses import dataclass
from typing import Optional
from omegaconf import MISSING
from verl.base_config import BaseConfig
# Public optimizer-config classes exported by ``from <module> import *``.
__all__ = [
    "OptimizerConfig",
    "FSDPOptimizerConfig",
    "McoreOptimizerConfig",
]
@dataclass
class OptimizerConfig(BaseConfig):
    """Base optimizer configuration, subclassed by the backend-specific configs.

    Attributes:
        lr: Learning rate. Defaults to OmegaConf's ``MISSING`` sentinel, i.e. it
            must be supplied explicitly; ``__post_init__`` rejects an unset value.
        lr_warmup_steps_ratio: Warmup length as a ratio (presumably of
            ``total_training_steps`` — confirm). Defaults to ``0.0``.
        total_training_steps: Total number of training steps. Defaults to ``-1``
            (NOTE(review): likely a placeholder filled in at runtime — confirm).
        weight_decay: Weight-decay coefficient. Defaults to ``0.01``.
        lr_warmup_steps: Absolute number of warmup steps. Defaults to ``-1``
            (NOTE(review): a non-positive value presumably defers to
            ``lr_warmup_steps_ratio`` — confirm against the scheduler code).
    """

    lr: float = MISSING
    lr_warmup_steps_ratio: float = 0.0
    total_training_steps: int = -1
    weight_decay: float = 0.01
    lr_warmup_steps: Optional[int] = -1

    def __post_init__(self):
        """Validate that a learning rate was provided.

        Raises:
            AssertionError: if ``lr`` is still the ``MISSING`` sentinel.
        """
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped under ``python -O``. AssertionError is kept (rather than
        # ValueError) for backward compatibility with callers that catch it.
        if self.lr == MISSING:
            raise AssertionError(
                f"{type(self).__name__}.lr is required but was not set (got {self.lr!r})"
            )
@dataclass
class FSDPOptimizerConfig(OptimizerConfig):
    """Optimizer configuration for the FSDP backend.

    Extends ``OptimizerConfig`` with learning-rate schedule options.

    Attributes:
        min_lr_ratio: Optional minimum learning rate expressed as a ratio
            (presumably of ``lr`` — confirm). Defaults to ``None``.
        warmup_style: Learning-rate schedule style; must be ``"constant"`` or
            ``"cosine"``. Defaults to ``"constant"``.
        num_cycles: Number of cosine cycles (NOTE(review): presumably only used
            when ``warmup_style == "cosine"`` — confirm). Defaults to ``0.5``.
    """

    min_lr_ratio: Optional[float] = None
    warmup_style: str = "constant"
    num_cycles: float = 0.5

    def __post_init__(self):
        """Validate ``warmup_style``, then run the parent class's validation.

        Raises:
            AssertionError: if ``warmup_style`` is not a supported value, or if
                the parent class's validation fails.
        """
        valid_styles = ("constant", "cosine")
        # Explicit raise (not ``assert``) so validation survives ``python -O``;
        # AssertionError is kept for backward compatibility with existing callers.
        if self.warmup_style not in valid_styles:
            raise AssertionError(
                f"warmup_style must be one of {valid_styles}, got {self.warmup_style!r}"
            )
        super().__post_init__()
@dataclass
class McoreOptimizerConfig(OptimizerConfig):
    """Optimizer configuration for the ``mcore`` backend (presumably Megatron-Core).

    Adds optimizer and learning-rate-scheduler fields on top of
    ``OptimizerConfig``. This class does not override ``__post_init__``, so
    only the parent class's validation runs on construction; none of the
    string-valued fields below are validated here.

    NOTE(review): the field names appear to mirror Megatron-Core's optimizer /
    ``OptimizerParamScheduler`` arguments — confirm the accepted values against
    the code that consumes this config.
    """

    # Optimizer name (not validated in this class).
    optimizer: str = "adam"
    # Gradient-clipping threshold (presumably a global grad-norm clip — confirm).
    clip_grad: float = 1.0
    # Presumably the learning rate at the start of warmup — confirm.
    lr_warmup_init: float = 0.0
    # None presumably means "derive from the total step count" — confirm.
    lr_decay_steps: Optional[int] = None
    lr_decay_style: str = "linear"
    # Absolute minimum learning rate (FSDPOptimizerConfig uses a ratio instead).
    min_lr: float = 0.0
    weight_decay_incr_style: str = "constant"
    # Presumably only consulted when a WSD decay style is selected — confirm.
    lr_wsd_decay_style: str = "exponential"
    lr_wsd_decay_steps: Optional[int] = None
    # Presumably controls whether scheduler settings come from a checkpoint — confirm.
    use_checkpoint_opt_param_scheduler: bool = False