"""
Checkpoint/resume utilities for the worm transient training pipeline.

Extracted from the training implementation to keep the main entrypoint focused on:
  - building the NEURON frontend model
  - attaching the HELIOX backend
  - running the per-epoch training step

This module intentionally preserves the legacy on-disk formats:
  - `weights_train_*.npy`, `x_train_*.npy`, `error_*.npy` (fallback)
  - `ckpt_*.npz` (preferred unified checkpoint)
  - optional snapshots inside the checkpoint (`opt_*`, `plateau_vmin_*`, `run_best_*`)
"""

from __future__ import annotations

import os
from typing import Any

import numpy as np


def _weights_train_path(output_path: str, prefix: str, suffix: str) -> str:
    """Return the path of the legacy `weights_train_*.npy` file inside `output_path`."""
    filename = f"weights_train_{prefix}_{suffix}.npy"
    return os.path.join(output_path, filename)


def _x_train_path(output_path: str, prefix: str, suffix: str) -> str:
    """Return the path of the legacy `x_train_*.npy` file inside `output_path`."""
    filename = f"x_train_{prefix}_{suffix}.npy"
    return os.path.join(output_path, filename)


def _weights_optimal_path(output_path: str, prefix: str, suffix: str) -> str:
    """Return the path of the `weights_optimal_*.npy` file inside `output_path`."""
    filename = f"weights_optimal_{prefix}_{suffix}.npy"
    return os.path.join(output_path, filename)


def _x_optimal_path(output_path: str, prefix: str, suffix: str) -> str:
    """Return the path of the `x_optimal_*.npy` file inside `output_path`."""
    filename = f"x_optimal_{prefix}_{suffix}.npy"
    return os.path.join(output_path, filename)


def _error_path(output_path: str, prefix: str, suffix: str) -> str:
    """Return the path of the legacy `error_*.npy` history file inside `output_path`."""
    filename = f"error_{prefix}_{suffix}.npy"
    return os.path.join(output_path, filename)


def _ckpt_path(output_path: str, prefix: str, suffix: str) -> str:
    """Return the path of the unified `ckpt_*.npz` checkpoint inside `output_path`."""
    filename = f"ckpt_{prefix}_{suffix}.npz"
    return os.path.join(output_path, filename)


def _infer_start_epoch_from_error(output_path: str, prefix: str, suffix: str) -> int | None:
    """
    Derive the epoch to resume from out of the legacy `error_*.npy` file.

    Returns `len(error)` (i.e. the next epoch index), or None when the file
    does not exist, cannot be loaded, or the loaded object has no length.
    """
    error_file = _error_path(output_path, prefix, suffix)
    if not os.path.exists(error_file):
        return None
    try:
        history = np.load(error_file, allow_pickle=True)
        n_recorded = len(history)
    except Exception:
        # Best-effort: any load/len failure means the epoch cannot be inferred.
        return None
    return int(n_recorded)


