# peptide/run_utils.py
from __future__ import annotations

import copy
import hashlib
import json
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union

import torch


# -------------------- TOML config loading --------------------

def load_toml(path: str | Path) -> Dict[str, Any]:
    """
    Loads TOML into a Python dict.

    - Python 3.11+: uses tomllib
    - Older: tries tomli (pip install tomli)
    """
    path = Path(path)
    try:
        import tomllib  # py>=3.11
        return tomllib.loads(path.read_text(encoding="utf-8"))
    except ModuleNotFoundError:
        try:
            import tomli  # type: ignore
            return tomli.loads(path.read_text(encoding="utf-8"))
        except ModuleNotFoundError as e:
            raise RuntimeError(
                "Python < 3.11 detected and 'tomli' is not installed.\n"
                "Either use Python 3.11+ or run: pip install tomli"
            ) from e


def merge_dicts(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a new dict holding the entries of ``base`` overlaid by ``override``.

    The merge is shallow: when a key is present in both, the value from
    ``override`` replaces the one from ``base`` wholesale (nested dicts are
    not merged). Neither argument is modified.
    """
    return {**base, **override}


def load_runs(config_path: Union[str, Path], *, run_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Load run configs from a TOML file laid out as:

      [defaults]
      [[runs]]
      id="run_27"
      seed=10      (or seeds=[10,11,...])

    Each [[runs]] table is shallow-merged over [defaults] and "materialized"
    into one config per seed. Every returned config has
    cfg["run_id"] = str(cfg["id"]) and an integer cfg["seed"].

    Seed resolution:
      - a run that sets 'seed' but not 'seeds' uses that single seed, even
        when [defaults] provides a 'seeds' list;
      - otherwise a non-None 'seeds' (from the run or from [defaults]) is
        expanded into one config per entry, and 'seeds' is removed from each;
      - otherwise 'seed' (from the run or from [defaults]) is required.

    Args:
        config_path: path of the TOML file.
        run_id: when given, only runs whose id equals it (compared as
            strings) are returned.

    Returns:
        List of config dicts, runs in file order and seeds in list order.

    Raises:
        ValueError: if 'runs' is not a list, a run has no 'id', 'seeds' is not
            a non-empty list, no seed can be resolved for a run, or ``run_id``
            was given and no config was produced for it.
    """
    cfg_all = load_toml(config_path)
    defaults = cfg_all.get("defaults", {}) or {}
    runs = cfg_all.get("runs", []) or []

    if not isinstance(runs, list):
        raise ValueError("TOML inválido: 'runs' deve ser uma lista (use [[runs]]).")

    out: List[Dict[str, Any]] = []

    for r in runs:
        # Entries of 'runs' that are not tables are ignored.
        if not isinstance(r, dict):
            continue

        rid = r.get("id")
        if rid is None:
            raise ValueError("Cada [[runs]] precisa ter 'id'.")

        if run_id is not None and str(rid) != str(run_id):
            continue

        # Shallow merge: run-level keys replace the defaults wholesale.
        merged = copy.deepcopy(defaults)
        merged.update(r)

        # A run-level 'seed' is an explicit override. Without this, a 'seeds'
        # list inherited from [defaults] would take the expansion branch below
        # and silently discard the seed the run asked for.
        if r.get("seed") is not None and "seeds" not in r:
            merged.pop("seeds", None)

        # Normalize id -> run_id.
        merged["run_id"] = str(rid)

        seeds = merged.get("seeds")
        if seeds is not None:
            # Expand into one independent config per seed.
            if not isinstance(seeds, list) or len(seeds) == 0:
                raise ValueError(f"'seeds' deve ser uma lista não-vazia em run id={rid}")
            for s in seeds:
                cfg2 = copy.deepcopy(merged)
                cfg2.pop("seeds", None)
                cfg2["seed"] = int(s)
                out.append(cfg2)
        else:
            # A single seed is mandatory (from [defaults] or from the run).
            if merged.get("seed") is None:
                raise ValueError(
                    f"Seed ausente para run id={rid}. "
                    f"Defina 'seed' em [defaults] ou em [[runs]], ou use 'seeds=[...]'."
                )
            merged["seed"] = int(merged["seed"])
            out.append(merged)

    if run_id is not None and len(out) == 0:
        raise ValueError(f"Nenhum run com id='{run_id}' encontrado em {config_path}")

    return out


# -------------------- run dir naming + IO --------------------

def _fmt_float(x: float) -> str:
    """
    Render a float as a short token usable in folder names.

    Zero is rendered as "0". Any other value is formatted with 4 significant
    digits (``.4g``), then "." is replaced by "p" and "-" by "m":
      0.94  -> 0p94
      0.01  -> 0p01
      -1.5  -> m1p5
      1e-05 -> 1em05
    """
    if x == 0:
        return "0"
    substitutions = str.maketrans({".": "p", "-": "m"})
    return f"{x:.4g}".translate(substitutions)


def _fmt(v: Any) -> str:
    """
    Stringify a config value for use in a folder name.

    Floats are rendered by ``_fmt_float``; every other value goes through
    ``str()``.
    """
    return _fmt_float(v) if isinstance(v, float) else str(v)


def _short_hash(cfg: Dict[str, Any], n: int = 6) -> str:
    """
    Return the first ``n`` hex characters of the SHA-1 of ``cfg``.

    The config is serialized as JSON with sorted keys before hashing, so the
    result does not depend on key insertion order. ``cfg`` must be
    JSON-serializable.
    """
    canonical = json.dumps(cfg, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha1(canonical.encode("utf-8")).hexdigest()
    return digest[:n]


def keys_for_method(method: str, cfg: Dict[str, Any]) -> List[str]:
    """
    Select the config keys that go into the folder name for readability.

    Every method gets the shared keys (including run_id, so runs can be
    identified quickly), the architecture keys and the learning-rate keys.
    "dtb" and "sa" append their own method-specific keys. Only keys actually
    present in ``cfg`` are returned, in that fixed order.
    """
    shared = [
        "run_id", "seed", "batch_size", "cut_off", "eps", "device",  # common
        "emb_dim", "hidden", "pos_dim", "window",                    # architecture
        "lr_pf", "lr_logz",                                          # learning rates
    ]

    # Only these two methods contribute extra keys; all others use the shared set.
    if method == "dtb":
        extras = ["div_lr_pf", "div_lr_logz", "threshold"]
    elif method == "sa":
        extras = ["sa_lr_pf", "sa_lr_logz"]
    else:
        extras = []

    return [name for name in shared + extras if name in cfg]


def make_run_name(method: str, cfg: Dict[str, Any], keys: Iterable[str]) -> str:
    """
    Build a folder name of the form
    "<timestamp>__<method>__<key><value>__...__<hash>".

    Args:
        method: method label, placed right after the timestamp.
        cfg: run config; must be JSON-serializable because it is hashed with
            ``_short_hash``.
        keys: config keys to spell out in the name, in order. Keys missing
            from ``cfg`` are skipped.

    Returns:
        The name, with parts joined by "__". The timestamp is the local time
        of the call ("%Y-%m-%d_%H%M%S"), so two calls with the same arguments
        can return different names.
    """
    ts = time.strftime("%Y-%m-%d_%H%M%S")
    tokens = [f"{k}{_fmt(cfg[k])}" for k in keys if k in cfg]
    return "__".join([ts, method, *tokens, _short_hash(cfg)])

def new_run_dir(*, exp: str, method: str, cfg: dict, out_root: str) -> Path:
    """
    Create (if needed) and return the directory for one run and seed:

        <out_root>/<exp>/<method>/<run_id>/seed_<seed>/

    The "checkpoints" and "samples" subdirectories are created as well, and
    the config is saved to "config.json" the first time the directory is set
    up (an existing config.json is never overwritten).

    Args:
        exp: experiment name (first path component under ``out_root``).
        method: method name (second path component).
        cfg: run config. ``cfg["run_id"]`` (falling back to ``cfg["id"]``,
            then to "run_unknown") names the run; ``cfg["seed"]`` is required.
        out_root: root directory for all runs.

    Returns:
        Path of the run directory.

    Raises:
        ValueError: if ``cfg`` has no 'seed' (or it is None).
        TypeError: if ``cfg`` is not JSON-serializable and config.json still
            has to be written.
    """
    run_id = cfg.get("run_id") or cfg.get("id") or "run_unknown"
    seed = cfg.get("seed")
    if seed is None:
        raise ValueError("cfg must contain 'seed' to build seed subdir")

    run_dir = Path(out_root) / exp / method / str(run_id) / f"seed_{int(seed)}"
    for sub in ("checkpoints", "samples"):
        (run_dir / sub).mkdir(parents=True, exist_ok=True)

    # Save config.json when the directory is first created (source of truth).
    cfg_path = run_dir / "config.json"
    if not cfg_path.exists():
        # Serialize before touching the file: streaming json.dump into an open
        # file would leave a truncated config.json behind if cfg turned out to
        # be non-serializable, and the exists() guard would then keep that
        # broken file forever.
        text = json.dumps(cfg, indent=2, ensure_ascii=False)
        cfg_path.write_text(text, encoding="utf-8")

    return run_dir


def log_jsonl(path: Path, row: Dict[str, Any]) -> None:
    """
    Append ``row`` to the file at ``path`` as one JSON line.

    The file is opened in append mode as UTF-8 (and created if missing);
    non-ASCII characters are written as-is rather than escaped.
    """
    with open(path, mode="a", encoding="utf-8") as sink:
        record = json.dumps(row, ensure_ascii=False)
        sink.write(f"{record}\n")


def save_ckpt(run_dir: Path, *, epoch: int, tag: str, payload: Dict[str, Any]) -> Path:
    """
    Save ``payload`` with ``torch.save`` under ``run_dir / "checkpoints"``.

    The file name depends on ``tag``:
    - tag == "epoch": "epoch_<epoch, zero-padded to 6 digits>.pt"
    - any other tag:  "<tag>.pt" (``epoch`` is not used in the name)

    The checkpoints directory is not created here; it must already exist.

    Returns:
        Path of the written checkpoint file.
    """
    filename = f"epoch_{epoch:06d}.pt" if tag == "epoch" else f"{tag}.pt"
    target = run_dir / "checkpoints" / filename
    torch.save(payload, target)
    return target