from __future__ import annotations

import copy
from typing import Any

import numpy as np

from worm_checkpoint import _save_ckpt
from worm_train_config import WormTrainConfig


def _persist_backend_adam_state(
    *,
    state: dict[str, Any],
    backend_adam_w_state: dict | None,
    opt_backend_adam_w_state: dict | None,
    plateau_vmin_best_backend_adam_w_state: dict | None,
    beta_1: float,
    beta_2: float,
    epsilon: float,
):
    """Flatten backend Adam weight-optimizer states into ``state`` (in place).

    Every state dict that is not ``None`` is written as six flat entries named
    ``<prefix>_step``, ``<prefix>_m``, ``<prefix>_v``, ``<prefix>_beta1``,
    ``<prefix>_beta2`` and ``<prefix>_epsilon``. The prefixes are:

    * ``backend_adam_w`` for ``backend_adam_w_state``;
    * ``opt_backend_adam_w`` for ``opt_backend_adam_w_state``;
    * ``plateau_vmin_best_backend_adam_w`` for
      ``plateau_vmin_best_backend_adam_w_state``.

    The three states are handled independently: a missing live state no
    longer prevents the two snapshots from being written, so they are not
    dropped from the checkpoint.

    Args:
        state: Checkpoint dict to extend; mutated in place.
        backend_adam_w_state: Live optimizer state, or ``None`` to skip it.
        opt_backend_adam_w_state: ``opt`` snapshot, or ``None`` to skip it.
        plateau_vmin_best_backend_adam_w_state: Plateau-vmin snapshot, or
            ``None`` to skip it.
        beta_1: Fallback used when a state dict has no ``"beta1"`` entry.
        beta_2: Fallback used when a state dict has no ``"beta2"`` entry.
        epsilon: Fallback used when a state dict has no ``"epsilon"`` entry.

    Returns:
        ``None``; the result is the mutation of ``state``.
    """
    # Order matters only for key insertion order in ``state``: live state
    # first, then the two snapshots, exactly as before.
    prefixed_states = (
        ("backend_adam_w", backend_adam_w_state),
        ("opt_backend_adam_w", opt_backend_adam_w_state),
        ("plateau_vmin_best_backend_adam_w", plateau_vmin_best_backend_adam_w_state),
    )
    for key_prefix, adam_state in prefixed_states:
        if adam_state is None:
            continue
        state.update(
            {
                # A missing step counter is stored as 0.
                f"{key_prefix}_step": int(adam_state.get("step", 0)),
                # Moments are coerced to float64 arrays; missing ones become empty arrays.
                f"{key_prefix}_m": np.asarray(adam_state.get("m", []), dtype=np.float64),
                f"{key_prefix}_v": np.asarray(adam_state.get("v", []), dtype=np.float64),
                # Hyper-parameters fall back to the values supplied by the caller.
                f"{key_prefix}_beta1": float(adam_state.get("beta1", beta_1)),
                f"{key_prefix}_beta2": float(adam_state.get("beta2", beta_2)),
                f"{key_prefix}_epsilon": float(adam_state.get("epsilon", epsilon)),
            }
        )


