"""
Configuration for running the Langevin metric-graph sampler with nonzero drift.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
import os

from gitbud.gitbud import inject_repo_into_sys_path

# NOTE(review): the `experiments` import below presumably relies on
# inject_repo_into_sys_path() having made the repository importable, which is
# why it sits after this assignment instead of with the other imports --
# confirm against gitbud before reordering.
REPO_ROOT = Path(inject_repo_into_sys_path())
from experiments.dumux_tracer.advection_settings import settings


def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment variable ``name``.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is unset, empty, or
            contains only whitespace.

    Returns:
        ``False`` when the value, ignoring case and surrounding whitespace,
        is one of ``0``, ``false``, ``f``, ``no``, ``n`` or ``off``;
        ``True`` for any other non-empty value.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    # Strip before comparing so values such as "0 " or " false" (easy to
    # produce from shell scripts and .env files) are not read as truthy.
    value = raw.strip().lower()
    if not value:
        return default
    return value not in {"0", "false", "f", "no", "n", "off"}


def _env_int(name: str, default: int) -> int:
    """Return the integer held in environment variable ``name``.

    ``default`` is returned when the variable is unset, empty, or holds text
    that ``int`` cannot parse.
    """
    text = os.environ.get(name)
    if not text:
        # Unset (None) or empty string: nothing to parse.
        return default
    try:
        parsed = int(text)
    except ValueError:
        # Best-effort: an unparsable override falls back silently.
        return default
    return parsed


def _env_path(name: str, default: Path) -> Path:
    """Return environment variable ``name`` as a ``Path``.

    ``default`` is returned unchanged when the variable is unset or empty.
    """
    override = os.environ.get(name)
    return Path(override) if override else default


@dataclass(frozen=True)
class MetricGraphConfig:
    """Immutable parameter set for the metric-graph sampler run.

    Most defaults are copied from the shared ``settings`` object; a few can
    be overridden through environment variables.  Every default, including
    the environment lookups, is evaluated once when this class body executes
    (i.e. at import time), so changing the environment afterwards has no
    effect on the defaults.
    """

    dgf_path: Path = settings.dumux_dgf_build
    # Overridable via the DUMUX_METRIC_OUTPUT environment variable.
    output_path: Path = _env_path("DUMUX_METRIC_OUTPUT", settings.metric_output)
    # Overridable via the DUMUX_TRACER_OUTPUT environment variable.
    tracer_npz_path: Path = _env_path("DUMUX_TRACER_OUTPUT", settings.dumux_tracer_output)
    # NOTE(review): units of viscosity / diffusion_coefficient are not stated
    # in this file -- confirm against the sampler.
    viscosity: float = 1e-3
    diffusion_coefficient: float = settings.diffusion_coefficient
    velocity_scale: float = 1.0  # use DuMuX-derived drift
    num_particles: int = settings.mg_particles
    num_bins: int = settings.mg_num_bins
    dt: float = settings.mg_dt
    # May be None; NOTE(review): presumably the step count is then derived
    # from time_horizon and dt -- confirm in the sampler.
    steps: int | None = settings.mg_steps
    time_horizon: float = settings.time_end
    record_interval: int = settings.mg_record_interval
    # Overridable via DUMUX_MG_SEED; unset, empty or non-integer values give 0.
    rng_seed: int = _env_int("DUMUX_MG_SEED", 0)
    refine_segments: int = 1
    weight_power: float = 2.0
    weight_by_length: bool = False
    decay_rate: float = 0.0
    # Overridable via DUMUX_INITIAL_FROM_DUMUX; defaults to True.
    initial_from_dumux: bool = _env_bool("DUMUX_INITIAL_FROM_DUMUX", True)
    initial_snapshot_idx: int = 0
    # None by default; NOTE(review): presumably a fixed drift that replaces
    # the DuMuX-derived one when set -- confirm in the sampler.
    constant_velocity: float | None = None
    validate_hist: bool = False
    validate_hist_every: int = 0

    @property
    def sigma(self) -> float:
        """Return ``sqrt(2 * diffusion_coefficient)``.

        NOTE(review): presumably the noise amplitude of the Langevin step --
        confirm against the sampler.
        """
        return (2.0 * self.diffusion_coefficient) ** 0.5


# Module-level default instance, built once at import with every field left
# at its default value.
config = MetricGraphConfig()
