"""
Generate advection–diffusion sweep data for the revision figures.

We fix a divergence-free synthetic velocity field (via `flux_scale`) and sweep
over time steps. MG always runs; FVM is attempted for every dt (may diverge).
Results are stored under `data/revision/drift/<timestamp>/` with a manifest
describing all runs.
"""

from __future__ import annotations

import json
import os
import shutil
import subprocess
import time
import signal
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

# Repository root as returned by the gitbud helper; data lives under <root>/data.
REPO_ROOT = Path(inject_repo_into_sys_path())
DATA_ROOT = REPO_ROOT / "data"

# Set DUMUX_ONLY_MG=1 to skip the FVM solve for quick MG-only iterations.
ONLY_MG = os.getenv("DUMUX_ONLY_MG", "0") == "1"

# Experiment knobs, each overridable through the environment variable of the
# same name (defaults tuned for visible advection while keeping FVM stable for
# small dt).
FLUX_SCALE = float(os.getenv("DUMUX_FLUX_SCALE", "1e-1"))
DIFFUSION_COEFF = float(os.getenv("DUMUX_DIFFUSION_COEFF", "1e-9"))
TARGET_DX = float(os.getenv("DUMUX_TARGET_DX", "1e-7"))
MIN_CELLS_PER_MIN_EDGE = int(os.getenv("DUMUX_MIN_CELLS_PER_MIN_EDGE", "10"))
TIME_END = float(os.getenv("DUMUX_TIME_END", "200.0"))
# Comma-separated list of time steps to sweep over.
DT_GRID = list(map(float, os.getenv("DUMUX_DT_GRID", "0.1,1,10,20").split(",")))
# Only recorded in the manifest for reporting; FVM is attempted regardless.
CFL_LIMIT = 1.0
PARTICLES = int(os.getenv("DUMUX_MG_PARTICLES", "1000000"))
MG_RECORD_INTERVAL = int(os.getenv("DUMUX_MG_RECORD_INTERVAL", "10"))
# Subsampling factor for the DuMuX snapshots.
TRACER_SAMPLE_EVERY = int(os.getenv("DUMUX_TRACER_SAMPLE_EVERY", "10"))

# File name of the preprocessed input copied into each sweep folder (the
# folder itself is only known once the timestamp is set).
INPUT_COPY_NAME = "network_input.npz"


def _run_cmd(cmd: List[str], env: Dict[str, str]) -> None:
    """Run *cmd* from the repository root with environment *env*.

    Blocks until the command finishes.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            return code.
    """
    subprocess.run(cmd, env=env, cwd=str(REPO_ROOT), check=True)


