# # util/wb_init.py
# from __future__ import annotations
# import os
# from typing import Any, Dict, Optional, List, Tuple, Sequence, Mapping

# # --- Optional deps: wandb / matplotlib ---
# try:
#     import wandb  # type: ignore

#     _WB_AVAILABLE = True
# except Exception:
#     wandb = None  # type: ignore
#     _WB_AVAILABLE = False

# try:
#     import numpy as _np
# except Exception:
#     _np = None

# try:
#     import matplotlib.pyplot as _plt  # type: ignore

#     _MPL_OK = True
# except Exception:
#     _plt = None
#     _MPL_OK = False


# # =========================
# # Internal helpers
# # =========================


# class _WBWrap:
#     """Wrapper so code won't crash if W&B not available or disabled."""

#     def __init__(self, run=None, enabled: bool = False):
#         self.run = run
#         self.enabled = bool(enabled)

#     @property
#     def ready(self) -> bool:
#         return self.enabled and (self.run is not None)


# def _expand(p: Optional[str]) -> Optional[str]:
#     if not p:
#         return p
#     return os.path.expanduser(os.path.expandvars(p))


# def _coerce_float(v: Any, default=_np.nan if _np is not None else float("nan")) -> float:
#     try:
#         return float(v)
#     except Exception:
#         return default


# # =========================
# # Public API
# # =========================


# def wb_setup(
#     args: Any = None,
#     project: Optional[str] = None,
#     entity: Optional[str] = None,
#     mode: Optional[str] = None,  # "online" | "offline" | "disabled"
#     group: Optional[str] = None,
#     job_type: Optional[str] = None,
#     tags: Optional[List[str]] = None,
#     notes: Optional[str] = None,
#     config: Optional[Dict[str, Any]] = None,
#     resume: Optional[str] = "allow",
#     dir_: Optional[str] = None,
# ) -> _WBWrap:
#     """
#     Robust W&B init. Will not crash if wandb missing. Returns a wrapper with .ready flag.
#     Args precedence: explicit args > attributes on `args` > sensible defaults.
#     """
#     if not _WB_AVAILABLE:
#         return _WBWrap(run=None, enabled=False)

#     # ----- prefer args -----
#     prj = project or getattr(args, "wandb_project", None) or "Active Curriculum Design"
#     ent = entity or getattr(args, "wandb_entity", None)
#     md = mode or getattr(args, "wandb_mode", None)  # "online"/"offline"/"disabled"
#     grp = group or getattr(args, "wandb_group", None)
#     job = job_type or getattr(args, "wandb_job_type", None)
#     tgs = tags or getattr(args, "wandb_tags", None)
#     nts = notes or getattr(args, "wandb_notes", None)
#     cfg = {}
#     if isinstance(config, dict):
#         cfg.update(config)
#     a_cfg = getattr(args, "wandb_config", None)
#     if isinstance(a_cfg, dict):
#         cfg.update(a_cfg)

#     # run dir
#     run_dir = None
#     log_dir = _expand(getattr(args, "log_dir", None))
#     xpid = getattr(args, "xpid", None)
#     if log_dir and xpid:
#         run_dir = os.path.join(log_dir, xpid)
#     if dir_:
#         run_dir = _expand(dir_)
#     if not run_dir:
#         run_dir = os.getcwd()
#     os.makedirs(run_dir, exist_ok=True)

#     # run name (prefer xpid)
#     name = getattr(args, "run_name", None) or xpid

#     try:
#         run = wandb.init(
#             project=prj,
#             entity=ent,
#             mode=md,  # None -> auto; or "online"/"offline"/"disabled"
#             group=grp,
#             job_type=job,
#             tags=tgs,
#             notes=nts,
#             config=cfg if cfg else None,
#             resume=resume,
#             dir=run_dir,
#             name=name,
#         )
#         return _WBWrap(run=run, enabled=True)
#     except Exception:
#         # Degrade to disabled if init fails
#         return _WBWrap(run=None, enabled=False)


# def wb_define_default_metrics(wb: _WBWrap):
#     """
#     Declare common metrics & step axis for nice default charts.
#     We use 'ppo/updates' as the global x-axis.
#     """
#     if not (wb and wb.ready):
#         return
#     try:
#         # Step axis
#         wandb.define_metric("ppo/updates", summary="max")

#         # Generic total env steps (optional)
#         wandb.define_metric("steps_total", step_metric="ppo/updates")

#         # Training losses / returns
#         wandb.define_metric("loss/value", step_metric="ppo/updates")
#         wandb.define_metric("loss/policy", step_metric="ppo/updates")
#         wandb.define_metric("loss/entropy", step_metric="ppo/updates")
#         wandb.define_metric("agent/mean_return_window", step_metric="ppo/updates")
#         wandb.define_metric("agent/max_return_window", step_metric="ppo/updates")

#         # Replay / macro
#         wandb.define_metric("replay/global_pool_size", step_metric="ppo/updates")
#         wandb.define_metric("replay/preferred_topk", step_metric="ppo/updates")
#         wandb.define_metric("macro/new_levels_added", step_metric="ppo/updates")
#         wandb.define_metric("macro/injected_replay_count", step_metric="ppo/updates")

#         # Neighbor probing / score (for macro moves)
#         wandb.define_metric("neighbor/num_neighbors", step_metric="ppo/updates")
#         wandb.define_metric("neighbor/best_score", step_metric="ppo/updates")
#         wandb.define_metric("neighbor/best_mu_bar", step_metric="ppo/updates")
#         wandb.define_metric("neighbor/best_std_mu", step_metric="ppo/updates")
#         wandb.define_metric("neighbor/target", step_metric="ppo/updates")

