# util/wb_init.py
from __future__ import annotations
import os
from typing import Any, Dict, Optional, List, Tuple, Sequence, Mapping

# --- Optional deps: wandb / matplotlib ---
try:
    import wandb  # type: ignore

    _WB_AVAILABLE = True
except Exception:
    wandb = None  # type: ignore
    _WB_AVAILABLE = False

try:
    import numpy as _np
except Exception:
    _np = None

try:
    import matplotlib.pyplot as _plt  # type: ignore

    _MPL_OK = True
except Exception:
    _plt = None
    _MPL_OK = False


# =========================
# Internal helpers
# =========================


class _WBWrap:
    """Wrapper so code won't crash if W&B not available or disabled."""

    def __init__(self, run=None, enabled: bool = False):
        self.run = run
        self.enabled = bool(enabled)

    @property
    def ready(self) -> bool:
        return self.enabled and (self.run is not None)


def _expand(p: Optional[str]) -> Optional[str]:
    if not p:
        return p
    return os.path.expanduser(os.path.expandvars(p))


def _coerce_float(v: Any, default=_np.nan if _np is not None else float("nan")) -> float:
    try:
        return float(v)
    except Exception:
        return default


# =========================
# Public API
# =========================


def wb_setup(
    args: Any = None,
    project: Optional[str] = None,
    entity: Optional[str] = None,
    mode: Optional[str] = None,  # "online" | "offline" | "disabled"
    group: Optional[str] = None,
    job_type: Optional[str] = None,
    tags: Optional[List[str]] = None,
    notes: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    resume: Optional[str] = "allow",
    dir_: Optional[str] = None,
) -> _WBWrap:
    """
    Robust W&B init. Will not crash if wandb is missing or init fails.

    Returns a `_WBWrap`; its `.ready` flag is True only when `wandb.init`
    returned a run. If wandb is not importable, the run directory cannot be
    prepared, or `wandb.init` raises, a disabled wrapper is returned instead.

    Precedence for every setting: explicit keyword argument > attribute on
    `args` (`wandb_project`, `wandb_entity`, `wandb_mode`, `wandb_group`,
    `wandb_job_type`, `wandb_tags`, `wandb_notes`, `wandb_config`) > default.
    For `config`, both dicts are merged and explicit keys win on conflict.

    Run directory: `dir_` if given, else `<args.log_dir>/<args.xpid>` when
    both exist, else the current working directory.
    Run name: `args.run_name` if set, else `args.xpid`.
    """
    if not _WB_AVAILABLE:
        return _WBWrap(run=None, enabled=False)

    # ----- explicit keyword first, then the attribute on `args` -----
    prj = project or getattr(args, "wandb_project", None) or "Active Curriculum Design"
    ent = entity or getattr(args, "wandb_entity", None)
    md = mode or getattr(args, "wandb_mode", None)  # "online"/"offline"/"disabled"
    grp = group or getattr(args, "wandb_group", None)
    job = job_type or getattr(args, "wandb_job_type", None)
    tgs = tags or getattr(args, "wandb_tags", None)
    nts = notes or getattr(args, "wandb_notes", None)

    # Merge configs lowest-priority first so that the explicit `config`
    # overrides `args.wandb_config`, matching the precedence used above.
    cfg: Dict[str, Any] = {}
    a_cfg = getattr(args, "wandb_config", None)
    if isinstance(a_cfg, dict):
        cfg.update(a_cfg)
    if isinstance(config, dict):
        cfg.update(config)

    xpid = getattr(args, "xpid", None)

    # run name (prefer explicit run_name, fall back to xpid)
    name = getattr(args, "run_name", None) or xpid

    try:
        # Directory resolution and creation live inside the try: a bad path
        # or missing permissions must degrade to "disabled", not raise.
        run_dir = None
        log_dir = _expand(getattr(args, "log_dir", None))
        if log_dir and xpid:
            run_dir = os.path.join(log_dir, xpid)
        if dir_:
            run_dir = _expand(dir_)
        if not run_dir:
            run_dir = os.getcwd()
        os.makedirs(run_dir, exist_ok=True)

        run = wandb.init(
            project=prj,
            entity=ent,
            mode=md,  # None -> auto; or "online"/"offline"/"disabled"
            group=grp,
            job_type=job,
            tags=tgs,
            notes=nts,
            config=cfg if cfg else None,
            resume=resume,
            dir=run_dir,
            name=name,
        )
        return _WBWrap(run=run, enabled=True)
    except Exception:
        # Degrade to disabled if directory setup or init fails
        return _WBWrap(run=None, enabled=False)