def maybe_emergency_brake(
    *,
    cfg: WormTrainConfig,
    epoch: int,
    mean_error: float,
    opt_mean_error: float,
    opt_epoch: int,
    alpha_multiplier: float,
    net,
    backend,
    x: np.ndarray,
    # snapshots to restore
    opt_w: np.ndarray,
    opt_x: np.ndarray,
    # optimizer snapshots
    ADAM_W: bool,
    ADAM_X: bool,
    use_backend_opt_w: bool,
    beta_1: float,
    beta_2: float,
    epsilon: float,
    adam_m_w: Any,
    adam_v_w: Any,
    beta_1_t_w: float,
    beta_2_t_w: float,
    adam_m_x: Any,
    adam_v_x: Any,
    beta_1_t_x: float,
    beta_2_t_x: float,
    opt_adam_params_w: Any,
    opt_adam_params_x: Any,
    opt_backend_adam_w_state: dict | None,
    plateau_vmin_best_backend_adam_w_state: dict | None,
    # run-best snapshot (persisted for resume)
    run_best_epoch: int,
    run_best_mean_error: float,
    run_best_vmin: float,
    run_best_vmax: float,
    run_best_w: Any,
    run_best_x: Any,
    # plateau vmin snapshot (persisted for resume)
    plateau_vmin_best: float,
    plateau_vmin_best_epoch: int,
    plateau_vmin_best_error: float,
    plateau_vmin_best_w: Any,
    plateau_vmin_best_x: Any,
    plateau_vmin_best_alpha_multiplier: float,
    plateau_vmin_best_adam_params_w: Any,
    plateau_vmin_best_adam_params_x: Any,
    # io
    train_error: list[float],
    output_path: str,
    prefix: str,
    suffix: str,
    logger=None,
) -> dict | None:
    """Roll back to the ``opt_*`` snapshot when the error jumps above it.

    The brake fires only when ``cfg.plateau_emergency_err_add`` is positive
    and ``mean_error`` is greater than ``opt_mean_error`` plus that margin.
    When it fires, the function:

    * loads ``opt_w`` into ``net`` and takes a copy of ``opt_x`` as ``x``;
    * restores the W Adam state from the ``opt`` snapshot (a backend-managed
      optimizer with no snapshot is reset instead) and the X Adam state from
      ``opt_adam_params_x`` when available;
    * records ``epoch`` as the new ``opt_epoch`` and multiplies
      ``alpha_multiplier`` by ``cfg.plateau_lr_multiplier``;
    * assembles a resume checkpoint and writes it with ``_save_ckpt``.

    The ``*_adam_params_*`` snapshots are indexed as
    ``(m, v, beta_1_t, beta_2_t)``.

    Returns:
        ``None`` when the brake does not fire; otherwise the checkpoint dict
        that was saved (presumably used by the caller to refresh its own
        training variables -- confirm against the call site).
    """
    err_margin = float(cfg.plateau_emergency_err_add)
    # Written as negated ">" tests so a NaN margin or NaN error never fires the brake.
    if not err_margin > 0:
        return None
    if not float(mean_error) > float(opt_mean_error) + err_margin:
        return None

    if logger:
        message = (
            f"plateau-emergency: err={mean_error:.5g} exceeds best+{err_margin} "
            f"(best={opt_mean_error:.5g}); restore weights & x"
        )
        logger.info(message)

    # --- roll parameters back to the snapshot -------------------------------
    net.set_weights(opt_w)
    x = np.copy(opt_x)

    # --- roll optimizer state back ------------------------------------------
    backend_adam_w_state = None
    if ADAM_W and use_backend_opt_w:
        if opt_backend_adam_w_state is None:
            # No snapshot to go back to: start the backend optimizer afresh.
            backend.reset_weight_adam_state()
            backend_adam_w_state = backend.get_weight_adam_state()
        else:
            backend.set_weight_adam_state(opt_backend_adam_w_state)
            backend_adam_w_state = copy.deepcopy(opt_backend_adam_w_state)
    elif ADAM_W and opt_adam_params_w is not None:
        adam_m_w, adam_v_w, beta_1_t_w, beta_2_t_w = copy.deepcopy(opt_adam_params_w)
    if ADAM_X and opt_adam_params_x is not None:
        adam_m_x, adam_v_x, beta_1_t_x, beta_2_t_x = copy.deepcopy(opt_adam_params_x)

    opt_epoch = int(epoch)
    alpha_multiplier *= float(cfg.plateau_lr_multiplier)

    def _snapshot_fields(enabled, params):
        # (m, v, beta_1_t, beta_2_t) of a snapshot, or neutral placeholders
        # when that optimizer is disabled or no snapshot exists.
        if enabled and params is not None:
            return params[0], params[1], params[2], params[3]
        return None, None, 1.0, 1.0

    # --- assemble the resume checkpoint (key order matches the old layout) ---
    state: dict[str, Any] = {
        "start_epoch": int(epoch) + 1,
        "x": x,
        "w": net.w.numpy(),
        "train_error": list(train_error),
        "opt_epoch": int(opt_epoch),
        "opt_mean_error": float(opt_mean_error),
        "opt_w": np.asarray(opt_w),
        "opt_x": np.asarray(opt_x),
        "run_best_epoch": int(run_best_epoch),
        "run_best_mean_error": float(run_best_mean_error),
        "run_best_vmin": float(run_best_vmin),
        "run_best_vmax": float(run_best_vmax),
        "run_best_w": run_best_w,
        "run_best_x": run_best_x,
        "alpha_multiplier": float(alpha_multiplier),
    }

    # Live optimizer state; disabled optimizers are stored as placeholders.
    if ADAM_W:
        live_w = (float(beta_1_t_w), float(beta_2_t_w), adam_m_w, adam_v_w)
    else:
        live_w = (1.0, 1.0, None, None)
    if ADAM_X:
        live_x = (float(beta_1_t_x), float(beta_2_t_x), adam_m_x, adam_v_x)
    else:
        live_x = (1.0, 1.0, None, None)
    state.update(
        {
            "beta_1_t_w": live_w[0],
            "beta_2_t_w": live_w[1],
            "beta_1_t_x": live_x[0],
            "beta_2_t_x": live_x[1],
            "adam_m_w": live_w[2],
            "adam_v_w": live_w[3],
            "adam_m_x": live_x[2],
            "adam_v_x": live_x[3],
        }
    )

    # Persist snapshots for next resume.
    opt_m_w, opt_v_w, opt_b1_w, opt_b2_w = _snapshot_fields(ADAM_W, opt_adam_params_w)
    opt_m_x, opt_v_x, opt_b1_x, opt_b2_x = _snapshot_fields(ADAM_X, opt_adam_params_x)
    state.update(
        {
            "opt_adam_m_w": opt_m_w,
            "opt_adam_v_w": opt_v_w,
            "opt_beta_1_t_w": opt_b1_w,
            "opt_beta_2_t_w": opt_b2_w,
            "opt_adam_m_x": opt_m_x,
            "opt_adam_v_x": opt_v_x,
            "opt_beta_1_t_x": opt_b1_x,
            "opt_beta_2_t_x": opt_b2_x,
        }
    )

    vmin_m_w, vmin_v_w, vmin_b1_w, vmin_b2_w = _snapshot_fields(ADAM_W, plateau_vmin_best_adam_params_w)
    vmin_m_x, vmin_v_x, vmin_b1_x, vmin_b2_x = _snapshot_fields(ADAM_X, plateau_vmin_best_adam_params_x)
    state.update(
        {
            "plateau_vmin_best": float(plateau_vmin_best),
            "plateau_vmin_best_epoch": int(plateau_vmin_best_epoch),
            "plateau_vmin_best_error": float(plateau_vmin_best_error),
            "plateau_vmin_best_w": plateau_vmin_best_w,
            "plateau_vmin_best_x": plateau_vmin_best_x,
            "plateau_vmin_best_alpha_multiplier": float(plateau_vmin_best_alpha_multiplier),
            "plateau_vmin_best_adam_m_w": vmin_m_w,
            "plateau_vmin_best_adam_v_w": vmin_v_w,
            "plateau_vmin_best_beta_1_t_w": vmin_b1_w,
            "plateau_vmin_best_beta_2_t_w": vmin_b2_w,
            "plateau_vmin_best_adam_m_x": vmin_m_x,
            "plateau_vmin_best_adam_v_x": vmin_v_x,
            "plateau_vmin_best_beta_1_t_x": vmin_b1_x,
            "plateau_vmin_best_beta_2_t_x": vmin_b2_x,
        }
    )

    if use_backend_opt_w:
        _persist_backend_adam_state(
            state=state,
            backend_adam_w_state=backend_adam_w_state,
            opt_backend_adam_w_state=opt_backend_adam_w_state,
            plateau_vmin_best_backend_adam_w_state=plateau_vmin_best_backend_adam_w_state,
            beta_1=float(beta_1),
            beta_2=float(beta_2),
            epsilon=float(epsilon),
        )

    _save_ckpt(output_path, prefix, suffix, state, logger=logger)
    return state


