"""
Generate diffusion-only sweep data for the revision figures.

Runs:
- DuMuX FVM baseline once.
- MG sweeps over particle counts with multiple seeds.

Outputs:
- All NPZ/manifest files under data/revision/diffusion/<timestamp>/.
"""

from __future__ import annotations

import json
import os
import shutil
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path

# Repository root, taken from gitbud's inject_repo_into_sys_path(); judging by its
# name it also puts the repo on sys.path so `experiments.*` modules resolve — verify.
REPO_ROOT = Path(inject_repo_into_sys_path())
# Top-level data directory: FVM baseline outputs are read from here after
# `make dumux-fvm`, and per-sweep outputs go under revision/diffusion/<timestamp>/.
DATA_ROOT = REPO_ROOT / "data"

# Sweep configuration
# MG particle counts to sweep; each count is run once per seed.
PARTICLE_COUNTS = [100_000, 200_000, 500_000, 1_000_000]
# Number of fresh random seeds per particle count (unless overridden below)
NUM_SEEDS = 1
# Optional explicit seed lists keyed by particle count; a count present here uses
# these seeds instead of NUM_SEEDS freshly drawn ones.
OVERRIDE_SEEDS: Dict[int, List[int]] = {}
# Simulation end time, exported to both solvers as DUMUX_TIME_END
# (units not visible here — presumably seconds; confirm against the solver config).
TIME_END = 2000.0
# Time step exported to both solvers as DUMUX_DT.
DUMUX_DT = 0.05  # shared with MG; keep modest for accuracy

# Paths for baseline outputs (DuMuX)
# File names (relative to DATA_ROOT) that `make dumux-fvm` is expected to produce.
FVM_NPZ_NAME = "dumux_network_tracer_1d.npz"
FVM_RUNTIME_NAME = "dumux_network_tracer_1d_runtime.txt"


