from __future__ import annotations

import contextlib
import time
from typing import Any

import numpy as np

from worm_train_config import WormTrainConfig


def _profile_phase(epoch_prof: Any | None, name: str):
    """Return ``epoch_prof.phase(name)`` if profiling is enabled, else a no-op context manager.

    Profiling counts as enabled when ``epoch_prof`` is truthy. This lets every simulation
    path use a single ``with`` statement instead of duplicated profiled/unprofiled code.
    """
    if epoch_prof:
        return epoch_prof.phase(name)
    return contextlib.nullcontext()


def _simulate_replay(
    *,
    backend,
    x: np.ndarray,
    cfg: WormTrainConfig,
    epoch_prof: Any | None,
    tstop_ms: float,
    dt_ms: float,
    v_init: float,
    k_mul: int,
    percise: bool,
) -> tuple[Any, Any | None, Any | None, Any | None]:
    """Run the whole forward pass inside a HELIOX backend helper (replay modes).

    No Python-side stepping happens here. Returns ``(output_vs, it_lr, ditdv_lr,
    ditdvpre_lr)``; the three LR entries are ``None`` unless neither
    ``cfg.replay_cache_signals`` nor ``cfg.replay_streaming`` is set.
    """
    if bool(cfg.replay_cache_signals):
        # One-simulation-pass path:
        # - simulate + capture output_vs to CPU, cache LR signals on GPU
        # - Python computes corr loss + dL/dv
        # - backend replays dw/dx from cached signals (no re-sim, no uploading huge matrices)
        with _profile_phase(epoch_prof, "simulate_capture_cached"):
            output_vs = backend.simulate_and_capture_lr_signals_cached(
                x,
                tstop_ms=float(tstop_ms),
                dt_ms=float(dt_ms),
                k_mul=int(k_mul),
                percise=bool(percise),
                use_vecplay=bool(cfg.replay_use_vecplay),
                v_init=float(v_init),
                assume_weights_already_set=True,
                assume_inputs_already_played=False,
            )
        return output_vs, None, None, None

    if bool(cfg.replay_streaming):
        # Streaming replay:
        # - pass 1: simulate and capture only output v(t) (for objective/loss).
        # - pass 2: simulate again and compute dw/dx on the backend with a GPU ring buffer (no full it_lr matrices).
        with _profile_phase(epoch_prof, "simulate_output_vs"):
            output_vs = backend.simulate_output_vs(
                x,
                tstop_ms=float(tstop_ms),
                dt_ms=float(dt_ms),
                use_vecplay=bool(cfg.replay_use_vecplay),
                v_init=float(v_init),
                assume_weights_already_set=False,
                assume_inputs_already_played=False,
            )
        return output_vs, None, None, None

    # Full-capture replay: the backend returns output voltages and all three LR signals.
    with _profile_phase(epoch_prof, "simulate_capture"):
        output_vs, it_lr, ditdv_lr, ditdvpre_lr = backend.simulate_and_capture_lr_signals(
            x,
            tstop_ms=float(tstop_ms),
            dt_ms=float(dt_ms),
            k_mul=int(k_mul),
            percise=bool(percise),
            use_vecplay=bool(cfg.replay_use_vecplay),
            v_init=float(v_init),
        )
    return output_vs, it_lr, ditdv_lr, ditdvpre_lr