#         # ===== Evaluation =====
#         # Summary scalars (averaged over eval episodes)
#         wandb.define_metric("eval/returns_mean", step_metric="ppo/updates")
#         wandb.define_metric("eval/returns_std", step_metric="ppo/updates")
#         wandb.define_metric("eval/success_rate", step_metric="ppo/updates")
#         wandb.define_metric("eval/len_mean", step_metric="ppo/updates")
#         wandb.define_metric("eval/len_std", step_metric="ppo/updates")
#         # Optional breakdowns (e.g., by partition)
#         wandb.define_metric("eval/breakdown/*", step_metric="ppo/updates")
#         # Optional diagnostics
#         wandb.define_metric("eval/num_episodes", step_metric="ppo/updates")
#     except Exception:
#         pass


# def wb_log(payload: Dict[str, Any], step: Optional[int] = None):
#     """Safe log to W&B."""
#     if not _WB_AVAILABLE or wandb.run is None:
#         return
#     try:
#         if step is None:
#             wandb.log(payload)
#         else:
#             wandb.log(payload, step=step)
#     except Exception:
#         pass


# def wb_log_train_step(
#     num_updates: int,
#     total_steps: Optional[int] = None,
#     value_loss: Optional[float] = None,
#     policy_loss: Optional[float] = None,
#     entropy: Optional[float] = None,
#     mean_return_window: Optional[float] = None,
#     max_return_window: Optional[float] = None,
#     global_pool_size: Optional[int] = None,
#     preferred_topk_size: Optional[int] = None,
# ):
#     """Convenience logging for each PPO update."""
#     payload: Dict[str, Any] = {
#         "ppo/updates": num_updates,
#     }
#     if total_steps is not None:
#         payload["steps_total"] = int(total_steps)
#     if value_loss is not None:
#         payload["loss/value"] = _coerce_float(value_loss)
#     if policy_loss is not None:
#         payload["loss/policy"] = _coerce_float(policy_loss)
#     if entropy is not None:
#         payload["loss/entropy"] = _coerce_float(entropy)
#     if mean_return_window is not None:
#         payload["agent/mean_return_window"] = _coerce_float(mean_return_window)
#     if max_return_window is not None:
#         payload["agent/max_return_window"] = _coerce_float(max_return_window)
#     if global_pool_size is not None:
#         payload["replay/global_pool_size"] = int(global_pool_size)
#     if preferred_topk_size is not None:
#         payload["replay/preferred_topk"] = int(preferred_topk_size)
#     wb_log(payload, step=num_updates)


# def wb_log_macro_move(
#     macro_step: int,
#     from_partition: Any,
#     to_partition: Any,
#     scores: Dict[Any, float],
#     stats: Dict[Any, Dict[str, float]],
#     injected_replay_count: int,
#     new_levels_added: int,
#     global_pool_size: int,
#     preferred_topk_size: int,
#     step: Optional[int] = None,
# ):
#     """
#     Log one partition move:
#       - scalars: best score/mu/std, neighbor count, from/to, injected_replay_count, etc.
#       - table: rows of {macro_step, from, to, neighbor, score, mu_bar, std_mu}.
#     """
#     if not _WB_AVAILABLE or wandb.run is None:
#         return

#     # Table
#     table = None
#     try:
#         table = wandb.Table(
#             columns=["macro_step", "from", "to", "neighbor", "score", "mu_bar", "std_mu"]
#         )
#         for k, v in scores.items():
#             key_str = str(k)
#             st = stats.get(k, {}) if isinstance(stats, dict) else {}
#             mu = _coerce_float(st.get("mu_bar", float("nan")))
#             sd = _coerce_float(st.get("std_mu", float("nan")))
#             table.add_data(
#                 int(macro_step),
#                 str(from_partition),
#                 str(to_partition),
#                 key_str,
#                 float(v),
#                 mu,
#                 sd,
#             )
#     except Exception:
#         table = None

#     payload = {
#         "ppo/updates": step if step is not None else macro_step,
#         "macro/macro_step": macro_step,
#         "macro/from": str(from_partition),
#         "macro/to": str(to_partition),
#         "macro/injected_replay_count": int(injected_replay_count),
#         "macro/new_levels_added": int(new_levels_added),
#         "replay/global_pool_size": int(global_pool_size),
#         "replay/preferred_topk": int(preferred_topk_size),
#         "neighbor/num_neighbors": len(scores),
#     }

#     # Best neighbor stats
#     try:
#         if scores:
#             best_key = max(scores, key=lambda k: scores[k])
#             st = stats.get(best_key, {}) if isinstance(stats, dict) else {}
#             payload.update(
#                 {
#                     "neighbor/best_score": float(scores[best_key]),
#                     "neighbor/best_mu_bar": _coerce_float(st.get("mu_bar", float("nan"))),
#                     "neighbor/best_std_mu": _coerce_float(st.get("std_mu", float("nan"))),
#                     "neighbor/target": str(best_key),
#                 }
#             )
#     except Exception:
#         pass

#     if table is not None:
#         payload["tables/neighbor_scores"] = table

#     wb_log(payload, step=step)


# # =========================
# # Evaluation logging
# # =========================


# def wb_log_eval_summary(
#     step_updates: int,
#     returns: Optional[Sequence[float]] = None,
#     lengths: Optional[Sequence[float]] = None,
#     success_rate: Optional[float] = None,
#     extra_scalars: Optional[Dict[str, float]] = None,
#     breakdown_by_partition: Optional[Mapping[Any, Dict[str, float]]] = None,
#     episodes_table: Optional[List[Dict[str, Any]]] = None,
# ):
#     """
#     Log evaluation summary at a given PPO update step.
#     - returns / lengths: sequences; we log mean/std and histogram (if matplotlib available)
#     - success_rate: scalar in [0,1]
#     - extra_scalars: any additional eval scalars -> logged under 'eval/*'
#     - breakdown_by_partition: dict[partition_key] -> {'mean':..., 'std':..., 'n':...}
#       -> each key will be logged as 'eval/breakdown/{key}_mean' etc.
#     - episodes_table: list of dicts per episode to be shown as a W&B table (columns are auto-inferred)
#     """
#     payload: Dict[str, Any] = {
#         "ppo/updates": int(step_updates),
#     }

