from __future__ import annotations

import os
from dataclasses import dataclass


def _env_str(name: str, default: str) -> str:
    """Return environment variable *name* verbatim, or *default* if it is unset.

    No stripping or case-folding is done here. A variable that is set to the
    empty string counts as "set", so ``""`` is returned rather than *default*.
    """
    try:
        raw = os.environ[name]
    except KeyError:
        return default
    return str(raw)


def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed with ``int()``, or *default* if unset.

    ``int()`` tolerates surrounding whitespace, but values such as ``""`` or
    ``"1.5"`` are rejected.

    Raises:
        ValueError: if the variable is set but is not a valid integer literal.
            The message names the variable and its raw value, because the bare
            ``int()`` error does not say which setting is misconfigured.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"environment variable {name}={raw!r} is not a valid int"
        ) from err


def _env_float(name: str, default: float) -> float:
    """Return environment variable *name* parsed with ``float()``, or *default* if unset.

    ``float()`` tolerates surrounding whitespace and accepts scientific
    notation such as ``"1e-5"``.

    Raises:
        ValueError: if the variable is set but cannot be parsed as a float.
            The message names the variable and its raw value, because the bare
            ``float()`` error does not say which setting is misconfigured.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(
            f"environment variable {name}={raw!r} is not a valid float"
        ) from err


def _env_is_one(name: str, default: str) -> bool:
    """Opt-in flag: True only when the whitespace-stripped value is exactly ``"1"``.

    Every other value (``""``, ``"0"``, ``"2"``, ``"true"``, ``"yes"``, ...)
    yields False. *default* is a string and goes through the same comparison.
    """
    stripped = _env_str(name, default).strip()
    return "1" == stripped


def _env_not_zero(name: str, default: str) -> bool:
    """Flag that is True unless the whitespace-stripped value is exactly ``"0"``.

    Note the asymmetry with an opt-in parser: an empty value, ``"false"`` or
    ``"no"`` all count as enabled here. *default* is a string and goes through
    the same comparison.
    """
    stripped = _env_str(name, default).strip()
    return not (stripped == "0")


@dataclass(frozen=True)
class WormTrainConfig:
    """
    Tunable parameters for the training loop, read mostly from environment variables.

    Design goals:
    - Let the core training logic read like library code instead of being
      littered with ``os.environ.get(...)`` calls.
    - Do not change existing behavior: defaults and the way booleans are
      parsed are kept consistent with historical usage as far as possible.

    Notes:
    - This demo still drives many experiment parameters through env vars
      (convenient for scripts, resuming runs, and batch sweeps).
    - For tighter control, these fields can gradually be turned into explicit
      Python arguments or a config file.

    Instances are frozen. Build one with ``from_env()`` or by passing every
    field explicitly; no field declares a default value.
    """

    # Misc
    profile: bool

    # Replay runtime
    replay: bool
    replay_use_vecplay: bool  # not env-controlled: from_env() always sets True
    replay_streaming: bool
    replay_cache_signals: bool

    # Optimizer routing
    opt_w_backend: bool

    # Objective
    corr_use_torch: bool
    corr_torch_device: str  # whitespace-stripped by from_env()

    # LR schedule scaling
    alpha_w0: float
    alpha_x0: float
    alpha_w_scale: float
    alpha_x_scale: float

    # Plateau / retreat
    plateau_use_vmin: bool
    plateau_vmin_err_tol: float
    plateau_vmin_eps: float
    plateau_vmin_save_snapshot: bool
    plateau_emergency_err_add: float
    plateau_lr_multiplier: float
    plateau_patience_epochs: int
    plateau_reset_adam: str  # stripped + lower-cased by from_env(); not validated here

    # Trainable freeze/schedule
    freeze_w: bool
    freeze_x: bool
    x_l2_coef: float
    x_update_every: int
    x_update_burst: int
    x_update_offset: int

    # Logging / printing
    save_run_best: bool
    print_timestep: bool
    print_epoch_time: bool
    print_interval_ms: float
    debug_backend_adam_state: bool

    @classmethod
    def from_env(cls) -> "WormTrainConfig":
        """Build a config from the ``EWORM_*`` environment variables.

        Unset variables fall back to the defaults written below. Boolean
        fields use one of two historical parsing conventions:

        * opt-in (``_env_is_one``): True only when the stripped value is ``"1"``;
        * not-zero (``_env_not_zero``): True unless the stripped value is
          ``"0"``, so e.g. ``""`` or ``"false"`` count as enabled.

        ``replay_use_vecplay`` has no env switch and is always True.

        Raises:
            ValueError: if a numeric variable is set to a value that
                ``int()`` / ``float()`` cannot parse.
        """
        # Local aliases that name each boolean convention at the call site.
        on_if_one = _env_is_one          # enabled only by "1"
        on_unless_zero = _env_not_zero   # enabled by anything except "0"

        # String settings are normalized up front (these reads cannot fail).
        torch_device = _env_str("EWORM_CORR_TORCH_DEVICE", "cuda:0").strip()
        reset_adam_mode = _env_str("EWORM_PLATEAU_RESET_ADAM", "none").strip().lower()

        return cls(
            # Misc
            profile=on_if_one("EWORM_PROFILE", "0"),
            # Replay runtime
            replay=on_if_one("EWORM_REPLAY", "0"),
            replay_use_vecplay=True,
            replay_streaming=on_if_one("EWORM_REPLAY_STREAMING", "0"),
            replay_cache_signals=on_if_one("EWORM_REPLAY_CACHE_SIGNALS", "0"),
            # Optimizer routing
            opt_w_backend=on_unless_zero("EWORM_OPT_W_BACKEND", "0"),
            # Objective
            corr_use_torch=on_unless_zero("EWORM_CORR_USE_TORCH", "1"),
            corr_torch_device=torch_device,
            # LR schedule scaling
            alpha_w0=_env_float("EWORM_ALPHA_W0", 1e-5),
            alpha_x0=_env_float("EWORM_ALPHA_X0", 3e-2),
            alpha_w_scale=_env_float("EWORM_ALPHA_W_SCALE", 1.0),
            alpha_x_scale=_env_float("EWORM_ALPHA_X_SCALE", 1.0),
            # Plateau / retreat
            plateau_use_vmin=on_unless_zero("EWORM_PLATEAU_USE_VMIN", "0"),
            plateau_vmin_err_tol=_env_float("EWORM_PLATEAU_VMIN_ERR_TOL", 0.002),
            plateau_vmin_eps=_env_float("EWORM_PLATEAU_VMIN_EPS", 5.0),
            plateau_vmin_save_snapshot=on_unless_zero("EWORM_PLATEAU_VMIN_SAVE_SNAPSHOT", "1"),
            plateau_emergency_err_add=_env_float("EWORM_PLATEAU_EMERGENCY_ERR_ADD", 0.0),
            plateau_lr_multiplier=_env_float("EWORM_PLATEAU_LR_MULTIPLIER", 0.3),
            plateau_patience_epochs=_env_int("EWORM_PLATEAU_PATIENCE_EPOCHS", 5),
            plateau_reset_adam=reset_adam_mode,
            # Trainable freeze/schedule
            freeze_w=on_unless_zero("EWORM_FREEZE_W", "0"),
            freeze_x=on_unless_zero("EWORM_FREEZE_X", "0"),
            x_l2_coef=_env_float("EWORM_X_L2_COEF", 1e-1),
            x_update_every=_env_int("EWORM_X_UPDATE_EVERY", 1),
            x_update_burst=_env_int("EWORM_X_UPDATE_BURST", 1),
            x_update_offset=_env_int("EWORM_X_UPDATE_OFFSET", 0),
            # Logging / printing
            save_run_best=on_unless_zero("EWORM_SAVE_RUN_BEST", "1"),
            print_timestep=on_if_one("EWORM_PRINT_TIMESTEP", "0"),
            print_epoch_time=on_if_one("EWORM_PRINT_EPOCH_TIME", "0"),
            print_interval_ms=_env_float("EWORM_PRINT_INTERVAL_MS", 100.0),
            debug_backend_adam_state=on_unless_zero("EWORM_DEBUG_BACKEND_ADAM_STATE", "0"),
        )
