"""
Single source of truth for the conservative diffusion baseline configs.

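Defaults live in Python, and a few numeric settings can be overridden through
``DUMUX_*`` environment variables (read once, when this module is imported).
All scripts import this module so that: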
- the DuMuX grid/discretization matches the metric-graph histogramming
- both methods share the same physical time step and end time
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
import sys
import os

from gitbud.gitbud import inject_repo_into_sys_path

# Repository root as reported by gitbud (presumably it also puts the repo on
# sys.path, per its name — confirm against gitbud if that matters).
REPO_ROOT: Path = Path(inject_repo_into_sys_path())
# Make the CUDA simulator's Python bindings importable: ExperimentSettings.mg_dt
# and ExperimentSettings.mg_steps import STEPS_PER_KERNEL from `langevin_simulator`.
sys.path.append(str(REPO_ROOT.joinpath("main", "langevin-gpu", "python")))


def _env_float(name: str, default: float) -> float:
    """Return environment variable *name* parsed as a float.

    Unset, empty, or whitespace-only values fall back to *default*; surrounding
    whitespace is ignored otherwise.

    Raises:
        ValueError: if the variable is set to something that is not a valid
            float. The message names the variable so a misconfigured run is
            easy to diagnose.
    """
    raw = os.environ.get(name)
    # Treat blank values (e.g. `export X=` or `X=" "`) the same as unset.
    if raw is None or not raw.strip():
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name}={raw!r} is not a valid float"
        ) from err


def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an int.

    Unset, empty, or whitespace-only values fall back to *default*; surrounding
    whitespace is ignored otherwise. Python's ``int()`` syntax applies, so
    underscores such as ``2_000_000`` are accepted.

    Raises:
        ValueError: if the variable is set to something that is not a valid
            integer. The message names the variable so a misconfigured run is
            easy to diagnose.
    """
    raw = os.environ.get(name)
    # Treat blank values (e.g. `export X=` or `X=" "`) the same as unset.
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name}={raw!r} is not a valid integer"
        ) from err


@dataclass(frozen=True)
class ExperimentSettings:
    """Shared configuration for the DuMuX and metric-graph (MG) diffusion runs.

    Several numeric defaults can be overridden through the ``DUMUX_*``
    environment variables named on each field. Field defaults are evaluated
    when this class body executes, so those variables are read once, at import
    time. Explicit constructor arguments still take precedence.

    Raises:
        ValueError: from ``__post_init__`` if ``dumux_dt`` is not a positive
            finite number, or ``time_end`` is negative or not finite. Both feed
            directly into ``mg_dt`` / ``mg_steps``.
    """

    # Geometry / discretization
    target_dx: float = _env_float("DUMUX_TARGET_DX", 5e-7)
    min_cells_per_min_edge: int = _env_int("DUMUX_MIN_CELLS_PER_MIN_EDGE", 10)
    flux_scale: float = _env_float("DUMUX_FLUX_SCALE", 0.0)  # pure diffusion; divergence-free flux disabled
    diffusion_coefficient: float = _env_float("DUMUX_DIFFUSION_COEFF", 1.0e-9)

    # Time stepping (shared between DuMuX and MG)
    time_end: float = _env_float("DUMUX_TIME_END", 4000.0)
    dumux_dt: float = _env_float("DUMUX_DT", 0.5)  # DuMuX TimeLoop.DtInitial (larger dt for faster marching)

    # Metric-graph parameters
    mg_particles: int = _env_int("DUMUX_MG_PARTICLES", 2_000_000)
    mg_record_interval: int = _env_int("DUMUX_MG_RECORD_INTERVAL", 1)
    mg_num_bins: int = 1  # one histogram bin per DuMuX cell (cells are edges after subdivision)

    # Paths
    data_dir: Path = REPO_ROOT / "data"
    figs_dir: Path = REPO_ROOT / "figs"

    def __post_init__(self) -> None:
        """Reject time-stepping values that would break ``mg_dt`` / ``mg_steps``."""
        from math import isfinite

        # A zero/negative/NaN dt makes mg_steps divide by zero or return a
        # meaningless count. A negative or infinite end time has no valid count
        # (ceil(inf) raises OverflowError).
        if not (isfinite(self.dumux_dt) and self.dumux_dt > 0):
            raise ValueError(
                f"dumux_dt must be a positive finite number, got {self.dumux_dt!r}"
            )
        if not (isfinite(self.time_end) and self.time_end >= 0):
            raise ValueError(
                f"time_end must be a non-negative finite number, got {self.time_end!r}"
            )

    @property
    def dune_root(self) -> Path:
        """Directory holding the external DUNE/DuMuX checkout."""
        return REPO_ROOT / "external"

    @property
    def dumux_input_npz(self) -> Path:
        """Network input archive: ``<data_dir>/dumux_network_input.npz``."""
        return self.data_dir / "dumux_network_input.npz"

    @property
    def dumux_tracer_output(self) -> Path:
        """DuMuX tracer result archive: ``<data_dir>/dumux_network_tracer_1d.npz``."""
        return self.data_dir / "dumux_network_tracer_1d.npz"

    @property
    def dumux_runtime_file(self) -> Path:
        """Runtime text file next to the tracer output, named ``<stem>_runtime.txt``."""
        return self.dumux_tracer_output.with_name(f"{self.dumux_tracer_output.stem}_runtime.txt")

    @property
    def dumux_example_dir(self) -> Path:
        """Build directory of the ``network_tracer_1d`` DuMuX example."""
        return self.dune_root / "dumux" / "build-cmake" / "examples" / "network_tracer_1d"

    @property
    def dumux_params_path(self) -> Path:
        """``params.input`` inside the DuMuX example build directory."""
        return self.dumux_example_dir / "params.input"

    @property
    def dumux_dgf_build(self) -> Path:
        """``network_filtered.dgf`` grid file inside the DuMuX example build directory."""
        return self.dumux_example_dir / "network_filtered.dgf"

    @property
    def dumux_example(self) -> Path:
        """Path of the ``example_network_tracer_1d`` executable."""
        return self.dumux_example_dir / "example_network_tracer_1d"

    @property
    def metric_output(self) -> Path:
        """Metric-graph simulation output archive in ``data_dir``."""
        return self.data_dir / "dumux_metric_graph_sim.npz"

    @property
    def mg_dt(self) -> float:
        """Per-step dt for the MG simulator.

        LangevinSimulator advances STEPS_PER_KERNEL times per loop. Dividing
        ``dumux_dt`` by that count makes each loop span one DuMuX time step.
        """
        from langevin_simulator import STEPS_PER_KERNEL  # type: ignore

        return self.dumux_dt / STEPS_PER_KERNEL

    @property
    def mg_steps(self) -> int:
        """Number of MG simulator loops needed to reach ``time_end``.

        Each loop advances ``STEPS_PER_KERNEL * mg_dt`` (``dumux_dt`` up to
        rounding). When ``time_end`` is an exact multiple of that, floating-point
        division can land just above the integer (``1.1 / 0.1`` evaluates to
        ``11.000000000000002``), and a plain ``ceil`` would add a spurious extra
        loop. Quotients within floating-point noise of an integer therefore snap
        to it. Anything else is rounded up so the run covers ``time_end``.
        """
        from math import ceil, isclose

        from langevin_simulator import STEPS_PER_KERNEL  # type: ignore

        loop_dt = STEPS_PER_KERNEL * self.mg_dt
        ratio = self.time_end / loop_dt
        nearest = round(ratio)
        if isclose(ratio, nearest, rel_tol=1e-12):
            return int(nearest)
        return int(ceil(ratio))

    @property
    def plot_dir(self) -> Path:
        """Figure output directory for the diffusion experiment."""
        return self.figs_dir / "diffusion"


# Shared, immutable settings instance built from the defaults (and any DUMUX_*
# environment overrides) at import time; scripts import this object.
settings: ExperimentSettings = ExperimentSettings()
