#!/usr/bin/env python3
"""
WORM 暂态训练 Demo：主入口

设计目标
--------
- 入口代码尽量短小：只做参数解析 + 日志 + 组装组件
- 训练循环走 `heliox_learn.Trainer.fit(task)`（通过 `worm_framework_task.py` 适配）
- 训练步进逻辑在 `worm_training_step.py` / `worm_training_impl.py`

注意
----
日常跑训练更推荐用 `run.sh`（更方便配置 env/恢复/改参数）。
本文件主要用于：
- 作为 Python 入口直接运行
- 阅读整体流程
"""

from __future__ import annotations

import argparse
import logging
import os
import shutil
import sys

import numpy as np

from worm_checkpoint import _error_path, _load_resume_state
from worm_demo_config import WormDemoConfig
from worm_defaults import (
    K_MAX_T,
    K_NBLOCK,
    NGPU,
    RANDOM_SEED,
    W_GAP_MAX,
    W_GAP_MIN,
    W_SYN_MAX,
    W_SYN_MIN,
)
from worm_framework_task import WormFrameworkConfig, run_framework_train
from worm_model_builder import build_worm_model
from worm_trainables import extract_trainables


def _setup_logger(*, output_path: str, prefix: str, suffix: str) -> logging.Logger:
    """Configure and return the shared ``worm_demo`` logger.

    Records at INFO level and above are written both to stdout and to
    ``<output_path>/log_<prefix>_<suffix>.txt``. ``output_path`` is created
    when it does not exist yet.

    Calling this function again reconfigures the same logger object: handlers
    installed by a previous call are removed *and closed* first, so the old
    log file is not left open and records are not emitted twice.

    Args:
        output_path: Directory that receives the log file.
        prefix: First variable component of the log file name.
        suffix: Second variable component of the log file name.

    Returns:
        The configured ``logging.Logger`` named ``"worm_demo"``.
    """
    os.makedirs(output_path, exist_ok=True)
    logger = logging.getLogger("worm_demo")
    logger.setLevel(logging.INFO)

    # ``handlers.clear()`` would only drop the references; close each stale
    # handler explicitly so a previously opened log file is released.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    fmt = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")

    sh = logging.StreamHandler(sys.stdout)
    sh.setLevel(logging.INFO)
    sh.setFormatter(fmt)
    logger.addHandler(sh)

    # Log messages contain non-ASCII text, so do not rely on the locale's
    # default encoding for the log file.
    log_file = os.path.join(output_path, f"log_{prefix}_{suffix}.txt")
    fh = logging.FileHandler(log_file, encoding="utf-8")
    fh.setLevel(logging.INFO)
    fh.setFormatter(fmt)
    logger.addHandler(fh)

    logger.info("\n###################################################################################################\n")
    return logger


def _ensure_trial_inputs(*, base_trial: str, out_dir: str, logger: logging.Logger) -> None:
    """Make sure ``out_dir`` holds the input files needed to build the model.

    Required files (copied from ``base_trial`` only when absent in ``out_dir``;
    existing files are never overwritten):

    - ``000_circuit_search_config.json``
    - ``sample_#0_circuit_old.pkl``

    Additionally, any ``K_eworm_v4_x*.npz`` file found in ``base_trial`` is
    copied when ``out_dir`` does not have it yet. This step is optional: it is
    skipped when ``base_trial`` is not an existing directory, so a run whose
    ``out_dir`` is already populated does not depend on ``base_trial``.

    Args:
        base_trial: Directory holding the reference trial files.
        out_dir: Output directory of the current run; created if missing.
        logger: Logger used to report every copy that is performed.

    Raises:
        FileNotFoundError: A required file is missing from both ``out_dir``
            and ``base_trial`` (raised by ``shutil.copy2``).
    """
    os.makedirs(out_dir, exist_ok=True)

    required = (
        ("config", "000_circuit_search_config.json"),
        ("connection", "sample_#0_circuit_old.pkl"),
    )
    for label, filename in required:
        src = os.path.join(base_trial, filename)
        dst = os.path.join(out_dir, filename)
        if not os.path.exists(dst):
            shutil.copy2(src, dst)
            logger.info(f"复制 {label}: {src} -> {dst}")

    # Optional: reuse K files that already exist in base_trial. Being optional,
    # a missing base_trial must not abort a run whose required inputs are
    # already in place.
    if not os.path.isdir(base_trial):
        logger.info(f"base_trial is not a directory, skipping optional K reuse: {base_trial}")
        return

    # Sorted so the copy order (and therefore the log) is deterministic.
    for name in sorted(os.listdir(base_trial)):
        if not (name.startswith("K_eworm_v4_x") and name.endswith(".npz")):
            continue
        src = os.path.join(base_trial, name)
        dst = os.path.join(out_dir, name)
        # isfile() also filters out broken symlinks and directories that
        # merely match the naming pattern.
        if os.path.isfile(src) and not os.path.exists(dst):
            shutil.copy2(src, dst)
            logger.info(f"复制 K: {src} -> {dst}")


