from dataclasses import dataclass, field
from typing import Any, Optional
from verl.base_config import BaseConfig
# Public names exported by a star-import of this module.
__all__ = ["AlgoConfig", "FilterGroupsConfig", "KLControlConfig"]
@dataclass
class KLControlConfig(BaseConfig):
    """Settings for the KL-penalty coefficient controller.

    This is a plain value container. It performs no validation itself; which
    fields matter, and how, is decided by the code that consumes it (not
    visible here). Shared behaviour comes from ``BaseConfig`` — see
    ``verl.base_config`` for what that base class adds.

    Field order is the positional order of the generated ``__init__``;
    keep it stable.

    Attributes:
        type: Controller kind; defaults to ``"fixed"``. Accepted values are
            not checked here — TODO confirm the valid set in the consumer.
        kl_coef: KL penalty coefficient; defaults to 0.001. Presumably the
            starting value when the controller adapts it — confirm.
        horizon: Defaults to 10000. Presumably only used by a non-fixed
            (adaptive) controller — confirm.
        target_kl: Defaults to 0.1. Presumably the KL target of an adaptive
            controller — confirm.
    """

    # NOTE: this field name shadows the builtin ``type`` within this class
    # body. It is part of the constructor keyword interface, so do not rename.
    type: str = "fixed"
    kl_coef: float = 0.001
    horizon: int = 10000
    target_kl: float = 0.1
@dataclass
class FilterGroupsConfig(BaseConfig):
    """Settings for optional filtering of generated groups.

    Disabled by default. This class body adds no cross-field validation:
    for example, it does not require ``metric`` to be set when ``enable`` is
    True. Any such checks must happen in the consumer.

    Field order is the positional order of the generated ``__init__``;
    keep it stable.

    Attributes:
        enable: Whether group filtering is turned on. Defaults to False.
        metric: Name of the metric the filter relies on; ``None`` means unset.
            Presumably required when ``enable`` is True — confirm in the
            consumer.
        max_num_gen_batches: Defaults to 0. Presumably a cap on how many
            generation batches may be drawn while filtering. How 0 or
            non-positive values are interpreted (e.g. "no limit") is defined
            by the consumer — TODO confirm.
    """

    enable: bool = False
    metric: Optional[str] = None
    max_num_gen_batches: int = 0
@dataclass
class AlgoConfig(BaseConfig):
    """Algorithm-level settings, including nested KL and group-filter configs.

    This is a plain value container with no validation in its own body.
    Shared behaviour comes from ``BaseConfig`` (see ``verl.base_config``).

    Field order is the positional order of the generated ``__init__``;
    keep it stable.

    Attributes:
        gamma: Defaults to 1.0. By name, the reward discount factor.
        lam: Defaults to 1.0. By name, the GAE lambda (bias/variance
            trade-off) — confirm in the advantage code.
        adv_estimator: Name of the advantage estimator; defaults to
            ``"gae"``. Accepted names are not checked here.
        norm_adv_by_std_in_grpo: Defaults to True. By name, whether
            GRPO-style advantages are divided by their std — confirm.
        use_kl_in_reward: Defaults to False. By name, whether a KL penalty is
            folded into the reward.
        kl_penalty: Name of the KL penalty / estimator variant; defaults to
            ``"kl"``. Accepted names are not checked here.
        kl_ctrl: KL coefficient controller settings. Each instance gets its
            own ``KLControlConfig()``.
        use_pf_ppo: Defaults to False. By name, toggles the ``pf_ppo`` options
            — confirm in the consumer.
        pf_ppo: Free-form options mapping. Each instance gets its own empty
            dict.
        filter_groups: Optional group-filter settings; ``None`` when not
            configured.
    """

    gamma: float = 1.0
    lam: float = 1.0
    adv_estimator: str = "gae"
    norm_adv_by_std_in_grpo: bool = True
    use_kl_in_reward: bool = False
    kl_penalty: str = "kl"
    # default_factory gives every AlgoConfig its own controller config rather
    # than one shared mutable default instance.
    kl_ctrl: KLControlConfig = field(default_factory=KLControlConfig)
    use_pf_ppo: bool = False
    # A bare `= {}` default is rejected by dataclasses (mutable default);
    # default_factory builds a fresh dict per instance.
    pf_ppo: dict[str, Any] = field(default_factory=dict)
    filter_groups: Optional[FilterGroupsConfig] = None