#!/usr/bin/env python3
"""
Plot W&B metrics for Qwen3-4B math ver@k retry experiments across estimators.

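Default filters match:
  - k=2 (max assistant turns)
  - per-attempt (retry) response length=512; max total response length=1280
  - group size (rollout.n)=20
  - model path contains "Qwen3-4B-Instruct"
  - train and val file paths contain "math_ver_k_retry_k2"
  - advantage estimator is one of the three default estimators
  - run state is "finished"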

Example:
  python scripts/analysis/plot_wandb_ver_k_estimators.py \
    --entity ... \
    --project verl_ver_k_retry_math \
    --output outputs/plots/qwen3_4b_math_ver_k2_n20_estimators
"""

from __future__ import annotations

import argparse
import json
import os
import re
import sys
from dataclasses import dataclass
from typing import Any, Iterable

import numpy as np
import pandas as pd

import matplotlib as mpl

mpl.use("Agg")  # headless-safe
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

import wandb


# Advantage estimators plotted when --adv-estimators is not given.
ESTIMATORS_DEFAULT = [
    "grpo_vectorized",
    "grpo_verk_step_reward_step_norm",
    "grpo_verk_step_reward_step_norm_reweight_future_only",
]

# Legend display names; estimators missing here are shown by their raw name.
ESTIMATOR_LABELS = {
    "grpo_vectorized": "TL-GRPO",
    "grpo_verk_step_reward_step_norm": "AL-GRPO",
    "grpo_verk_step_reward_step_norm_reweight_future_only": "W-AL-GRPO",
}

# Line colours, indexed by an estimator's position in --adv-estimators (wrapping around).
# The first three colours are used for the default estimators; the extra entries keep up
# to seven estimators visually distinct instead of recycling colours after the third.
COLOR_CYCLE = [
    "#E07A7A",
    "#7FBF7A",
    "#6FA8DC",
    "#F6B26B",
    "#B48EAD",
    "#8E8E8E",
    "#76C7C0",
]


@dataclass(frozen=True)
class RunFilter:
    """Exact-match constraint on a single W&B run-config entry.

    Instances are immutable (frozen dataclass), comparable and hashable.
    """

    # Dotted path into the run config, e.g. "actor_rollout_ref.rollout.n".
    path: str
    # Value the config entry at ``path`` is compared against.
    expected: Any


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Define and parse the command-line interface.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            exactly like the previous zero-argument call; an explicit list lets tests
            and notebooks drive the parser.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Plot W&B curves for Qwen3-4B math ver@k retry experiments.",
        # Appends "(default: ...)" to every help string, so help texts below do not
        # restate their defaults.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--entity", default=os.environ.get("WANDB_ENTITY"), help="W&B entity/user name")
    parser.add_argument("--project", default="verl_ver_k_retry_math", help="W&B project name")

    # --- Run-config filters ---
    parser.add_argument("--k", type=int, default=2, help="Max assistant turns (Ver@K)")
    parser.add_argument("--retry-resp-len", type=int, default=512, help="Per-attempt response length cap")
    parser.add_argument("--max-resp-len", type=int, default=1280, help="Total response length cap")
    parser.add_argument("--group-size", type=int, default=20, help="GRPO group size (rollout.n)")
    parser.add_argument("--model-contains", default="Qwen3-4B-Instruct", help="Substring to match model path")
    parser.add_argument(
        "--train-file-contains",
        default="math_ver_k_retry_k2",
        help="Substring to match training file path",
    )
    parser.add_argument(
        "--val-file-contains",
        default="math_ver_k_retry_k2",
        help="Substring to match validation file path",
    )
    parser.add_argument(
        "--adv-estimators",
        nargs="+",
        # A fresh copy, so mutating args.adv_estimators can never alter the module default.
        default=list(ESTIMATORS_DEFAULT),
        help="Advantage estimators to include",
    )
    parser.add_argument(
        "--state",
        default="finished",
        help="Run state filter (e.g., finished, running, crashed). Use 'any' to disable.",
    )

    # --- Metric selection ---
    parser.add_argument(
        "--metric",
        default=None,
        help="Exact W&B metric key to plot. If omitted, auto-detects a key matching --metric-regex.",
    )
    parser.add_argument(
        "--metric-regex",
        default=r"^val-core/.*/acc/mean@20$",
        help="Regex used to auto-detect a metric key when --metric is not set.",
    )
    parser.add_argument(
        "--turn-metric-regex",
        default=r"^turn_success/p_turn(\d+)_cond$",
        help="Regex used to auto-detect per-turn success metrics.",
    )
    parser.add_argument(
        "--assistant-turns-metric-regex",
        default=r"^(val-aux/assistant_turns/mean|assistant_turns/mean)$",
        help="Regex used to auto-detect assistant_turns/mean metrics.",
    )
    parser.add_argument(
        "--step-weight-metric-regex",
        default=r"^step_weight/future_only/turn(\d+)_cond$",
        help="Regex used to auto-detect step_weight metrics for future_only estimator.",
    )
    parser.add_argument(
        "--step-reward-std-metric-regex",
        default=r"^step_reward/std_turn(\d+)_cond$",
        help="Regex used to auto-detect step_reward std metrics.",
    )
    parser.add_argument(
        "--step-output-length-metric-regex",
        default=r"^step_output_length/mean_turn(\d+)_cond$",
        help="Regex used to auto-detect step_output_length mean metrics.",
    )
    parser.add_argument(
        "--no-turn-success",
        action="store_true",
        help="Disable plotting turn_success/p_turn<i>_cond metrics.",
    )

    # --- X-axis and data processing ---
    parser.add_argument(
        "--x-key",
        default="_step",
        help="X-axis key in W&B history.",
    )
    parser.add_argument("--min-step", type=int, default=None, help="Minimum step to include")
    parser.add_argument("--max-step", type=int, default=None, help="Maximum step to include")
    parser.add_argument("--x-keep-every", type=int, default=10, help="Keep only steps that are multiples of this value")
    parser.add_argument("--smooth", type=int, default=0, help="Rolling mean window for smoothing (0 disables)")
    parser.add_argument("--print-val-acc", action="store_true", help="Print validation accuracy values for debugging")
    parser.add_argument(
        "--earliest-per-estimator",
        action="store_true",
        help="Keep only the earliest-created run for each estimator",
    )

    # --- Figure appearance ---
    parser.add_argument("--figsize", default="5.5,4.0", help="Figure size in inches, e.g. 5.5,4.0")
    parser.add_argument("--dpi", type=int, default=300, help="Output DPI for saved figures")
    parser.add_argument("--title", default=None, help="Optional plot title")
    parser.add_argument("--xlabel", default=None, help="X-axis label (falls back to 'Training step')")
    parser.add_argument("--ylabel", default=None, help="Y-axis label (falls back to a label derived from the metric key)")
    parser.add_argument("--legend-loc", default="best", help="Legend location")
    parser.add_argument("--x-tick-every", type=int, default=10, help="Major tick spacing on the x-axis")
    parser.add_argument(
        "--attempt-legend-y-success",
        type=float,
        default=None,
        help="Y anchor for attempt legend in attempt success plot (0-1).",
    )
    parser.add_argument(
        "--attempt-legend-y-weight",
        type=float,
        default=None,
        help="Y anchor for attempt legend in attempt weight plot (0-1).",
    )
    parser.add_argument(
        "--attempt-legend-loc-mean-length",
        default=None,
        help="Legend location for attempt legend in mean response length plot.",
    )
    parser.add_argument(
        "--attempt-legend-y-mean-length",
        type=float,
        default=None,
        help="Y anchor for attempt legend in mean response length plot (0-1).",
    )
    parser.add_argument(
        "--estimator-legend-y-std",
        type=float,
        default=None,
        help="Y anchor for estimator legend in attempt reward std plot (0-1).",
    )
    parser.add_argument(
        "--attempt-legend-loc-std",
        default=None,
        help="Legend location for attempt legend in attempt reward std plot.",
    )
    parser.add_argument(
        "--attempt-legend-y-std",
        type=float,
        default=None,
        help="Y anchor for attempt legend in attempt reward std plot (0-1).",
    )

    # --- Outputs ---
    parser.add_argument(
        "--formats",
        default="png,pdf",
        help="Comma-separated output formats (png,pdf,svg,eps).",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="Output path without extension (default is auto-derived from filters)",
    )
    parser.add_argument("--no-csv", action="store_true", help="Do not save aggregated CSV data")
    parser.add_argument("--save-runs", action="store_true", help="Save matched run metadata as JSON")

    return parser.parse_args(argv)


def slugify(value: str) -> str:
    """Turn *value* into a lowercase, underscore-separated token for file names.

    Each run of characters outside ``[a-z0-9]`` becomes one underscore. Leading and
    trailing underscores are trimmed. Returns ``"plot"`` if nothing is left.
    """
    collapsed = re.sub(r"[^a-z0-9]+", "_", value.strip().lower())
    trimmed = collapsed.strip("_")
    return trimmed if trimmed else "plot"


def get_nested(cfg: dict[str, Any], path: str) -> Any:
    """Follow a dotted *path* through nested dicts and return the value found there.

    Returns ``None`` as soon as a hop lands on a non-dict or a missing key.
    """
    node: Any = cfg
    for segment in path.split("."):
        if not isinstance(node, dict) or segment not in node:
            return None
        node = node[segment]
    return node


def normalize_number(value: Any) -> Any:
    """Best-effort conversion of numeric-looking strings to ``int`` or ``float``.

    Non-string values (including ``bool``, ``int`` and ``float``) are returned as-is.
    A non-blank string containing ``.``, ``e`` or ``E`` is parsed with ``float``, and
    any other non-blank string with ``int``. Blank or unparseable strings come back
    unchanged.
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    if not text:
        return value
    looks_float = "." in text or "e" in text or "E" in text
    parse = float if looks_float else int
    try:
        return parse(text)
    except ValueError:
        return value


def values_equal(val: Any, expected: Any) -> bool:
    """Loosely compare a run-config value against an expected filter value.

    A ``None`` value never matches. If both sides are numbers after
    ``normalize_number``, they are compared numerically (so ``"1280"`` equals
    ``1280``). Otherwise their ``str()`` forms are compared.
    """
    if val is None:
        return False
    numeric = (int, float)
    lhs = normalize_number(val)
    rhs = normalize_number(expected)
    both_numeric = isinstance(lhs, numeric) and isinstance(rhs, numeric)
    return lhs == rhs if both_numeric else str(val) == str(expected)


def value_contains(val: Any, needle: str | None) -> bool:
    """Return whether *needle* is a substring of *val*, or of any element of a list/tuple *val*.

    An empty or ``None`` needle matches anything. Otherwise a ``None`` value never
    matches. Non-string values are compared through ``str()``.
    """
    if not needle:
        return True
    if val is None:
        return False
    haystack = val if isinstance(val, (list, tuple)) else (val,)
    return any(needle in str(item) for item in haystack)


def get_adv_estimator(cfg: dict[str, Any]) -> str | None:
    """Return the advantage-estimator name recorded in a run config.

    Tries the nested ``algorithm.adv_estimator`` entry, then the flattened dotted key,
    then a top-level ``adv_estimator`` key. The first truthy value wins; if none is
    truthy, the last lookup's result is returned.
    """
    nested = get_nested(cfg, "algorithm.adv_estimator")
    if nested:
        return nested
    return cfg.get("algorithm.adv_estimator") or cfg.get("adv_estimator")


def filter_runs(
    runs: Iterable[wandb.apis.public.Run],
    filters: list[RunFilter],
    model_contains: str | None,
    train_file_contains: str | None,
    val_file_contains: str | None,
    estimators: list[str],
) -> list[wandb.apis.public.Run]:
    """Keep only the runs whose config passes every client-side check.

    A run is kept when all of the following hold:
      * every ``RunFilter`` whose ``expected`` is not ``None`` matches (``values_equal``)
        the config value at its dotted path. A ``None`` expectation means
        "unconstrained", the same way ``None`` values are left out of the W&B API
        filters;
      * the model path, train files and val files contain the requested substrings
        (a check is skipped when its substring is empty or ``None``);
      * the run's advantage estimator is in *estimators*.

    Config values are looked up as nested dicts first, then as a flat dotted key.
    """

    def lookup(cfg: dict[str, Any], path: str) -> Any:
        # W&B configs may be nested ({"data": {...}}) or flattened ({"data.x": ...}).
        value = get_nested(cfg, path)
        return cfg.get(path) if value is None else value

    active_filters = [filt for filt in filters if filt.expected is not None]
    substring_checks = [
        (path, needle)
        for path, needle in (
            ("actor_rollout_ref.model.path", model_contains),
            ("data.train_files", train_file_contains),
            ("data.val_files", val_file_contains),
        )
        if needle
    ]

    matched = []
    for run in runs:
        cfg = run.config or {}
        if not all(values_equal(lookup(cfg, filt.path), filt.expected) for filt in active_filters):
            continue
        if not all(value_contains(lookup(cfg, path), needle) for path, needle in substring_checks):
            continue
        if get_adv_estimator(cfg) not in estimators:
            continue
        matched.append(run)

    return matched


def choose_metric_key(
    runs: list[wandb.apis.public.Run],
    metric: str | None,
    metric_regex: str,
) -> str:
    """Pick the W&B metric key to plot.

    Args:
        runs: Matched runs; only their ``summary`` keys are inspected.
        metric: Explicit key; returned unchanged when given.
        metric_regex: Pattern (``re.search``) that candidate summary keys must match.

    Returns:
        The chosen key.
        * Keys matching *metric_regex* in every run are preferred; failing that, keys
          matching in any run.
        * Among candidates, ``/math/`` keys win, then ``/test/`` keys, then shorter
          keys.
        * If nothing matches, the fallback uses ``val-core/<...>/(acc|reward)/mean@N``:
          ``acc`` beats ``reward``, and a larger ``N`` beats a smaller one.
        * Remaining ties are broken alphabetically, so the result never depends on
          set iteration order (which varies with string hash randomization).

    Raises:
        RuntimeError: If *runs* is empty (and *metric* is unset) or no key can be found.
    """
    if metric:
        return metric

    def preference(key: str) -> tuple[int, int, int, str]:
        # Lower tuples win; the key itself is the final, deterministic tie-breaker.
        return (
            0 if "/math/" in key else 1,
            0 if "/test/" in key else 1,
            len(key),
            key,
        )

    run_keys = [set(run.summary.keys()) if run.summary else set() for run in runs]
    if not run_keys:
        raise RuntimeError("No runs available to infer metric key.")

    regex = re.compile(metric_regex)
    matching = [{k for k in keys if regex.search(k)} for keys in run_keys]
    candidates = set.intersection(*matching) or set.union(*matching)
    if candidates:
        return min(candidates, key=preference)

    # Fallback: look for val-core acc/reward mean@N; deduplicate keys across runs.
    fallback_pattern = re.compile(r"^val-core/.*/(acc|reward)/mean@(\d+)$")
    fallback: dict[str, tuple[str, int]] = {}
    for keys in run_keys:
        for key in keys:
            match = fallback_pattern.match(key)
            if match:
                fallback[key] = (match.group(1), int(match.group(2)))

    if not fallback:
        raise RuntimeError(
            f"Could not find metric keys matching regex '{metric_regex}' in W&B summaries. "
            "Pass --metric explicitly."
        )

    # Prefer acc over reward when available.
    if any(var == "acc" for var, _ in fallback.values()):
        fallback = {key: info for key, info in fallback.items() if info[0] == "acc"}

    # Largest mean@N first, then the usual math/test/length/name preference.
    return min(fallback, key=lambda key: (-fallback[key][1], preference(key)))


def fetch_metric_series(
    run: wandb.apis.public.Run,
    metric_key: str,
    x_key: str,
    min_step: int | None,
    max_step: int | None,
    keep_every: int | None,
) -> pd.DataFrame:
    """Download one metric's history from a run as an ``x``/``y`` frame.

    Args:
        run: W&B run; only ``run.scan_history(keys=...)`` is used.
        metric_key: History key plotted on the y-axis.
        x_key: History key for the x-axis. Falls back to ``_step`` when the run never
            logged it, or logged only nulls for it.
        min_step: Inclusive lower bound on ``x`` (applied last), or ``None``.
        max_step: Inclusive upper bound on ``x`` (applied last), or ``None``.
        keep_every: When positive and ``x`` is numeric, keep only ``x`` values that are
            multiples of it.

    Returns:
        A frame with columns ``_step``, ``x`` and ``y``, sorted by ``x``, with one row
        per distinct ``x`` (the latest-logged value wins). Numeric ``x`` values that
        are not (near-)integers are dropped. If the metric is absent, an empty frame
        with columns ``x`` and ``y`` is returned.
    """
    empty = pd.DataFrame(columns=["x", "y"])

    def scan(keys: list[str]) -> pd.DataFrame:
        rows = list(run.scan_history(keys=keys))
        return pd.DataFrame(rows) if rows else pd.DataFrame()

    df = pd.DataFrame()
    if x_key != "_step":
        df = scan(["_step", x_key, metric_key])
    if df.empty or metric_key not in df.columns:
        # scan_history(keys=...) only yields rows in which *every* requested key is set,
        # so a custom x-axis key that was never logged would hide the metric entirely.
        # Retry with just _step so the x-axis fallback below can take effect.
        df = scan(["_step", metric_key])
    if df.empty or metric_key not in df.columns:
        return empty

    if x_key not in df.columns or df[x_key].isna().all():
        x_key = "_step"

    if x_key == "_step":
        df = df[["_step", metric_key]].rename(columns={metric_key: "y"})
        df["x"] = df["_step"]
    else:
        df = df[["_step", x_key, metric_key]].rename(columns={x_key: "x", metric_key: "y"})

    df = df.dropna(subset=["x", "y"])
    # Sort by _step so groupby(...).last() keeps the most recently logged value per x.
    df = df.sort_values("_step")
    df = df.groupby("x", as_index=False).last().sort_values("x")

    # Keep only integer steps and (optionally) multiples of keep_every.
    if np.issubdtype(df["x"].dtype, np.number):
        x_rounded = df["x"].round().astype(int)
        mask = np.isclose(df["x"], x_rounded, atol=1e-6)
        df = df[mask].copy()
        df["x"] = x_rounded[mask]
        if keep_every and keep_every > 0:
            df = df[df["x"] % keep_every == 0]

    if min_step is not None:
        df = df[df["x"] >= min_step]
    if max_step is not None:
        df = df[df["x"] <= max_step]

    return df


def smooth_series(df: pd.DataFrame, window: int) -> pd.DataFrame:
    """Return *df* sorted by ``x`` with ``y`` replaced by a centred rolling mean.

    A *window* of 1 or less returns *df* itself, untouched. Partial windows at the
    edges average whatever rows are available. The input frame is never modified.
    """
    if window > 1:
        smoothed = df.sort_values("x").copy()
        rolling = smoothed["y"].rolling(window=window, min_periods=1, center=True)
        smoothed["y"] = rolling.mean()
        return smoothed
    return df


def metric_to_label(metric_key: str) -> str:
    """Derive a human-readable axis label from a W&B metric key.

    * keys containing ``/acc/`` become ``"Validation Accuracy"``
    * ``val-core/<source>/<var>/<stat...>`` becomes ``"Validation <Var> (<stat...>)"``
    * ``critic/...`` becomes ``"Critic ..."`` with slashes turned into spaces
    * any other key has its underscores turned into spaces
    """
    if "/acc/" in metric_key:
        return "Validation Accuracy"
    segments = metric_key.split("/")
    if metric_key.startswith("val-core/") and len(segments) >= 4:
        variable = segments[2].replace("_", " ").title()
        statistic = "/".join(segments[3:])
        return f"Validation {variable} ({statistic})"
    if metric_key.startswith("critic/"):
        return metric_key.replace("critic/", "Critic ").replace("/", " ")
    return metric_key.replace("_", " ")


def extract_turn_index(metric_key: str, pattern: re.Pattern[str]) -> int | None:
    """Return the integer captured by the first group of *pattern* in *metric_key*.

    The match is anchored at the start of the key (``pattern.match``). Returns
    ``None`` in any of these cases:
      * the pattern does not match;
      * the pattern has no capture group;
      * the group did not take part in the match (e.g. an optional group), which
        previously raised ``TypeError`` from ``int(None)``;
      * the captured text is not an integer.
    """
    match = pattern.match(metric_key)
    if match is None:
        return None
    try:
        captured = match.group(1)
    except IndexError:
        return None
    if captured is None:
        return None
    try:
        return int(captured)
    except ValueError:
        return None


def _bump_legend_font(legend: mpl.legend.Legend, delta: float = 2.0) -> None:
    for text in legend.get_texts():
        text.set_fontsize(text.get_fontsize() + delta)


def choose_assistant_metric(keys: set[str]) -> str | None:
    """Pick the assistant-turns metric from a run's summary keys.

    The validation metric is preferred and the training metric is the fallback.
    Returns ``None`` if neither key is present.
    """
    preference_order = ("val-aux/assistant_turns/mean", "assistant_turns/mean")
    return next((candidate for candidate in preference_order if candidate in keys), None)


def discover_metric_keys_from_history(
    run: wandb.apis.public.Run, pattern: re.Pattern[str], max_rows: int = 200
) -> set[str]:
    """Collect the history keys of *run* that match *pattern* (``pattern.match``).

    Rows from ``run.scan_history()`` are read in order. Once at least one key has been
    found, scanning stops after the row at index *max_rows*. If nothing ever matches,
    the whole history is read.
    """
    found: set[str] = set()
    for position, row in enumerate(run.scan_history()):
        found.update(filter(pattern.match, row))
        if position >= max_rows and found:
            break
    return found


def apply_plot_style(figsize: tuple[float, float], dpi: int) -> None:
    """Install the shared plot style into ``matplotlib.rcParams`` (global state).

    *figsize* becomes the default figure size and *dpi* the resolution of saved
    files. On-screen figures keep 100 dpi. Text is bold sans-serif, trying the listed
    font families in order.
    """
    font_stack = [
        "Source Sans 3",
        "Inter",
        "Avenir",
        "Helvetica Neue",
        "Arial",
        "DejaVu Sans",
    ]
    style = {
        # Figure geometry and resolution.
        "figure.figsize": figsize,
        "figure.dpi": 100,
        "savefig.dpi": dpi,
        # Typography.
        "font.family": "sans-serif",
        "font.sans-serif": font_stack,
        "font.size": 13,
        "font.weight": "bold",
        "axes.labelsize": 15,
        "axes.labelweight": "bold",
        "axes.titlesize": 15,
        "axes.titleweight": "bold",
        "legend.fontsize": 11,
        "legend.title_fontsize": 11,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        # Ticks, spines and grid.
        "xtick.major.width": 1.0,
        "ytick.major.width": 1.0,
        "xtick.major.size": 5,
        "ytick.major.size": 5,
        "axes.linewidth": 0.8,
        "grid.alpha": 0.25,
    }
    mpl.rcParams.update(style)


def main() -> None:
    args = parse_args()

    if not args.entity:
        print("ERROR: W&B entity not set. Use --entity or set WANDB_ENTITY.", file=sys.stderr)
        sys.exit(1)

    api = wandb.Api()

    api_filters: dict[str, Any] = {}
    if args.state and args.state != "any":
        api_filters["state"] = args.state
    if args.group_size is not None:
        api_filters["config.actor_rollout_ref.rollout.n"] = args.group_size
    if args.max_resp_len is not None:
        api_filters["config.data.max_response_length"] = args.max_resp_len
    if args.retry_resp_len is not None:
        api_filters["config.actor_rollout_ref.rollout.multi_turn.max_assistant_response_length"] = args.retry_resp_len
    if args.k is not None:
        api_filters["config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns"] = args.k
    if args.adv_estimators:
        api_filters["config.algorithm.adv_estimator"] = {"$in": args.adv_estimators}

    runs = list(api.runs(f"{args.entity}/{args.project}", filters=api_filters))

    filters = [
        RunFilter("actor_rollout_ref.rollout.n", args.group_size),
        RunFilter("data.max_response_length", args.max_resp_len),
        RunFilter("actor_rollout_ref.rollout.multi_turn.max_assistant_response_length", args.retry_resp_len),
        RunFilter("actor_rollout_ref.rollout.multi_turn.max_assistant_turns", args.k),
    ]

    matched = filter_runs(
        runs,
        filters=filters,
        model_contains=args.model_contains,
        train_file_contains=args.train_file_contains,
        val_file_contains=args.val_file_contains,
        estimators=args.adv_estimators,
    )

    if not matched:
        print("No runs matched the filters.", file=sys.stderr)
        print(f"Project: {args.entity}/{args.project}", file=sys.stderr)
        print(f"API filters: {api_filters}", file=sys.stderr)
        sys.exit(1)

    metric_key = choose_metric_key(matched, args.metric, args.metric_regex)
    turn_metrics: list[str] = []
    turn_pattern = re.compile(args.turn_metric_regex)
    assistant_metric: str | None = None
    assistant_pattern = re.compile(args.assistant_turns_metric_regex)
    step_weight_metrics: list[str] = []
    step_weight_pattern = re.compile(args.step_weight_metric_regex)
    step_reward_std_metrics: list[str] = []
    step_reward_std_pattern = re.compile(args.step_reward_std_metric_regex)
    step_output_len_metrics: list[str] = []
    step_output_len_pattern = re.compile(args.step_output_length_metric_regex)
    for run in matched:
        keys = set(run.summary.keys()) if run.summary else set()
        if not args.no_turn_success:
            for key in keys:
                if turn_pattern.match(key):
                    turn_metrics.append(key)
        for key in keys:
            if step_weight_pattern.match(key):
                step_weight_metrics.append(key)
            if step_reward_std_pattern.match(key):
                step_reward_std_metrics.append(key)
            if step_output_len_pattern.match(key):
                step_output_len_metrics.append(key)
        if assistant_metric is None:
            assistant_metric = choose_assistant_metric(keys)
        if assistant_metric is None:
            for key in keys:
                if assistant_pattern.match(key):
                    assistant_metric = key
                    break
    if not args.no_turn_success:
        turn_metrics = sorted(set(turn_metrics), key=lambda k: extract_turn_index(k, turn_pattern) or 0)
    step_weight_metrics = sorted(
        set(step_weight_metrics), key=lambda k: extract_turn_index(k, step_weight_pattern) or 0
    )
    step_reward_std_metrics = sorted(
        set(step_reward_std_metrics), key=lambda k: extract_turn_index(k, step_reward_std_pattern) or 0
    )
    step_output_len_metrics = sorted(
        set(step_output_len_metrics), key=lambda k: extract_turn_index(k, step_output_len_pattern) or 0
    )

    # Fallback: discover step_weight metrics from history if not in summaries
    if not step_weight_metrics:
        for run in matched:
            adv = get_adv_estimator(run.config or {})
            if adv != "grpo_verk_step_reward_step_norm_reweight_future_only":
                continue
            discovered = discover_metric_keys_from_history(run, step_weight_pattern)
            if discovered:
                step_weight_metrics = sorted(
                    discovered, key=lambda k: extract_turn_index(k, step_weight_pattern) or 0
                )
                break

    # Fallback: discover step_reward/std and step_output_length/mean metrics
    if not step_reward_std_metrics:
        for run in matched:
            discovered = discover_metric_keys_from_history(run, step_reward_std_pattern)
            if discovered:
                step_reward_std_metrics = sorted(
                    discovered, key=lambda k: extract_turn_index(k, step_reward_std_pattern) or 0
                )
                break
    if not step_output_len_metrics:
        for run in matched:
            discovered = discover_metric_keys_from_history(run, step_output_len_pattern)
            if discovered:
                step_output_len_metrics = sorted(
                    discovered, key=lambda k: extract_turn_index(k, step_output_len_pattern) or 0
                )
                break

    runs_by_estimator: dict[str, list[wandb.apis.public.Run]] = {k: [] for k in args.adv_estimators}
    for run in matched:
        adv = get_adv_estimator(run.config or {})
        if adv in runs_by_estimator:
            runs_by_estimator[adv].append(run)

    if args.earliest_per_estimator:
        def created_key(run: wandb.apis.public.Run) -> int:
            ts = pd.to_datetime(run.created_at, utc=True, errors="coerce")
            if pd.isna(ts):
                return sys.maxsize
            return int(ts.value)

        for estimator, runs_list in list(runs_by_estimator.items()):
            if not runs_list:
                continue
            runs_by_estimator[estimator] = sorted(runs_list, key=created_key)[:1]

    series_rows = []
    run_meta = []
    all_metric_keys = [metric_key] + turn_metrics + step_weight_metrics + step_reward_std_metrics + step_output_len_metrics
    if assistant_metric:
        all_metric_keys.append(assistant_metric)
    for estimator, runs_list in runs_by_estimator.items():
        for run in runs_list:
            for key in all_metric_keys:
                df = fetch_metric_series(
                    run,
                    key,
                    args.x_key,
                    args.min_step,
                    args.max_step,
                    args.x_keep_every,
                )
                if df.empty:
                    continue
                df["estimator"] = estimator
                df["metric"] = key
                series_rows.append(df)
            run_meta.append(
                {
                    "id": run.id,
                    "name": run.name,
                    "url": run.url,
                    "estimator": estimator,
                }
            )

    if not series_rows:
        print(f"No history data found for metric '{metric_key}'.", file=sys.stderr)
        sys.exit(1)

    data = pd.concat(series_rows, ignore_index=True)

    grouped = (
        data.groupby(["metric", "estimator", "x"])
        .agg(y_mean=("y", "mean"), y_std=("y", "std"), n=("y", "count"))
        .reset_index()
    )
    grouped["y_sem"] = grouped["y_std"] / np.sqrt(grouped["n"].clip(lower=1))

    if args.smooth and args.smooth > 1:
        grouped = grouped.groupby(["metric", "estimator"], group_keys=False).apply(
            lambda df: smooth_series(df.rename(columns={"y_mean": "y"}), args.smooth).rename(columns={"y": "y_mean"})
        )

    if not args.output:
        dataset_hint = args.train_file_contains or args.val_file_contains or args.project
        output_name = (
            f"{args.project}_k{args.k}_n{args.group_size}_resp{args.max_resp_len}_{slugify(dataset_hint)}_estimators"
        )
        args.output = os.path.join("outputs", "plots", output_name)

    figsize = tuple(float(x) for x in args.figsize.split(","))
    include_turn = (not args.no_turn_success) and len(turn_metrics) > 0
    apply_plot_style(figsize=figsize, dpi=args.dpi)

    fig, ax_main = plt.subplots()

    # Main (validation accuracy)
    main_data = grouped[grouped["metric"] == metric_key]
    if args.print_val_acc:
        print(f"Validation metric: {metric_key}")
        for estimator in args.adv_estimators:
            est_data = main_data[main_data["estimator"] == estimator].sort_values("x")
            if est_data.empty:
                continue
            label = ESTIMATOR_LABELS.get(estimator, estimator)
            print(label)
            for row in est_data[["x", "y_mean", "n"]].itertuples(index=False):
                print(f"  step={int(row.x)} mean={row.y_mean:.6f} n={int(row.n)}")
    for idx, estimator in enumerate(args.adv_estimators):
        est_data = main_data[main_data["estimator"] == estimator]
        if est_data.empty:
            continue
        label = ESTIMATOR_LABELS.get(estimator, estimator)
        color = COLOR_CYCLE[idx % len(COLOR_CYCLE)]
        ax_main.plot(est_data["x"], est_data["y_mean"], label=label, color=color, linewidth=2.5)

    ax_main.grid(True, axis="y")
    ax_main.set_xlabel(args.xlabel or "Training step")
    ax_main.set_ylabel(args.ylabel or metric_to_label(metric_key))
    if args.x_tick_every and args.x_tick_every > 0:
        ax_main.xaxis.set_major_locator(mticker.MultipleLocator(args.x_tick_every))
    if args.title:
        ax_main.set_title(args.title)
    legend = ax_main.legend(loc=args.legend_loc, frameon=False)
    if legend is not None:
        for text in legend.get_texts():
            text.set_fontweight("bold")
        _bump_legend_font(legend)
    ax_main.spines["top"].set_visible(False)
    ax_main.spines["right"].set_visible(False)
    for tick in ax_main.get_xticklabels() + ax_main.get_yticklabels():
        tick.set_fontweight("bold")

    fig.tight_layout()

    # Turn success plot (separate figure)
    fig_turn = None
    fig_assistant = None
    fig_step_weight = None
    fig_step_reward_std = None
    fig_step_output_len = None
    if include_turn:
        fig_turn, ax_turn = plt.subplots()
        linestyles = ["-", "--", ":", "-.", (0, (1, 1))]
        turn_data = grouped[grouped["metric"].isin(turn_metrics)]
        for idx, estimator in enumerate(args.adv_estimators):
            est_turn = turn_data[turn_data["estimator"] == estimator]
            if est_turn.empty:
                continue
            color = COLOR_CYCLE[idx % len(COLOR_CYCLE)]
            for metric in turn_metrics:
                turn_idx = extract_turn_index(metric, turn_pattern)
                if turn_idx is None:
                    continue
                metric_df = est_turn[est_turn["metric"] == metric]
                if metric_df.empty:
                    continue
                style = linestyles[turn_idx % len(linestyles)]
                ax_turn.plot(
                    metric_df["x"],
                    metric_df["y_mean"],
                    color=color,
                    linestyle=style,
                    linewidth=2.5,
                )

        ax_turn.grid(True, axis="y")
        ax_turn.set_xlabel(args.xlabel or "Training step")
        ax_turn.set_ylabel("Attempt Success")
        if args.x_tick_every and args.x_tick_every > 0:
            ax_turn.xaxis.set_major_locator(mticker.MultipleLocator(args.x_tick_every))
        ax_turn.spines["top"].set_visible(False)
        ax_turn.spines["right"].set_visible(False)
        for tick in ax_turn.get_xticklabels() + ax_turn.get_yticklabels():
            tick.set_fontweight("bold")

        # Estimator legend (color only)
        est_handles = []
        est_labels = []
        for idx, estimator in enumerate(args.adv_estimators):
            color = COLOR_CYCLE[idx % len(COLOR_CYCLE)]
            label = ESTIMATOR_LABELS.get(estimator, estimator)
            est_handles.append(mpl.lines.Line2D([0], [0], color=color, linewidth=2.5))
            est_labels.append(label)
        legend_est = ax_turn.legend(est_handles, est_labels, loc="upper left", frameon=False)
        if legend_est is not None:
            for text in legend_est.get_texts():
                text.set_fontweight("bold")
            _bump_legend_font(legend_est)
            ax_turn.add_artist(legend_est)

        # Turn legend (linestyle only)
        turn_handles = []
        turn_labels = []
        for metric in turn_metrics:
            turn_idx = extract_turn_index(metric, turn_pattern)
            if turn_idx is None:
                continue
            style = linestyles[turn_idx % len(linestyles)]
            handle = mpl.lines.Line2D([0], [0], color="black", linestyle=style, linewidth=2.5)
            turn_handles.append(handle)
            turn_labels.append(f"attempt {turn_idx + 1}")
        if turn_handles:
            legend_turn = ax_turn.legend(
                turn_handles,
                turn_labels,
                loc="center right",
                bbox_to_anchor=(0.98, args.attempt_legend_y_success if args.attempt_legend_y_success is not None else 0.61),
                frameon=False,
            )
            if legend_turn is not None:
                for text in legend_turn.get_texts():
                    text.set_fontweight("bold")
                ax_turn.add_artist(legend_turn)

        fig_turn.tight_layout()

    # Assistant turns plot (separate figure)
    if assistant_metric:
        fig_assistant, ax_assist = plt.subplots()
        assistant_data = grouped[grouped["metric"] == assistant_metric]
        for idx, estimator in enumerate(args.adv_estimators):
            est_assist = assistant_data[assistant_data["estimator"] == estimator]
            if est_assist.empty:
                continue
            color = COLOR_CYCLE[idx % len(COLOR_CYCLE)]
            label = ESTIMATOR_LABELS.get(estimator, estimator)
            ax_assist.plot(
                est_assist["x"],
                est_assist["y_mean"],
                color=color,
                linewidth=2.5,
                label=label,
            )

        ax_assist.grid(True, axis="y")
        ax_assist.set_xlabel(args.xlabel or "Training step")
        ax_assist.set_ylabel("Avg # of attempts")
        if args.x_tick_every and args.x_tick_every > 0:
            ax_assist.xaxis.set_major_locator(mticker.MultipleLocator(args.x_tick_every))
        ax_assist.spines["top"].set_visible(False)
        ax_assist.spines["right"].set_visible(False)
        legend_assist = ax_assist.legend(loc="best", frameon=False)
        if legend_assist is not None:
            for text in legend_assist.get_texts():
                text.set_fontweight("bold")
            _bump_legend_font(legend_assist)
        for tick in ax_assist.get_xticklabels() + ax_assist.get_yticklabels():
            tick.set_fontweight("bold")
        fig_assistant.tight_layout()

    # Step weight plot (separate figure) for future_only estimator if available
    if step_weight_metrics:
        fig_step_weight, ax_weight = plt.subplots()
        weight_data = grouped[grouped["metric"].isin(step_weight_metrics)]
        if not weight_data.empty:
            linestyles = ["-", "--", ":", "-.", (0, (1, 1))]
            for idx, estimator in enumerate(args.adv_estimators):
                if estimator != "grpo_verk_step_reward_step_norm_reweight_future_only":
                    continue
                est_weight = weight_data[weight_data["estimator"] == estimator]
                if est_weight.empty:
                    continue
                color = COLOR_CYCLE[idx % len(COLOR_CYCLE)]
                for metric in step_weight_metrics:
                    turn_idx = extract_turn_index(metric, step_weight_pattern)
                    if turn_idx is None:
                        continue
                    metric_df = est_weight[est_weight["metric"] == metric]
                    if metric_df.empty:
                        continue
                    style = linestyles[turn_idx % len(linestyles)]
                    ax_weight.plot(
                        metric_df["x"],
                        metric_df["y_mean"],
                        color=color,
                        linestyle=style,
                        linewidth=2.5,
                    )

            ax_weight.grid(True, axis="y")
            ax_weight.set_xlabel(args.xlabel or "Training step")
            ax_weight.set_ylabel("Attempt Weights")
            if args.x_tick_every and args.x_tick_every > 0:
                ax_weight.xaxis.set_major_locator(mticker.MultipleLocator(args.x_tick_every))
            ax_weight.spines["top"].set_visible(False)
            ax_weight.spines["right"].set_visible(False)
            for tick in ax_weight.get_xticklabels() + ax_weight.get_yticklabels():
                tick.set_fontweight("bold")

            # Turn legend (linestyle only)
            weight_handles = []
            weight_labels = []
            for metric in step_weight_metrics:
                turn_idx = extract_turn_index(metric, step_weight_pattern)
                if turn_idx is None:
                    continue
                style = linestyles[turn_idx % len(linestyles)]
                handle = mpl.lines.Line2D([0], [0], color="black", linestyle=style, linewidth=2.5)
                weight_handles.append(handle)
                weight_labels.append(f"attempt {turn_idx + 1}")
            if weight_handles:
                legend_weight = ax_weight.legend(
                    weight_handles,
                    weight_labels,
                    loc="center right",
                    bbox_to_anchor=(0.98, args.attempt_legend_y_weight if args.attempt_legend_y_weight is not None else 0.61),
                    frameon=False,
                )
                if legend_weight is not None:
                    for text in legend_weight.get_texts():
                        text.set_fontweight("bold")

            fig_step_weight.tight_layout()

    def plot_attempt_metric(
        metric_keys: list[str],
        pattern: re.Pattern[str],
        ylabel: str,
        output_ax: plt.Axes,
        estimator_legend_y: float | None = None,
        attempt_legend_loc: str = "center right",
        attempt_legend_y: float = 0.61,
    ) -> None:
        """Plot a per-attempt metric family on ``output_ax``.

        One curve is drawn per (estimator, attempt) pair found in ``grouped``
        (rows whose ``"metric"`` is in ``metric_keys``; ``"x"`` vs ``"y_mean"``).
        Colour encodes the estimator (``COLOR_CYCLE`` indexed by its position in
        ``args.adv_estimators``, so colours match the other figures) and line
        style encodes the attempt index parsed from the metric key by
        ``extract_turn_index(metric, pattern)``. Keys whose index cannot be
        parsed are ignored.

        Two legends are added: a colour-only estimator legend (upper left) and
        a linestyle-only attempt legend. Each lists only estimators / attempts
        that actually produced a curve; a legend with no entries is omitted.

        Args:
            metric_keys: Metric names to plot.
            pattern: Regex handed to ``extract_turn_index`` to recover the attempt index.
            ylabel: Y-axis label.
            output_ax: Axes to draw on.
            estimator_legend_y: Optional y anchor (axes coords) for the estimator
                legend; ``None`` lets matplotlib place it at ``"upper left"``.
            attempt_legend_loc: ``loc`` for the attempt legend.
            attempt_legend_y: y anchor (axes coords) for the attempt legend
                (x anchor is fixed at 0.98).
        """
        linestyles = ["-", "--", ":", "-.", (0, (1, 1))]

        # Resolve each key's attempt index once instead of once per estimator.
        turn_by_metric: dict[str, int] = {}
        for metric in metric_keys:
            turn_idx = extract_turn_index(metric, pattern)
            if turn_idx is not None:
                turn_by_metric[metric] = turn_idx

        metric_data = grouped[grouped["metric"].isin(metric_keys)]
        plotted_estimators: list[tuple[str, str]] = []  # (estimator, colour) that drew >= 1 curve
        plotted_turns: set[int] = set()
        for idx, estimator in enumerate(args.adv_estimators):
            est_metric = metric_data[metric_data["estimator"] == estimator]
            if est_metric.empty:
                continue
            color = COLOR_CYCLE[idx % len(COLOR_CYCLE)]
            drew_any = False
            for metric, turn_idx in turn_by_metric.items():
                metric_df = est_metric[est_metric["metric"] == metric]
                if metric_df.empty:
                    continue
                output_ax.plot(
                    metric_df["x"],
                    metric_df["y_mean"],
                    color=color,
                    linestyle=linestyles[turn_idx % len(linestyles)],
                    linewidth=2.5,
                )
                drew_any = True
                plotted_turns.add(turn_idx)
            if drew_any:
                plotted_estimators.append((estimator, color))

        output_ax.grid(True, axis="y")
        output_ax.set_xlabel(args.xlabel or "Training step")
        output_ax.set_ylabel(ylabel)
        if args.x_tick_every and args.x_tick_every > 0:
            output_ax.xaxis.set_major_locator(mticker.MultipleLocator(args.x_tick_every))
        output_ax.spines["top"].set_visible(False)
        output_ax.spines["right"].set_visible(False)
        for tick in output_ax.get_xticklabels() + output_ax.get_yticklabels():
            tick.set_fontweight("bold")

        # Legend entries: estimators coloured, attempts black with their linestyle,
        # attempts listed in ascending attempt order.
        est_handles = [
            mpl.lines.Line2D([0], [0], color=color, linewidth=2.5) for _, color in plotted_estimators
        ]
        est_labels = [ESTIMATOR_LABELS.get(estimator, estimator) for estimator, _ in plotted_estimators]
        attempt_turns = sorted(plotted_turns)
        attempt_handles = [
            mpl.lines.Line2D(
                [0], [0], color="black", linestyle=linestyles[t % len(linestyles)], linewidth=2.5
            )
            for t in attempt_turns
        ]
        attempt_labels = [f"attempt {t + 1}" for t in attempt_turns]

        # Estimator legend (color only)
        if est_handles:
            est_kwargs: dict[str, Any] = {"loc": "upper left", "frameon": False}
            if estimator_legend_y is not None:
                est_kwargs["bbox_to_anchor"] = (0.02, estimator_legend_y)
            legend_est = output_ax.legend(est_handles, est_labels, **est_kwargs)
            for text in legend_est.get_texts():
                text.set_fontweight("bold")
            _bump_legend_font(legend_est)
            if attempt_handles:
                # The next ax.legend() call replaces ax.legend_, so keep this legend
                # alive as a standalone artist. Only do this when a second legend
                # follows; otherwise it would be drawn twice (as ax.legend_ and as a child).
                output_ax.add_artist(legend_est)

        # Attempt legend (linestyle only). It becomes output_ax.legend_ and is drawn
        # with the axes, so it must NOT also be passed to add_artist.
        if attempt_handles:
            legend_attempt = output_ax.legend(
                attempt_handles,
                attempt_labels,
                loc=attempt_legend_loc,
                bbox_to_anchor=(0.98, attempt_legend_y),
                frameon=False,
            )
            for text in legend_attempt.get_texts():
                text.set_fontweight("bold")

    if step_reward_std_metrics:
        fig_step_reward_std, ax_step_reward = plt.subplots()
        plot_attempt_metric(
            step_reward_std_metrics,
            step_reward_std_pattern,
            "Attempt Reward Std",
            ax_step_reward,
            estimator_legend_y=args.estimator_legend_y_std if args.estimator_legend_y_std is not None else 0.88,
            attempt_legend_loc=args.attempt_legend_loc_std or "lower right",
            attempt_legend_y=args.attempt_legend_y_std if args.attempt_legend_y_std is not None else 0.12,
        )
        fig_step_reward_std.tight_layout()

    if step_output_len_metrics:
        fig_step_output_len, ax_step_len = plt.subplots()
        plot_attempt_metric(
            step_output_len_metrics,
            step_output_len_pattern,
            "Mean Response Length",
            ax_step_len,
            estimator_legend_y=0.78,
            attempt_legend_loc=args.attempt_legend_loc_mean_length or "center right",
            attempt_legend_y=args.attempt_legend_y_mean_length if args.attempt_legend_y_mean_length is not None else 0.61,
        )
        fig_step_output_len.tight_layout()

    output_base = args.output
    os.makedirs(os.path.dirname(output_base), exist_ok=True)
    for fmt in [f.strip() for f in args.formats.split(",") if f.strip()]:
        fig.savefig(f"{output_base}.{fmt}", bbox_inches="tight")
        if fig_turn is not None:
            fig_turn.savefig(f"{output_base}_turns.{fmt}", bbox_inches="tight")
        if fig_assistant is not None:
            fig_assistant.savefig(f"{output_base}_assistant_turns.{fmt}", bbox_inches="tight")
        if fig_step_weight is not None:
            fig_step_weight.savefig(f"{output_base}_step_weight.{fmt}", bbox_inches="tight")
        if fig_step_reward_std is not None:
            fig_step_reward_std.savefig(f"{output_base}_step_reward_std.{fmt}", bbox_inches="tight")
        if fig_step_output_len is not None:
            fig_step_output_len.savefig(f"{output_base}_mean_response_length.{fmt}", bbox_inches="tight")

    if not args.no_csv:
        grouped.to_csv(f"{output_base}.csv", index=False)

    if args.save_runs:
        with open(f"{output_base}_runs.json", "w", encoding="utf-8") as f:
            json.dump(run_meta, f, indent=2)

    print(f"Saved: {output_base}.{args.formats}")
    if fig_turn is not None:
        print(f"Saved: {output_base}_turns.{args.formats}")
    if fig_assistant is not None:
        print(f"Saved: {output_base}_assistant_turns.{args.formats}")
    if fig_step_weight is not None:
        print(f"Saved: {output_base}_step_weight.{args.formats}")
    if fig_step_reward_std is not None:
        print(f"Saved: {output_base}_step_reward_std.{args.formats}")
    if fig_step_output_len is not None:
        print(f"Saved: {output_base}_mean_response_length.{args.formats}")
    print(f"Metric: {metric_key}")
    for est, runs_list in runs_by_estimator.items():
        print(f"{est}: {len(runs_list)} runs")


# Script entry point: run the plotting CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