def main() -> int:
    """Run the WORM transient-training demo from configuration to saved error curve.

    The function parses CLI arguments into a ``WormDemoConfig``, sets up
    logging, makes sure the trial input files are present, builds the model,
    initialises weights and input currents (from a resume state when one is
    requested and found, otherwise from the bundled seed files when both
    exist), runs the training loop through ``run_framework_train`` and finally
    saves the ``train_error`` list of the final state as a ``.npy`` file.

    Returns:
        ``0`` once training has finished and the error curve has been saved.
    """
    cli = argparse.ArgumentParser(description="WORM 暂态训练 demo")
    WormDemoConfig.add_cli_args(cli)
    cfg = WormDemoConfig.from_args_and_env(cli.parse_args())

    output_path = cfg.public.output_path
    base_trial = cfg.public.base_trial
    run_prefix = cfg.public.prefix
    run_suffix = cfg.public.suffix

    logger = _setup_logger(output_path=output_path, prefix=run_prefix, suffix=run_suffix)
    logger.info(f"config: {cfg.describe()}")

    _ensure_trial_inputs(base_trial=base_trial, out_dir=output_path, logger=logger)

    # Step 1: build the model and prepare target/input. The hyper-parameters
    # come from worm_defaults plus the public config so runs stay reproducible.
    build_kwargs = dict(
        output_path=output_path,
        random_seed=RANDOM_SEED,
        K_mul=int(cfg.public.k_mul),
        ngpu=NGPU,
        K_max_t_default_ms=float(K_MAX_T),
        K_nblock=K_NBLOCK,
        w_gap_max=W_GAP_MAX,
        w_gap_min=W_GAP_MIN,
        w_syn_max=W_SYN_MAX,
        w_syn_min=W_SYN_MIN,
        k_len=cfg.public.k_len,
        k_max_t_ms=cfg.public.k_max_t_ms,
        tstop_override_ms=cfg.public.tstop_override_ms,
        logger=logger,
    )
    model = build_worm_model(**build_kwargs)

    # The build result is a frozen dataclass (per the original author's note),
    # so the input-current matrix used for training lives in a local variable.
    input_is = np.asarray(model.input_is)

    resume_state = None
    if cfg.public.resume:
        resume_state = _load_resume_state(output_path, run_prefix, run_suffix, logger=logger)
        if resume_state is not None:
            forced_epoch = cfg.public.resume_start_epoch
            if forced_epoch is not None:
                resume_state["start_epoch"] = int(forced_epoch)
        else:
            logger.info("resume requested, but no checkpoint/last-train state found; starting fresh.")

    if resume_state is None:
        # Default initialisation: the pre-trained weight/input pair shipped in
        # the "seeds" directory next to this file, used only if both exist.
        seed_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "seeds")
        weight_seed = os.path.join(seed_dir, "weights_optimal_eworm_v4.npy")
        input_seed = os.path.join(seed_dir, "x_optimal_eworm_v4.npy")
        if all(os.path.exists(p) for p in (weight_seed, input_seed)):
            model.net.set_weights(np.load(weight_seed))
            input_is = np.copy(np.load(input_seed))
            logger.info("已加载预训练初值（0.125 预训练对）")
    else:
        model.net.set_weights(resume_state["w"])
        input_is = np.copy(resume_state["x"])
        logger.info(f"resumed: start_epoch={resume_state['start_epoch']}")

    # Step 2: extract the objects needed for training.
    bundle = extract_trainables(
        model,
        output_path=output_path,
        export_path=cfg.public.export_path,
        logger=logger,
    )

    # Step 3: the per-epoch implementation lives in worm_training_step; dt,
    # tstop and v_init are handed over explicitly instead of relying on
    # module-level globals inside that module.
    import worm_training_step as step_module

    def _train_one_epoch(
        net,
        output_names,
        target,
        *,
        output_path: str,
        prefix: str,
        suffix: str,
        epoch: int,
        state: dict,
        logger=None,
    ) -> dict:
        """Forward one epoch to ``worm_training_step.train_one_epoch`` with model constants bound."""
        model_constants = {
            "dt_ms": float(model.dt),
            "tstop_ms": float(model.tstop),
            "v_init": float(model.v_r),
            "k_mul": int(cfg.public.k_mul),
        }
        return step_module.train_one_epoch(
            net,
            output_names,
            target,
            output_path=output_path,
            prefix=prefix,
            suffix=suffix,
            epoch=epoch,
            state=state,
            logger=logger,
            **model_constants,
        )

    # Step 4: run the training loop through the framework adapter.
    framework_cfg = WormFrameworkConfig(output_path=output_path, prefix=run_prefix, suffix=run_suffix)
    result = run_framework_train(
        net=bundle.net,
        output_names=bundle.output_names,
        target=bundle.target,
        cfg=framework_cfg,
        epochs_total=int(cfg.public.epochs_total),
        resume_state=resume_state,
        init_x=input_is,
        train_one_epoch_fn=_train_one_epoch,
        logger=logger,
    )

    # A non-dict result or a missing "state"/"train_error" entry yields an
    # empty error curve rather than an exception.
    final_state = {}
    if isinstance(result, dict):
        final_state = result.get("state", {})
    error_file = _error_path(output_path, run_prefix, run_suffix)
    error_curve = np.asarray(final_state.get("train_error", []), dtype=np.float64)
    np.save(error_file, error_curve)
    logger.info(f"训练完成：保存 error 曲线到 {error_file}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
