from __future__ import annotations

import os
import sys
from dataclasses import dataclass
from typing import Any

import numpy as np

# This module keeps the Task/Trainer adapter logic out of the training implementation
# (`worm_training_impl.py`) so that `train.py` remains a small entrypoint.


def _ensure_heliox_learn_on_path() -> None:
    heliox_lib = os.environ.get("HELIOX_PYTHON_LIB", "").strip()
    if heliox_lib:
        heliox_lib = os.path.abspath(os.path.expanduser(heliox_lib))
        if heliox_lib not in sys.path and os.path.isdir(heliox_lib):
            sys.path.insert(0, heliox_lib)


@dataclass(frozen=True)
class WormFrameworkConfig:
    """Immutable output-location and run-naming settings for framework training.

    The three values are forwarded to the legacy per-epoch routine, and
    ``prefix``/``suffix`` also make up the framework checkpoint name.
    """

    # Directory where training output is written.
    output_path: str
    # Leading part of the run name.
    prefix: str
    # Trailing part of the run name.
    suffix: str


class WormCorrTask:
    """Worm transient training task adapter for heliox_learn.Trainer.

    This adapter intentionally reuses the existing implementation
    (train_one_epoch / checkpoint layout) while presenting a stable
    Trainer/Task interface for the demo entrypoint.

    It exposes ``spec``, ``setup(runtime)``, which builds the initial state dict,
    and ``run_epoch(runtime, state, epoch)``, which advances one epoch and returns
    the new state plus a small metrics dict. The ``runtime`` argument is accepted
    for interface compatibility but is not used by this class.
    """

    def __init__(
        self,
        *,
        net: Any,
        output_names: list[str],
        target: np.ndarray,
        cfg: WormFrameworkConfig,
        init_x: np.ndarray | None,
        init_state: dict | None,
        train_one_epoch_fn: Any,
        logger: Any = None,
    ) -> None:
        """Store the training inputs and build the heliox_learn ``TaskSpec``.

        Args:
            net: Network being trained. ``setup`` reads ``net.w.numpy()`` when no
                ``init_state`` is given.
            output_names: Output names forwarded to ``train_one_epoch_fn``. They
                are copied into a new list.
            target: Training target, converted with ``np.asarray``.
            cfg: Output path and run-name prefix/suffix forwarded to
                ``train_one_epoch_fn``.
            init_x: Starting ``x`` for a fresh run. Required when ``init_state``
                is None.
            init_state: State dict to resume from (shallow-copied), or None.
            train_one_epoch_fn: Callable invoked as ``fn(net, output_names,
                target, output_path=..., prefix=..., suffix=..., epoch=...,
                state=..., logger=...)``. It must return the updated state dict.
            logger: Optional logger forwarded to ``train_one_epoch_fn``.

        Raises:
            ImportError: If ``heliox_learn`` cannot be imported.
        """
        self.net = net
        self.output_names = list(output_names)
        self.target = np.asarray(target)
        self.cfg = cfg
        self._init_x = None if init_x is None else np.asarray(init_x)
        self._init_state = dict(init_state) if init_state is not None else None
        self._train_one_epoch = train_one_epoch_fn
        self._logger = logger

        _ensure_heliox_learn_on_path()
        from heliox_learn import TaskSpec  # type: ignore

        self._spec = TaskSpec(name="worm_v4_val_corr")

    @property
    def spec(self):
        """The ``TaskSpec`` built in ``__init__`` (name ``"worm_v4_val_corr"``)."""
        return self._spec

    def setup(self, runtime: Any) -> dict:
        """Return the state dict the first epoch starts from.

        A provided ``init_state`` is returned as a fresh shallow copy, so only
        top-level keys are isolated from the stored dict. Otherwise a new state
        is built from ``init_x``, the network's current weights and an empty
        error history.

        Raises:
            ValueError: If neither ``init_state`` nor ``init_x`` was provided.
        """
        if self._init_state is not None:
            return dict(self._init_state)
        if self._init_x is None:
            raise ValueError("WormCorrTask requires init_x when init_state is not provided")
        # Minimal default state: start from provided x, current net weights, and empty history.
        return {"start_epoch": 0, "x": np.asarray(self._init_x), "w": self.net.w.numpy(), "train_error": []}

    def run_epoch(self, runtime: Any, state: dict, epoch: int) -> tuple[dict, dict[str, float]]:
        """Train one epoch through the legacy ``train_one_epoch_fn``.

        Args:
            runtime: Unused.
            state: Current training state. ``state['start_epoch']`` (default 0)
                is used as the absolute epoch number.
            epoch: Trainer-local epoch counter. It is ignored in favour of the
                absolute epoch.

        Returns:
            ``(new_state, metrics)``. ``metrics["epoch"]`` is the absolute epoch
            as a float. ``metrics["mean_error"]`` is the last entry of
            ``new_state['train_error']``, or NaN when that history is missing or
            empty.
        """
        # `epoch` in Trainer.fit is a local counter (0..epochs-1). We use the absolute epoch
        # stored in state['start_epoch'] to keep continuity with legacy resume behavior.
        # NOTE(review): this relies on train_one_epoch_fn advancing state['start_epoch'];
        # if it does not, every call reuses the same absolute epoch — confirm.
        abs_epoch = int(state.get("start_epoch", 0))
        state = self._train_one_epoch(
            self.net,
            self.output_names,
            self.target,
            output_path=self.cfg.output_path,
            prefix=self.cfg.prefix,
            suffix=self.cfg.suffix,
            epoch=abs_epoch,
            state=state,
            logger=self._logger,
        )
        # Provide minimal metrics to the training history. The error history may be
        # a list or a NumPy array. Check emptiness with len(): an array's truth value
        # is ambiguous (ValueError) once it holds more than one element.
        train_error = state.get("train_error")
        if train_error is None or len(train_error) == 0:
            mean_error = float("nan")
        else:
            mean_error = float(train_error[-1])
        return state, {"epoch": float(abs_epoch), "mean_error": mean_error}


