from __future__ import annotations

import argparse
import os
from dataclasses import dataclass


def _env_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean flag.

    If the variable is unset, ``bool(default)`` is returned. Otherwise the value
    is stripped and compared case-insensitively: ``""``, ``"0"``, ``"false"``,
    ``"no"`` and ``"off"`` mean False; any other value means True.
    """
    v = os.environ.get(name)
    if v is None:
        return bool(default)
    # Case-insensitive so spellings like "fAlse" or "OFF" are not silently
    # interpreted as True.
    return v.strip().lower() not in ("", "0", "false", "no", "off")


def _env_str(name: str, default: str) -> str:
    """Return environment variable *name* verbatim, or *default* if it is unset.

    A variable that is set to an empty string is returned as ``""``; only a
    missing variable falls back to *default*.
    """
    return os.environ.get(name, default)


def _env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an int.

    Unset or blank (whitespace-only) values fall back to *default*, so an
    exported-but-empty variable (e.g. ``NAME=``) behaves like an unset one.

    Raises:
        ValueError: if the variable is set to something that is not an integer
            literal; the message names the variable.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(f"environment variable {name} must be an int (got {raw!r})") from err


def _env_float(name: str, default: float) -> float:
    """Read environment variable *name* as a float.

    Unset or blank (whitespace-only) values fall back to *default*, so an
    exported-but-empty variable (e.g. ``NAME=``) behaves like an unset one.

    Raises:
        ValueError: if the variable is set to something ``float()`` rejects;
            the message names the variable.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(f"environment variable {name} must be a float (got {raw!r})") from err


@dataclass(frozen=True)
class WormDemoPublicConfig:
    """User-facing settings for the worm demo.

    Immutable (``frozen=True``). Populated from parsed CLI arguments and a few
    ``EWORM_*`` environment variables by ``WormDemoConfig.from_args_and_env``.
    """

    # Absolute path of the base trial directory (per the CLI help, it holds the config/pkl files).
    base_trial: str
    # Absolute output directory (per the CLI help, checkpoints and logs are written there).
    output_path: str
    # Output file name prefix.
    prefix: str
    # Output file name suffix.
    suffix: str
    # Total (absolute) number of training epochs.
    epochs_total: int
    # Whether to resume (per the CLI help: from the ckpt/last-train under output_path).
    resume: bool
    # When set, forces the start epoch on resume; None leaves it to the resume logic.
    resume_start_epoch: int | None

    # Training-time LR downsample factor: LR tick every K_mul steps (dt-grid).
    # 5 is the historical default; 1 means "no downsample".
    k_mul: int

    # K history length controls (transfer impedance temporal horizon).
    #
    # Relationship (dt-grid):
    #   K_len_lr = floor(K_max_t_ms / (dt_ms * K_mul))
    #
    # You can override either:
    # - `k_len` (preferred when you want exact window length in LR ticks), or
    # - `k_max_t_ms` (preferred when you want to specify a physical horizon in ms).
    #
    # Per the --k-len CLI help, k_len takes precedence when both are set.
    # If both are None, a default is used (chosen by the consumer, not by this class).
    k_len: int | None
    k_max_t_ms: float | None

    # runtime knobs (still public because users often care)
    # From EWORM_TSTOP_MS; presumably overrides the simulation stop time in ms — TODO confirm.
    tstop_override_ms: float | None
    # From EWORM_HELIOX_EXPORT_PATH; None when unset or empty.
    export_path: str | None
    export_path: str | None


@dataclass(frozen=True)
class WormDemoInternalConfig:
    """Internal knobs for the worm demo, read from ``EWORM_*`` environment variables.

    Immutable (``frozen=True``). Populated by ``WormDemoConfig.from_args_and_env``.
    """

    # These envs are used in multiple places (shell/demo/train). We centralize parsing here.
    # NOTE: This does not change training behavior by itself; it only provides a single source of truth.
    # From EWORM_REPLAY (default False).
    replay: bool
    # Currently always constructed as True; not controlled by an env var.
    replay_use_vecplay: bool
    # From EWORM_REPLAY_STREAMING (default False).
    replay_streaming: bool
    # From EWORM_REPLAY_CACHE_SIGNALS (default False).
    replay_cache_signals: bool

    # Loss / misc
    # From EWORM_CORR_USE_TORCH (default True); presumably selects a torch-based
    # correlation computation — TODO confirm against the loss code.
    corr_use_torch: bool
    # From EWORM_CORR_TORCH_DEVICE (default "cuda:0"), whitespace-stripped.
    corr_torch_device: str

    # Misc verbosity/profiling
    # From EWORM_PROFILE (default False).
    profile: bool