#     # Basic stats
#     if returns is not None and len(returns) > 0:
#         try:
#             mean_r = (
#                 float(_np.mean(returns)) if _np is not None else sum(returns) / len(returns)
#             )
#             std_r = float(_np.std(returns)) if _np is not None else 0.0
#         except Exception:
#             mean_r, std_r = _coerce_float(returns[-1]), float("nan")
#         payload["eval/returns_mean"] = mean_r
#         payload["eval/returns_std"] = std_r

#     if lengths is not None and len(lengths) > 0:
#         try:
#             mean_l = (
#                 float(_np.mean(lengths)) if _np is not None else sum(lengths) / len(lengths)
#             )
#             std_l = float(_np.std(lengths)) if _np is not None else 0.0
#         except Exception:
#             mean_l, std_l = _coerce_float(lengths[-1]), float("nan")
#         payload["eval/len_mean"] = mean_l
#         payload["eval/len_std"] = std_l

#     if success_rate is not None:
#         payload["eval/success_rate"] = _coerce_float(success_rate)

#     if episodes_table is not None and _WB_AVAILABLE and wandb.run is not None:
#         try:
#             # infer columns
#             cols: List[str] = []
#             for row in episodes_table:
#                 for k in row.keys():
#                     if k not in cols:
#                         cols.append(k)
#             table = wandb.Table(columns=cols)
#             for row in episodes_table:
#                 table.add_data(*[row.get(c, None) for c in cols])
#             payload["tables/eval_episodes"] = table
#             payload["eval/num_episodes"] = len(episodes_table)
#         except Exception:
#             pass

#     # Partition breakdown
#     if breakdown_by_partition:
#         for k, d in breakdown_by_partition.items():
#             key = str(k)
#             for subk, val in d.items():
#                 payload[f"eval/breakdown/{key}_{subk}"] = _coerce_float(val)

#     # Any extra eval scalars
#     if extra_scalars:
#         for k, v in extra_scalars.items():
#             payload[f"eval/{k}"] = _coerce_float(v)

#     wb_log(payload, step=step_updates)

#     # Optional histograms as images (matplotlib)
#     if _MPL_OK and _WB_AVAILABLE and wandb.run is not None:
#         try:
#             if returns is not None and len(returns) > 0:
#                 _plt.figure()
#                 _plt.hist(list(returns), bins=30)
#                 _plt.title("Eval Returns Histogram")
#                 _plt.xlabel("return")
#                 _plt.ylabel("count")
#                 wandb.log({"plots/eval_returns_hist": wandb.Image(_plt)}, step=step_updates)
#                 _plt.close()

#             if lengths is not None and len(lengths) > 0:
#                 _plt.figure()
#                 _plt.hist(list(lengths), bins=30)
#                 _plt.title("Eval Episode Lengths Histogram")
#                 _plt.xlabel("length")
#                 _plt.ylabel("count")
#                 wandb.log({"plots/eval_lengths_hist": wandb.Image(_plt)}, step=step_updates)
#                 _plt.close()
#         except Exception:
#             pass


# def wb_log_eval_curve(
#     step_updates: int,
#     curves: Mapping[str, Tuple[Sequence[float], Sequence[float]]],
#     title_prefix: str = "eval_curve",
# ):
#     """
#     Log line curves as Tables so W&B can render interactive plots.
#     curves: dict[name] = (xs, ys)
#     """
#     if not _WB_AVAILABLE or wandb.run is None:
#         return
#     for name, (xs, ys) in curves.items():
#         try:
#             table = wandb.Table(columns=["x", "y"])
#             for x, y in zip(xs, ys):
#                 table.add_data(_coerce_float(x), _coerce_float(y))
#             wandb.log({f"tables/{title_prefix}/{name}": table}, step=step_updates)
#         except Exception:
#             continue


# def wb_log_eval_heatmap(
#     step_updates: int,
#     grid: Sequence[Sequence[float]],
#     title: str = "Eval Heatmap",
#     x_label: str = "X",
#     y_label: str = "Y",
# ):
#     """
#     Log a heatmap image (e.g., returns over a 2D partition grid).
#     Falls back silently if matplotlib not available.
#     """
#     if not (_MPL_OK and _WB_AVAILABLE and wandb.run is not None):
#         return
#     try:
#         arr = _np.array(grid) if _np is not None else grid
#         _plt.figure()
#         _plt.imshow(arr, aspect="auto")
#         _plt.colorbar()
#         _plt.title(title)
#         _plt.xlabel(x_label)
#         _plt.ylabel(y_label)
#         wandb.log({f"plots/{title}": wandb.Image(_plt)}, step=step_updates)
#         _plt.close()
#     except Exception:
#         pass


# def wb_log_eval_hist(
#     step_updates: int,
#     name: str,
#     values: Sequence[float],
#     bins: int = 30,
#     title: Optional[str] = None,
# ):
#     """
#     Log one histogram as an image (matplotlib).
#     """
#     if not (_MPL_OK and _WB_AVAILABLE and wandb.run is not None):
#         return
#     try:
#         _plt.figure()
#         _plt.hist(list(values), bins=bins)
#         _plt.title(title or f"Histogram: {name}")
#         _plt.xlabel(name)
#         _plt.ylabel("count")
#         wandb.log({f"plots/hist/{name}": wandb.Image(_plt)}, step=step_updates)
#         _plt.close()
#     except Exception:
#         pass


# def wb_finish():
#     """Finish the W&B run if available."""
#     if not _WB_AVAILABLE or wandb.run is None:
#         return
#     try:
#         wandb.finish()
#     except Exception:
#         pass


# util/wb_init.py
from __future__ import annotations
import os
import re
import hashlib
from typing import Any, Dict, Optional, List, Tuple, Sequence, Mapping

# --- Optional deps: wandb / matplotlib ---
try:
    import wandb  # type: ignore

    _WB_AVAILABLE = True
