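"""
Settings for the conservative advection-diffusion baseline (constant drift per edge).

The same settings drive the DuMuX reference run and the metric-graph (MG) simulation.
Most numeric values can be overridden through ``DUMUX_*`` environment variables.
These are read once, when this module is imported. ``DUMUX_PLOT_DIR`` is the
exception: it is read each time ``plot_dir`` is accessed.
"""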

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
import sys
import os

from gitbud.gitbud import inject_repo_into_sys_path

# Resolve the repository root via gitbud. Judging by its name,
# inject_repo_into_sys_path() also puts the repo on sys.path (not verified here).
# Then append the langevin-gpu Python sources, which hold the
# `langevin_simulator` module imported lazily by ExperimentSettings.
REPO_ROOT = Path(inject_repo_into_sys_path())
sys.path.append(os.fspath(REPO_ROOT.joinpath("main", "langevin-gpu", "python")))


def _env_float(name: str, default: float) -> float:
    """Return environment variable *name* parsed as a float, or *default*.

    An unset variable, an empty string, and a whitespace-only string all fall
    back to *default*. Surrounding whitespace around a real value is ignored.

    Raises:
        ValueError: if the variable holds text that ``float`` cannot parse.
            The message names the offending variable so misconfigured runs
            are easy to diagnose.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(
            f"environment variable {name}={raw!r} is not a valid float"
        ) from err


def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an int, or *default*.

    An unset variable, an empty string, and a whitespace-only string all fall
    back to *default*. The value must be an integer literal accepted by
    ``int``, such as ``"10"`` or ``"1_000_000"``. Float spellings such as
    ``"1e6"`` are rejected.

    Raises:
        ValueError: if the variable holds text that ``int`` cannot parse.
            The message names the offending variable.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"environment variable {name}={raw!r} is not a valid integer"
        ) from err


@dataclass(frozen=True)
class ExperimentSettings:
    """Immutable bundle of parameters shared by the DuMuX and metric-graph runs.

    Numeric field defaults come from ``DUMUX_*`` environment variables via
    ``_env_float`` / ``_env_int``. Those calls sit in the class body, so the
    environment is sampled once, when this module is imported. Changing it
    afterwards does not affect new instances unless values are passed
    explicitly. ``plot_dir`` is the exception: it reads ``DUMUX_PLOT_DIR`` on
    every access.
    """

    # Geometry / discretization
    target_dx: float = _env_float("DUMUX_TARGET_DX", 1e-6)
    min_cells_per_min_edge: int = _env_int("DUMUX_MIN_CELLS_PER_MIN_EDGE", 10)
    # Target max velocity magnitude on any edge (m/s), chosen large enough to
    # make advection visible. With the defaults (dx=1e-6, dt=0.25) the Courant
    # number v*dt/dx is 2.5, so stability relies on the time integrator
    # tolerating CFL > 1. NOTE(review): confirm the DuMuX scheme is implicit.
    flux_scale: float = _env_float("DUMUX_FLUX_SCALE", 1e-5)
    synthetic_velocity: float | None = None
    diffusion_coefficient: float = _env_float("DUMUX_DIFFUSION_COEFF", 5.0e-10)

    # Time stepping (shared between DuMuX and MG)
    time_end: float = _env_float("DUMUX_TIME_END", 1500.0)
    dumux_dt: float = _env_float("DUMUX_DT", 0.25)

    # Metric-graph parameters
    mg_particles: int = _env_int("DUMUX_MG_PARTICLES", 1_000_000)
    # Record every DuMuX time step by default for aligned mass plots
    mg_record_interval: int = _env_int("DUMUX_MG_RECORD_INTERVAL", 1)
    mg_num_bins: int = 1

    # Paths
    data_dir: Path = REPO_ROOT / "data"
    figs_dir: Path = REPO_ROOT / "figs"

    @property
    def dune_root(self) -> Path:
        """Directory holding the DUNE/DuMuX checkout (``<repo>/external``)."""
        return REPO_ROOT / "external"

    @property
    def dumux_input_npz(self) -> Path:
        """Network input archive written for the DuMuX run."""
        return self.data_dir / "dumux_advection_network_input.npz"

    @property
    def dumux_tracer_output(self) -> Path:
        """Archive holding the DuMuX 1D tracer results."""
        return self.data_dir / "dumux_advection_tracer_1d.npz"

    @property
    def dumux_runtime_file(self) -> Path:
        """Text file next to the tracer output, named ``<stem>_runtime.txt``."""
        return self.dumux_tracer_output.with_name(f"{self.dumux_tracer_output.stem}_runtime.txt")

    @property
    def dumux_example_dir(self) -> Path:
        """Build directory of the DuMuX ``network_tracer_1d`` example."""
        return self.dune_root / "dumux" / "build-cmake" / "examples" / "network_tracer_1d"

    @property
    def dumux_params_path(self) -> Path:
        """``params.input`` file inside the example build directory."""
        return self.dumux_example_dir / "params.input"

    @property
    def dumux_dgf_build(self) -> Path:
        """Filtered network grid (``.dgf``) inside the example build directory."""
        return self.dumux_example_dir / "network_filtered.dgf"

    @property
    def dumux_example(self) -> Path:
        """Compiled DuMuX example executable."""
        return self.dumux_example_dir / "example_network_tracer_1d"

    @property
    def metric_output(self) -> Path:
        """Archive holding the metric-graph simulation results."""
        return self.data_dir / "dumux_advection_metric_graph_sim.npz"

    @property
    def mg_dt(self) -> float:
        """Metric-graph step size: ``dumux_dt / STEPS_PER_KERNEL``.

        With this step, ``STEPS_PER_KERNEL`` MG steps span one DuMuX step.
        ``langevin_simulator`` is imported lazily so that reading other
        settings does not require the GPU package.
        """
        from langevin_simulator import STEPS_PER_KERNEL  # type: ignore

        return self.dumux_dt / STEPS_PER_KERNEL

    @property
    def mg_steps(self) -> int:
        """Number of blocks of ``STEPS_PER_KERNEL`` MG steps needed to reach ``time_end``.

        Mathematically this is ``ceil(time_end / (STEPS_PER_KERNEL * mg_dt))``,
        i.e. ``ceil(time_end / dumux_dt)``. It is computed from ``dumux_dt``
        directly, avoiding the lossy divide-then-multiply round trip.
        A ratio within rounding error of an integer is snapped to that
        integer, so float noise (e.g. ``1.1 / 0.1 == 11.000000000000002``)
        cannot add a spurious extra step.
        """
        from math import ceil, isclose

        ratio = self.time_end / self.dumux_dt
        nearest = round(ratio)
        if isclose(ratio, nearest, rel_tol=1e-9, abs_tol=1e-12):
            return int(nearest)
        return ceil(ratio)

    @property
    def plot_dir(self) -> Path:
        """Figure output directory; ``DUMUX_PLOT_DIR`` (read on each access) overrides it."""
        override = os.environ.get("DUMUX_PLOT_DIR")
        return Path(override) if override else self.figs_dir / "advection"


# Shared default instance, built from the environment as seen at import time.
settings: ExperimentSettings = ExperimentSettings()
