import warnings
from dataclasses import dataclass, field
from typing import Any, Optional
from verl.base_config import BaseConfig
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["FSDPEngineConfig", "McoreEngineConfig"]
@dataclass
class McoreEngineConfig(BaseConfig):
    """Engine configuration for the Mcore (presumably Megatron-Core) backend.

    Groups parallelism sizes, offload switches, checkpointing options, a seed
    and free-form override dictionaries. After construction, ``__post_init__``
    forces ``sequence_parallel`` to ``False`` when
    ``tensor_model_parallel_size == 1``.
    """

    # ``sequence_parallel`` is reassigned in ``__post_init__``, so it is added to
    # the inherited set of fields that remain assignable after construction.
    # NOTE(review): BaseConfig presumably rejects re-assignment of fields not
    # listed here — confirm in verl.base_config.
    _mutable_fields = BaseConfig._mutable_fields | {"sequence_parallel"}

    # Offload switches; the offloading itself is implemented by the consumer.
    param_offload: bool = False
    grad_offload: bool = False
    optimizer_offload: bool = False

    # Parallelism degrees.
    tensor_model_parallel_size: int = 1
    expert_model_parallel_size: int = 1
    # None presumably means "let the backend choose / inherit" — TODO confirm.
    expert_tensor_parallel_size: Optional[int] = None
    pipeline_model_parallel_size: int = 1
    # None presumably disables virtual (interleaved) pipeline stages — TODO confirm.
    virtual_pipeline_model_parallel_size: Optional[int] = None
    context_parallel_size: int = 1
    # Forced to False by __post_init__ when tensor_model_parallel_size == 1.
    sequence_parallel: bool = True
    use_distributed_optimizer: bool = True

    # Checkpointing; the path is presumably only read when
    # use_dist_checkpointing is True — verify against the consumer.
    use_dist_checkpointing: bool = False
    dist_checkpointing_path: Optional[str] = None

    seed: int = 42

    # Free-form override mappings. default_factory gives every instance its own
    # empty dict, so instances never share one mutable default.
    override_ddp_config: dict[str, Any] = field(default_factory=dict)
    override_transformer_config: dict[str, Any] = field(default_factory=dict)

    use_mbridge: bool = False

    def __post_init__(self) -> None:
        """Normalize settings after the dataclass-generated ``__init__`` runs.

        When ``tensor_model_parallel_size == 1``, ``sequence_parallel`` is
        always set to ``False``. A ``UserWarning`` is emitted only if
        sequence parallelism had actually been requested (a truthy value), so
        configs that already disable it do not produce a spurious warning.
        """
        if self.tensor_model_parallel_size == 1:
            if self.sequence_parallel:
                # stacklevel=3 skips this method and the generated __init__ so the
                # warning is attributed to the code that constructed the config.
                warnings.warn("set sequence parallel to false as TP size is 1", stacklevel=3)
            # Always normalize, even when already falsy, to keep a strict bool.
            self.sequence_parallel = False
@dataclass
class FSDPEngineConfig(BaseConfig):
    """Engine configuration for the FSDP training backend.

    A dataclass of defaults. The declaration order of the fields below
    defines the positional order of the generated ``__init__``, so new
    fields should be appended rather than inserted.
    """

    # Free-form wrapping-policy options. default_factory gives each instance its
    # own empty dict. The expected keys are not defined here — see the consumer.
    wrap_policy: dict[str, Any] = field(default_factory=dict)
    # Offload switches for parameters and optimizer state; the offloading
    # itself is implemented by the engine that consumes this config.
    param_offload: bool = False
    optimizer_offload: bool = False
    # NOTE(review): presumably enables an FSDP2 CPU offload policy — confirm.
    offload_policy: bool = False
    # NOTE(review): presumably forwarded as FSDP2 ``reshard_after_forward`` — confirm.
    reshard_after_forward: bool = True
    # NOTE(review): -1 presumably means "use the whole world size" — TODO confirm.
    fsdp_size: int = -1
    forward_prefetch: bool = False
    # Dtype given as a string name; accepted spellings are decided by the consumer.
    model_dtype: str = "fp32"
    # NOTE(review): presumably forwarded to FSDP1 ``use_orig_params`` — confirm.
    use_orig_params: bool = False