except Exception:
    wandb = None  # type: ignore
    _WB_AVAILABLE = False

try:
    import numpy as _np
except Exception:
    _np = None

try:
    import matplotlib.pyplot as _plt  # type: ignore

    _MPL_OK = True
except Exception:
    _plt = None
    _MPL_OK = False


# =========================
# Internal helpers
# =========================


class _WBWrap:
    """Thin holder for a W&B run so callers can check `.ready` instead of crashing.

    Attributes:
        run: the object returned by ``wandb.init`` (or ``None``).
        enabled: whether logging was successfully switched on.
    """

    def __init__(self, run=None, enabled: bool = False):
        # Normalise the flag to a real bool so `.ready` always yields a bool.
        self.enabled = bool(enabled)
        self.run = run

    @property
    def ready(self) -> bool:
        """True only when the wrapper is enabled and actually holds a run."""
        has_run = self.run is not None
        return self.enabled and has_run


def _expand(p: Optional[str]) -> Optional[str]:
    """Expand environment variables and a leading ``~`` in *p*.

    Falsy inputs (``None`` or an empty string) are handed back unchanged.
    """
    if p:
        with_vars = os.path.expandvars(p)
        return os.path.expanduser(with_vars)
    return p


def _coerce_float(v: Any, default: float = float("nan")) -> float:
    """Convert *v* to ``float``, returning *default* when conversion fails.

    The default is the built-in float NaN. ``numpy.nan`` is that same Python
    float value, so the helper does not need to consult the optional numpy
    import to build its default.

    Args:
        v: any value accepted by ``float()`` (numbers, numeric strings,
            objects implementing ``__float__``).
        default: value returned when ``float(v)`` raises.

    Returns:
        ``float(v)`` on success, otherwise *default*.
    """
    try:
        return float(v)
    except Exception:
        # Deliberately broad: this is a best-effort logging helper and an
        # arbitrary object's __float__ may raise anything.
        return default


# Matches every run of characters outside the allowed set [A-Za-z0-9_.-].
_slug_re = re.compile(r"[^A-Za-z0-9_\-\.]+")


def _slug(s: str) -> str:
    """Return a mild slug of *s*.

    Surrounding whitespace is stripped first; each remaining run of
    characters outside ``[A-Za-z0-9_.-]`` collapses into a single ``_``.
    """
    trimmed = s.strip()
    return _slug_re.sub("_", trimmed)


def _hash8(s: str) -> str:
    """Return the first 8 hex characters of the MD5 digest of *s* (UTF-8 encoded)."""
    digest = hashlib.md5(s.encode("utf-8"))
    hex_text = digest.hexdigest()
    return hex_text[:8]


def _safe_trim_with_hash(src: str, max_len: int) -> str:
    """Shorten *src* to at most *max_len* characters, keeping it distinguishable.

    Strings that already fit are returned untouched. Longer ones become
    ``<prefix>-<hash8>``, where the hash is taken over the full original
    string. If *max_len* leaves no room for any prefix, only the (possibly
    truncated) hash is returned.
    """
    if len(src) > max_len:
        digest = _hash8(src)
        # Characters left for the prefix once '-' and the hash are reserved.
        room = max_len - len(digest) - 1
        if room > 0:
            return src[:room] + "-" + digest
        return digest[:max_len]
    return src


def _safe_name(name: Optional[str], max_len: int = 128) -> Optional[str]:
    """Slug a run display name and cap it at *max_len*; falsy input passes through."""
    if name:
        return _safe_trim_with_hash(_slug(str(name)), max_len)
    return name


def _safe_id(run_id: Optional[str], max_len: int = 64) -> Optional[str]:
    """Slug a run id and cap it at a conservative *max_len*; falsy input passes through."""
    if run_id:
        return _safe_trim_with_hash(_slug(str(run_id)), max_len)
    return run_id


def _safe_group(group: Optional[str], max_len: int = 64) -> Optional[str]:
    """Slug a group name and cap it at *max_len*; falsy input passes through."""
    if group:
        return _safe_trim_with_hash(_slug(str(group)), max_len)
    return group


def _safe_project(p: Optional[str], max_len: int = 128) -> Optional[str]:
    """Slug a project name and cap it at *max_len*; falsy input passes through."""
    if p:
        return _safe_trim_with_hash(_slug(str(p)), max_len)
    return p


def _safe_entity(e: Optional[str], max_len: int = 128) -> Optional[str]:
    """Slug an entity name and cap it at *max_len*; falsy input passes through."""
    if e:
        return _safe_trim_with_hash(_slug(str(e)), max_len)
    return e


# =========================
# Public API
# =========================


