#!/usr/bin/env python3
"""
Create Figure 3: composite figure for probe-and-switch + external tasks.

Panels:
(a) External tasks: ranking heatmap (5 algorithms × 6 tasks)
(b) Single-crossing check: conditional advantage curve Δ(p)
(c) Threshold transfer: improvement lollipop (Δregret) for τ=0.12 and τ=0.22 across tasks

Output:
  evidence/paper_figures/figure3_composite.(pdf|png)
"""

from __future__ import annotations

import argparse
import csv
import os
from dataclasses import dataclass

import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np

from _project import BASE_DIR, repo_relpath
from plot_style import apply_style, get_algo_color, save_figure, WIDTHS, add_grid, COLORS
from berwes.utils.display_names import get_display_name


def _load_runs(path: str) -> tuple[list[str], dict[str, list[float]]]:
    """
    Load runs.csv from an external task. Returns (metric_candidates, algo->values).

    Different tasks store the post-eval metric under different column names; we return
    all discovered numeric columns so the caller can pick.
    """
    with open(path, newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
    if not rows:
        return [], {}

    # Prefer these metrics in order (lower is better).
    preferred_cols = [
        "post_true",
        "post_mean",
        "post_median",
        "post_cvar20",
        "final_loss",
    ]

    # Determine which preferred metric exists.
    cols = set(rows[0].keys())
    metric_cols = [c for c in preferred_cols if c in cols]

    values_by_algo: dict[str, list[float]] = {}
    metric_used = metric_cols[0] if metric_cols else ""
    for r in rows:
        algo = str(r.get("algorithm", "")).strip()
        if not algo:
            continue
        val_s = r.get(metric_used, "") if metric_used else ""
        try:
            v = float(val_s)
        except Exception:
            continue
        if not np.isfinite(v):
            continue
        values_by_algo.setdefault(algo, []).append(float(v))

    return metric_cols, values_by_algo


def _panel_a_ranking_heatmap(ax: plt.Axes, evidence_dir: str) -> None:
    """
    Draw panel (a): a task × algorithm heatmap of per-task ranks.

    For every task whose runs.csv exists, each algorithm's median metric is
    computed from ``_load_runs`` and the algorithms that have data are ranked
    1..k (1 = lowest median). Cells without data — a missing runs.csv, or an
    algorithm with no finite values in that file — stay NaN and are drawn as
    "?" instead of being given a made-up rank.

    Args:
        ax: Axes to draw into.
        evidence_dir: Directory containing the per-task result folders.
    """
    # (row label, path to runs.csv). Six tasks across Control, HPO and RL.
    # NOTE(review): the original comment says rows are ordered by RB rank
    # (best first), then by domain; that order is fixed here, not recomputed.
    tasks = [
        ("LQR\n(Control)", os.path.join(evidence_dir, "application_lqr_heavytail_control_fixed_budget_resample/runs.csv")),
        ("Breast Cancer\n(HPO)", os.path.join(evidence_dir, "application_hpo_breast_cancer_budget10/runs.csv")),
        ("Digits\n(HPO)", os.path.join(evidence_dir, "application_hpo_digits0_budget10/runs.csv")),
        ("CartPole-HT\n(RL)", os.path.join(evidence_dir, "test_rl_std10_df2/runs.csv")),
        ("CartPole\n(RL)", os.path.join(evidence_dir, "application_rl_cartpole_budget3/runs.csv")),
        ("Pendulum\n(RL)", os.path.join(evidence_dir, "application_rl_pendulum_full/runs.csv")),
    ]

    # Column order is fixed. NOTE(review): the original author lists it as
    # "ordered by average rank (best first)"; those averages are not computed
    # in this function.
    algos = [
        "BERW-HeteroRobust",
        "CMA-ES-sep",
        "UH-CMA-ES(maxevals=30)",
        "CMA-ES-Resample(k=5)",
        "CMA-ES-Resample(k=10)",
    ]
    algo_labels = ["Residual\nBootstrapping", "CMA-ES", "UH-CMA-ES", "Resample(5)", "Resample(10)"]
    n_tasks, n_algos = len(tasks), len(algos)

    # Rank matrix (rows = tasks, cols = algos); NaN means "no data".
    rank_matrix = np.full((n_tasks, n_algos), np.nan)
    for t_idx, (_, runs_path) in enumerate(tasks):
        if not os.path.isfile(runs_path):
            continue
        _, by_algo = _load_runs(runs_path)
        # Median per column, only for algorithms that actually have values.
        medians = {
            col: float(np.median(by_algo[algo]))
            for col, algo in enumerate(algos)
            if by_algo.get(algo)
        }
        # sorted() is stable, so ties resolve deterministically by column order.
        for rank, col in enumerate(sorted(medians, key=medians.get), start=1):
            rank_matrix[t_idx, col] = rank

    # Green (rank 1) -> yellow (middle) -> red (worst possible rank).
    cmap = plt.cm.RdYlGn_r
    norm = mcolors.Normalize(vmin=1, vmax=n_algos)
    ax.imshow(rank_matrix, cmap=cmap, norm=norm, aspect="auto")

    # White separators between cells.
    for i in range(n_tasks + 1):
        ax.axhline(i - 0.5, color="white", linewidth=1.5)
    for j in range(n_algos + 1):
        ax.axvline(j - 0.5, color="white", linewidth=1.5)

    # Cell annotations: the rank, or "?" where there is no data.
    for i in range(n_tasks):
        for j in range(n_algos):
            cell = rank_matrix[i, j]
            if np.isnan(cell):
                ax.text(j, i, "?", ha="center", va="center", fontsize=11,
                        fontweight="bold", color="#666666")
                continue
            rank = int(cell)
            # The two ends of the colormap are dark, so use white text there.
            color = "white" if rank in (1, n_algos) else "black"
            ax.text(j, i, str(rank), ha="center", va="center",
                    fontsize=11, fontweight="bold", color=color)

    # Axis labels — algorithm names on top.
    ax.xaxis.tick_top()
    ax.set_xticks(range(n_algos))
    ax.set_xticklabels(algo_labels, fontsize=7, rotation=0, ha="center")
    ax.set_yticks(range(n_tasks))
    ax.set_yticklabels([t[0] for t in tasks], fontsize=7)

    # Remove tick marks.
    ax.tick_params(length=0)


@dataclass(frozen=True)
class _BinStats:
    """Immutable per-bin summary used for the panel (b) error-bar curve."""

    # Representative x position of the bin (panel (b) passes the midpoint of the bin edges).
    p_mid: float
    # Mean of the values that fell into the bin.
    mean: float
    # Standard error of that mean (panel (b) passes 0.0 for single-sample bins).
    stderr: float


def _quantile_bins(p: np.ndarray, n_bins: int) -> list[tuple[float, float]]:
    qs = np.linspace(0.0, 1.0, int(n_bins) + 1)
    edges = np.quantile(p, qs)
    out: list[tuple[float, float]] = []
    for a, b in zip(edges[:-1], edges[1:]):
        if not out:
            out.append((float(a), float(b)))
            continue
        prev_a, prev_b = out[-1]
        if float(a) <= float(prev_b) + 1e-18:
            a = prev_b
        if float(b) <= float(a) + 1e-18:
            continue
        out.append((float(a), float(b)))
    return out


def _panel_b_single_crossing(ax: plt.Axes, evidence_dir: str) -> None:
    """
    Draw panel (b): binned conditional advantage curve Δ(p).

    Reads decision_points.csv, computes per row
    ``Δ = log10(best_f_cma) − log10(best_f_berw)`` (both values clamped below
    at 1e-12 before the log), groups rows into quantile bins of the probe
    statistic ``misranking_rd`` and plots each bin's mean with ±1.96 standard
    errors. A dashed vertical line marks the threshold τ = 0.12.

    If the CSV is missing, a placeholder message is drawn and the axes are
    switched off.

    Args:
        ax: Axes to draw into.
        evidence_dir: Directory containing the decision-accuracy result folder.

    Raises:
        SystemExit: If the CSV has no rows, no row has finite values, or no
            bin ends up with any samples.
    """
    decision_points = os.path.join(
        evidence_dir, "bbob_noisy_probe_decision_accuracy_noisefree_i1-15_B500/decision_points.csv"
    )
    if not os.path.isfile(decision_points):
        ax.text(0.5, 0.5, "Missing decision_points.csv", ha="center", va="center", transform=ax.transAxes)
        ax.set_title("(b) Single-crossing property", fontsize=9, fontweight="bold")
        ax.axis("off")
        return

    with open(decision_points, newline="") as f:
        rows: list[dict[str, str]] = list(csv.DictReader(f))
    if not rows:
        raise SystemExit("Empty decision_points.csv")

    def _getf(row: dict[str, str], key: str) -> float:
        # Blank, missing (None) or non-numeric cells become NaN and are
        # filtered out below.
        try:
            return float(row.get(key, ""))
        except (TypeError, ValueError):
            return float("nan")

    p = np.asarray([_getf(r, "misranking_rd") for r in rows], dtype=float)
    best_cma = np.asarray([_getf(r, "best_f_cma") for r in rows], dtype=float)
    best_berw = np.asarray([_getf(r, "best_f_berw") for r in rows], dtype=float)

    # Clamp before log10 so zero/negative best values do not produce -inf/NaN.
    eps = 1e-12
    delta = np.log10(np.maximum(best_cma, eps)) - np.log10(np.maximum(best_berw, eps))

    ok = np.isfinite(p) & np.isfinite(delta)
    p = p[ok]
    delta = delta[ok]
    if p.size <= 0:
        raise SystemExit("No finite rows in decision_points.csv after filtering.")

    bins = _quantile_bins(p, n_bins=12)
    last_idx = len(bins) - 1
    stats: list[_BinStats] = []
    for idx, (lo, hi) in enumerate(bins):
        # Left-closed, right-open, except the last bin which also includes its
        # right edge so the maximum of p is not dropped.
        if idx == last_idx:
            mask = (p >= lo) & (p <= hi)
        else:
            mask = (p >= lo) & (p < hi)
        vals = delta[mask]  # already finite: delta was filtered with `ok`
        if vals.size <= 0:
            continue
        mean = float(np.mean(vals))
        # Sample standard deviation needs at least two points.
        sd = float(np.std(vals, ddof=1)) if vals.size >= 2 else 0.0
        stderr = sd / float(np.sqrt(max(1, vals.size)))
        stats.append(_BinStats(p_mid=float(0.5 * (lo + hi)), mean=mean, stderr=stderr))

    if not stats:
        # Fail with a clear message instead of a ValueError from np.nanmax below.
        raise SystemExit("No non-empty bins in decision_points.csv after binning.")

    xs = np.asarray([s.p_mid for s in stats], dtype=float)
    ys = np.asarray([s.mean for s in stats], dtype=float)
    es = 1.96 * np.asarray([s.stderr for s in stats], dtype=float)

    color = get_algo_color("BERW-Hetero")
    ax.errorbar(xs, ys, yerr=es, fmt="o-", lw=1.2, ms=3.5, capsize=2, color=color, alpha=0.95)
    ax.axhline(0.0, color="#64748B", lw=0.9, alpha=0.8)

    # Threshold marker, labelled at the height of the largest bin mean.
    tau = 0.12
    ax.axvline(tau, color="#EF4444", lw=1.1, alpha=0.9, linestyle="--")
    ax.text(tau, float(np.nanmax(ys)), f"  τ={tau:.2f}", color="#B91C1C", va="top", fontsize=7)

    ax.set_xlabel("Probe statistic $P$ (misranking RD)", fontsize=8)
    ax.set_ylabel(r"$\Delta(p)$", fontsize=8)
    ax.set_title("(b) Single-crossing property", fontsize=9, fontweight="bold")
    add_grid(ax, alpha=0.2)

    # In-axes footnote defining Δ(p).
    ax.text(
        0.02,
        0.02,
        "Δ(p)=E[log10 f(CMA) − log10 f(RB) | P=p]\n(positive ⇒ RB better)",
        transform=ax.transAxes,
        fontsize=5.8,
        va="bottom",
        ha="left",
        alpha=0.75,
    )


def _panel_c_transfer(ax: plt.Axes, evidence_dir: str) -> None:
    """
    Draw panel (c): dual lollipop chart of regret improvement per target.

    For each row of transfer_summary_compact.csv (minus the excluded targets)
    two improvements over the always-CMA baseline are computed:

    * τ=0.12: ``always_cma_regret_mean − bbob_B500_regret_mean``
    * τ=0.22: ``always_cma_regret_mean − safe_regret_mean``

    Positive means the transferred threshold improves on always-CMA. Rows are
    sorted by the τ=0.12 improvement (largest at the top); rows whose τ=0.12
    value is unavailable are placed last in file order. A row is skipped only
    when its always-CMA regret is missing or not finite; a missing or
    unparsable transfer value just leaves that one lollipop undrawn.

    If the CSV is missing, a placeholder message is drawn and the axes are
    switched off. If no usable rows remain, "No data" is drawn.

    Args:
        ax: Axes to draw into.
        evidence_dir: Directory containing the transfer summary folder.
    """
    from matplotlib.lines import Line2D

    transfer_csv = os.path.join(evidence_dir, "probeswitch_transfer_overhead_summary/transfer_summary_compact.csv")
    if not os.path.isfile(transfer_csv):
        ax.text(0.5, 0.5, "Missing transfer_summary_compact.csv", ha="center", va="center", transform=ax.transAxes)
        ax.set_title("(c) Zero-tuning transfer", fontsize=9, fontweight="bold")
        ax.axis("off")
        return

    with open(transfer_csv, newline="") as f:
        rows: list[dict[str, str]] = list(csv.DictReader(f))

    # Targets to exclude (boundary cases less relevant to main argument).
    EXCLUDE_TARGETS = {"bbob_B200_d10", "bbob_B200_d20"}

    def _regret(row: dict[str, str], key: str) -> float:
        # Absent column, blank cell or junk text all mean "not available".
        try:
            return float(row.get(key, "nan"))
        except (TypeError, ValueError):
            return float("nan")

    def _improvement(baseline: float, transfer: float) -> float:
        # delta = always_cma_regret - transfer_regret (positive = transfer improves).
        return baseline - transfer if np.isfinite(transfer) else float("nan")

    lollipop_data: list[dict] = []
    for r in rows:
        target = str(r.get("target", ""))
        if target in EXCLUDE_TARGETS:
            continue
        always_reg = _regret(r, "always_cma_regret_mean")
        if not np.isfinite(always_reg):
            # Without the baseline neither improvement can be computed.
            continue
        lollipop_data.append({
            "target": target,
            "label": str(r.get("target_label", r.get("target", "?"))),
            "delta_012": _improvement(always_reg, _regret(r, "bbob_B500_regret_mean")),
            "delta_022": _improvement(always_reg, _regret(r, "safe_regret_mean")),
        })

    if not lollipop_data:
        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        ax.set_title("(c) Zero-tuning transfer", fontsize=9, fontweight="bold")
        return

    def _sort_key(item: dict) -> tuple[int, float]:
        # NaN compares False with everything, which makes a plain sort on the
        # raw value order-dependent. Finite values sort descending; NaN rows
        # go last.
        d = item["delta_012"]
        return (0, -d) if np.isfinite(d) else (1, 0.0)

    # Best τ=0.12 improvement first (drawn at the top after invert_yaxis).
    all_items = sorted(lollipop_data, key=_sort_key)

    # Draw dual lollipops: τ=0.12 (blue, upper) and τ=0.22 (gray, lower).
    y_offset = 0.15
    color_012 = COLORS["blue"]
    color_022 = "#888888"

    labels: list[str] = []
    for y_base, item in enumerate(all_items):
        labels.append(item["label"])

        # τ=0.12 (upper line).
        d_012 = item["delta_012"]
        if np.isfinite(d_012):
            ax.hlines(y_base - y_offset, 0, d_012, color=color_012, linewidth=1.5)
            ax.plot(d_012, y_base - y_offset, "o", color=color_012, markersize=4)

        # τ=0.22 (lower line).
        d_022 = item["delta_022"]
        if np.isfinite(d_022):
            ax.hlines(y_base + y_offset, 0, d_022, color=color_022, linewidth=1.5)
            ax.plot(d_022, y_base + y_offset, "s", color=color_022, markersize=3.5)

    # Zero reference line.
    ax.axvline(0, color="black", linewidth=0.8, linestyle="-")

    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontsize=6)
    ax.set_xlabel(r"$\Delta\,\mathrm{regret}$ (always-CMA $-$ transfer)", fontsize=7)
    ax.set_title("(c) Zero-tuning transfer", fontsize=9, fontweight="bold")

    # Legend for dual thresholds.
    legend_elements = [
        Line2D([0], [0], marker='o', color=color_012, lw=1.5, label=r'$\tau=0.12$', ms=4),
        Line2D([0], [0], marker='s', color=color_022, lw=1.5, label=r'$\tau=0.22$', ms=3.5),
    ]
    ax.legend(handles=legend_elements, loc='lower right', fontsize=6, frameon=False)

    # Row 0 (largest improvement) at the top.
    ax.invert_yaxis()
    add_grid(ax, axis="x", alpha=0.2)