def _load_resume_state(output_path: str, prefix: str, suffix: str, logger=None) -> dict[str, Any] | None:
    """
    Load the state needed to resume training, or return None if nothing is resumable.

    Resume priority:
    1) Unified checkpoint `ckpt_*.npz` (includes epoch, weights, x, best-so-far, Adam state).
    2) Fallback: `weights_train_*.npy` + `x_train_*.npy` and infer epoch from `error_*.npy` if available.

    Args:
        output_path: Directory holding the checkpoint / legacy files.
        prefix: First component of the file names.
        suffix: Second component of the file names.
        logger: Optional logger; when truthy, progress is reported via `logger.info`
            and an unreadable error history via `logger.warning`.

    Returns:
        A dict of resume state, or None when neither the checkpoint nor the
        `weights_train`/`x_train` pair exists. The checkpoint branch returns every
        key listed below (optional entries are None or their default when absent
        from the file); the fallback branch returns only the core keys
        (`start_epoch`, `x`, `w`, `train_error`, `opt_*` best-so-far values,
        `alpha_multiplier` and the non-snapshot Adam fields).

    Raises:
        KeyError: if the checkpoint exists but has no "x" or "w" entry.
        Exception: whatever `np.load` raises for an unreadable checkpoint or
            unreadable `x_train`/`weights_train` files.
    """
    ckpt = _ckpt_path(output_path, prefix, suffix)
    if os.path.exists(ckpt):
        if logger:
            logger.info(f"resume: loading checkpoint {ckpt}")
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
        # execute arbitrary code; only resume from trusted output directories. Kept so
        # that existing files keep loading -- confirm before tightening.
        # The context manager closes the archive deterministically; every value read
        # below is materialized in memory before the file is closed.
        with np.load(ckpt, allow_pickle=True) as data:

            # Scalars may come back as 0-d arrays; normalize.
            def _scalar(key, default):
                if key not in data:
                    return default
                v = data[key]
                if isinstance(v, np.ndarray) and v.shape == ():
                    return v.item()
                return v

            def _int(key, default):
                return int(_scalar(key, default))

            def _float(key, default):
                return float(_scalar(key, default))

            def _array(key):
                # Optional array: None when the key is absent.
                return data[key] if key in data else None

            def _nonempty_array(key):
                # Optional array stored as an empty placeholder when unset; read it once.
                if key not in data:
                    return None
                arr = data[key]
                return arr if arr.size else None

            def _optional_int(key):
                # Unlike _int, absence is reported as None rather than a default.
                return int(_scalar(key, None)) if key in data else None

            def _optional_float(key):
                return float(_scalar(key, None)) if key in data else None

            state = {
                "start_epoch": _int("start_epoch", 0),
                "x": data["x"],
                "w": data["w"],
                "train_error": list(data["train_error"]) if "train_error" in data else [],
                "opt_epoch": _int("opt_epoch", -1),
                "opt_mean_error": _float("opt_mean_error", 1e100),
                "opt_w": _nonempty_array("opt_w"),
                "opt_x": _nonempty_array("opt_x"),
                "alpha_multiplier": _float("alpha_multiplier", 1.0),
                # Adam params (optional)
                "adam_m_w": _array("adam_m_w"),
                "adam_v_w": _array("adam_v_w"),
                "beta_1_t_w": _float("beta_1_t_w", 1.0),
                "beta_2_t_w": _float("beta_2_t_w", 1.0),
                "adam_m_x": _array("adam_m_x"),
                "adam_v_x": _array("adam_v_x"),
                "beta_1_t_x": _float("beta_1_t_x", 1.0),
                "beta_2_t_x": _float("beta_2_t_x", 1.0),
                # Best-so-far optimizer snapshot (used by retreat logic; optional).
                "opt_adam_m_w": _array("opt_adam_m_w"),
                "opt_adam_v_w": _array("opt_adam_v_w"),
                "opt_beta_1_t_w": _float("opt_beta_1_t_w", 1.0),
                "opt_beta_2_t_w": _float("opt_beta_2_t_w", 1.0),
                "opt_adam_m_x": _array("opt_adam_m_x"),
                "opt_adam_v_x": _array("opt_adam_v_x"),
                "opt_beta_1_t_x": _float("opt_beta_1_t_x", 1.0),
                "opt_beta_2_t_x": _float("opt_beta_2_t_x", 1.0),
                # Plateau vmin-preferred snapshot (optional; backward-compatible).
                "plateau_vmin_best": _float("plateau_vmin_best", -1e100),
                "plateau_vmin_best_epoch": _int("plateau_vmin_best_epoch", -1),
                "plateau_vmin_best_error": _float("plateau_vmin_best_error", 1e100),
                "plateau_vmin_best_w": _array("plateau_vmin_best_w"),
                "plateau_vmin_best_x": _array("plateau_vmin_best_x"),
                "plateau_vmin_best_alpha_multiplier": _float("plateau_vmin_best_alpha_multiplier", 1.0),
                "plateau_vmin_best_adam_m_w": _array("plateau_vmin_best_adam_m_w"),
                "plateau_vmin_best_adam_v_w": _array("plateau_vmin_best_adam_v_w"),
                "plateau_vmin_best_beta_1_t_w": _float("plateau_vmin_best_beta_1_t_w", 1.0),
                "plateau_vmin_best_beta_2_t_w": _float("plateau_vmin_best_beta_2_t_w", 1.0),
                "plateau_vmin_best_adam_m_x": _array("plateau_vmin_best_adam_m_x"),
                "plateau_vmin_best_adam_v_x": _array("plateau_vmin_best_adam_v_x"),
                "plateau_vmin_best_beta_1_t_x": _float("plateau_vmin_best_beta_1_t_x", 1.0),
                "plateau_vmin_best_beta_2_t_x": _float("plateau_vmin_best_beta_2_t_x", 1.0),
                # Run-best snapshot (optional).
                "run_best_epoch": _int("run_best_epoch", -1),
                "run_best_mean_error": _float("run_best_mean_error", 1e100),
                "run_best_vmin": _float("run_best_vmin", 0.0),
                "run_best_vmax": _float("run_best_vmax", 0.0),
                "run_best_w": _nonempty_array("run_best_w"),
                "run_best_x": _nonempty_array("run_best_x"),
                # Backend Adam state for weights (optional; only present when EWORM_OPT_W_BACKEND=1).
                "backend_adam_w_step": _optional_int("backend_adam_w_step"),
                "backend_adam_w_m": _array("backend_adam_w_m"),
                "backend_adam_w_v": _array("backend_adam_w_v"),
                "backend_adam_w_beta1": _optional_float("backend_adam_w_beta1"),
                "backend_adam_w_beta2": _optional_float("backend_adam_w_beta2"),
                "backend_adam_w_epsilon": _optional_float("backend_adam_w_epsilon"),
                # Best-so-far optimizer snapshot (backend Adam; optional).
                "opt_backend_adam_w_step": _optional_int("opt_backend_adam_w_step"),
                "opt_backend_adam_w_m": _array("opt_backend_adam_w_m"),
                "opt_backend_adam_w_v": _array("opt_backend_adam_w_v"),
                "opt_backend_adam_w_beta1": _optional_float("opt_backend_adam_w_beta1"),
                "opt_backend_adam_w_beta2": _optional_float("opt_backend_adam_w_beta2"),
                "opt_backend_adam_w_epsilon": _optional_float("opt_backend_adam_w_epsilon"),
                # Plateau vmin-preferred snapshot (backend Adam; optional).
                "plateau_vmin_best_backend_adam_w_step": _optional_int("plateau_vmin_best_backend_adam_w_step"),
                "plateau_vmin_best_backend_adam_w_m": _array("plateau_vmin_best_backend_adam_w_m"),
                "plateau_vmin_best_backend_adam_w_v": _array("plateau_vmin_best_backend_adam_w_v"),
                "plateau_vmin_best_backend_adam_w_beta1": _optional_float("plateau_vmin_best_backend_adam_w_beta1"),
                "plateau_vmin_best_backend_adam_w_beta2": _optional_float("plateau_vmin_best_backend_adam_w_beta2"),
                "plateau_vmin_best_backend_adam_w_epsilon": _optional_float(
                    "plateau_vmin_best_backend_adam_w_epsilon"
                ),
            }
        return state

    w_path = _weights_train_path(output_path, prefix, suffix)
    x_path = _x_train_path(output_path, prefix, suffix)
    if not (os.path.exists(w_path) and os.path.exists(x_path)):
        return None

    start_epoch = _infer_start_epoch_from_error(output_path, prefix, suffix)
    if start_epoch is None:
        start_epoch = 0

    # Best-so-far weights/x are optional; both files must load for either to be used.
    opt_w = None
    opt_x = None
    opt_epoch = -1
    opt_mean_error = 1e100
    w_opt_path = _weights_optimal_path(output_path, prefix, suffix)
    x_opt_path = _x_optimal_path(output_path, prefix, suffix)
    if os.path.exists(w_opt_path) and os.path.exists(x_opt_path):
        try:
            opt_w = np.load(w_opt_path)
            opt_x = np.load(x_opt_path)
            opt_epoch = int(start_epoch) - 1
        except Exception:
            opt_w = None
            opt_x = None

    # The error history is best-effort, matching _infer_start_epoch_from_error: an
    # unreadable file yields start_epoch == 0 there, so fall back to an empty history
    # here instead of aborting the resume.
    train_error = []
    err_path = _error_path(output_path, prefix, suffix)
    if os.path.exists(err_path):
        try:
            train_error = list(np.load(err_path, allow_pickle=True))
        except Exception as exc:
            train_error = []
            if logger:
                logger.warning(f"resume: could not read error history {err_path} ({exc}); continuing without it")
    if train_error:
        try:
            opt_mean_error = float(np.min(np.asarray(train_error, dtype=np.float64)))
        except Exception:
            opt_mean_error = 1e100

    if logger:
        logger.info(f"resume: loading last-train state w={w_path}, x={x_path}, start_epoch={start_epoch}")
    return {
        "start_epoch": int(start_epoch),
        "x": np.load(x_path),
        "w": np.load(w_path),
        "train_error": train_error,
        "opt_epoch": opt_epoch,
        "opt_mean_error": opt_mean_error,
        "opt_w": opt_w,
        "opt_x": opt_x,
        # Legacy files carry no optimizer state; start Adam from scratch.
        "alpha_multiplier": 1.0,
        "adam_m_w": None,
        "adam_v_w": None,
        "beta_1_t_w": 1.0,
        "beta_2_t_w": 1.0,
        "adam_m_x": None,
        "adam_v_x": None,
        "beta_1_t_x": 1.0,
        "beta_2_t_x": 1.0,
    }