def wb_setup(
    args: Any = None,
    project: Optional[str] = None,
    entity: Optional[str] = None,
    mode: Optional[str] = None,  # "online" | "offline" | "disabled"
    group: Optional[str] = None,
    job_type: Optional[str] = None,
    tags: Optional[List[str]] = None,
    notes: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    resume: Optional[str] = "allow",
    dir_: Optional[str] = None,
) -> _WBWrap:
    """Initialise a W&B run with a stable run id and length-limited names.

    Every explicit keyword wins over the matching ``args.wandb_*`` attribute.
    The project, entity, group, run name and run id are slugged and clamped
    (avoiding the server's ``128 limit exceeded for Name`` rejection).

    Args:
        args: optional namespace-like object; the attributes read from it are
            ``wandb_project``, ``wandb_entity``, ``wandb_mode``,
            ``wandb_group``, ``wandb_job_type``, ``wandb_tags``,
            ``wandb_notes``, ``wandb_config``, ``wandb_resume``, ``wandb_id``,
            ``run_id``, ``run_name``, ``log_dir`` and ``xpid``.
        project, entity, mode, group, job_type, tags, notes: forwarded to
            ``wandb.init`` (project/entity/group after sanitising).
        config: dict merged into the run config.
        resume: resume policy used when ``args.wandb_resume`` is not set.
        dir_: explicit local run directory; overrides ``log_dir/xpid``.

    Returns:
        A ``_WBWrap``. It is enabled and holds the run on success; it is
        disabled when wandb is not importable or when ``wandb.init`` fails.

    Side effects:
        Creates the run directory when possible, sets ``WANDB_RUN_ID`` /
        ``WANDB_RESUME`` in ``os.environ`` and prints a one-line status.
    """
    if not _WB_AVAILABLE:
        return _WBWrap(run=None, enabled=False)

    # ----- explicit keyword first, then the attribute on `args` -----
    prj_raw = project or getattr(args, "wandb_project", None) or "Active_Curriculum_Design"
    ent_raw = entity or getattr(args, "wandb_entity", None)
    md = mode or getattr(args, "wandb_mode", None)  # "online"/"offline"/"disabled"
    grp_raw = group or getattr(args, "wandb_group", None)
    job = job_type or getattr(args, "wandb_job_type", None)
    tgs = tags or getattr(args, "wandb_tags", None)
    nts = notes or getattr(args, "wandb_notes", None)

    cfg: Dict[str, Any] = {}
    if isinstance(config, dict):
        cfg.update(config)
    a_cfg = getattr(args, "wandb_config", None)
    if isinstance(a_cfg, dict):
        # NOTE(review): unlike the other parameters, args.wandb_config overrides
        # the explicit `config` on key clashes -- confirm this is intended.
        cfg.update(a_cfg)

    # --- local run_dir (does NOT affect name/id limits) ---
    run_dir = None
    log_dir = _expand(getattr(args, "log_dir", None))
    xpid = getattr(args, "xpid", None)
    if log_dir and xpid:
        # str() so a numeric xpid does not make os.path.join raise TypeError.
        run_dir = os.path.join(log_dir, str(xpid))
    if dir_:
        run_dir = _expand(dir_)
    if not run_dir:
        run_dir = os.getcwd()
    try:
        os.makedirs(run_dir, exist_ok=True)
    except OSError as e:
        # Logging must never abort the caller: report and let wandb.init
        # decide; any failure there is handled below.
        try:
            print(f"[wandb] could not create run dir {run_dir!r} ({e}).")
        except Exception:
            pass

    # --- visible name (can be long; clamp to 128) ---
    # prefer args.run_name else use "<basename(log_dir)>-<xpid>"
    name_raw = getattr(args, "run_name", None)
    if not name_raw:
        try:
            base = os.path.basename(os.path.abspath(os.path.expanduser(log_dir or "")))
            name_raw = f"{base}-{xpid}" if base and xpid else (xpid or base or "run")
        except Exception:
            name_raw = xpid or "run"
    name_safe = _safe_name(name_raw, max_len=128)

    # --- stable run id (for resume), clamp to 64 ---
    # prefer args.wandb_id else args.run_id else args.xpid
    run_id_input = getattr(args, "wandb_id", None) or getattr(args, "run_id", None) or xpid
    run_id_safe = _safe_id(run_id_input, max_len=64) if run_id_input else None

    # --- project/entity/group also clamped ---
    prj_safe = _safe_project(prj_raw, max_len=128)
    ent_safe = _safe_entity(ent_raw, max_len=128) if ent_raw else None
    grp_safe = _safe_group(grp_raw, max_len=64) if grp_raw else None

    # --- resume policy ---
    # Local state (a checkpoint or a previous wandb dir) suggests a resumed
    # job, so "must" is tried first. "must" needs a run id to attach to, so it
    # is only forced when one is available.
    has_ckpt = os.path.exists(os.path.join(run_dir, "model.tar"))
    has_local_wandb = os.path.exists(os.path.join(run_dir, "wandb", "latest-run"))
    requested_resume = getattr(args, "wandb_resume", None) or resume or "allow"
    has_run_id = bool(run_id_safe or os.environ.get("WANDB_RUN_ID"))
    force_must = (has_ckpt or has_local_wandb) and has_run_id and requested_resume != "must"

    # A forced "must" fails when the run never reached the server (e.g. the
    # first launch ran with wandb disabled); fall back to the requested policy
    # once rather than silently losing logging for the whole job.
    attempts = ["must", requested_resume] if force_must else [requested_resume]

    for idx, resume_mode in enumerate(attempts):
        # --- env to make CLI/SDK consistent ---
        try:
            if run_id_safe:
                os.environ["WANDB_RUN_ID"] = str(run_id_safe)
            os.environ["WANDB_RESUME"] = str(resume_mode)
        except Exception:
            pass

        try:
            run = wandb.init(
                project=prj_safe,
                entity=ent_safe,
                mode=md,  # None -> auto; or "online"/"offline"/"disabled"
                group=grp_safe,
                job_type=job,
                tags=tgs,
                notes=nts,
                config=cfg if cfg else None,
                resume=resume_mode,  # "must"/"allow"/"never"
                id=run_id_safe,
                dir=run_dir,
                name=name_safe,
                settings=wandb.Settings(init_timeout=90),
                allow_val_change=True,
            )
        except Exception as e:
            is_last = idx == len(attempts) - 1
            try:
                if is_last:
                    print(f"[wandb] init failed ({e}); running with wandb disabled.")
                else:
                    print(
                        f"[wandb] init with resume={resume_mode} failed ({e}); "
                        f"retrying with resume={attempts[idx + 1]}."
                    )
            except Exception:
                pass
            continue

        try:
            print(
                f"[wandb] project={prj_safe}, name={name_safe}, id={run_id_safe}, "
                f"resume={resume_mode}, url={run.url}"
            )
        except Exception:
            pass
        return _WBWrap(run=run, enabled=True)

    # Every attempt failed: degrade to a disabled wrapper.
    return _WBWrap(run=None, enabled=False)