def _simulate_step_loop(
    *,
    net,
    backend,
    n_outputs: int,
    x: np.ndarray,
    epoch: int,
    epoch_prof: Any | None,
    tstop_ms: float,
    total_steps: int,
    k_mul: int,
    percise: bool,
    print_timestep: bool,
    print_every_steps: int,
) -> np.ndarray:
    """Step the backend from Python until ``backend.get_t() >= tstop_ms``, recording outputs.

    Returns a float32 array of shape ``(n_outputs, total_steps + 1)``. Column 0 is the
    output voltage before the first step; column ``i`` is the voltage after step ``i``.
    """
    output_vs = np.zeros((n_outputs, int(total_steps) + 1), dtype=np.float32)
    output_vs[:, 0] = backend.read_output_v()
    n_cols = output_vs.shape[1]

    # Loop invariants, converted once instead of on every timestep.
    t_stop = float(tstop_ms)
    k = int(k_mul)
    precise = bool(percise)

    tstep = 0
    pre_time = time.time()
    with _profile_phase(epoch_prof, "simulate_step_loop"):
        while backend.get_t() < t_stop:
            if print_timestep and tstep % print_every_steps == 0:
                now_time = time.time()
                t_ms = backend.get_t()
                print(f"epoch: {epoch}, t: {t_ms}, used: {now_time-pre_time:.3f}s", end="\r")
                pre_time = now_time
            # NOTE(review): x[:, tstep] is not bounds-guarded like output_vs below. If the
            # time-based loop condition runs more than total_steps steps, x presumably needs
            # extra columns — confirm.
            backend.set_input_amps(x[:, tstep])
            backend.fadvance()
            tstep += 1
            # Extra steps beyond total_steps are not recorded rather than overflowing output_vs.
            if tstep < n_cols:
                output_vs[:, tstep] = backend.read_output_v()
            # Every k_mul steps, let the network accumulate dv/dw at index tstep // k_mul.
            if tstep % k == 0:
                net.update_dvdw(tstep // k, percise=precise)
    return output_vs


def simulate_and_record_epoch(
    *,
    net,
    backend,
    output_names: list[str],
    x: np.ndarray,
    epoch: int,
    cfg: WormTrainConfig,
    epoch_prof: Any | None,
    tstop_ms: float,
    dt_ms: float,
    v_init: float,
    total_steps: int,
    k_mul: int,
    percise: bool,
) -> tuple[np.ndarray, Any | None, Any | None, Any | None]:
    """
    Run one epoch forward simulation (HELIOX-only) and record:
    - output voltage traces (output_vs)
    - optional LR signals for dv/dw replay (it_lr, ditdv_lr, ditdvpre_lr)

    Mode selection (from ``cfg``):
    - ``replay`` + ``replay_cache_signals``: ``backend.simulate_and_capture_lr_signals_cached``;
      only output_vs is returned.
    - ``replay`` + ``replay_streaming``: ``backend.simulate_output_vs``; only output_vs is returned.
    - ``replay`` otherwise: ``backend.simulate_and_capture_lr_signals``; all four values are returned.
    - not ``replay``: Python loop of ``set_input_amps`` / ``fadvance``, calling
      ``net.update_dvdw`` every ``k_mul`` steps; the LR entries are ``None``.

    ``epoch_prof`` (optional, used when truthy) must provide ``phase(name)`` returning a
    context manager; the phase names are "simulate_capture_cached", "simulate_output_vs",
    "simulate_capture" and "simulate_step_loop". ``percise`` is forwarded as-is (the name
    is kept for API compatibility).

    Returns:
        ``(output_vs, it_lr, ditdv_lr, ditdvpre_lr)``.
    """
    # Evaluated in every mode, so an invalid print interval or a zero dt raises here.
    print_every_steps = max(1, int(round(float(cfg.print_interval_ms) / float(dt_ms))))
    print_timestep = bool(cfg.print_timestep)

    if bool(cfg.replay):
        # In replay modes the forward pass is done by HELIOX helpers; nothing steps in Python.
        output_vs, it_lr, ditdv_lr, ditdvpre_lr = _simulate_replay(
            backend=backend,
            x=x,
            cfg=cfg,
            epoch_prof=epoch_prof,
            tstop_ms=tstop_ms,
            dt_ms=dt_ms,
            v_init=v_init,
            k_mul=k_mul,
            percise=percise,
        )
    else:
        output_vs = _simulate_step_loop(
            net=net,
            backend=backend,
            n_outputs=len(output_names),
            x=x,
            epoch=epoch,
            epoch_prof=epoch_prof,
            tstop_ms=tstop_ms,
            total_steps=total_steps,
            k_mul=k_mul,
            percise=percise,
            print_timestep=print_timestep,
            print_every_steps=print_every_steps,
        )
        it_lr = ditdv_lr = ditdvpre_lr = None

    if print_timestep:
        # Finish the "\r"-overwritten progress line.
        print("")

    return output_vs, it_lr, ditdv_lr, ditdvpre_lr