def _save_ckpt(output_path: str, prefix: str, suffix: str, state: dict, logger=None) -> None:
    """
    Write the unified `ckpt_*.npz` checkpoint for `state`.

    Always stored: `start_epoch`, `x`, `w`, `train_error`, `opt_epoch`,
    `opt_mean_error`, `opt_w`/`opt_x` (an empty array when unset),
    `alpha_multiplier` and the four `beta_*_t_*` values.

    Stored only when `state` holds a non-None value: the `opt_*` optimizer
    snapshot, the backend Adam state (plain, `opt_*` and `plateau_vmin_best_*`
    variants), the `plateau_vmin_best*` snapshot, the `run_best_*` snapshot and
    the current `adam_m_*`/`adam_v_*` arrays. `plateau_vmin_best_w/x` and
    `run_best_w/x` are additionally skipped when empty.

    The archive is written to `<checkpoint>.tmp` and then moved over the final
    path with `os.replace`, so an interrupted save leaves the previous
    checkpoint in place instead of a truncated file.

    Args:
        output_path: Directory to write the checkpoint into.
        prefix: First component of the checkpoint file name.
        suffix: Second component of the checkpoint file name.
        state: Training state; must contain `start_epoch`, `x` and `w`.
        logger: Optional logger; when truthy, the saved path is reported via `logger.info`.

    Raises:
        KeyError: if `state` lacks `start_epoch`, `x` or `w`.
        OSError: if the temporary file cannot be written or moved into place.
    """
    path = _ckpt_path(output_path, prefix, suffix)

    def _as_scalar(v):
        # Unwrap 0-d arrays / NumPy scalars to plain Python values before int()/float().
        v = np.asarray(v)
        if v.shape == ():
            return v.item()
        return v

    def _array_or_empty(key):
        # Unset best-so-far arrays are stored as an empty placeholder.
        value = state.get(key)
        return np.asarray(value) if value is not None else np.asarray([])

    payload = {
        "start_epoch": int(state["start_epoch"]),
        "x": np.asarray(state["x"]),
        "w": np.asarray(state["w"]),
        "train_error": np.asarray(state.get("train_error", []), dtype=np.float64),
        "opt_epoch": int(state.get("opt_epoch", -1)),
        "opt_mean_error": float(state.get("opt_mean_error", 1e100)),
        "opt_w": _array_or_empty("opt_w"),
        "opt_x": _array_or_empty("opt_x"),
        "alpha_multiplier": float(state.get("alpha_multiplier", 1.0)),
        "beta_1_t_w": float(state.get("beta_1_t_w", 1.0)),
        "beta_2_t_w": float(state.get("beta_2_t_w", 1.0)),
        "beta_1_t_x": float(state.get("beta_1_t_x", 1.0)),
        "beta_2_t_x": float(state.get("beta_2_t_x", 1.0)),
    }

    # Field layouts shared by several snapshots. Kinds:
    #   "array"    -> stored via np.asarray
    #   "nonempty" -> like "array", but skipped when the array has no elements
    #   "int"/"float" -> stored as a plain scalar
    adam_fields = (
        ("adam_m_w", "array"),
        ("adam_v_w", "array"),
        ("beta_1_t_w", "float"),
        ("beta_2_t_w", "float"),
        ("adam_m_x", "array"),
        ("adam_v_x", "array"),
        ("beta_1_t_x", "float"),
        ("beta_2_t_x", "float"),
    )
    backend_fields = (
        ("backend_adam_w_step", "int"),
        ("backend_adam_w_m", "array"),
        ("backend_adam_w_v", "array"),
        ("backend_adam_w_beta1", "float"),
        ("backend_adam_w_beta2", "float"),
        ("backend_adam_w_epsilon", "float"),
    )

    def _prefixed(tag, fields):
        return [(tag + name, kind) for name, kind in fields]

    optional_fields = [
        # Snapshot of optimizer state at current best-so-far (for retreat correctness).
        *_prefixed("opt_", adam_fields),
        # Backend optimizer state for weight Adam (HELIOX internal optimizer).
        *backend_fields,
        *_prefixed("opt_", backend_fields),
        # Plateau vmin-preferred snapshot state.
        ("plateau_vmin_best", "float"),
        ("plateau_vmin_best_epoch", "int"),
        ("plateau_vmin_best_error", "float"),
        ("plateau_vmin_best_w", "nonempty"),
        ("plateau_vmin_best_x", "nonempty"),
        ("plateau_vmin_best_alpha_multiplier", "float"),
        *_prefixed("plateau_vmin_best_", adam_fields),
        *_prefixed("plateau_vmin_best_", backend_fields),
        # Run-best snapshot (best within a run/phase, useful when global opt is fixed).
        ("run_best_epoch", "int"),
        ("run_best_mean_error", "float"),
        ("run_best_vmin", "float"),
        ("run_best_vmax", "float"),
        ("run_best_w", "nonempty"),
        ("run_best_x", "nonempty"),
        # Current Adam moments (the matching beta_*_t_* values are always stored above).
        ("adam_m_w", "array"),
        ("adam_v_w", "array"),
        ("adam_m_x", "array"),
        ("adam_v_x", "array"),
    ]

    for key, kind in optional_fields:
        value = state.get(key)
        if value is None:
            continue
        if kind == "int":
            payload[key] = int(_as_scalar(value))
        elif kind == "float":
            payload[key] = float(_as_scalar(value))
        else:
            arr = np.asarray(value)
            if kind == "array" or arr.size:
                payload[key] = arr

    # Write to a sibling temp file and swap it in, so a crash mid-write cannot leave a
    # truncated checkpoint that the loader would then prefer over the legacy files.
    # Passing an open file object also stops NumPy from appending ".npz" to the temp name.
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as fh:
            np.savez_compressed(fh, **payload)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temp file is gone; otherwise clean up the partial file.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    if logger:
        logger.info(f"checkpoint saved: {path}")


def _plateau_vmin_snapshot_path(output_path: str, prefix: str, suffix: str) -> str:
    """Return the path of the `plateau_vmin_*.npz` snapshot file inside `output_path`."""
    filename = f"plateau_vmin_{prefix}_{suffix}.npz"
    return os.path.join(output_path, filename)


def _run_best_snapshot_path(output_path: str, prefix: str, suffix: str) -> str:
    """Return the path of the `run_best_*.npz` snapshot file inside `output_path`."""
    filename = f"run_best_{prefix}_{suffix}.npz"
    return os.path.join(output_path, filename)