@dataclass(frozen=True)
class WormDemoConfig:
    """Complete worm-demo configuration: public (user-facing) plus internal parts.

    Typical use: register options with :meth:`add_cli_args`, parse, then build
    the config with :meth:`from_args_and_env`.
    """

    public: WormDemoPublicConfig
    internal: WormDemoInternalConfig

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> None:
        """Register the demo's public command-line options on *parser*.

        ``--k-mul``, ``--k-len`` and ``--k-max-t-ms`` default to ``None`` so that
        :meth:`from_args_and_env` can detect "not given on the CLI" and fall back
        to the matching ``EWORM_*`` environment variable.
        """
        # Base trial directory containing the config/pkl files.
        parser.add_argument("--base-trial", default="./trial10", help="包含 config/pkl 的基础 trial 目录")
        # Output directory; checkpoints and logs are written here.
        parser.add_argument("--output-path", default="./runs/trial/demo_run", help="输出目录（会写 ckpt/log）")
        # Output file name prefix / suffix.
        parser.add_argument("--prefix", default="eworm", help="输出文件前缀")
        parser.add_argument("--suffix", default="demo", help="输出文件后缀")
        # Total (absolute) number of training epochs.
        parser.add_argument("--epochs", type=int, default=50, help="训练总 epoch 数（absolute）")
        # LR downsample factor: one LR tick every K_mul dt steps; 1 = no downsample.
        parser.add_argument(
            "--k-mul",
            type=int,
            default=None,
            help="LR 降采样倍率（K_mul）：每 K_mul 个 dt step 采样/更新一次；1 表示不降采样（默认 5）",
        )
        # K history length in LR ticks; takes precedence over --k-max-t-ms.
        parser.add_argument(
            "--k-len",
            type=int,
            default=None,
            help="K 的历史长度（LR tick 数）。若设置，将优先于 --k-max-t-ms。",
        )
        # K history horizon in ms; a built-in default is used when unset.
        parser.add_argument(
            "--k-max-t-ms",
            type=float,
            default=None,
            help="K 的历史最大时间长度（ms）。默认使用内置值（与历史最佳设置一致）。",
        )
        # Resume from the ckpt/last-train under --output-path.
        parser.add_argument("--resume", action="store_true", help="从 output-path 的 ckpt/last-train 恢复")
        # Force start_epoch when resuming.
        parser.add_argument("--resume-start-epoch", type=int, default=None, help="恢复时强制设置 start_epoch")

    @staticmethod
    def _env_value(env_name: str, cast: type) -> int | float | None:
        """Parse environment variable *env_name* with *cast* (``int`` or ``float``).

        Returns ``None`` when the variable is unset or blank (whitespace only), so
        an exported-but-empty variable behaves like an unset one.

        Raises:
            ValueError: if the variable is set but *cast* rejects it; the message
                names the variable.
        """
        raw = os.environ.get(env_name)
        if raw is None or not raw.strip():
            return None
        try:
            return cast(raw)
        except ValueError as err:
            raise ValueError(f"{env_name} must be a valid {cast.__name__} (got {raw!r})") from err

    @classmethod
    def _cli_or_env(
        cls, cli_value: object, flag: str, env_name: str, cast: type
    ) -> tuple[int | float | None, str]:
        """Resolve one option: the CLI value wins, otherwise the env var.

        Returns ``(value, source)``. *value* is ``None`` when neither is set, and
        *source* is the flag or env-var name the value came from, so that
        validation errors point at what the user actually set.
        """
        if cli_value is not None:
            return cast(cli_value), flag
        return cls._env_value(env_name, cast), env_name

    @classmethod
    def from_args_and_env(cls, args: argparse.Namespace) -> "WormDemoConfig":
        """Build the configuration from parsed CLI *args* plus ``EWORM_*`` env vars.

        Precedence for ``k_mul``, ``k_len`` and ``k_max_t_ms`` is: CLI flag, then
        the matching env var (kept for quick experiments from shell scripts),
        then the default. Blank env values are treated as unset.

        Raises:
            ValueError: if an env value cannot be parsed, if ``k_mul`` or ``k_len``
                is not a positive int, or if ``k_max_t_ms`` is not a positive
                finite number. The message names the flag or env var that
                supplied the value.
        """
        k_mul, k_mul_src = cls._cli_or_env(args.k_mul, "--k-mul", "EWORM_K_MUL", int)
        if k_mul is None:
            k_mul = 5  # historical default
        if k_mul <= 0:
            raise ValueError(f"{k_mul_src} must be a positive int (got {k_mul})")

        k_len, k_len_src = cls._cli_or_env(args.k_len, "--k-len", "EWORM_K_LEN", int)
        if k_len is not None and k_len <= 0:
            raise ValueError(f"{k_len_src} must be a positive int (got {k_len})")

        k_max_t_ms, k_max_t_ms_src = cls._cli_or_env(
            args.k_max_t_ms, "--k-max-t-ms", "EWORM_K_MAX_T_MS", float
        )
        # The chained comparison also rejects NaN (every comparison with NaN is
        # False) and +inf, neither of which gives a finite history window.
        if k_max_t_ms is not None and not (0.0 < k_max_t_ms < float("inf")):
            raise ValueError(f"{k_max_t_ms_src} must be a positive finite number (got {k_max_t_ms})")

        pub = WormDemoPublicConfig(
            base_trial=os.path.abspath(os.path.expanduser(str(args.base_trial))),
            output_path=os.path.abspath(os.path.expanduser(str(args.output_path))),
            prefix=str(args.prefix),
            suffix=str(args.suffix),
            epochs_total=int(args.epochs),
            resume=bool(args.resume),
            resume_start_epoch=None if args.resume_start_epoch is None else int(args.resume_start_epoch),
            k_mul=k_mul,
            k_len=k_len,
            k_max_t_ms=k_max_t_ms,
            tstop_override_ms=cls._env_value("EWORM_TSTOP_MS", float),
            # An empty string is normalized to None.
            export_path=os.environ.get("EWORM_HELIOX_EXPORT_PATH") or None,
        )
        internal = WormDemoInternalConfig(
            replay=_env_bool("EWORM_REPLAY", False),
            replay_use_vecplay=True,
            replay_streaming=_env_bool("EWORM_REPLAY_STREAMING", False),
            replay_cache_signals=_env_bool("EWORM_REPLAY_CACHE_SIGNALS", False),
            corr_use_torch=_env_bool("EWORM_CORR_USE_TORCH", True),
            corr_torch_device=_env_str("EWORM_CORR_TORCH_DEVICE", "cuda:0").strip(),
            profile=_env_bool("EWORM_PROFILE", False),
        )
        return cls(public=pub, internal=internal)

    def describe(self) -> dict:
        """Return a short summary dict of the most relevant settings (e.g. for logging).

        Deliberately partial: ``resume_start_epoch``, ``export_path`` and the
        correlation/profiling knobs are not included.
        """
        # Keep it short: this is a demo.
        return {
            "base_trial": self.public.base_trial,
            "output_path": self.public.output_path,
            "prefix": self.public.prefix,
            "suffix": self.public.suffix,
            "epochs_total": self.public.epochs_total,
            "resume": self.public.resume,
            "k_mul": self.public.k_mul,
            "k_len": self.public.k_len,
            "k_max_t_ms": self.public.k_max_t_ms,
            "tstop_override_ms": self.public.tstop_override_ms,
            "replay": self.internal.replay,
            "replay_use_vecplay": self.internal.replay_use_vecplay,
            "replay_streaming": self.internal.replay_streaming,
            "replay_cache_signals": self.internal.replay_cache_signals,
        }