def wb_define_default_metrics(wb: _WBWrap):
    """
    Register the standard metric names with W&B so that they are all plotted
    against 'ppo/updates' (the global x-axis) by default.

    Does nothing when `wb` is falsy or not ready. Any error raised by W&B
    while defining metrics is swallowed; metrics after the failing one are
    then left undefined.
    """
    if not wb or not wb.ready:
        return

    x_axis = "ppo/updates"
    plotted_against_updates = (
        # Generic total env steps (optional)
        "steps_total",
        # Training losses / returns
        "loss/value",
        "loss/policy",
        "loss/entropy",
        "agent/mean_return_window",
        "agent/max_return_window",
        # Replay / macro
        "replay/global_pool_size",
        "replay/preferred_topk",
        "macro/new_levels_added",
        "macro/injected_replay_count",
        # Neighbor probing / score (for macro moves)
        "neighbor/num_neighbors",
        "neighbor/best_score",
        "neighbor/best_mu_bar",
        "neighbor/best_std_mu",
        "neighbor/target",
        # Evaluation summary scalars (averaged over eval episodes)
        "eval/returns_mean",
        "eval/returns_std",
        "eval/success_rate",
        "eval/len_mean",
        "eval/len_std",
        # Optional breakdowns (e.g., by partition)
        "eval/breakdown/*",
        # Optional diagnostics
        "eval/num_episodes",
    )

    try:
        # The step axis itself: keep its maximum in the run summary.
        wandb.define_metric(x_axis, summary="max")
        for metric_name in plotted_against_updates:
            wandb.define_metric(metric_name, step_metric=x_axis)
    except Exception:
        pass


def wb_log(payload: Dict[str, Any], step: Optional[int] = None):
    """
    Best-effort `wandb.log`.

    No-op when wandb is unavailable or there is no active run. `step` is only
    forwarded when it is not None. Exceptions from W&B are swallowed.
    """
    if _WB_AVAILABLE and wandb.run is not None:
        step_kwargs = {} if step is None else {"step": step}
        try:
            wandb.log(payload, **step_kwargs)
        except Exception:
            pass


def wb_log_train_step(
    num_updates: int,
    total_steps: Optional[int] = None,
    value_loss: Optional[float] = None,
    policy_loss: Optional[float] = None,
    entropy: Optional[float] = None,
    mean_return_window: Optional[float] = None,
    max_return_window: Optional[float] = None,
    global_pool_size: Optional[int] = None,
    preferred_topk_size: Optional[int] = None,
):
    """
    Log the scalars of one PPO update.

    'ppo/updates' is always logged; every other field is included only when
    its argument is not None. Counts are cast with `int`, losses/returns with
    `_coerce_float`. The payload is sent through `wb_log` at step `num_updates`.
    """
    # (metric key, raw value, converter) in the order the keys are emitted.
    optional_fields = (
        ("steps_total", total_steps, int),
        ("loss/value", value_loss, _coerce_float),
        ("loss/policy", policy_loss, _coerce_float),
        ("loss/entropy", entropy, _coerce_float),
        ("agent/mean_return_window", mean_return_window, _coerce_float),
        ("agent/max_return_window", max_return_window, _coerce_float),
        ("replay/global_pool_size", global_pool_size, int),
        ("replay/preferred_topk", preferred_topk_size, int),
    )

    payload: Dict[str, Any] = {"ppo/updates": num_updates}
    payload.update(
        (key, convert(raw)) for key, raw, convert in optional_fields if raw is not None
    )
    wb_log(payload, step=num_updates)


