from __future__ import annotations

import hashlib
import json
from pathlib import Path
from typing import Any

try:
    import tomllib  # py>=3.11
except Exception:  # pragma: no cover
    import tomli as tomllib  # type: ignore


def _stable_hash(cfg: dict) -> str:
    s = json.dumps(cfg, sort_keys=True, separators=(",", ":"), ensure_ascii=True)
    return hashlib.sha1(s.encode("utf-8")).hexdigest()[:10]


def load_runs(toml_path: str, run_id: str | None = None) -> list[dict[str, Any]]:
    """
    Supports either:
      - single seed via `seed = 42` (default behavior)
      - multiple seeds via `seeds = [10, 11, 12]` inside a [[runs]] block
    For multiple seeds, we expand into multiple cfg dicts, keeping the SAME run_id
    (group id) and varying only cfg["seed"].
    """
    p = Path(toml_path)
    data = tomllib.loads(p.read_text(encoding="utf-8"))

    defaults = data.get("defaults", {})
    if defaults is None:
        defaults = {}
    if not isinstance(defaults, dict):
        raise ValueError(f"'defaults' must be a table [defaults] in {toml_path}")

    runs = data.get("runs", None)
    if runs is None:
        raise ValueError(f"No [[runs]] found in {toml_path}")
    if not isinstance(runs, list):
        raise ValueError(f"'runs' must be an array of tables [[runs]] in {toml_path}")

    merged: list[dict[str, Any]] = []

    for r in runs:
        rr = dict(defaults)
        rr.update(dict(r))  # run-specific overrides defaults

        # optional multi-seed expansion
        seeds = rr.pop("seeds", None)

        # ensure run_id exists and is STABLE across seeds:
        # - if provided: keep it
        # - else: hash config WITHOUT seed so it's stable for the group
        rid = rr.get("run_id")
        if rid is None:
            cfg_for_hash = dict(rr)
            cfg_for_hash.pop("seed", None)
            rid = _stable_hash(cfg_for_hash)
            rr["run_id"] = rid

        if seeds is None:
            merged.append(rr)
        else:
            # expand to many cfgs, same run_id, different seed
            for s in seeds:
                cfg_s = dict(rr)
                cfg_s["seed"] = int(s)
                merged.append(cfg_s)

    if run_id is None:
        return merged

    # filter by group run_id
    return [r for r in merged if str(r.get("run_id")) == str(run_id)]


def new_run_dir(*, exp: str, method: str, cfg: dict, out_root: str = "runs") -> Path:
    """
    Create a fresh per-seed run directory and write its config.

    Alternative B layout:
      runs/<exp>/<method>/<run_id>/seed_<seed>/
        config.json   (cfg, including the resolved run_id)
        checkpoints/
        figures/

    If cfg has no "run_id", one is derived by hashing cfg without "seed".
    The caller's dict is not modified, because a shallow copy is used.

    Args:
        exp: Experiment name (first path component under out_root).
        method: Method name (second path component).
        cfg: Run config. Must contain "seed".
        out_root: Root directory for all runs.

    Returns:
        Path of the newly created seed directory.

    Raises:
        ValueError: If cfg has no "seed".
        FileExistsError: If the seed directory already exists. An existing
            run is never reused or overwritten.
        TypeError: If cfg contains values json cannot serialize. This is
            raised before any directory is created.
    """
    cfg = dict(cfg)

    rid = cfg.get("run_id")
    if rid is None:
        cfg_for_hash = dict(cfg)
        cfg_for_hash.pop("seed", None)
        rid = _stable_hash(cfg_for_hash)
        cfg["run_id"] = rid

    seed = cfg.get("seed", None)
    if seed is None:
        raise ValueError("cfg must contain 'seed' (int) for Alternative B directory layout")
    seed = int(seed)

    # Serialize BEFORE touching the filesystem. If serialization failed after
    # the exist_ok=False mkdir below, an empty seed directory would be left
    # behind and every retry would then hit FileExistsError.
    config_text = json.dumps(cfg, indent=2, sort_keys=True, ensure_ascii=False)

    run_dir = Path(out_root) / exp / method / str(rid) / f"seed_{seed}"
    # parents=True creates the missing <exp>/<method>/<run_id> chain, while
    # exist_ok=False still refuses to reuse an existing seed directory.
    run_dir.mkdir(parents=True, exist_ok=False)

    (run_dir / "config.json").write_text(config_text, encoding="utf-8")
    (run_dir / "checkpoints").mkdir(parents=True, exist_ok=True)
    (run_dir / "figures").mkdir(parents=True, exist_ok=True)
    return run_dir


def log_jsonl(path: Path, row: dict) -> None:
    """Append *row* to *path* as a single JSON line.

    Missing parent directories are created first. The file is opened in
    append mode with UTF-8 encoding, and non-ASCII characters are written
    as-is rather than escaped.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with open(path, "a", encoding="utf-8") as sink:
        sink.write(f"{json.dumps(row, ensure_ascii=False)}\n")


def save_ckpt(run_dir: Path, *, epoch: int, tag: str, payload: dict) -> Path:
    """
    Save *payload* with torch.save under <run_dir>/checkpoints/.

    File naming:
      - tag == "latest" -> latest.pt (overwritten on every call; epoch unused)
      - otherwise       -> <tag>_<epoch:07d>.pt

    The payload is first written to a sibling temporary file and then moved
    onto the final name with os.replace. An interrupted save therefore cannot
    leave a truncated file at the final path. This matters most for latest.pt,
    which is overwritten repeatedly. On failure the temp file is removed and
    the exception is re-raised.

    Returns:
        Path of the written checkpoint.
    """
    import os

    import torch

    ckpt_dir = run_dir / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    name = "latest.pt" if tag == "latest" else f"{tag}_{epoch:07d}.pt"
    path = ckpt_dir / name

    # Same directory as the target, so os.replace is a rename on one filesystem.
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Also covers KeyboardInterrupt: clean up the partial file, then propagate.
        tmp_path.unlink(missing_ok=True)
        raise
    return path