def wb_define_default_metrics(wb: _WBWrap):
    """Register the shared x-axis and the metrics plotted against it.

    ``ppo/updates`` is declared as the step axis (summary ``max``); every other
    metric below is bound to it via ``step_metric``. Does nothing unless *wb*
    is ready. Any error from wandb stops the remaining registrations and is
    swallowed.
    """
    if not wb or not wb.ready:
        return

    step_axis = "ppo/updates"
    plotted_against_step = (
        # Generic total env steps (optional)
        "steps_total",
        # Training losses / returns
        "loss/value",
        "loss/policy",
        "loss/entropy",
        "agent/mean_return_window",
        "agent/max_return_window",
        # Replay / macro
        "replay/global_pool_size",
        "replay/preferred_topk",
        "macro/new_levels_added",
        "macro/injected_replay_count",
        # Neighbor probing / score (for macro moves)
        "neighbor/num_neighbors",
        "neighbor/best_score",
        "neighbor/best_mu_bar",
        "neighbor/best_std_mu",
        "neighbor/target",
        # Evaluation
        "eval/returns_mean",
        "eval/returns_std",
        "eval/success_rate",
        "eval/len_mean",
        "eval/len_std",
        "eval/num_episodes",
        "eval/breakdown/*",
        # Per-env curves
        "eval/env/*",
    )

    try:
        wandb.define_metric(step_axis, summary="max")
        for metric_name in plotted_against_step:
            wandb.define_metric(metric_name, step_metric=step_axis)
    except Exception:
        pass


def wb_log(payload: Dict[str, Any], step: Optional[int] = None):
    """Send *payload* to the active W&B run, never raising on logging errors.

    A no-op when wandb is unavailable or no run is active. ``step`` is only
    forwarded to ``wandb.log`` when it is not ``None``.
    """
    if _WB_AVAILABLE and wandb.run is not None:
        extra = {} if step is None else {"step": step}
        try:
            wandb.log(payload, **extra)
        except Exception:
            pass


def wb_log_train_step(
    num_updates: int,
    total_steps: Optional[int] = None,
    value_loss: Optional[float] = None,
    policy_loss: Optional[float] = None,
    entropy: Optional[float] = None,
    mean_return_window: Optional[float] = None,
    max_return_window: Optional[float] = None,
    global_pool_size: Optional[int] = None,
    preferred_topk_size: Optional[int] = None,
):
    """Log the scalars of one PPO update, keyed and stepped by *num_updates*.

    Only arguments that are not ``None`` are included. Counters are cast with
    ``int``; losses and returns go through ``_coerce_float``.
    """
    # (metric key, raw value, converter) in the order they enter the payload.
    optional_fields = (
        ("steps_total", total_steps, int),
        ("loss/value", value_loss, _coerce_float),
        ("loss/policy", policy_loss, _coerce_float),
        ("loss/entropy", entropy, _coerce_float),
        ("agent/mean_return_window", mean_return_window, _coerce_float),
        ("agent/max_return_window", max_return_window, _coerce_float),
        ("replay/global_pool_size", global_pool_size, int),
        ("replay/preferred_topk", preferred_topk_size, int),
    )

    payload: Dict[str, Any] = {"ppo/updates": num_updates}
    for metric_key, raw_value, convert in optional_fields:
        if raw_value is not None:
            payload[metric_key] = convert(raw_value)

    wb_log(payload, step=num_updates)


def wb_log_macro_move(
    macro_step: int,
    from_partition: Any,
    to_partition: Any,
    scores: Dict[Any, float],
    stats: Dict[Any, Dict[str, float]],
    injected_replay_count: int,
    new_levels_added: int,
    global_pool_size: int,
    preferred_topk_size: int,
    step: Optional[int] = None,
):
    """Log a single partition move: summary scalars plus a per-neighbor table.

    The table has one row per entry of *scores* with its ``mu_bar`` /
    ``std_mu`` looked up in *stats*; it is dropped entirely if building it
    fails. The highest-scoring neighbor is also reported under
    ``neighbor/best_*``. A no-op when wandb is unavailable or no run is active.
    """
    if not (_WB_AVAILABLE and wandb.run is not None):
        return

    def _mu_and_std(neighbor_key):
        # Missing stats (or a non-dict `stats`) fall back to NaN for both values.
        entry = stats.get(neighbor_key, {}) if isinstance(stats, dict) else {}
        mu_bar = _coerce_float(entry.get("mu_bar", float("nan")))
        std_mu = _coerce_float(entry.get("std_mu", float("nan")))
        return mu_bar, std_mu

    # Table of neighbors
    try:
        neighbor_table = wandb.Table(
            columns=["macro_step", "from", "to", "neighbor", "score", "mu_bar", "std_mu"]
        )
        for neighbor, score in scores.items():
            label = str(neighbor)
            mu_bar, std_mu = _mu_and_std(neighbor)
            neighbor_table.add_data(
                int(macro_step),
                str(from_partition),
                str(to_partition),
                label,
                float(score),
                mu_bar,
                std_mu,
            )
    except Exception:
        neighbor_table = None

    payload = {
        "ppo/updates": macro_step if step is None else step,
        "macro/macro_step": macro_step,
        "macro/from": str(from_partition),
        "macro/to": str(to_partition),
        "macro/injected_replay_count": int(injected_replay_count),
        "macro/new_levels_added": int(new_levels_added),
        "replay/global_pool_size": int(global_pool_size),
        "replay/preferred_topk": int(preferred_topk_size),
        "neighbor/num_neighbors": len(scores),
    }

    # Best neighbor stats
    try:
        if scores:
            winner = max(scores, key=scores.__getitem__)
            best_mu, best_std = _mu_and_std(winner)
            # Build the whole group first so a failure adds none of the keys.
            best_fields = {
                "neighbor/best_score": float(scores[winner]),
                "neighbor/best_mu_bar": best_mu,
                "neighbor/best_std_mu": best_std,
                "neighbor/target": str(winner),
            }
            payload.update(best_fields)
    except Exception:
        pass

    if neighbor_table is not None:
        payload["tables/neighbor_scores"] = neighbor_table

    wb_log(payload, step=step)


