from __future__ import annotations

import math
from typing import Any

import numpy as np
import torch

from worm_train_config import WormTrainConfig


def _to_numpy_dLtdv_lr(dLtdv: np.ndarray | torch.Tensor, *, k_mul: int) -> np.ndarray:
    # NOTE: `dLtdv` is on dt-grid; backend expects LR-grid (subsampled by K_mul).
    if isinstance(dLtdv, torch.Tensor):
        dLtdv_lr = dLtdv[:, :: int(k_mul)].detach()
        dLtdv_lr = torch.nan_to_num(dLtdv_lr, nan=0.0, posinf=0.0, neginf=0.0)
        return np.asarray(dLtdv_lr.contiguous().cpu().numpy(), dtype=np.float32)
    return np.nan_to_num(
        np.asarray(dLtdv[:, :: int(k_mul)], dtype=np.float32),
        nan=0.0,
        posinf=0.0,
        neginf=0.0,
    )


def _align_dLtdv_lr_to_replay_ticks(
    dLtdv_lr: np.ndarray,
    *,
    k_mul: int,
    it_lr: Any | None,
    total_steps: int,
) -> np.ndarray:
    # Align dL/dv (LR grid) with the replay loop's dvtdw tick index:
    # - replay dvtdw tick `t_lr=1` uses signals captured after the first K_mul fadvance steps (t=K_mul*dt),
    #   while `dLtdv[:, ::K_mul]` includes the t=0 sample at index 0.
    # - Backend kernel indexes `dLtdv_lr_to[t_lr-1]`, so we shift by one.
    #
    # We don't have `dLtdv` for the final t=K_mul*ksteps_total sample because the legacy loss uses
    # `output_vs[:, lr_start:lr_end]` (excludes the last recorder point). We pad the last LR tick by
    # repeating the last available sample.
    if it_lr is not None:
        ksteps_total = int(it_lr.shape[1] - 1)
    else:
        ksteps_total = int(int(total_steps) // int(k_mul))
    if dLtdv_lr.shape[1] < ksteps_total:
        raise RuntimeError(f"dLtdv_lr has too few LR samples ({dLtdv_lr.shape[1]}) for ksteps_total={ksteps_total}")
    dLtdv_lr_aligned = np.empty((dLtdv_lr.shape[0], ksteps_total), dtype=np.float32)
    if ksteps_total > 1:
        dLtdv_lr_aligned[:, : ksteps_total - 1] = dLtdv_lr[:, 1:ksteps_total]
    dLtdv_lr_aligned[:, ksteps_total - 1] = dLtdv_lr[:, ksteps_total - 1]
    return np.ascontiguousarray(dLtdv_lr_aligned, dtype=np.float32)


def compute_dw_dx(
    *,
    net,
    backend,
    dLtdv: np.ndarray | torch.Tensor,
    it_lr: Any | None,
    ditdv_lr: Any | None,
    ditdvpre_lr: Any | None,
    x: np.ndarray,
    total_steps: int,
    k_mul: int,
    percise: bool,
    cfg: WormTrainConfig,
    epoch_prof: Any | None,
    lr_start: int,
    lr_end: int,
    dt_ms: float,
    tstop_ms: float,
    v_init: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Ask the backend for the weight and input gradients of one epoch.

    `dLtdv` is first reduced to the LR grid and aligned to the replay ticks. One of
    three backend entry points is then used:

    - `it_lr is None` and `cfg.replay_cache_signals` is truthy:
      `backend.replay_grads_from_cached_signals`;
    - `it_lr is None` otherwise: `backend.simulate_and_replay_grads_streaming`
      (the only path that uses `x`, `tstop_ms`, `v_init` and `cfg.replay_use_vecplay`);
    - `it_lr` given: `backend.replay_grads_from_signals` with `it_lr`, `ditdv_lr`
      and `ditdvpre_lr`.

    When `epoch_prof` is truthy, the backend call alone runs inside
    `epoch_prof.phase(<name>)`. `lr_start` and `lr_end` are accepted but not read here.
    `percise` is forwarded under that exact keyword spelling.

    Returns:
        ``(dw, dx)`` as NumPy arrays, taken from `grads.dw_out_n` and
        `grads.dx_lr_it`. If the backend reports `dx_lr_it` as None, `dx` is a float32
        zero array of shape ``(net.N_input, <number of aligned LR ticks>)``.
    """
    stride = int(k_mul)
    grad_v_lr = _align_dLtdv_lr_to_replay_ticks(
        _to_numpy_dLtdv_lr(dLtdv, k_mul=stride),
        k_mul=stride,
        it_lr=it_lr,
        total_steps=int(total_steps),
    )

    # Pick the backend entry point (and its profiler phase label) exactly once.
    if it_lr is None and bool(cfg.replay_cache_signals):
        phase_label = "get_dw_dx_backend_cached"

        def _query_backend():
            return backend.replay_grads_from_cached_signals(
                grad_v_lr,
                percise=bool(percise),
                dt_ms=float(dt_ms),
                grad_scale=1.0,
                eps=1e-6,
            )

    elif it_lr is None:
        phase_label = "get_dw_dx_backend_streaming"

        def _query_backend():
            return backend.simulate_and_replay_grads_streaming(
                x,
                grad_v_lr,
                tstop_ms=float(tstop_ms),
                dt_ms=float(dt_ms),
                k_mul=int(k_mul),
                percise=bool(percise),
                use_vecplay=bool(cfg.replay_use_vecplay),
                v_init=float(v_init),
                grad_scale=1.0,
                eps=1e-6,
                assume_weights_already_set=True,
                assume_inputs_already_played=bool(cfg.replay_use_vecplay),
            )

    else:
        phase_label = "get_dw_dx_backend"

        def _query_backend():
            return backend.replay_grads_from_signals(
                it_lr,
                ditdv_lr,
                ditdvpre_lr,
                grad_v_lr,
                percise=bool(percise),
                dt_ms=float(dt_ms),
                grad_scale=1.0,
                eps=1e-6,
            )

    # Only the backend call itself is attributed to the profiler phase.
    if epoch_prof:
        with epoch_prof.phase(phase_label):
            grads = _query_backend()
    else:
        grads = _query_backend()

    dw = np.asarray(grads.dw_out_n)
    dx_from_backend = grads.dx_lr_it
    if dx_from_backend is None:
        # Backend produced no input gradient: substitute zeros on the LR grid.
        dx = np.zeros((int(net.N_input), int(grad_v_lr.shape[1])), dtype=np.float32)
    else:
        dx = np.asarray(dx_from_backend)
    return np.asarray(dw), np.asarray(dx)


def _adam_direction(
    grad: np.ndarray,
    *,
    beta_1: float,
    beta_2: float,
    epsilon: float,
    m: Any,
    v: Any,
    beta_1_t: float,
    beta_2_t: float,
    adam_state_cls: Any | None,
) -> tuple[Any, Any, Any, Any, Any]:
    """Advance one Adam step; return ``(direction, m, v, beta_1_t, beta_2_t)``.

    With `adam_state_cls` set, the step is delegated to
    ``adam_state_cls.from_legacy(...).step(grad)`` and the new state is read back
    with ``to_legacy()``. Otherwise the bias-corrected update is computed inline.
    """
    if adam_state_cls is not None:
        state = adam_state_cls.from_legacy(
            beta1=beta_1, beta2=beta_2, eps=epsilon, m=m, v=v, beta1_t=beta_1_t, beta2_t=beta_2_t
        )
        direction = state.step(grad)
        m, v, beta_1_t, beta_2_t = state.to_legacy()
        return direction, m, v, beta_1_t, beta_2_t

    m = beta_1 * m + (1.0 - beta_1) * grad
    v = beta_2 * v + (1.0 - beta_2) * grad * grad
    # beta_*_t carry the running powers beta^t used for bias correction.
    beta_1_t = beta_1_t * beta_1
    beta_2_t = beta_2_t * beta_2
    m_hat = m / (1.0 - beta_1_t)
    v_hat = v / (1.0 - beta_2_t)
    return m_hat / (np.sqrt(v_hat) + epsilon), m, v, beta_1_t, beta_2_t


def prepare_dw_dx_for_update(
    *,
    dw: np.ndarray,
    dx: np.ndarray,
    epoch: int,
    cfg: WormTrainConfig,
    logger,
    ADAM_W: bool,
    ADAM_X: bool,
    use_backend_opt_w: bool,
    alpha_w: float,
    alpha_x: float,
    k_mul: int,
    tstep: int,
    lr_start: int,
    lr_end: int,
    x: np.ndarray,
    beta_1: float,
    beta_2: float,
    epsilon: float,
    use_heliox_learn_adam: bool,
    AdamState: Any | None,
    adam_m_w: Any,
    adam_v_w: Any,
    beta_1_t_w: float,
    beta_2_t_w: float,
    adam_m_x: Any,
    adam_v_x: Any,
    beta_1_t_x: float,
    beta_2_t_x: float,
    epoch_prof: Any | None,
) -> tuple[np.ndarray, np.ndarray, bool, bool, Any, Any, float, float, Any, Any, float, float]:
    """Turn raw gradients into the update terms for weights (`dw`) and inputs (`dx`).

    Steps, in order:
    1. Replace NaN/+inf/-inf in `dw` and `dx` by 0.0 (copies; inputs are not mutated).
    2. Resolve freezing from `cfg.freeze_w` / `cfg.freeze_x`, including the optional
       x-update schedule (`cfg.x_update_every`, `cfg.x_update_burst`,
       `cfg.x_update_offset`); frozen gradients are zeroed.
    3. Optionally run Adam: for weights only when `ADAM_W` is set, the backend does
       not own the optimizer (`use_backend_opt_w` false) and weights are not frozen;
       for inputs only when `ADAM_X` is set and x is not frozen. `AdamState` is used
       when `use_heliox_learn_adam` is truthy and `AdamState` is not None.
    4. Scale by `alpha_w` / `alpha_x`. If x is not frozen, `dx` is interpolated
       from the LR grid (one sample every `k_mul` steps over `tstep` steps) onto the
       columns ``lr_start:lr_end`` of the dt grid, and ``-cfg.x_l2_coef * x[:, window]``
       is added when the coefficient is non-zero. When `epoch_prof` is truthy this
       step runs inside ``epoch_prof.phase("opt_math")``.

    Frozen weights skip the Adam step entirely, as frozen inputs do. Otherwise stale
    Adam moments would produce a non-zero `dw` from an all-zero gradient, and the
    Adam state would advance on epochs where nothing is updated.

    Returns:
        ``(dw, dx, freeze_w, freeze_x, adam_m_w, adam_v_w, beta_1_t_w, beta_2_t_w,
        adam_m_x, adam_v_x, beta_1_t_x, beta_2_t_x)``. With x frozen, `dx` keeps its
        LR-grid shape and is all zeros; otherwise it has one column per dt-step of
        the learning window.
    """
    # Training-time robustness: keep the pipeline finite. The legacy training pipeline
    # used "inf -> 0" behavior to avoid optimizer blow-ups.
    dw = np.nan_to_num(np.asarray(dw), nan=0.0, posinf=0.0, neginf=0.0)
    dx = np.nan_to_num(np.asarray(dx), nan=0.0, posinf=0.0, neginf=0.0)

    freeze_w = bool(cfg.freeze_w)
    freeze_x_base = bool(cfg.freeze_x)

    # Optional x update schedule (intended for the "mostly freeze x, occasionally
    # unfreeze" workflow).
    #
    # IMPORTANT:
    # - `freeze_x=1` must mean "do not update x" by default.
    # - The schedule is only active when explicitly configured (update_every > 1 and
    #   burst > 0). Otherwise the default values would accidentally unfreeze x every
    #   epoch.
    x_update_every = int(cfg.x_update_every)
    x_update_burst = int(cfg.x_update_burst)
    x_update_offset = int(cfg.x_update_offset)
    x_update_now = False
    if freeze_x_base and x_update_every > 1 and x_update_burst > 0:
        phase = (int(epoch) - x_update_offset) % x_update_every
        x_update_now = phase < x_update_burst
    freeze_x = freeze_x_base and not x_update_now

    if freeze_w:
        dw[...] = 0.0
    if freeze_x:
        dx[...] = 0.0
    if logger and freeze_x_base and (x_update_every != 1 or x_update_burst != 1 or x_update_offset != 0):
        logger.info(
            f"x-schedule: freeze_x_base=1 update_every={x_update_every} burst={x_update_burst} "
            f"offset={x_update_offset} update_now={int(bool(x_update_now))}"
        )

    adam_state_cls = AdamState if use_heliox_learn_adam else None

    # When `use_backend_opt_w` is set, Adam for weights is handled inside the HELIOX
    # backend (keeps optimizer state off Python), so nothing is done here.
    if ADAM_W and not use_backend_opt_w and not freeze_w:
        dw, adam_m_w, adam_v_w, beta_1_t_w, beta_2_t_w = _adam_direction(
            dw,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            m=adam_m_w,
            v=adam_v_w,
            beta_1_t=beta_1_t_w,
            beta_2_t=beta_2_t_w,
            adam_state_cls=adam_state_cls,
        )

    if ADAM_X and not freeze_x:
        dx, adam_m_x, adam_v_x, beta_1_t_x, beta_2_t_x = _adam_direction(
            dx,
            beta_1=beta_1,
            beta_2=beta_2,
            epsilon=epsilon,
            m=adam_m_x,
            v=adam_v_x,
            beta_1_t=beta_1_t_x,
            beta_2_t=beta_2_t_x,
            adam_state_cls=adam_state_cls,
        )

    x_l2_coef = float(cfg.x_l2_coef)

    def _do_opt_math():
        nonlocal dw, dx
        # Frozen weights stay exactly zero; backend-side Adam applies its own rate.
        if not use_backend_opt_w and not freeze_w:
            dw *= float(alpha_w)
        if freeze_x:
            dx[...] = 0.0
            return
        dx *= float(alpha_x)
        # dx is on LR grid; expand to dt grid, evaluating only the learning window.
        # Both grids are the same for every row, so build them once.
        window = slice(int(lr_start), int(lr_end))
        t_window = np.arange(int(tstep))[window]
        t_lr = np.arange(0, int(tstep), int(k_mul))
        rows = [np.interp(t_window, t_lr, row) for row in dx]
        # The reshape keeps the result 2-D even when there are no rows.
        dx = np.array(rows).reshape(len(rows), t_window.size)
        if x_l2_coef != 0.0:
            dx += -x_l2_coef * x[:, window]

    if epoch_prof:
        with epoch_prof.phase("opt_math"):
            _do_opt_math()
    else:
        _do_opt_math()

    return (
        np.asarray(dw),
        np.asarray(dx),
        bool(freeze_w),
        bool(freeze_x),
        adam_m_w,
        adam_v_w,
        float(beta_1_t_w),
        float(beta_2_t_w),
        adam_m_x,
        adam_v_x,
        float(beta_1_t_x),
        float(beta_2_t_x),
    )


def apply_updates(
    *,
    net,
    backend,
    dw: np.ndarray,
    dx: np.ndarray,
    x: np.ndarray,
    lr_start: int,
    lr_end: int,
    freeze_w: bool,
    freeze_x: bool,
    use_backend_opt_w: bool,
    alpha_w: float,
    beta_1: float,
    beta_2: float,
    epsilon: float,
    cfg: WormTrainConfig,
    epoch_prof: Any | None,
    logger,
) -> tuple[np.ndarray, Any | None]:
    """Apply the prepared weight and input updates; return ``(x, backend_adam_state)``.

    Weights:
    - `use_backend_opt_w` false: ``net.update_weights(dw)`` is called (regardless of
      `freeze_w`).
    - `use_backend_opt_w` true and not `freeze_w`: the backend performs the Adam step
      with ``learning_rate=-alpha_w``; the pulled weights are written into
      ``net.w[backend.weight_p_indices]`` and ``net.set_weights()`` is called. The
      backend's Adam state is returned as the second tuple element.
    - otherwise nothing happens and the returned state is None.

    Inputs: unless `freeze_x`, `dx` is added in place to ``x[:, lr_start:lr_end]``
    (the caller's array is modified) and a clipped copy limited to [-0.2, 0.2] is
    returned. With `freeze_x` the same `x` object is returned untouched.

    When `epoch_prof` is truthy, all of the above runs inside
    ``epoch_prof.phase("apply_updates")``.
    """

    def _backend_adam_step():
        # Step-counter tracing is opt-in via cfg and needs a logger to report to.
        trace_steps = bool(cfg.debug_backend_adam_state) and bool(logger)
        step_before = None
        if trace_steps:
            try:
                step_before = int(backend.get_weight_adam_state().get("step", -1))
            except Exception:
                step_before = None

        backend.ensure_weight_adam_optimizer(beta1=float(beta_1), beta2=float(beta_2), epsilon=float(epsilon))
        backend.adam_step_weights_from_dw(dw, learning_rate=float(-alpha_w))
        pulled_weights = backend.pull_weights()
        adam_state = backend.get_weight_adam_state()

        if trace_steps:
            # Best-effort diagnostics only: never let logging break the update.
            try:
                step_after = int(adam_state.get("step", -1))
                logger.info(f"backend-adam(step): {step_before} -> {step_after}")
            except Exception:
                pass

        # Mirror the backend-owned subset of weights back into the network object.
        positions = torch.as_tensor(backend.weight_p_indices, dtype=torch.long)
        net.w[positions] = torch.asarray(pulled_weights, dtype=torch.float32)
        net.set_weights()
        return adam_state

    def _apply_all():
        adam_state = None
        if not use_backend_opt_w:
            net.update_weights(dw)
        elif not freeze_w:
            adam_state = _backend_adam_step()

        if freeze_x:
            return x, adam_state
        # In-place add on the learning window, then clip into a fresh array.
        x[:, int(lr_start) : int(lr_end)] += dx
        return np.clip(x, a_min=-0.2, a_max=0.2), adam_state

    if epoch_prof:
        with epoch_prof.phase("apply_updates"):
            x_out, backend_adam_w_state = _apply_all()
    else:
        x_out, backend_adam_w_state = _apply_all()

    return x_out, backend_adam_w_state