def wb_log_macro_move(
    macro_step: int,
    from_partition: Any,
    to_partition: Any,
    scores: Dict[Any, float],
    stats: Dict[Any, Dict[str, float]],
    injected_replay_count: int,
    new_levels_added: int,
    global_pool_size: int,
    preferred_topk_size: int,
    step: Optional[int] = None,
):
    """
    Log one partition move:
      - scalars: best score/mu/std, neighbor count, from/to, injected_replay_count, etc.
      - table: rows of {macro_step, from, to, neighbor, score, mu_bar, std_mu}.

    `scores` maps neighbor key -> score; `stats` maps neighbor key -> a dict
    read for 'mu_bar' and 'std_mu'. Missing or non-numeric values are logged
    as NaN rather than discarding the whole table. `scores=None` is treated
    as "no neighbors". 'ppo/updates' is set to `step`, or to `macro_step` when
    `step` is None; `step` itself is forwarded unchanged to `wb_log`.

    No-op when wandb is unavailable or there is no active run.
    """
    if not _WB_AVAILABLE or wandb.run is None:
        return

    if scores is None:
        scores = {}

    def _neighbor_stats(key: Any) -> Tuple[float, float]:
        # Tolerate a non-dict `stats` or a missing / malformed per-key entry.
        entry = stats.get(key) if isinstance(stats, dict) else None
        if not hasattr(entry, "get"):
            entry = {}
        return _coerce_float(entry.get("mu_bar")), _coerce_float(entry.get("std_mu"))

    # Table
    table = None
    try:
        table = wandb.Table(
            columns=["macro_step", "from", "to", "neighbor", "score", "mu_bar", "std_mu"]
        )
        step_cell = int(macro_step)
        from_cell = str(from_partition)
        to_cell = str(to_partition)
        for k, v in scores.items():
            mu, sd = _neighbor_stats(k)
            # _coerce_float: one bad score becomes NaN instead of dropping the table.
            table.add_data(step_cell, from_cell, to_cell, str(k), _coerce_float(v), mu, sd)
    except Exception:
        table = None

    payload = {
        "ppo/updates": step if step is not None else macro_step,
        "macro/macro_step": macro_step,
        "macro/from": str(from_partition),
        "macro/to": str(to_partition),
        "macro/injected_replay_count": int(injected_replay_count),
        "macro/new_levels_added": int(new_levels_added),
        "replay/global_pool_size": int(global_pool_size),
        "replay/preferred_topk": int(preferred_topk_size),
        "neighbor/num_neighbors": len(scores),
    }

    # Best neighbor stats
    try:
        # NaN compares False against everything, which makes max() depend on
        # iteration order; rank only the scores that are real numbers.
        comparable = {}
        for k, v in scores.items():
            s = _coerce_float(v)
            if s == s:  # filters out NaN
                comparable[k] = s
        if comparable:
            best_key = max(comparable, key=comparable.get)
            best_mu, best_sd = _neighbor_stats(best_key)
            payload.update(
                {
                    "neighbor/best_score": comparable[best_key],
                    "neighbor/best_mu_bar": best_mu,
                    "neighbor/best_std_mu": best_sd,
                    "neighbor/target": str(best_key),
                }
            )
    except Exception:
        pass

    if table is not None:
        payload["tables/neighbor_scores"] = table

    wb_log(payload, step=step)


# =========================
# Evaluation logging
# =========================


def _wb_mean_std(values: Sequence[float]) -> Tuple[float, float]:
    """Return (mean, population std) of a non-empty sequence, with or without numpy."""
    if _np is not None:
        return float(_np.mean(values)), float(_np.std(values))
    vals = [float(v) for v in values]
    mean = sum(vals) / len(vals)
    variance = sum((v - mean) ** 2 for v in vals) / len(vals)
    return mean, variance ** 0.5