# =========================
# Evaluation logging
# =========================


def _wb_eval_mean_std(values: Sequence[float]) -> Tuple[float, float]:
    """Return ``(mean, population std)`` of a non-empty sequence.

    NumPy is used when it was importable; otherwise both statistics are
    computed in pure Python (population std, i.e. divide by ``len(values)``)
    so the logged std is a real value rather than a placeholder.  If the
    computation raises, falls back to ``(_coerce_float(values[-1]), nan)``.
    """
    try:
        if _np is not None:
            return float(_np.mean(values)), float(_np.std(values))
        count = len(values)
        mean = sum(values) / count
        variance = sum((v - mean) ** 2 for v in values) / count
        return mean, variance ** 0.5
    except Exception:
        return _coerce_float(values[-1]), float("nan")


def _wb_eval_hist_image(
    key: str,
    values: Sequence[float],
    title: str,
    x_label: str,
    step: int,
) -> None:
    """Best-effort: draw a 30-bin histogram and log it under ``key`` as a wandb.Image."""
    try:
        fig = _plt.figure()
        try:
            _plt.hist(list(values), bins=30)
            _plt.title(title)
            _plt.xlabel(x_label)
            _plt.ylabel("count")
            wandb.log({key: wandb.Image(_plt)}, step=step)
        finally:
            # Always release the figure, even when plotting or logging fails.
            _plt.close(fig)
    except Exception:
        # Plots are optional; a plotting failure must not break the caller.
        pass


def wb_log_eval_summary(
    step_updates: int,
    returns: Optional[Sequence[float]] = None,
    lengths: Optional[Sequence[float]] = None,
    success_rate: Optional[float] = None,
    extra_scalars: Optional[Dict[str, float]] = None,
    breakdown_by_partition: Optional[Mapping[Any, Dict[str, float]]] = None,
    episodes_table: Optional[List[Dict[str, Any]]] = None,
):
    """Log evaluation summary at a given PPO update step (global).

    Scalars and the optional episodes table are collected into one payload and
    sent through ``wb_log``; histogram images are logged afterwards with
    ``wandb.log`` when matplotlib and an active W&B run are available.

    Args:
        step_updates: Logged as ``ppo/updates`` and used as the logging step.
        returns: Values summarised as ``eval/returns_mean`` / ``eval/returns_std``
            and plotted as a histogram.  Skipped when ``None`` or empty.
        lengths: Values summarised as ``eval/len_mean`` / ``eval/len_std`` and
            plotted as a histogram.  Skipped when ``None`` or empty.
        success_rate: Logged as ``eval/success_rate`` when not ``None``.
        extra_scalars: Each item is logged as ``eval/<key>``.
        breakdown_by_partition: ``{partition: {metric: value}}``; each value is
            logged as ``eval/breakdown/<partition>_<metric>``.
        episodes_table: Row dicts logged as ``tables/eval_episodes``.  Columns
            are the union of row keys in first-seen order; a key missing from
            a row is stored as ``None``.
    """
    payload: Dict[str, Any] = {"ppo/updates": int(step_updates)}

    # len() instead of truthiness: array-like inputs may not support bool().
    if returns is not None and len(returns) > 0:
        mean_r, std_r = _wb_eval_mean_std(returns)
        payload["eval/returns_mean"] = mean_r
        payload["eval/returns_std"] = std_r

    if lengths is not None and len(lengths) > 0:
        mean_l, std_l = _wb_eval_mean_std(lengths)
        payload["eval/len_mean"] = mean_l
        payload["eval/len_std"] = std_l

    if success_rate is not None:
        payload["eval/success_rate"] = _coerce_float(success_rate)

    if episodes_table is not None and _WB_AVAILABLE and wandb.run is not None:
        try:
            # dict.fromkeys de-duplicates while preserving first-seen order.
            cols: List[str] = list(
                dict.fromkeys(k for row in episodes_table for k in row.keys())
            )
            table = wandb.Table(columns=cols)
            for row in episodes_table:
                table.add_data(*[row.get(c, None) for c in cols])
            payload["tables/eval_episodes"] = table
            payload["eval/num_episodes"] = len(episodes_table)
        except Exception:
            # The table is optional; the scalar metrics are still logged below.
            pass

    if breakdown_by_partition:
        for k, d in breakdown_by_partition.items():
            key = str(k)
            for subk, val in d.items():
                payload[f"eval/breakdown/{key}_{subk}"] = _coerce_float(val)

    if extra_scalars:
        for k, v in extra_scalars.items():
            payload[f"eval/{k}"] = _coerce_float(v)

    wb_log(payload, step=step_updates)

    # Optional histograms; each one is attempted independently so a failure
    # in the first does not suppress the second.
    if _MPL_OK and _WB_AVAILABLE and wandb.run is not None:
        if returns is not None and len(returns) > 0:
            _wb_eval_hist_image(
                "plots/eval_returns_hist",
                returns,
                "Eval Returns Histogram",
                "return",
                step_updates,
            )
        if lengths is not None and len(lengths) > 0:
            _wb_eval_hist_image(
                "plots/eval_lengths_hist",
                lengths,
                "Eval Episode Lengths Histogram",
                "length",
                step_updates,
            )