def maybe_plateau_restore(
    *,
    cfg: WormTrainConfig,
    epoch: int,
    opt_epoch: int,
    opt_mean_error: float,
    opt_w: np.ndarray,
    opt_x: np.ndarray,
    alpha_multiplier: float,
    net,
    backend,
    x: np.ndarray,
    # plateau vmin candidate
    use_plateau_vmin: bool,
    plateau_vmin_best_epoch: int,
    plateau_vmin_best_error: float,
    plateau_vmin_best: float,
    plateau_vmin_best_w: Any,
    plateau_vmin_best_x: Any,
    plateau_vmin_best_alpha_multiplier: float,
    plateau_vmin_best_adam_params_w: Any,
    plateau_vmin_best_adam_params_x: Any,
    plateau_vmin_best_backend_adam_w_state: dict | None,
    # optimizer snapshots
    ADAM_W: bool,
    ADAM_X: bool,
    use_backend_opt_w: bool,
    beta_1: float,
    beta_2: float,
    epsilon: float,
    adam_m_w: Any,
    adam_v_w: Any,
    beta_1_t_w: float,
    beta_2_t_w: float,
    adam_m_x: Any,
    adam_v_x: Any,
    beta_1_t_x: float,
    beta_2_t_x: float,
    opt_adam_params_w: Any,
    opt_adam_params_x: Any,
    opt_backend_adam_w_state: dict | None,
    # run-best snapshot (persisted for resume)
    run_best_epoch: int,
    run_best_mean_error: float,
    run_best_vmin: float,
    run_best_vmax: float,
    run_best_w: Any,
    run_best_x: Any,
    # io
    train_error: list[float],
    output_path: str,
    prefix: str,
    suffix: str,
    logger=None,
) -> dict | None:
    """Restore a snapshot and scale the learning rate after a plateau.

    The restore fires only when ``cfg.plateau_patience_epochs`` is positive
    and at least that many epochs have passed since ``opt_epoch``. When it
    fires, the function:

    * picks the snapshot to restore: the ``opt_*`` snapshot by default, or
      the plateau-vmin snapshot when ``use_plateau_vmin`` is set, that
      snapshot is complete, and its error is within
      ``cfg.plateau_vmin_err_tol`` of ``opt_mean_error``;
    * loads the chosen weights into ``net`` and copies the chosen ``x``;
    * handles Adam state according to ``cfg.plateau_reset_adam``:
      ``"w"``/``"x"``/``"both"`` zero the respective state, ``"none"``
      restores it from the chosen snapshot when one exists;
    * records ``epoch`` as the new ``opt_epoch`` and multiplies
      ``alpha_multiplier`` by ``cfg.plateau_lr_multiplier``;
    * assembles a resume checkpoint and writes it with ``_save_ckpt``.

    Note that resetting the local W/X Adam moments zeroes the arrays passed
    in as ``adam_m_*`` / ``adam_v_*`` in place.

    The ``*_adam_params_*`` snapshots are indexed as
    ``(m, v, beta_1_t, beta_2_t)``.

    Returns:
        ``None`` when the restore does not fire; otherwise the checkpoint
        dict that was saved.

    Raises:
        ValueError: If ``cfg.plateau_reset_adam`` is not one of ``none``,
            ``w``, ``x`` or ``both``. Raised before ``net`` or ``backend``
            are modified.
    """
    plateau_patience = int(cfg.plateau_patience_epochs)
    if not (plateau_patience > 0 and (int(epoch) - int(opt_epoch)) >= plateau_patience):
        return None

    # Validate the reset mode up front so a bad setting cannot leave the
    # network with restored weights but an unhandled optimizer state.
    plateau_reset_adam = str(cfg.plateau_reset_adam)
    if plateau_reset_adam not in ("none", "w", "x", "both"):
        raise ValueError("EWORM_PLATEAU_RESET_ADAM must be none|w|x|both")

    # --- choose which snapshot to restore -----------------------------------
    restore_tag = "opt"
    restore_w = opt_w
    restore_x = opt_x
    restore_adam_params_w = opt_adam_params_w
    restore_backend_adam_w_state = opt_backend_adam_w_state
    restore_adam_params_x = opt_adam_params_x
    vmin_snapshot_complete = (
        use_plateau_vmin
        and (plateau_vmin_best_epoch >= 0)
        and (plateau_vmin_best_w is not None)
        and (plateau_vmin_best_x is not None)
    )
    if vmin_snapshot_complete:
        err_tol = float(cfg.plateau_vmin_err_tol)
        if float(plateau_vmin_best_error) <= float(opt_mean_error) + err_tol:
            restore_tag = "vmin"
            restore_w = np.asarray(plateau_vmin_best_w)
            restore_x = np.asarray(plateau_vmin_best_x)
            # Optimizer snapshots fall back to the ``opt`` ones when the
            # vmin snapshot did not record them.
            if plateau_vmin_best_adam_params_w is not None:
                restore_adam_params_w = plateau_vmin_best_adam_params_w
            if plateau_vmin_best_backend_adam_w_state is not None:
                restore_backend_adam_w_state = plateau_vmin_best_backend_adam_w_state
            if plateau_vmin_best_adam_params_x is not None:
                restore_adam_params_x = plateau_vmin_best_adam_params_x

    if logger:
        logger.info(
            f"no improvement since epoch {opt_epoch} for {plateau_patience} epochs, restore weights & x, "
            f"learning rate *= {cfg.plateau_lr_multiplier}"
        )
        if restore_tag == "vmin":
            logger.info(
                f"plateau-restore: prefer vmin checkpoint (epoch={plateau_vmin_best_epoch}, "
                f"err={plateau_vmin_best_error:.5g})"
            )

    # --- restore parameters --------------------------------------------------
    net.set_weights(restore_w)
    x = np.copy(restore_x)

    # --- reset or restore optimizer state -----------------------------------
    backend_adam_w_state = None
    if plateau_reset_adam in ("w", "both"):
        if use_backend_opt_w:
            backend.reset_weight_adam_state()
            backend_adam_w_state = backend.get_weight_adam_state()
        else:
            if ADAM_W and (adam_m_w is not None) and (adam_v_w is not None):
                # In-place zeroing: the caller's arrays are modified.
                adam_m_w[...] = 0.0
                adam_v_w[...] = 0.0
            beta_1_t_w = 1.0
            beta_2_t_w = 1.0
    if plateau_reset_adam in ("x", "both"):
        if ADAM_X and (adam_m_x is not None) and (adam_v_x is not None):
            adam_m_x[...] = 0.0
            adam_v_x[...] = 0.0
        beta_1_t_x = 1.0
        beta_2_t_x = 1.0
    if plateau_reset_adam == "none":
        if ADAM_W:
            if use_backend_opt_w:
                if restore_backend_adam_w_state is not None:
                    backend.set_weight_adam_state(restore_backend_adam_w_state)
                    backend_adam_w_state = copy.deepcopy(restore_backend_adam_w_state)
            elif restore_adam_params_w is not None:
                adam_m_w, adam_v_w, beta_1_t_w, beta_2_t_w = copy.deepcopy(restore_adam_params_w)
        if ADAM_X and restore_adam_params_x is not None:
            adam_m_x, adam_v_x, beta_1_t_x, beta_2_t_x = copy.deepcopy(restore_adam_params_x)

    # The backend optimizer was neither reset nor restored above (mode "x",
    # or mode "none" without a snapshot). Capture its live state so the
    # checkpoint still carries it; otherwise a resume would lose it.
    # NOTE(review): assumes get_weight_adam_state() is side-effect free, as
    # its use right after reset_weight_adam_state() suggests -- confirm.
    if ADAM_W and use_backend_opt_w and backend_adam_w_state is None:
        backend_adam_w_state = backend.get_weight_adam_state()

    opt_epoch = int(epoch)
    alpha_multiplier *= float(cfg.plateau_lr_multiplier)

    def _snapshot_fields(enabled, params):
        # (m, v, beta_1_t, beta_2_t) of a snapshot, or neutral placeholders
        # when that optimizer is disabled or no snapshot exists.
        if enabled and params is not None:
            return params[0], params[1], params[2], params[3]
        return None, None, 1.0, 1.0

    opt_m_w, opt_v_w, opt_b1_w, opt_b2_w = _snapshot_fields(ADAM_W, opt_adam_params_w)
    opt_m_x, opt_v_x, opt_b1_x, opt_b2_x = _snapshot_fields(ADAM_X, opt_adam_params_x)
    vmin_m_w, vmin_v_w, vmin_b1_w, vmin_b2_w = _snapshot_fields(ADAM_W, plateau_vmin_best_adam_params_w)
    vmin_m_x, vmin_v_x, vmin_b1_x, vmin_b2_x = _snapshot_fields(ADAM_X, plateau_vmin_best_adam_params_x)

    # --- assemble the resume checkpoint (key order matches the old layout) ---
    state: dict[str, Any] = {
        "start_epoch": int(epoch) + 1,
        "x": x,
        "w": net.w.numpy(),
        "train_error": list(train_error),
        "opt_epoch": int(opt_epoch),
        "opt_mean_error": float(opt_mean_error),
        "opt_w": np.asarray(opt_w),
        "opt_x": np.asarray(opt_x),
        "run_best_epoch": int(run_best_epoch),
        "run_best_mean_error": float(run_best_mean_error),
        "run_best_vmin": float(run_best_vmin),
        "run_best_vmax": float(run_best_vmax),
        "run_best_w": run_best_w,
        "run_best_x": run_best_x,
        "alpha_multiplier": float(alpha_multiplier),
        # Live optimizer state; disabled optimizers are stored as placeholders.
        "beta_1_t_w": float(beta_1_t_w) if ADAM_W else 1.0,
        "beta_2_t_w": float(beta_2_t_w) if ADAM_W else 1.0,
        "beta_1_t_x": float(beta_1_t_x) if ADAM_X else 1.0,
        "beta_2_t_x": float(beta_2_t_x) if ADAM_X else 1.0,
        "adam_m_w": adam_m_w if ADAM_W else None,
        "adam_v_w": adam_v_w if ADAM_W else None,
        "adam_m_x": adam_m_x if ADAM_X else None,
        "adam_v_x": adam_v_x if ADAM_X else None,
        # Snapshots, carried along so the next resume can restore them.
        "opt_adam_m_w": opt_m_w,
        "opt_adam_v_w": opt_v_w,
        "opt_beta_1_t_w": opt_b1_w,
        "opt_beta_2_t_w": opt_b2_w,
        "opt_adam_m_x": opt_m_x,
        "opt_adam_v_x": opt_v_x,
        "opt_beta_1_t_x": opt_b1_x,
        "opt_beta_2_t_x": opt_b2_x,
        "plateau_vmin_best": float(plateau_vmin_best),
        "plateau_vmin_best_epoch": int(plateau_vmin_best_epoch),
        "plateau_vmin_best_error": float(plateau_vmin_best_error),
        "plateau_vmin_best_w": plateau_vmin_best_w,
        "plateau_vmin_best_x": plateau_vmin_best_x,
        "plateau_vmin_best_alpha_multiplier": float(plateau_vmin_best_alpha_multiplier),
        "plateau_vmin_best_adam_m_w": vmin_m_w,
        "plateau_vmin_best_adam_v_w": vmin_v_w,
        "plateau_vmin_best_beta_1_t_w": vmin_b1_w,
        "plateau_vmin_best_beta_2_t_w": vmin_b2_w,
        "plateau_vmin_best_adam_m_x": vmin_m_x,
        "plateau_vmin_best_adam_v_x": vmin_v_x,
        "plateau_vmin_best_beta_1_t_x": vmin_b1_x,
        "plateau_vmin_best_beta_2_t_x": vmin_b2_x,
    }

    if use_backend_opt_w:
        _persist_backend_adam_state(
            state=state,
            backend_adam_w_state=backend_adam_w_state,
            opt_backend_adam_w_state=opt_backend_adam_w_state,
            plateau_vmin_best_backend_adam_w_state=plateau_vmin_best_backend_adam_w_state,
            beta_1=float(beta_1),
            beta_2=float(beta_2),
            epsilon=float(epsilon),
        )

    _save_ckpt(output_path, prefix, suffix, state, logger=logger)
    return state