def _wb_log_hist_image(key: str, values: Sequence[float], title: str, x_label: str, step: int):
    """Draw a 30-bin histogram of `values`, log it under `key`, and always close the figure."""
    try:
        fig = _plt.figure()
        try:
            _plt.hist(list(values), bins=30)
            _plt.title(title)
            _plt.xlabel(x_label)
            _plt.ylabel("count")
            wandb.log({key: wandb.Image(fig)}, step=step)
        finally:
            # Close even when plotting/logging raised, so figures do not
            # accumulate in pyplot's global registry over a long run.
            _plt.close(fig)
    except Exception:
        pass


def wb_log_eval_summary(
    step_updates: int,
    returns: Optional[Sequence[float]] = None,
    lengths: Optional[Sequence[float]] = None,
    success_rate: Optional[float] = None,
    extra_scalars: Optional[Dict[str, float]] = None,
    breakdown_by_partition: Optional[Mapping[Any, Dict[str, float]]] = None,
    episodes_table: Optional[List[Dict[str, Any]]] = None,
):
    """
    Log evaluation summary at a given PPO update step.
    - returns / lengths: sequences; we log mean and population std, plus a
      histogram image for each (only if matplotlib and an active run exist)
    - success_rate: scalar, logged as 'eval/success_rate'
    - extra_scalars: any additional eval scalars -> logged under 'eval/*'
    - breakdown_by_partition: dict[partition_key] -> {'mean':..., 'std':..., 'n':...}
      -> each key will be logged as 'eval/breakdown/{key}_mean' etc.
    - episodes_table: list of dicts per episode to be shown as a W&B table
      (columns are the union of row keys, in first-seen order); also sets
      'eval/num_episodes'

    Scalars and tables go out in one `wb_log` call at step `step_updates`;
    histogram images are logged afterwards at the same step.
    """
    payload: Dict[str, Any] = {
        "ppo/updates": int(step_updates),
    }

    # Basic stats
    if returns is not None and len(returns) > 0:
        try:
            mean_r, std_r = _wb_mean_std(returns)
        except Exception:
            # Could not aggregate: report the last value and an unknown spread.
            mean_r, std_r = _coerce_float(returns[-1]), float("nan")
        payload["eval/returns_mean"] = mean_r
        payload["eval/returns_std"] = std_r

    if lengths is not None and len(lengths) > 0:
        try:
            mean_l, std_l = _wb_mean_std(lengths)
        except Exception:
            mean_l, std_l = _coerce_float(lengths[-1]), float("nan")
        payload["eval/len_mean"] = mean_l
        payload["eval/len_std"] = std_l

    if success_rate is not None:
        payload["eval/success_rate"] = _coerce_float(success_rate)

    if episodes_table is not None and _WB_AVAILABLE and wandb.run is not None:
        try:
            # infer columns: ordered, de-duplicated union of all row keys
            cols: List[str] = list(dict.fromkeys(k for row in episodes_table for k in row))
            table = wandb.Table(columns=cols)
            for row in episodes_table:
                table.add_data(*[row.get(c, None) for c in cols])
            payload["tables/eval_episodes"] = table
            payload["eval/num_episodes"] = len(episodes_table)
        except Exception:
            pass

    # Partition breakdown
    if breakdown_by_partition:
        for k, d in breakdown_by_partition.items():
            key = str(k)
            for subk, val in d.items():
                payload[f"eval/breakdown/{key}_{subk}"] = _coerce_float(val)

    # Any extra eval scalars
    if extra_scalars:
        for k, v in extra_scalars.items():
            payload[f"eval/{k}"] = _coerce_float(v)

    wb_log(payload, step=step_updates)

    # Optional histograms as images (matplotlib). Each one is attempted
    # independently so a failure in the first does not skip the second.
    if _MPL_OK and _WB_AVAILABLE and wandb.run is not None:
        if returns is not None and len(returns) > 0:
            _wb_log_hist_image(
                "plots/eval_returns_hist", returns, "Eval Returns Histogram", "return", step_updates
            )
        if lengths is not None and len(lengths) > 0:
            _wb_log_hist_image(
                "plots/eval_lengths_hist",
                lengths,
                "Eval Episode Lengths Histogram",
                "length",
                step_updates,
            )


