#!/usr/bin/env python3
from __future__ import annotations

import argparse
import datetime as _dt
import os
import shutil
import tempfile

import numpy as np


def _ckpt_path(output_path: str, prefix: str, suffix: str) -> str:
    return os.path.join(output_path, f"ckpt_{prefix}_{suffix}.npz")

def _plateau_vmin_path(output_path: str, prefix: str, suffix: str) -> str:
    return os.path.join(output_path, f"plateau_vmin_{prefix}_{suffix}.npz")

def _run_best_path(output_path: str, prefix: str, suffix: str) -> str:
    return os.path.join(output_path, f"run_best_{prefix}_{suffix}.npz")


def _backup_file(path: str) -> str:
    ts = _dt.datetime.now().strftime("%Y%m%d-%H%M%S")
    return f"{path}.bak.{ts}"


def _maybe(arr):
    arr = np.asarray(arr)
    return arr if arr.size else None


def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse *argv* (``sys.argv[1:]`` when None)."""
    ap = argparse.ArgumentParser(description="Edit eworm training checkpoint in-place (with backup).")
    ap.add_argument("--out", required=True, help="output dir that contains ckpt_*.npz")
    ap.add_argument("--prefix", default="eworm")
    ap.add_argument("--suffix", required=True)
    ap.add_argument("--no-backup", action="store_true", help="do not create a .bak timestamp copy")

    ap.add_argument("--restore-opt", action="store_true", help="set (w,x) to (opt_w,opt_x)")
    ap.add_argument(
        "--restore-vmin-snapshot",
        action="store_true",
        help="set (w,x) and, if saved, alpha_multiplier from plateau_vmin_*.npz (error if the file is missing)",
    )
    ap.add_argument(
        "--restore-run-best",
        action="store_true",
        help="set (w,x) and, if saved, alpha_multiplier from run_best_*.npz (error if the file is missing)",
    )
    ap.add_argument(
        "--reset-adam",
        choices=["none", "w", "x", "both"],
        default="none",
        help="reset Adam moments for w and/or x",
    )
    ap.add_argument("--set-alpha-multiplier", type=float, default=None)
    ap.add_argument("--set-start-epoch", type=int, default=None)
    return ap.parse_args(argv)


def _copy_opt_fields(state: dict, fields) -> None:
    """For each ``(key, dtype)`` pair, copy ``state["opt_<key>"]`` to ``state[key]`` if it is saved."""
    for key, dtype in fields:
        src = f"opt_{key}"
        if src in state:
            state[key] = np.asarray(state[src], dtype=dtype)


def _restore_opt(state: dict) -> None:
    """Replace w/x with opt_w/opt_x and copy any saved ``opt_*`` Adam state over the live one.

    Raises:
        ValueError: if opt_w or opt_x is absent or empty.
    """
    opt_w = _maybe(state.get("opt_w", np.asarray([])))
    opt_x = _maybe(state.get("opt_x", np.asarray([])))
    if opt_w is None or opt_x is None:
        raise ValueError("opt_w/opt_x missing in ckpt; cannot --restore-opt")
    state["w"] = np.asarray(opt_w)
    state["x"] = np.asarray(opt_x)
    # Also restore optimizer state when available, to make `--resume --restore-opt` truly continue
    # from the best-so-far checkpoint (weights + moments).
    if "opt_adam_m_w" in state and "opt_adam_v_w" in state:
        state["adam_m_w"] = np.asarray(state["opt_adam_m_w"])
        state["adam_v_w"] = np.asarray(state["opt_adam_v_w"])
        _copy_opt_fields(state, (("beta_1_t_w", np.float64), ("beta_2_t_w", np.float64)))
    if "opt_backend_adam_w_m" in state and "opt_backend_adam_w_v" in state:
        state["backend_adam_w_m"] = np.asarray(state["opt_backend_adam_w_m"])
        state["backend_adam_w_v"] = np.asarray(state["opt_backend_adam_w_v"])
        _copy_opt_fields(
            state,
            (
                ("backend_adam_w_step", np.int64),
                ("backend_adam_w_beta1", np.float64),
                ("backend_adam_w_beta2", np.float64),
                ("backend_adam_w_epsilon", np.float64),
            ),
        )
    print("[edit_ckpt] restored w/x from opt_w/opt_x")


def _restore_snapshot(state: dict, snap_path: str, label: str) -> dict:
    """Overwrite w/x (and alpha_multiplier, if saved) in *state* from the snapshot at *snap_path*.

    *label* is used only in the error message. The archive is closed before
    returning. Returns the snapshot's raw ``epoch`` / ``mean_error`` arrays
    (only the keys that are present) so the caller can log them.

    Raises:
        FileNotFoundError: if *snap_path* does not exist.
        ValueError: if the snapshot lacks ``w`` or ``x``.
    """
    if not os.path.exists(snap_path):
        raise FileNotFoundError(snap_path)
    with np.load(snap_path, allow_pickle=False) as snap:
        if "w" not in snap.files or "x" not in snap.files:
            raise ValueError(f"{label} snapshot missing w/x: {snap_path}")
        state["w"] = np.asarray(snap["w"])
        state["x"] = np.asarray(snap["x"])
        if "alpha_multiplier" in snap.files:
            state["alpha_multiplier"] = float(np.asarray(snap["alpha_multiplier"]).item())
        return {k: snap[k] for k in ("epoch", "mean_error") if k in snap.files}


def _reset_adam(state: dict, which: str) -> None:
    """Zero the Adam moment arrays present for ``which`` in {"none", "w", "x", "both"}.

    The ``beta_*_t_*`` entries are set to 1.0, and ``backend_adam_w_step`` (if
    present) is set to 0.
    """
    if which in ("w", "both"):
        for key in ("adam_m_w", "adam_v_w", "backend_adam_w_m", "backend_adam_w_v"):
            if key in state:
                state[key] = np.zeros_like(state[key])
        state["beta_1_t_w"] = np.asarray(1.0, dtype=np.float64)
        state["beta_2_t_w"] = np.asarray(1.0, dtype=np.float64)
        if "backend_adam_w_step" in state:
            state["backend_adam_w_step"] = np.asarray(0, dtype=np.int64)
        print("[edit_ckpt] reset Adam state for w")

    if which in ("x", "both"):
        for key in ("adam_m_x", "adam_v_x"):
            if key in state:
                state[key] = np.zeros_like(state[key])
        state["beta_1_t_x"] = np.asarray(1.0, dtype=np.float64)
        state["beta_2_t_x"] = np.asarray(1.0, dtype=np.float64)
        print("[edit_ckpt] reset Adam state for x")


def _to_payload(state: dict) -> dict:
    """Normalize *state* for savez: known scalar keys become int/float, everything else ndarray."""
    int_keys = {"start_epoch", "opt_epoch"}
    float_keys = {
        "opt_mean_error",
        "alpha_multiplier",
        "beta_1_t_w",
        "beta_2_t_w",
        "beta_1_t_x",
        "beta_2_t_x",
    }
    payload = {}
    for k, v in state.items():
        if k in int_keys:
            payload[k] = int(np.asarray(v).item())
        elif k in float_keys:
            payload[k] = float(np.asarray(v).item())
        else:
            payload[k] = np.asarray(v)
    return payload


def _atomic_savez_compressed(path: str, payload: dict) -> None:
    """Write *payload* to *path* atomically: temp file in the same dir, then ``os.replace``.

    A crash mid-write therefore never leaves a truncated checkpoint. The
    original file's permission bits are copied onto the replacement, because
    ``mkstemp`` creates the file as 0600.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp = tempfile.mkstemp(prefix=os.path.basename(path) + ".", suffix=".tmp", dir=directory)
    try:
        # Passing a file object keeps savez from appending ".npz" to the temp name.
        with os.fdopen(fd, "wb") as fh:
            np.savez_compressed(fh, **payload)
        shutil.copymode(path, tmp)
        os.replace(tmp, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original checkpoint is untouched.
        if os.path.exists(tmp):
            os.remove(tmp)
        raise


def main(argv: list[str] | None = None) -> int:
    """Edit an eworm checkpoint in place according to the CLI flags; return the exit code.

    Flow:
    1. Load ``ckpt_<prefix>_<suffix>.npz`` fully into memory and close the file.
    2. Apply the requested restores in this order: opt, vmin snapshot, run-best.
    3. Apply the alpha/start-epoch overrides and the Adam resets.
    4. Back up the original file (unless ``--no-backup``).
    5. Atomically replace the checkpoint.

    *argv* defaults to ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    ckpt = _ckpt_path(args.out, args.prefix, args.suffix)
    if not os.path.exists(ckpt):
        raise FileNotFoundError(ckpt)

    # Materialize every array and close the archive: this same path is overwritten below.
    with np.load(ckpt, allow_pickle=False) as z:
        state = {k: z[k] for k in z.files}

    if sum((args.restore_opt, args.restore_vmin_snapshot, args.restore_run_best)) > 1:
        print(
            "[edit_ckpt] warning: multiple restore sources given; applied in order "
            "opt, vmin, run-best (the last one wins for w/x)"
        )

    if args.restore_opt:
        _restore_opt(state)

    if args.restore_vmin_snapshot:
        snap_path = _plateau_vmin_path(args.out, args.prefix, args.suffix)
        meta = _restore_snapshot(state, snap_path, "vmin")
        if "epoch" in meta:
            snap_epoch = int(np.asarray(meta["epoch"]).item())
            print(f"[edit_ckpt] restored w/x from plateau vmin snapshot (epoch={snap_epoch})")
        else:
            print("[edit_ckpt] restored w/x from plateau vmin snapshot")

    if args.restore_run_best:
        snap_path = _run_best_path(args.out, args.prefix, args.suffix)
        meta = _restore_snapshot(state, snap_path, "run_best")
        if "epoch" in meta:
            snap_epoch = int(np.asarray(meta["epoch"]).item())
            snap_err = float(np.asarray(meta["mean_error"]).item()) if "mean_error" in meta else float("nan")
            print(f"[edit_ckpt] restored w/x from run-best snapshot (epoch={snap_epoch}, err={snap_err})")
        else:
            print("[edit_ckpt] restored w/x from run-best snapshot")

    if args.set_alpha_multiplier is not None:
        state["alpha_multiplier"] = float(args.set_alpha_multiplier)
        print(f"[edit_ckpt] alpha_multiplier={float(state['alpha_multiplier'])}")

    if args.set_start_epoch is not None:
        state["start_epoch"] = int(args.set_start_epoch)
        print(f"[edit_ckpt] start_epoch={int(state['start_epoch'])}")

    _reset_adam(state, args.reset_adam)

    payload = _to_payload(state)

    # Back up only now: if any edit above raised, the checkpoint was never touched.
    if not args.no_backup:
        bak = _backup_file(ckpt)
        shutil.copy2(ckpt, bak)
        print(f"[edit_ckpt] backup: {bak}")

    _atomic_savez_compressed(ckpt, payload)
    print(f"[edit_ckpt] wrote: {ckpt}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    _status = main()
    raise SystemExit(_status)
