"""
Sorcerun entrypoint for DuMuX + metric-graph runs.
"""

import os
import subprocess
import time
from pathlib import Path

import numpy as np
from gitbud.gitbud import inject_repo_into_sys_path
from sorcerun.git_utils import get_repo

inject_repo_into_sys_path()

from experiments.dumux_tracer import run_metric_graph
from experiments.dumux_tracer import extract_dumux_tracer

# Git repository handle returned by sorcerun's ``get_repo()`` (presumably a
# GitPython ``Repo`` for the repository this script lives in — confirm).
REPO = get_repo()
# Working-tree root of that repository as a ``Path``.
# NOTE(review): not referenced elsewhere in this file; confirm whether other
# modules import it before removing.
REPO_ROOT = Path(REPO.working_dir)


def _dumux_command(example: Path, params: Path) -> "tuple[list[str], Path]":
    """Build the argv and working directory for launching a DuMuX example.

    ``subprocess`` resolves a relative executable path against ``cwd`` on
    POSIX. So a caller-relative path such as ``build/ex/foo``, combined with
    ``cwd=example.parent``, would be looked up as ``build/ex/build/ex/foo``.
    Such paths are therefore made absolute (relative to the current working
    directory) first.

    Absolute paths are kept unchanged. Bare names without a directory part are
    also kept unchanged, so they are still found through ``PATH``.
    ``params`` is passed through as given, so a relative parameter file is
    still interpreted by the child relative to the example's directory.

    Args:
        example: Path to the DuMuX executable.
        params: Path to the parameter file passed as the only argument.

    Returns:
        ``(argv, workdir)``, where ``workdir`` is the executable's directory.
    """
    if example.is_absolute() or example.parent == Path("."):
        executable = example
    else:
        executable = Path.cwd() / example
    return [str(executable), str(params)], executable.parent


def _maybe_run_dumux(cfg: dict) -> None:
    """Optionally run a DuMuX example, then extract its tracer output.

    Does nothing unless ``cfg["run_dumux"]`` is truthy. Otherwise it:

    1. Runs ``cfg["dumux_example"]`` with ``cfg["dumux_params"]`` as its only
       argument, using the example's directory as working directory.
    2. Exports ``DUMUX_VTK_DIR`` and ``DUMUX_TRACER_OUTPUT`` into
       ``os.environ``.
    3. Calls ``extract_dumux_tracer.main()``.

    Args:
        cfg: Experiment config. Requires the ``dumux_example``,
            ``dumux_params``, ``dumux_vtk_dir`` and ``dumux_tracer_output``
            keys when ``run_dumux`` is set.

    Raises:
        KeyError: If a required ``dumux_*`` key is missing. This is checked
            before the simulation is launched.
        subprocess.CalledProcessError: If the DuMuX example exits non-zero.
    """
    if not cfg.get("run_dumux", False):
        return
    # Read every required key up front so a config typo fails immediately,
    # not after a potentially long simulation has finished.
    example = Path(cfg["dumux_example"])
    params = Path(cfg["dumux_params"])
    # os.environ only accepts str values; str() also accepts Path entries.
    vtk_dir = str(cfg["dumux_vtk_dir"])
    tracer_output = str(cfg["dumux_tracer_output"])

    argv, workdir = _dumux_command(example, params)
    print(f"[sorcerun] Running DuMuX example {example} with {params}")
    subprocess.run(argv, check=True, cwd=str(workdir))

    os.environ["DUMUX_VTK_DIR"] = vtk_dir
    os.environ["DUMUX_TRACER_OUTPUT"] = tracer_output
    extract_dumux_tracer.main()


def _metric_env(config: dict) -> dict:
    """Map the experiment config onto the metric-run environment variables.

    Values are stringified because ``os.environ`` only accepts ``str``.

    Raises:
        KeyError: If any required config key is missing.
    """
    env_map = {
        "DUMUX_MG_PARTICLES": config["num_particles"],
        "DUMUX_MG_NUM_BINS": config["num_bins"],
        "DUMUX_MG_DT": config["dt"],
        "DUMUX_MG_STEPS": config["steps"],
        "DUMUX_MG_RECORD_INTERVAL": config["record_interval"],
        "DUMUX_MG_VELOCITY_SCALE": config["velocity_scale"],
        "DUMUX_MG_REFINE": config["refine_segments"],
        "DUMUX_MG_DECAY": config["decay_rate"],
        "DUMUX_MG_SEED": config["rng_seed"],
        "DUMUX_MG_OUTPUT": config["output_path"],
        "DUMUX_TRACER_OUTPUT": config["dumux_tracer_output"],
        "DUMUX_MG_DGF": config["dgf_path"],
    }
    return {name: str(value) for name, value in env_map.items()}


def _log_metric_outputs(sacred_run, output_path, wall: float) -> None:
    """Log the wall time and summary scalars from the metric run's output.

    Missing keys in the ``.npz`` archive are logged as ``nan`` (total wall
    time) or ``-1`` (counts).
    """
    sacred_run.log_scalar("metric_wall_time", wall)
    sacred_run.info["output_path"] = output_path
    # The NpzFile returned by np.load keeps the archive open; the context
    # manager releases the file handle once the values have been read.
    with np.load(output_path) as npz:
        total_wall_time = float(npz.get("total_wall_time", np.nan))
        num_particles = int(npz.get("num_particles", -1))
        num_bins = int(npz.get("num_bins", -1))
        steps = int(npz.get("steps", -1))
    sacred_run.log_scalar("metric_total_wall_time", total_wall_time)
    sacred_run.log_scalar("metric_num_particles", num_particles)
    sacred_run.log_scalar("metric_num_bins", num_bins)
    sacred_run.log_scalar("metric_steps", steps)


def adapter(config: dict, _run) -> None:
    """Sacred adapter to run DuMuX (optional) then the metric graph.

    The adapter:

    1. Optionally runs DuMuX via ``_maybe_run_dumux``.
    2. Exports the ``DUMUX_MG_*`` / ``DUMUX_TRACER_OUTPUT`` environment
       overrides and times ``run_metric_graph.run()``.
    3. Logs the wall time and the summary values stored in
       ``config["output_path"]`` to the Sacred run.

    Args:
        config: Experiment config; see ``_metric_env`` for the required keys.
        _run: Sacred run object (uses ``log_scalar`` and ``info``).

    Raises:
        KeyError: If a required config key is missing. This is raised before
            any simulation is started.
    """
    # Build the overrides first so a missing key fails before the (possibly
    # long) DuMuX run. They are applied afterwards, so they still take
    # precedence over anything _maybe_run_dumux exported.
    metric_env = _metric_env(config)
    _maybe_run_dumux(config)
    os.environ.update(metric_env)

    start = time.perf_counter()
    run_metric_graph.run()
    wall = time.perf_counter() - start

    _log_metric_outputs(_run, config["output_path"], wall)


if __name__ == "__main__":
    # Script-only imports: resolved lazily so importing this module as a
    # library does not pull in the Sacred runner or the experiment config.
    from sorcerun.sacred_utils import run_sacred_experiment
    from experiments.dumux_tracer.sorcerun_config import config

    finished_run = run_sacred_experiment(adapter, config)
    run_id = finished_run._id
    print("Run id:", run_id)