def _run_cmd(cmd: List[str], env: Dict[str, str]) -> None:
    """Execute *cmd* from the repository root using the environment *env*.

    Output is not captured; it streams to this process's stdout/stderr.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    subprocess.run(cmd, env=env, cwd=str(REPO_ROOT), check=True)


def run_fvm(out_dir: Path, env: Dict[str, str]) -> Dict[str, Any]:
    """Run the DuMuX FVM baseline, archive its outputs and summarize them.

    Invokes ``make dumux-fvm`` from the repository root, copies the produced NPZ
    and runtime files from ``DATA_ROOT`` into *out_dir*, and derives summary
    numbers from the NPZ.

    Args:
        out_dir: Directory that receives copies of the FVM output files.
        env: Environment passed to ``make``.

    Returns:
        Manifest entry with the output file names, time settings, step count,
        initial total mass, total network length, cell count, and wall-clock /
        per-step timings.

    Raises:
        subprocess.CalledProcessError: if ``make dumux-fvm`` fails.
        RuntimeError: if either expected output file is missing afterwards.
    """
    print("[revision diffusion] Running FVM baseline...")
    _run_cmd(["make", "dumux-fvm"], env=env)
    fvm_npz = DATA_ROOT / FVM_NPZ_NAME
    fvm_runtime = DATA_ROOT / FVM_RUNTIME_NAME
    if not fvm_npz.exists() or not fvm_runtime.exists():
        raise RuntimeError("FVM outputs missing after dumux-fvm")
    shutil.copy(fvm_npz, out_dir / fvm_npz.name)
    shutil.copy(fvm_runtime, out_dir / fvm_runtime.name)

    # np.load on an .npz returns a lazily-read archive that holds an open file
    # handle; the context manager closes it once the arrays are materialized.
    with np.load(fvm_npz) as data:
        tracer = np.asarray(data["tracer"], dtype=float)
        cell_lengths = np.asarray(data["cell_lengths"], dtype=float)
        cell_radii = np.asarray(data.get("cell_radii", []), dtype=float)

    # Cross-sectional areas when per-cell radii are available; otherwise fall
    # back to unit area (pure 1D length weighting).
    if cell_radii.size == cell_lengths.size:
        area = np.pi * cell_radii * cell_radii
    else:
        area = np.ones_like(cell_lengths)
    # 1000.0 / 0.018 looks like water density / molar mass (kg/m^3 / kg/mol), i.e.
    # converting a mole fraction to amount per volume — NOTE(review): confirm units.
    total_mass = float((tracer[0] * cell_lengths * area * (1000.0 / 0.018)).sum())
    total_length = float(cell_lengths.sum())

    wall = float(fvm_runtime.read_text().strip())
    # Small tolerance so a float ratio landing just above an integer
    # (e.g. 1.1 / 0.1 == 11.000000000000002) does not count an extra step.
    steps = int(np.ceil(TIME_END / DUMUX_DT - 1e-9))

    return {
        "npz": fvm_npz.name,
        "runtime_file": fvm_runtime.name,
        "time_end": TIME_END,
        "dt": DUMUX_DT,
        "steps": steps,
        "total_mass0": total_mass,
        "total_length": total_length,
        "cells": int(cell_lengths.size),
        "wall_time_s": wall,
        "per_step_s": wall / steps if steps > 0 else None,
    }


def run_mg(count: int, seed: int, out_dir: Path, base_env: Dict[str, str]) -> Dict[str, Any]:
    """Run one MG simulation via ``make dumux-metric`` and summarize its output.

    The particle count, seed and output path are passed to the make target via
    ``DUMUX_MG_PARTICLES``, ``DUMUX_MG_SEED`` and ``DUMUX_METRIC_OUTPUT``; every
    other setting comes from *base_env* (which is not modified).

    Args:
        count: Number of MG particles.
        seed: Random seed for the MG run.
        out_dir: Directory in which the run's NPZ (``mg_p<count>_s<seed>.npz``) is written.
        base_env: Base environment shared by all runs.

    Returns:
        Manifest entry. ``status`` is ``"ok"`` only when ``make`` succeeded AND the
        output NPZ exists; otherwise it is ``"fail"``. ``npz`` is always the file
        name relative to *out_dir*. When the NPZ exists, the entry also carries
        cell/step counts, end time, dt, timings and the final total mass.
    """
    mg_out = out_dir / f"mg_p{count}_s{seed}.npz"
    env = base_env.copy()
    env.update(
        {
            "DUMUX_MG_PARTICLES": str(count),
            "DUMUX_MG_SEED": str(seed),
            "DUMUX_METRIC_OUTPUT": str(mg_out),
        }
    )
    status = "ok"
    try:
        _run_cmd(["make", "dumux-metric"], env=env)
    except subprocess.CalledProcessError:
        status = "fail"

    if not mg_out.exists():
        # A clean exit without an output file is not a usable run; reporting it as
        # "ok" would make callers that gate on status try to load a missing file.
        return {
            "particles": count,
            "seed": seed,
            "status": "fail",
            "npz": mg_out.name,
        }

    # Close the lazily-read .npz archive once the needed arrays are materialized.
    with np.load(mg_out) as data:
        cell_lengths = np.asarray(data["cell_lengths"], dtype=float)
        density = np.asarray(data["density"], dtype=float)
        # Fallback counts recorded frames (rows of `density`), which may differ
        # from solver steps when a record interval > 1 is used — TODO confirm.
        steps = int(np.asarray(data.get("steps", [density.shape[0]])).item())
        wall = float(np.asarray(data.get("total_wall_time", [np.nan])).item())
        time_end = float(np.asarray(data["times"])[-1])
        dt = float(np.asarray(data.get("dt", [np.nan])).item())
    # Final-frame density times cell length, summed over cells.
    total_mass = float((density[-1] * cell_lengths).sum())

    return {
        "particles": count,
        "seed": seed,
        "status": status,
        "npz": mg_out.name,
        "cells": int(cell_lengths.size),
        "steps": steps,
        "time_end": time_end,
        "dt": dt,
        "wall_time_s": wall,
        "per_step_s": wall / steps if steps > 0 else None,
        "total_mass_final": total_mass,
    }


def compute_uniform_error(run_npz: Path, total_mass0: float, total_length: float) -> float:
    """Relative L2 distance between a run's final mass distribution and the uniform state.

    The reference is *total_mass0* spread uniformly by cell volume
    (cross-sectional area times cell length; unit area when ``cell_radii`` is
    absent or mismatched).

    Two NPZ layouts are supported:

    * MG output (has ``density``): per-cell mass is ``density[-1] * cell_lengths``,
      rescaled to sum to *total_mass0* when its sum is positive.
    * FVM output (has ``tracer``): per-cell mass is
      ``tracer[-1] * cell_lengths * area * (1000 / 0.018)``, not rescaled.

    Args:
        run_npz: Path to the run's NPZ file.
        total_mass0: Reference total mass (from the FVM initial state).
        total_length: Unused; kept so existing call sites keep working.

    Returns:
        ``||mass - target|| / max(total_mass0, 1e-20)``.
    """
    # Close the lazily-read .npz archive once the needed arrays are materialized.
    with np.load(run_npz) as data:
        cell_lengths = np.asarray(data["cell_lengths"], dtype=float)
        cell_radii = np.asarray(data.get("cell_radii", []), dtype=float)
        # `data.files` avoids Mapping.__contains__, which would load the whole array.
        is_mg = "density" in data.files
        final = np.asarray(data["density"] if is_mg else data["tracer"], dtype=float)[-1]

    if cell_radii.size == cell_lengths.size:
        area = np.pi * cell_radii * cell_radii
    else:
        area = np.ones_like(cell_lengths)

    if is_mg:
        mass = final * cell_lengths
        mass_sum = mass.sum()
        if mass_sum > 0:
            mass = mass * (total_mass0 / mass_sum)  # rescale to match FVM total mass
    else:
        # Same mole-fraction-to-amount conversion as the FVM baseline summary.
        mass = final * cell_lengths * area * (1000.0 / 0.018)

    cell_volumes = area * cell_lengths
    total_volume = float(cell_volumes.sum())
    target_mass = (total_mass0 / total_volume) * cell_volumes
    err = np.linalg.norm(mass - target_mass)
    return float(err / max(total_mass0, 1e-20))


def main() -> None:
    """Run the FVM baseline and the MG particle-count sweep, recording a manifest.

    All outputs go to ``DATA_ROOT/revision/diffusion/<timestamp>/``.
    ``manifest.json`` is rewritten after the baseline and after every MG run, so
    an exception partway through a long sweep still leaves a manifest describing
    everything that completed.

    Environment overrides read here (defaults in parentheses):
    ``DUMUX_TARGET_DX`` (``5e-7``), ``DUMUX_MIN_CELLS_PER_MIN_EDGE`` (``10``) and
    ``DUMUX_MG_RECORD_INTERVAL`` (``200``). The values used are stored under
    ``settings`` in the manifest.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = DATA_ROOT / "revision" / "diffusion" / timestamp
    out_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = out_dir / "manifest.json"

    # Discretization and MG record-interval knobs, overridable at invocation time.
    target_dx = os.environ.get("DUMUX_TARGET_DX", "5e-7")
    min_cells = os.environ.get("DUMUX_MIN_CELLS_PER_MIN_EDGE", "10")
    record_interval = os.environ.get("DUMUX_MG_RECORD_INTERVAL", "200")

    base_env = os.environ.copy()
    base_env.update(
        {
            "DUMUX_SETTINGS_MODULE": "experiments.dumux_tracer.experiment_settings",
            "DUMUX_CONFIG_MODULE": "experiments.dumux_tracer.metric_graph_config",
            "DUMUX_COMPARE_MODULE": "experiments.dumux_tracer.compare_config",
            "DUMUX_TIME_END": str(TIME_END),
            "DUMUX_DT": str(DUMUX_DT),
            "DUMUX_TARGET_DX": target_dx,
            "DUMUX_MIN_CELLS_PER_MIN_EDGE": min_cells,
            "DUMUX_INITIAL_FROM_DUMUX": "1",
            "DUMUX_MG_RECORD_INTERVAL": record_interval,
        }
    )

    seeds_per_count: Dict[int, List[int]] = {}
    mg_runs: List[Dict[str, Any]] = []
    manifest: Dict[str, Any] = {
        "timestamp": timestamp,
        "settings": {
            "time_end": TIME_END,
            "dt": DUMUX_DT,
            "particle_counts": PARTICLE_COUNTS,
            "num_seeds": NUM_SEEDS,
            "seeds_per_count": seeds_per_count,
            "target_dx": target_dx,
            "min_cells_per_min_edge": min_cells,
            "mg_record_interval": record_interval,
        },
        "fvm": None,
        "mg_runs": mg_runs,
    }

    def write_manifest() -> None:
        # seeds_per_count and mg_runs are referenced (not copied) by `manifest`,
        # so each write reflects everything gathered so far.
        manifest_path.write_text(json.dumps(manifest, indent=2))

    fvm_info = run_fvm(out_dir, base_env)
    total_mass0 = fvm_info["total_mass0"]
    total_length = fvm_info["total_length"]
    fvm_info["rel_l2_uniform"] = compute_uniform_error(out_dir / fvm_info["npz"], total_mass0, total_length)
    manifest["fvm"] = fvm_info
    write_manifest()

    rng = np.random.default_rng()
    for count in PARTICLE_COUNTS:
        seeds = OVERRIDE_SEEDS.get(count)
        if seeds is None:
            seeds = rng.integers(0, 2**31 - 1, size=NUM_SEEDS, endpoint=False, dtype=np.int64).tolist()
        seeds_per_count[count] = seeds
        write_manifest()  # record the seeds before the (long) runs start
        for seed in seeds:
            info = run_mg(count, seed, out_dir, base_env)
            if info.get("status") == "ok":
                info["rel_l2_uniform"] = compute_uniform_error(out_dir / info["npz"], total_mass0, total_length)
            mg_runs.append(info)
            write_manifest()

    print(f"[revision diffusion] wrote manifest {manifest_path}")


if __name__ == "__main__":
    # Script entry point: the sweep runs only when executed directly, not on import.
    main()