def run_framework_train(
    *,
    net: Any,
    output_names: list[str],
    target: np.ndarray,
    cfg: WormFrameworkConfig,
    epochs_total: int,
    resume_state: dict | None,
    init_x: np.ndarray | None,
    train_one_epoch_fn: Any,
    logger: Any = None,
) -> dict:
    """Run worm training via heliox_learn.Trainer using WormCorrTask.

    ``epochs_total`` is an absolute target. Training continues from
    ``resume_state['start_epoch']`` (0 without a resume state), and only the
    epochs still missing are handed to ``Trainer.fit``.

    Returns:
        ``{"state": ..., "history": []}`` when no epochs remain; in that case the
        state is ``resume_state`` or, if that is falsy, ``{"start_epoch": ...}``.
        Otherwise the return value of ``Trainer.fit``, unchanged.

    Raises:
        ImportError: If ``heliox_learn`` cannot be imported. This is checked
            before the early return.
    """
    _ensure_heliox_learn_on_path()
    from heliox_learn import Trainer, TrainerConfig  # type: ignore

    # Trainer.fit counts epochs locally from zero, so turn the absolute target
    # into the number of epochs that still have to run.
    if resume_state is None:
        start_epoch = 0
    else:
        start_epoch = int(resume_state.get("start_epoch", 0))
    remaining = int(epochs_total) - start_epoch

    if remaining <= 0:
        if logger:
            logger.info(f"train: nothing to do (start_epoch={start_epoch} >= epochs={epochs_total})")
        final_state = resume_state if resume_state else {"start_epoch": start_epoch}
        return {"state": final_state, "history": []}

    # Keep this checkpoint separate from the legacy eworm ckpt. The trainer's own
    # resume is disabled; continuation comes from resume_state, which is passed to
    # the task below.
    trainer_cfg = TrainerConfig(
        output_dir=cfg.output_path,
        resume=False,
        checkpoint_name=f"framework_ckpt_{cfg.prefix}_{cfg.suffix}",
        save_every=1,
    )
    trainer = Trainer(runtime=_RuntimeNoop(), cfg=trainer_cfg)

    task = WormCorrTask(
        net=net,
        output_names=output_names,
        target=target,
        cfg=cfg,
        init_x=init_x,
        init_state=resume_state,
        train_one_epoch_fn=train_one_epoch_fn,
        logger=logger,
    )
    return trainer.fit(task, epochs=remaining)


class _RuntimeNoop:
    """Runtime adapter: EWORM has already constructed and attached a backend."""

    loaded = True

    def load(self) -> None:
        return
