from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import numpy as np

from worm_model_builder import WormModelBuildResult, attach_heliox_backend


@dataclass(frozen=True)
class WormTrainablesBundle:
    """
    从生物学模型（由 NEURON 构建）中抽取出“训练所需对象”，用于把训练脚本写得更像普通 ML。

    约定：
    - 本 bundle 不暴露底层后端实现细节（例如 export/load/vecplay 等），只提供训练所需的上层对象。
    - 训练实现仍复用 `worm_training_impl.train_one_epoch`，此处只负责组装。
    """

    net: Any
    backend: Any
    output_names: list[str]
    target: np.ndarray
    x_init: np.ndarray


def extract_trainables(
    model: WormModelBuildResult,
    *,
    output_path: str,
    export_path: str | None,
    logger=None,
) -> WormTrainablesBundle:
    """
    从 `build_worm_model(...)` 的输出中，抽取训练所需的对象并 attach HELIOX 后端。
    """
    backend = attach_heliox_backend(
        model.net,
        model.output_names,
        dt=model.dt,
        v_init=model.v_r,
        output_path=output_path,
        export_path=export_path,
    )
    if logger:
        logger.info("backend attached")
    return WormTrainablesBundle(
        net=model.net,
        backend=backend,
        output_names=list(model.output_names),
        target=np.asarray(model.target),
        x_init=np.asarray(model.input_is),
    )
