"""
Config for the real microBlooM network with flow-derived drift (linear potential).
"""

from __future__ import annotations

from pathlib import Path

import numpy as np
from sorcerun.git_utils import (
    is_dirty,
    get_repo,
    get_commit_hash,
    get_time_str,
    get_tree_hash,
)
from config import _compute_stability_limit

# Resolve the enclosing git repository once; its working directory anchors
# every path used by this config.
repo = get_repo()
REPO_ROOT = Path(repo.working_dir)

# Location of the flow-derived graph archive inside the repository.
GRAPH_PATH = REPO_ROOT.joinpath(
    "experiments",
    "vascular_hex",
    "microbloom_real_flow_graph.npz",
)

# Fail at import time, with an actionable message, when the archive is absent.
if not GRAPH_PATH.exists():
    _missing_graph_msg = (
        f"Missing flow-derived graph at {GRAPH_PATH}. Run converter with --use_flow_as_drift."
    )
    raise FileNotFoundError(_missing_graph_msg)

# --- Simulation parameters -------------------------------------------------
T = 0.05                    # time horizon; divided by dt to obtain the step count
sigma = 1.0                 # noise amplitude; the diffusion used below is sigma**2 / 2
num_particles = 1_000_000   # particle count passed to the simulator
num_bins = 200              # bin count; also sets dx for the stability limit

# --- Output / diagnostics switches ----------------------------------------
error_norm = 2
make_gif = False
run_fvm = False

# Load the graph description from the .npz archive.
#
# np.load() on an .npz returns a lazily-reading archive object that keeps the
# underlying file open. Using it as a context manager releases the handle as
# soon as every array has been converted to plain Python lists; without it the
# file stays open for as long as the module-level `data` object is alive.
with np.load(GRAPH_PATH) as data:
    num_edges = int(data["num_edges"])
    # Per-edge quantities, converted to lists of Python floats.
    edge_lengths = data["edge_lengths"].astype(float).tolist()
    drift_coeffs = data["drift_coeffs"].astype(float).tolist()
    # Graph connectivity, converted to (nested) lists of Python ints.
    # NOTE(review): the vertex_edge_* arrays look like an offsets/indices
    # adjacency layout -- confirm against the converter that writes the file.
    edge_vertices = data["edge_vertices"].astype(int).tolist()
    vertex_edge_offsets = data["vertex_edge_offsets"].astype(int).tolist()
    vertex_edge_indices = data["vertex_edge_indices"].astype(int).tolist()
    vertex_edge_orientations = data["vertex_edge_orientations"].astype(int).tolist()
    vertex_edge_cumweights = data["vertex_edge_cumweights"].astype(float).tolist()

# Uniform jump weights: every edge receives the same weight 1 / num_edges.
jump_weights = (np.ones(num_edges, dtype=float) / num_edges).tolist()

# Spatial resolution fed to the stability estimate.
# NOTE(review): only the first edge's length is used here; if edge lengths
# differ across the network, the shortest edge may give a tighter limit --
# confirm how the simulator bins each edge.
if num_bins > 0:
    dx = float(edge_lengths[0]) / num_bins
else:
    # Fallback spacing when no bins are requested.
    dx = 1.0

# Stability bound from the project helper, with diffusion D = sigma**2 / 2.
d_max = _compute_stability_limit(drift_coeffs, D=sigma**2 / 2, dx=dx)

# Use half the stability bound, clamped to the interval [1e-7, 1e-5].
dt = min(1e-5, max(d_max / 2, 1e-7))

# int() truncates toward zero, so steps * dt never exceeds T.
steps = int(T / dt)

# Configuration dictionary for this experiment: simulation parameters, the
# loaded graph description, output switches and git provenance metadata.
config = {
    # Simulation parameters.
    "num_particles": num_particles,
    "num_bins": num_bins,
    "steps": steps,
    "dt": dt,
    "sigma": sigma,
    # Graph description loaded from the .npz archive.
    "num_edges": num_edges,
    "edge_lengths": edge_lengths,
    "drift_coeffs": drift_coeffs,
    "potential_type": "linear",
    "jump_weights": jump_weights,
    "edge_vertices": edge_vertices,
    "vertex_edge_offsets": vertex_edge_offsets,
    "vertex_edge_indices": vertex_edge_indices,
    "vertex_edge_orientations": vertex_edge_orientations,
    "vertex_edge_cumweights": vertex_edge_cumweights,
    # Output / diagnostics switches.
    "make_gif": make_gif,
    "error_norm": error_norm,
    "run_fvm": run_fvm,
    "backend": "cuda",
    # Provenance metadata. Every entry uses the single module-level `repo`
    # handle, so the commit hash, tree hash and dirty flag all describe the
    # same repository object instead of a second, separately discovered one.
    "commit_hash": get_commit_hash(repo),
    "main_tree_hash": get_tree_hash(repo, "main"),
    "time_str": get_time_str(),
    "dirty": is_dirty(repo),
}