def _run_cmd_with_logs(
    cmd: List[str],
    env: Dict[str, str],
    label: str,
    max_logs: int = 10,
    interval: float = 30.0,
) -> None:
    """
    Run a command from the repository root, printing keepalive messages.

    The child is polled once per second. Every `interval` seconds of runtime a
    "[label] running for … min" line is printed, at most `max_logs` times, so
    long MG runs do not look stalled.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    child = subprocess.Popen(cmd, env=env, cwd=str(REPO_ROOT))
    started_at = time.monotonic()
    log_due_at = interval
    emitted = 0

    exit_code = child.poll()
    while exit_code is None:
        elapsed = time.monotonic() - started_at
        if emitted < max_logs and elapsed >= log_due_at:
            print(f"[{label}] running for {elapsed / 60:.1f} min …")
            emitted += 1
            log_due_at += interval
        time.sleep(1.0)
        exit_code = child.poll()

    if exit_code != 0:
        raise subprocess.CalledProcessError(exit_code, cmd)


def _dt_tag(dt: float) -> str:
    # Safe file/field tag (e.g., 0.005 -> dt0005)
    s = f"{dt:.4f}".rstrip("0").rstrip(".")
    return s.replace(".", "p")


def preprocess(out_dir: Path, base_env: Dict[str, str]) -> Dict[str, Any]:
    """Run preprocess_network once and copy the input NPZ into the sweep folder.

    The generated NPZ is written to ``out_dir / INPUT_COPY_NAME`` and then
    copied over the default input path under ``<repo>/data`` so that the
    downstream DuMuX/MG steps pick it up.

    Args:
        out_dir: Sweep output folder that receives the input NPZ.
        base_env: Environment passed to the preprocessing subprocess.

    Returns:
        A dict with the NPZ file name, the number of cells, min/max/median of
        ``cell_velocities``, ``base_dx`` and a ``cfl`` map keyed by ``str(dt)``
        holding ``max|v| * dt / base_dx`` for every dt in ``DT_GRID``.

    Raises:
        subprocess.CalledProcessError: if the preprocessing command fails.
    """
    output_npz = out_dir / INPUT_COPY_NAME
    env = base_env.copy()
    # Build the DGF/NPZ with our chosen flux/dx.
    _run_cmd(
        [
            "python",
            "-m",
            "experiments.dumux_tracer.preprocess_network",
            "--target-dx",
            str(TARGET_DX),
            "--min-cells-per-min-edge",
            str(MIN_CELLS_PER_MIN_EDGE),
            "--flux-scale",
            str(FLUX_SCALE),
            "--output-npz",
            str(output_npz),
        ],
        env=env,
    )
    # Also overwrite the default input NPZ path so downstream DuMuX/MG pick it up.
    default_input = REPO_ROOT / "data" / "dumux_advection_network_input.npz"
    shutil.copy(output_npz, default_input)

    # Read everything we need inside the context manager so the underlying
    # archive handle is closed deterministically.
    with np.load(output_npz) as data:
        vel = np.asarray(data["cell_velocities"], dtype=float)
        base_dx = float(np.asarray(data["base_dx"]).item())

    # The peak speed does not depend on dt; compute it once for all CFL numbers.
    max_speed = np.max(np.abs(vel))
    cfl_map = {f"{dt}": float(max_speed * dt / base_dx) for dt in DT_GRID}
    return {
        "input_npz": output_npz.name,
        "cells": int(vel.size),
        "vel_min": float(vel.min()),
        "vel_max": float(vel.max()),
        "vel_median": float(np.median(vel)),
        "base_dx": base_dx,
        "cfl": cfl_map,
    }


def run_fvm(
    dt: float,
    out_dir: Path,
    env: Dict[str, str],
    cfl_max: float,
) -> Dict[str, Any]:
    """Run DuMuX FVM for a single dt (status may be 'fail' on divergence).

    Launches ``make dumux-advection-run`` in its own session under a
    wall-clock timeout, then always attempts the tracer extraction step and
    summarises the resulting NPZ.

    Environment:
        DUMUX_FVM_TIMEOUT: timeout in seconds (default 120). A negative value
            or ``inf`` disables the timeout; unparseable or NaN values fall
            back to the default.
        DUMUX_FVM_STREAM: when "1" the solver output is inherited by this
            process instead of being captured, so the recorded ``stdout`` /
            ``stderr`` are empty strings.

    Args:
        dt: Time step passed to the solver via ``DUMUX_DT``.
        out_dir: Sweep folder receiving the tracer NPZ and runtime file.
        env: Base environment; copied, never mutated.
        cfl_max: CFL number for this dt, recorded verbatim in the result.

    Returns:
        A dict describing the run. ``execution_status`` is "ok", "fail" or
        "timeout"; ``status`` equals it unless adaptive time stepping was
        detected, in which case it is "adaptive". If no tracer NPZ exists
        after extraction, only the identifying fields plus the captured
        ``stdout``/``stderr`` are returned.
    """
    tag = _dt_tag(dt)
    tracer_npz = out_dir / f"dumux_fvm_dt{tag}.npz"
    runtime_file = out_dir / f"dumux_fvm_dt{tag}_runtime.txt"

    env_fvm = env.copy()
    env_fvm.update(
        {
            "DUMUX_DT": str(dt),
            "DUMUX_TRACER_OUTPUT": str(tracer_npz),
            "DUMUX_RUNTIME_FILE": str(runtime_file),
            "DUMUX_TRACER_SAMPLE_EVERY": str(TRACER_SAMPLE_EVERY),
            "DUMUX_TIME_END": str(TIME_END),
        }
    )

    # float() accepts "inf"/"nan" without raising, so non-finite values have to
    # be handled explicitly instead of in the ValueError branch.
    default_timeout = 120.0
    raw_timeout = os.environ.get("DUMUX_FVM_TIMEOUT", "120")
    timeout_s: float | None
    try:
        timeout_val = float(raw_timeout)
    except ValueError:
        timeout_val = default_timeout
    if np.isnan(timeout_val):
        timeout_s = default_timeout
    elif timeout_val < 0 or np.isinf(timeout_val):
        timeout_s = None
    else:
        timeout_s = timeout_val

    stream_logs = os.environ.get("DUMUX_FVM_STREAM", "0") == "1"
    pipe = None if stream_logs else subprocess.PIPE

    # Run the DuMuX solve with a wall-clock guard; extract whatever steps were
    # written so far.
    execution_status = "ok"
    stdout = ""
    stderr = ""
    start_wall = time.monotonic()
    try:
        proc = subprocess.Popen(
            ["make", "dumux-advection-run"],
            env=env_fvm,
            cwd=str(REPO_ROOT),
            stdout=pipe,
            stderr=pipe,
            text=True,
            # New session => new process group whose id is proc.pid, so the
            # whole make/solver tree can be killed at once on timeout.
            start_new_session=True,
        )
        try:
            out, err = proc.communicate(timeout=timeout_s)
        except subprocess.TimeoutExpired:
            execution_status = "timeout"
            try:
                os.killpg(proc.pid, signal.SIGKILL)
            except ProcessLookupError:
                # The group exited between the timeout and the kill.
                pass
            out, err = proc.communicate()
        # communicate() yields None for streams that were not piped.
        stdout = out or ""
        stderr = err or ""
        if proc.returncode not in (0, None) and execution_status != "timeout":
            execution_status = "fail"
    except Exception as exc:
        execution_status = "fail"
        stderr = f"{stderr}\n{exc}"
    elapsed = time.monotonic() - start_wall

    # Always attempt extraction of whatever was written.
    try:
        _run_cmd(
            ["python", "-m", "experiments.dumux_tracer.extract_dumux_tracer"],
            env=env_fvm,
        )
    except Exception as exc:
        # Best effort: the failure is reflected by the missing npz below, but
        # leave a trace so it can be diagnosed.
        print(f"[FVM dt={dt}] tracer extraction failed: {exc}")

    if not tracer_npz.exists():
        return {
            "status": execution_status,
            "dt": dt,
            "cfl_max": cfl_max,
            "npz": tracer_npz.name,
            "runtime_file": runtime_file.name,
            "execution_status": execution_status,
            "stdout": stdout,
            "stderr": stderr,
        }

    # Load all arrays inside the context manager so the archive is closed.
    with np.load(tracer_npz) as data:
        tracer = np.asarray(data["tracer"], dtype=float)
        cell_lengths = np.asarray(data["cell_lengths"], dtype=float)
        cell_radii = np.asarray(data.get("cell_radii", []), dtype=float)
        times = np.asarray(data.get("times", []), dtype=float)

    # Cross-section from the radii when they match the cells, else unit area.
    area = (
        np.pi * cell_radii * cell_radii
        if cell_radii.size == cell_lengths.size
        else np.ones_like(cell_lengths)
    )
    # NOTE(review): presumably the molar density of water (1000 kg/m^3 over
    # 0.018 kg/mol) - confirm against the DuMuX problem setup.
    molar_density = 1000.0 / 0.018
    total_mass = float((tracer[0] * cell_lengths * area * molar_density).sum())

    # Prefer the runtime reported by the solver; fall back to our own clock if
    # the file is missing or does not contain a number.
    wall = elapsed
    if runtime_file.exists():
        try:
            wall = float(runtime_file.read_text().strip())
        except ValueError:
            wall = elapsed

    steps = times.size if times.size > 0 else int(np.ceil(TIME_END / dt))
    dt_diffs = np.diff(times) if times.size > 1 else np.array([dt])
    min_dt = float(dt_diffs.min())
    max_dt = float(dt_diffs.max())

    retried = "Retrying with time step" in stdout
    adaptive = retried
    # If we're subsampling outputs, the observed dt_diffs reflect sampling, not
    # adaptation, so only trust them when every step is recorded.
    if TRACER_SAMPLE_EVERY <= 1:
        adaptive = adaptive or min_dt < 0.9 * dt or max_dt > 1.1 * dt

    return {
        "status": "adaptive" if adaptive else execution_status,
        "execution_status": execution_status,
        "dt": dt,
        "cfl_max": cfl_max,
        "npz": tracer_npz.name,
        "runtime_file": runtime_file.name,
        "wall_time_s": wall,
        "steps": steps,
        "per_step_s": wall / steps if steps > 0 else None,
        "total_mass0": total_mass,
        "min_dt": min_dt,
        "max_dt": max_dt,
        "adaptive_time_step": adaptive,
        "log_contains_retry": retried,
        "stdout": stdout,
        "stderr": stderr,
    }


def run_mg(
    dt: float, out_dir: Path, env: Dict[str, str], use_dumux_init: bool
) -> Dict[str, Any]:
    """Run the metric-graph (MG) solve for a single dt and summarise its output.

    Launches ``make dumux-advection-metric`` with keepalive logging and reads
    the resulting ``mg_dt<tag>.npz`` from ``out_dir``.

    Args:
        dt: Time step passed to the run via ``DUMUX_DT``.
        out_dir: Sweep folder receiving the MG NPZ.
        env: Base environment; copied, never mutated.
        use_dumux_init: Forwarded as ``DUMUX_INITIAL_FROM_DUMUX`` ("1"/"0").

    Returns:
        A dict with ``status`` ("ok" or "fail"), ``dt`` and the NPZ name. When
        the NPZ exists it also holds the recorded dt, step count, wall time,
        time per step and ``total_mass_final`` (last density snapshot
        integrated over the cell lengths).
    """
    tag = _dt_tag(dt)
    mg_out = out_dir / f"mg_dt{tag}.npz"
    env_mg = env.copy()
    env_mg.update(
        {
            "DUMUX_DT": str(dt),
            "DUMUX_TIME_END": str(TIME_END),
            "DUMUX_TRACER_OUTPUT": str(out_dir / f"dumux_fvm_dt{tag}.npz"),
            "DUMUX_METRIC_OUTPUT": str(mg_out),
            "DUMUX_MG_PARTICLES": str(PARTICLES),
            "DUMUX_MG_RECORD_INTERVAL": str(MG_RECORD_INTERVAL),
            "DUMUX_INITIAL_FROM_DUMUX": "1" if use_dumux_init else "0",
        }
    )

    status = "ok"
    try:
        _run_cmd_with_logs(
            ["make", "dumux-advection-metric"], env=env_mg, label=f"MG dt={dt}"
        )
    except subprocess.CalledProcessError:
        status = "fail"

    if not mg_out.exists():
        return {"status": status, "dt": dt, "npz": mg_out.name}

    # Read everything inside the context manager so the archive is closed.
    with np.load(mg_out) as data:
        density = np.asarray(data["density"], dtype=float)
        cell_lengths = np.asarray(data["cell_lengths"], dtype=float)
        # Optional scalars: fall back to the snapshot count / NaN when absent.
        steps = int(np.asarray(data.get("steps", [density.shape[0]])).item())
        wall = float(np.asarray(data.get("total_wall_time", [np.nan])).item())
        mg_dt = float(np.asarray(data.get("dt", [np.nan])).item())

    total_mass = float((density[-1] * cell_lengths).sum())

    return {
        "status": status,
        "dt": dt,
        "npz": mg_out.name,
        "mg_dt": mg_dt,
        "steps": steps,
        "wall_time_s": wall,
        "per_step_s": wall / steps if steps > 0 else None,
        "total_mass_final": total_mass,
    }


def main() -> None:
    """Run the whole dt sweep and write ``manifest.json`` for it.

    Creates ``data/revision/drift/<timestamp>/``, preprocesses the network
    once, then for every dt in ``DT_GRID`` runs FVM (unless ``ONLY_MG``) and
    MG. The settings, the preprocess statistics and the per-dt results are
    finally dumped into the manifest.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = DATA_ROOT / "revision" / "drift" / timestamp
    out_dir.mkdir(parents=True, exist_ok=True)

    sweep_overrides = {
        "DUMUX_SETTINGS_MODULE": "experiments.dumux_tracer.advection_settings",
        "DUMUX_CONFIG_MODULE": "experiments.dumux_tracer.advection_metric_graph_config",
        "DUMUX_COMPARE_MODULE": "experiments.dumux_tracer.compare_config",
        "DUMUX_FLUX_SCALE": str(FLUX_SCALE),
        "DUMUX_DIFFUSION_COEFF": str(DIFFUSION_COEFF),
        "DUMUX_TARGET_DX": str(TARGET_DX),
        "DUMUX_MIN_CELLS_PER_MIN_EDGE": str(MIN_CELLS_PER_MIN_EDGE),
        "DUMUX_TIME_END": str(TIME_END),
        # DuMuX params input uses this as initial; each run overrides it below.
        "DUMUX_DT": str(DT_GRID[0]),
    }
    base_env = {**os.environ, **sweep_overrides}

    preprocess_info = preprocess(out_dir, base_env)
    cfl_by_dt = preprocess_info["cfl"]

    runs: List[Dict[str, Any]] = []
    for dt in DT_GRID:
        cfl_max = cfl_by_dt[f"{dt}"]
        # FVM is attempted (it may diverge) unless disabled; MG always runs and
        # is kept independent of a potentially diverged FVM result.
        if ONLY_MG:
            fvm_info: Dict[str, Any] = {}
        else:
            fvm_info = run_fvm(dt, out_dir, base_env, cfl_max=cfl_max)
        mg_info = run_mg(dt, out_dir, base_env, use_dumux_init=False)
        runs.append({"dt": dt, "cfl_max": cfl_max, "fvm": fvm_info, "mg": mg_info})

    settings = {
        "flux_scale": FLUX_SCALE,
        "diffusion_coefficient": DIFFUSION_COEFF,
        "target_dx": TARGET_DX,
        "min_cells_per_min_edge": MIN_CELLS_PER_MIN_EDGE,
        "time_end": TIME_END,
        "dt_grid": DT_GRID,
        "particles": PARTICLES,
        "mg_record_interval": MG_RECORD_INTERVAL,
        "cfl_limit": CFL_LIMIT,
        "tracer_sample_every": TRACER_SAMPLE_EVERY,
    }
    manifest = {
        "timestamp": timestamp,
        "only_mg": ONLY_MG,
        "settings": settings,
        "preprocess": preprocess_info,
        "runs": runs,
    }
    manifest_path = out_dir / "manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2))
    print(f"[revision drift] wrote manifest {manifest_path}")


# Script entry point: run the sweep only when executed directly, not on import.
if __name__ == "__main__":
    main()