def wb_log_eval_curve(
    step_updates: int,
    curves: Mapping[str, Tuple[Sequence[float], Sequence[float]]],
    title_prefix: str = "eval_curve",
):
    """
    Log each curve as a two-column ("x", "y") W&B Table so it can be plotted
    interactively.

    curves: dict[name] = (xs, ys); points are paired with `zip`, so the
    longer sequence is truncated. Every curve is logged separately under
    'tables/{title_prefix}/{name}' at step `step_updates`; a curve whose
    table or log call fails is skipped. No-op without wandb or an active run.
    """
    if _WB_AVAILABLE and wandb.run is not None:
        for curve_name, (x_vals, y_vals) in curves.items():
            try:
                points = wandb.Table(columns=["x", "y"])
                for pair in zip(x_vals, y_vals):
                    points.add_data(_coerce_float(pair[0]), _coerce_float(pair[1]))
                log_key = f"tables/{title_prefix}/{curve_name}"
                wandb.log({log_key: points}, step=step_updates)
            except Exception:
                pass


def wb_log_eval_heatmap(
    step_updates: int,
    grid: Sequence[Sequence[float]],
    title: str = "Eval Heatmap",
    x_label: str = "X",
    y_label: str = "Y",
):
    """
    Log a heatmap image (e.g., returns over a 2D partition grid).

    `grid` is drawn with `imshow` (converted to a numpy array when numpy is
    available) and logged under 'plots/{title}' at step `step_updates`.
    Falls back silently if matplotlib or wandb is unavailable, there is no
    active run, or plotting/logging raises.
    """
    if not (_MPL_OK and _WB_AVAILABLE and wandb.run is not None):
        return
    try:
        arr = _np.array(grid) if _np is not None else grid
        fig = _plt.figure()
        try:
            _plt.imshow(arr, aspect="auto")
            _plt.colorbar()
            _plt.title(title)
            _plt.xlabel(x_label)
            _plt.ylabel(y_label)
            wandb.log({f"plots/{title}": wandb.Image(fig)}, step=step_updates)
        finally:
            # Close on every path: pyplot keeps open figures alive globally,
            # so skipping this after an error would leak one per call.
            _plt.close(fig)
    except Exception:
        pass


def wb_log_eval_hist(
    step_updates: int,
    name: str,
    values: Sequence[float],
    bins: int = 30,
    title: Optional[str] = None,
):
    """
    Log one histogram as an image (matplotlib).

    The image is logged under 'plots/hist/{name}' at step `step_updates`;
    `title` defaults to 'Histogram: {name}'. Silently does nothing if
    matplotlib or wandb is unavailable, there is no active run, or
    plotting/logging raises.
    """
    if not (_MPL_OK and _WB_AVAILABLE and wandb.run is not None):
        return
    try:
        fig = _plt.figure()
        try:
            _plt.hist(list(values), bins=bins)
            _plt.title(title or f"Histogram: {name}")
            _plt.xlabel(name)
            _plt.ylabel("count")
            wandb.log({f"plots/hist/{name}": wandb.Image(fig)}, step=step_updates)
        finally:
            # Close on every path: pyplot keeps open figures alive globally,
            # so skipping this after an error would leak one per call.
            _plt.close(fig)
    except Exception:
        pass


def wb_finish():
    """
    Close the active W&B run, if any.

    No-op when wandb is unavailable or no run is active; errors raised by
    `wandb.finish()` are swallowed.
    """
    if _WB_AVAILABLE and wandb.run is not None:
        try:
            wandb.finish()
        except Exception:
            pass