def wb_log_eval_summary_per_env(
    step_updates: int,
    env_name: str,
    returns: Optional[Sequence[float]] = None,
    lengths: Optional[Sequence[float]] = None,
    success_rate: Optional[float] = None,
    extra_scalars: Optional[Dict[str, float]] = None,
):
    """Log evaluation summary for a single environment (separate curves per env).

    Does nothing unless wandb is importable and a run is active.  All metrics
    are logged under ``eval/env/<slug>/...`` where ``<slug>`` is
    ``_slug(env_name)``, together with ``ppo/updates``.

    Args:
        step_updates: Logged as ``ppo/updates`` and used as the logging step.
        env_name: Environment name; passed through ``_slug`` to build the key prefix.
        returns: Values summarised as ``returns_mean`` / ``returns_std``.
            Skipped when ``None`` or empty.
        lengths: Values summarised as ``len_mean`` / ``len_std``.
            Skipped when ``None`` or empty.
        success_rate: Logged as ``success_rate`` when not ``None``.
        extra_scalars: Each item is logged as ``eval/env/<slug>/<key>``.
    """
    if not (_WB_AVAILABLE and wandb.run is not None):
        return
    prefix = f"eval/env/{_slug(env_name)}"
    payload: Dict[str, Any] = {"ppo/updates": int(step_updates)}

    for stat_name, values in (("returns", returns), ("len", lengths)):
        # len() instead of truthiness: array-like inputs may not support bool().
        if values is None or len(values) == 0:
            continue
        try:
            if _np is not None:
                mean, std = float(_np.mean(values)), float(_np.std(values))
            else:
                # Pure-Python population std (divide by N) so the logged value
                # is real rather than a placeholder when NumPy is missing.
                count = len(values)
                mean = sum(values) / count
                std = (sum((v - mean) ** 2 for v in values) / count) ** 0.5
        except Exception:
            # Non-numeric or otherwise unusable data: report the last value, std unknown.
            mean, std = _coerce_float(values[-1]), float("nan")
        payload[f"{prefix}/{stat_name}_mean"] = mean
        payload[f"{prefix}/{stat_name}_std"] = std

    if success_rate is not None:
        payload[f"{prefix}/success_rate"] = _coerce_float(success_rate)

    if extra_scalars:
        for k, v in extra_scalars.items():
            payload[f"{prefix}/{k}"] = _coerce_float(v)

    wb_log(payload, step=step_updates)


def wb_log_eval_curve(
    step_updates: int,
    curves: Mapping[str, Tuple[Sequence[float], Sequence[float]]],
    title_prefix: str = "eval_curve",
):
    """Log line curves as Tables so W&B can render interactive plots.

    Does nothing unless wandb is importable and a run is active.  Each curve is
    logged on its own as ``tables/<title_prefix>/<name>`` with columns ``x``
    and ``y``; a curve that cannot be built or logged is skipped without
    affecting the others.

    Args:
        step_updates: Step passed to ``wandb.log``.
        curves: ``{name: (xs, ys)}``.  Points are paired with ``zip``, so the
            longer of ``xs`` / ``ys`` is truncated to the shorter.
        title_prefix: Key segment placed between ``tables/`` and the curve name.
    """
    if not _WB_AVAILABLE or wandb.run is None:
        return
    for name, series in curves.items():
        try:
            # Unpack inside the try so a malformed entry (not a 2-item pair)
            # only skips this curve instead of raising out of the logger.
            xs, ys = series
            table = wandb.Table(columns=["x", "y"])
            for x, y in zip(xs, ys):
                table.add_data(_coerce_float(x), _coerce_float(y))
            wandb.log({f"tables/{title_prefix}/{name}": table}, step=step_updates)
        except Exception:
            continue


def wb_log_eval_heatmap(
    step_updates: int,
    grid: Sequence[Sequence[float]],
    title: str = "Eval Heatmap",
    x_label: str = "X",
    y_label: str = "Y",
):
    """Log a heatmap image (e.g., returns over a 2D partition grid).

    Does nothing unless matplotlib and wandb are importable and a run is
    active.  Failures while plotting or logging are swallowed (best-effort).

    Args:
        step_updates: Step passed to ``wandb.log``.
        grid: 2D values to display; converted with ``numpy.array`` when NumPy
            is available, otherwise handed to ``imshow`` as-is.
        title: Figure title; also used in the log key ``plots/<title>``.
        x_label: X-axis label.
        y_label: Y-axis label.
    """
    if not (_MPL_OK and _WB_AVAILABLE and wandb.run is not None):
        return
    try:
        arr = _np.array(grid) if _np is not None else grid
        fig = _plt.figure()
        try:
            _plt.imshow(arr, aspect="auto")
            _plt.colorbar()
            _plt.title(title)
            _plt.xlabel(x_label)
            _plt.ylabel(y_label)
            wandb.log({f"plots/{title}": wandb.Image(_plt)}, step=step_updates)
        finally:
            # Always release the figure, even when plotting or logging fails.
            _plt.close(fig)
    except Exception:
        # Plots are optional; a plotting failure must not break the caller.
        pass


def wb_log_eval_hist(
    step_updates: int,
    name: str,
    values: Sequence[float],
    bins: int = 30,
    title: Optional[str] = None,
):
    """Log one histogram as an image (matplotlib).

    Does nothing unless matplotlib and wandb are importable and a run is
    active.  Failures while plotting or logging are swallowed (best-effort).

    Args:
        step_updates: Step passed to ``wandb.log``.
        name: X-axis label; also used in the log key ``plots/hist/<name>``.
        values: Data to histogram (copied into a list before plotting).
        bins: Passed to ``hist`` as ``bins``.
        title: Figure title; defaults to ``"Histogram: <name>"``.
    """
    if not (_MPL_OK and _WB_AVAILABLE and wandb.run is not None):
        return
    try:
        fig = _plt.figure()
        try:
            _plt.hist(list(values), bins=bins)
            _plt.title(title or f"Histogram: {name}")
            _plt.xlabel(name)
            _plt.ylabel("count")
            wandb.log({f"plots/hist/{name}": wandb.Image(_plt)}, step=step_updates)
        finally:
            # Always release the figure, even when plotting or logging fails.
            _plt.close(fig)
    except Exception:
        # Plots are optional; a plotting failure must not break the caller.
        pass


def wb_finish():
    """Close the active W&B run, if there is one.

    A no-op when wandb could not be imported or no run is active.  Any error
    raised by ``wandb.finish()`` is swallowed so shutdown never fails here.
    """
    run_active = _WB_AVAILABLE and wandb.run is not None
    if run_active:
        try:
            wandb.finish()
        except Exception:
            pass