def main() -> None:
    """
    Command-line entry point: build the three-panel figure and save it.

    Parses ``--evidence-dir`` and ``--output`` (both default to locations
    under ``BASE_DIR``), changes the working directory to ``BASE_DIR``,
    applies the shared plot style, lays out panels (a), (b), (c) side by side,
    hands the figure to ``save_figure`` and prints the .pdf/.png paths.
    """
    cli = argparse.ArgumentParser(description="Create Figure 3 composite plot")
    cli.add_argument(
        "--evidence-dir",
        default=os.path.join(BASE_DIR, "evidence"),
        help="Path to evidence directory",
    )
    cli.add_argument(
        "--output",
        default=os.path.join(BASE_DIR, "evidence/paper_figures/figure3_composite"),
        help="Output path prefix (without extension)",
    )
    opts = cli.parse_args()
    evidence_dir = str(opts.evidence_dir)
    out_prefix = str(opts.output)

    # NOTE: the chdir happens after argument parsing, so relative paths given
    # on the command line are interpreted relative to BASE_DIR from here on.
    os.chdir(BASE_DIR)
    apply_style()

    fig = plt.figure(figsize=(WIDTHS["double"], 3.2))
    layout = gridspec.GridSpec(1, 3, figure=fig, width_ratios=[1.2, 1.0, 1.2], wspace=0.45)

    # Create all three axes first, then fill them left to right.
    axes = [fig.add_subplot(layout[col]) for col in range(3)]
    painters = (
        _panel_a_ranking_heatmap,
        _panel_b_single_crossing,
        _panel_c_transfer,
    )
    for panel_ax, paint in zip(axes, painters):
        paint(panel_ax, evidence_dir)

    save_figure(fig, out_prefix)
    plt.close(fig)

    for ext in (".pdf", ".png"):
        print("Wrote:", repo_relpath(out_prefix + ext))


if __name__ == "__main__":
    main()
