#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt

# 仅用于“合并/对比”模式，让重叠曲线也能分辨
import matplotlib.patheffects as pe


# -------------------------
# IO
# -------------------------
def load_jsonl(path):
    """Read a JSON-Lines file and return its records as a list.

    Each non-blank line (after stripping surrounding whitespace) is decoded
    with ``json.loads``; blank or whitespace-only lines are skipped.

    Args:
        path: Path of the ``.jsonl`` file, read as UTF-8 text.

    Returns:
        List of decoded objects, in file order.
    """
    with open(path, "r", encoding="utf-8") as fh:
        cleaned = (raw.strip() for raw in fh)
        return [json.loads(text) for text in cleaned if text]


def set_global_style(font_size: int):
    """Set matplotlib's global font sizes relative to a base ``font_size``.

    Titles get ``font_size + 1``; body text and axis labels get ``font_size``;
    tick labels and legend text get ``font_size - 1``.
    """
    smaller = font_size - 1
    plt.rcParams.update({
        "font.size": font_size,
        "axes.titlesize": font_size + 1,
        "axes.labelsize": font_size,
        **{name: smaller for name in ("xtick.labelsize", "ytick.labelsize", "legend.fontsize")},
    })


def safe_get(d, k, default=None):
    """Return ``d[k]`` when key ``k`` is present in dict ``d``, else ``default``.

    A stored ``None`` is returned as-is (presence, not truthiness, decides).
    """
    return d.get(k, default)


# -------------------------
# Detect which-field for a jsonl
# -------------------------
def detect_variant(rows):
    """Infer whether a metrics jsonl describes the pruned or the dense model.

    ``"pruned"`` wins if *any* row carries ``m_curve_pruned``; otherwise
    ``"dense"`` is returned if any row carries ``m_curve_dense``.

    Raises:
        ValueError: if no row has either curve field.
    """
    if any("m_curve_pruned" in row for row in rows):
        return "pruned"
    if any("m_curve_dense" in row for row in rows):
        return "dense"
    raise ValueError("Cannot detect variant: jsonl has neither m_curve_pruned nor m_curve_dense.")


def get_keys(which: str):
    """Map logical metric names to the jsonl field names of one model variant.

    ``which == "dense"`` selects the ``*_dense`` fields; any other value
    selects the ``*_pruned`` fields.

    Returns:
        Dict with the logical keys ``curve, last, star, corr, L, delta,
        late_sum, late_mean, persist, pos_ratio, last_pos, min_after, amp``
        (in that order) mapped to concrete field names.
    """
    sfx = "dense" if which == "dense" else "pruned"
    return {
        "curve": f"m_curve_{sfx}",
        "last": f"m_last_{sfx}",
        "star": f"l_star_{sfx}",
        "corr": f"correct_{sfx}",
        "L": f"{sfx}_layers",
        "delta": f"delta_m_curve_{sfx}",
        "late_sum": f"late_gain_sum_{sfx}",
        "late_mean": f"late_gain_mean_{sfx}",
        "persist": f"persist_all_after_cross_{sfx}",
        "pos_ratio": f"pos_ratio_after_cross_{sfx}",
        "last_pos": f"last_pos_layer_{sfx}",
        "min_after": f"min_margin_after_cross_{sfx}",
        "amp": f"amp_topk_{sfx}",
    }


# -------------------------
# Parse one jsonl into aligned arrays
# -------------------------
def parse_metrics(rows, which: str, L_use: int):
    """Parse the rows of one metrics jsonl into aligned per-sample numpy arrays.

    Field names are resolved with ``get_keys(which)``, so ``which`` selects the
    ``*_dense`` or ``*_pruned`` variant of every field.

    Args:
        rows: Non-empty list of row dicts (one per sample), e.g. from ``load_jsonl``.
        which: ``"dense"`` or ``"pruned"`` (any other value maps to the pruned keys).
        L_use: Number of leading layers kept from each margin curve; the
            per-step gain arrays keep ``L_use - 1`` entries.

    Returns:
        Dict with:
          - ``curves``: float32 array ``[N, L_use]`` of per-layer margins.
          - ``deltas``: float32 array ``[N, L_use - 1]`` of layer-to-layer gains,
            taken from the stored delta field, or recomputed as first differences
            of the truncated curve when that field is missing or null.
          - ``m_last``: float32 ``[N]``.
          - ``l_star``, ``corr``: int32 ``[N]``.
          - ``late_sum``, ``late_mean``, ``pos_ratio``, ``min_after``: float32 ``[N]``;
            NaN where the field is missing or JSON null.
          - ``persist``: int32 ``[N]`` (0/1); missing/null counts as False.
          - ``last_pos``: int32 ``[N]``; 0 where the field is missing or null.
          - ``amp_layers``: flat list of ``to_layer`` ints from every ``amp`` entry.
          - ``N``: number of rows.

    Raises:
        KeyError: if a required field (curve, last, star, corr) is missing.
        ValueError: (from numpy) if the per-row gains cannot be stacked, e.g.
            ``rows`` is empty or a stored delta list is shorter than ``L_use - 1``.
    """
    keys = get_keys(which)

    def _opt_float(value):
        # Optional float fields may be absent (default NaN) or JSON null; map
        # both to NaN instead of letting float(None) raise TypeError.
        return np.nan if value is None else float(value)

    def _opt_int(value, default):
        # Integer fields: JSON null behaves like a missing field.
        return default if value is None else int(value)

    def _opt_float_field(name):
        return np.array(
            [_opt_float(safe_get(r, keys[name], np.nan)) for r in rows],
            dtype=np.float32,
        )

    curves = np.array(
        [np.array(r[keys["curve"]][:L_use], dtype=np.float32) for r in rows]
    )  # [N, L_use]
    m_last = np.array([float(r[keys["last"]]) for r in rows], dtype=np.float32)
    l_star = np.array([int(r[keys["star"]]) for r in rows], dtype=np.int32)
    corr = np.array([int(r[keys["corr"]]) for r in rows], dtype=np.int32)

    deltas_list = []
    for r in rows:
        d = safe_get(r, keys["delta"], None)
        if d is None:
            # No stored gains: recompute first differences of the truncated curve.
            mc = np.array(r[keys["curve"]][:L_use], dtype=np.float32)
            d = np.diff(mc)
        deltas_list.append(np.array(d[:(L_use - 1)], dtype=np.float32))
    deltas = np.stack(deltas_list, axis=0)  # [N, L_use - 1]

    late_sum = _opt_float_field("late_sum")
    late_mean = _opt_float_field("late_mean")
    persist = np.array(
        [int(bool(safe_get(r, keys["persist"], False))) for r in rows], dtype=np.int32
    )
    pos_ratio = _opt_float_field("pos_ratio")
    last_pos = np.array(
        [_opt_int(safe_get(r, keys["last_pos"], 0), 0) for r in rows], dtype=np.int32
    )
    min_after = _opt_float_field("min_after")

    # Collect the destination layer of every recorded amplification event;
    # malformed entries (non-list field, non-dict items) are skipped.
    amp_layers = []
    for r in rows:
        amps = safe_get(r, keys["amp"], [])
        if isinstance(amps, list):
            for it in amps:
                if isinstance(it, dict) and "to_layer" in it:
                    amp_layers.append(int(it["to_layer"]))

    return {
        "curves": curves,
        "deltas": deltas,
        "m_last": m_last,
        "l_star": l_star,
        "corr": corr,
        "late_sum": late_sum,
        "late_mean": late_mean,
        "persist": persist,
        "pos_ratio": pos_ratio,
        "last_pos": last_pos,
        "min_after": min_after,
        "amp_layers": amp_layers,
        "N": len(rows),
    }


def compute_min_common_L(jsonl_specs):
    """Return the smallest layer count shared by every jsonl spec.

    For each spec (a dict with ``path``, ``rows`` and ``which``), the layer
    count field chosen by ``get_keys(which)["L"]`` is read from every row that
    has it. The minimum is taken per file and then across files.

    Raises:
        ValueError: if a spec has no row carrying the layer-count field, or if
            the resulting common length is below 2.
    """
    per_file_min = []
    for spec in jsonl_specs:
        rows = spec["rows"]
        layer_field = get_keys(spec["which"])["L"]
        lengths = [int(row[layer_field]) for row in rows if layer_field in row]
        if not lengths:
            raise ValueError(f"No layer-length field found for {spec['path']} (which={spec['which']}).")
        per_file_min.append(int(min(lengths)))

    common = int(min(per_file_min))
    if common <= 1:
        raise ValueError(f"Invalid common L_use={common}. Need at least 2.")
    return common


def default_label_from_path(p):
    """Derive a legend label from a file path: its base name minus the last extension."""
    stem, _extension = os.path.splitext(os.path.basename(p))
    return stem


# -------------------------
# Presentation helpers
# -------------------------
def format_pr_tag(prune_ratio):
    """Format a pruning ratio as a title tag such as ``"PR-3.1%"``.

    Values ``<= 1.0`` are treated as fractions and scaled by 100; larger
    values are taken as percentages already. ``None`` yields ``None``.
    """
    if prune_ratio is None:
        return None
    value = float(prune_ratio)
    if value > 1.0:
        percent = value
    else:
        percent = value * 100.0
    return f"PR-{percent:.1f}%"


def make_title(base, pr_tag, extra=None):
    """Join a figure title from ``base``, then ``extra``, then ``pr_tag``.

    Falsy ``extra``/``pr_tag`` values are omitted; parts are separated by ``" | "``.
    """
    segments = [base]
    segments.extend(part for part in (extra, pr_tag) if part)
    return " | ".join(segments)


def savefig(out_dir, name, dpi):
    """Tighten the layout, write the current figure to ``out_dir/name`` at ``dpi``, then close it."""
    target = os.path.join(out_dir, name)
    plt.tight_layout()
    plt.savefig(target, dpi=dpi)
    plt.close()


# -------------------------
# Compare-mode helpers (ONLY affect overlay plots)
# -------------------------
def gaussian_smooth_1d(y, sigma_bins=2.0):
    """Smooth a 1-D sequence with a normalized Gaussian kernel.

    The kernel spans ``radius = max(3, ceil(4 * sigma_bins))`` bins on each
    side and sums to 1. Values beyond the ends are treated as zero, so mass
    near the boundaries leaks out (as with ``np.convolve(..., mode="same")``).

    Args:
        y: 1-D sequence of values (e.g. histogram heights or a PMF).
        sigma_bins: Gaussian standard deviation in bins. ``None`` or a value
            ``<= 0`` disables smoothing and returns ``y`` unchanged.

    Returns:
        A 1-D array with exactly ``len(y)`` samples (or ``y`` itself when
        smoothing is disabled).
    """
    if sigma_bins is None or sigma_bins <= 0:
        return y
    y = np.asarray(y)
    if y.size == 0:
        # np.convolve raises on empty input; there is nothing to smooth.
        return y.astype(np.float64)

    sigma = float(sigma_bins)
    radius = int(max(3, np.ceil(4 * sigma)))
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-(offsets ** 2) / (2 * sigma ** 2))
    kernel = kernel / kernel.sum()

    # np.convolve(mode="same") returns max(len(y), len(kernel)) samples, so a
    # short input (e.g. a PMF over a few layers) came back LONGER than it went
    # in and no longer lined up with its x-axis. The centred len(y) window of
    # the full convolution equals "same" when len(y) >= len(kernel) and keeps
    # the correct length otherwise.
    full = np.convolve(y, kernel, mode="full")
    return full[radius:radius + y.shape[0]]


def smooth_density_curve(x, bins=120, sigma_bins=2.0, xlim=None):
    """Estimate a smoothed probability density of ``x`` on a fixed grid.

    Non-finite values are dropped first. The histogram range is ``xlim`` when
    given. Otherwise it is the 0.5th-99.5th percentile of the data. If those
    percentiles coincide, it falls back to ``[min - 1e-6, max + 1e-6]``.
    A ``density=True`` histogram with ``bins`` bins is then smoothed with
    ``gaussian_smooth_1d``.

    Returns:
        ``(bin_centers, smoothed_density)``, or ``(None, None)`` if no finite
        values remain.
    """
    values = np.asarray(x, dtype=np.float64)
    values = values[np.isfinite(values)]
    if not values.size:
        return None, None

    if xlim is not None:
        lo, hi = xlim
    else:
        lo, hi = np.percentile(values, [0.5, 99.5])
        if lo == hi:
            lo = values.min() - 1e-6
            hi = values.max() + 1e-6

    density, edges = np.histogram(values, bins=int(bins), range=(lo, hi), density=True)
    centers = (edges[:-1] + edges[1:]) / 2.0
    return centers, gaussian_smooth_1d(density, sigma_bins=sigma_bins)


# def smooth_pmf_curve_int(x_int, xmin, xmax, sigma_bins=1.2):
#     x = np.asarray(x_int, dtype=np.int64)
#     x = x[np.isfinite(x)]
#     if x.size == 0:
#         return None, None
#     xs = np.arange(int(xmin), int(xmax) + 1)
#     cnt = np.zeros_like(xs, dtype=np.float64)
#     x = x[(x >= xmin) & (x <= xmax)]
#     if x.size == 0:
#         return None, None
#     idx = x - xmin
#     np.add.at(cnt, idx, 1.0)
#     pmf = cnt / max(1.0, cnt.sum())
#     pmf_s = gaussian_smooth_1d(pmf, sigma_bins=sigma_bins)
#     return xs, pmf_s

def smooth_pmf_curve_int(
    x_int,
    xmin: int,
    xmax: int,
    sigma_bins: float = 1.2,
    valid_max: int | None = None,
    renorm: bool = True,
):
    """Estimate a smoothed PMF for a discrete integer variable (e.g. crossing layer ``l_star``).

    Key properties:
      - ``xs`` is always the full support ``[xmin, xmax]`` (e.g. 1..32), so the
        x-axis is not cut short just because the observed data stop at, say, 6.
      - ``valid_max`` applies a structural constraint: e.g. with
        ``pruned_layers=16``, the probability at every ``xs > 16`` is forced to 0.
      - ``renorm`` controls whether the PMF is re-normalized after that zeroing.

    Args:
        x_int: Sequence of integer-valued observations. Non-finite entries
            (NaN/inf) are ignored; float values are truncated toward zero.
        xmin: Inclusive lower bound of the support.
        xmax: Inclusive upper bound of the support.
        sigma_bins: Smoothing width in bins, passed to ``gaussian_smooth_1d``.
        valid_max: Optional upper bound; smoothed mass at ``xs > valid_max`` is zeroed.
        renorm: If True and ``valid_max`` is set, rescale to sum to 1 when any
            mass remains after zeroing.

    Returns:
        ``(xs, pmf)`` arrays, or ``(None, None)`` when there are no finite
        observations or ``xmin > xmax``. If observations exist but none fall
        inside the support, ``pmf`` is all zeros.

    Note:
        The ``valid_max`` mask is built from ``xs``, so this relies on
        ``gaussian_smooth_1d`` returning an array of the same length as its input.
    """
    # Filter NaN/inf in float space *before* casting to int: after an int64
    # cast an isfinite() check is always True, and NaN would already have
    # raised or turned into a garbage integer.
    x = np.asarray(x_int, dtype=np.float64).ravel()
    x = x[np.isfinite(x)].astype(np.int64)
    if x.size == 0:
        return None, None

    xmin = int(xmin)
    xmax = int(xmax)
    if xmin > xmax:
        return None, None

    # Fixed support: always the full [xmin, xmax] range, regardless of the data's extent.
    xs = np.arange(xmin, xmax + 1, dtype=np.int64)
    cnt = np.zeros(xs.shape[0], dtype=np.float64)

    # Only count observations that fall inside [xmin, xmax].
    x = x[(x >= xmin) & (x <= xmax)]
    if x.size == 0:
        # Nothing lands in the support: return an all-zero PMF.
        return xs, np.zeros_like(xs, dtype=np.float64)

    np.add.at(cnt, x - xmin, 1.0)
    pmf = cnt / max(1.0, cnt.sum())

    pmf_s = gaussian_smooth_1d(pmf, sigma_bins=sigma_bins)

    # Structural constraint, e.g. a pruned model with at most 16 layers:
    # force probability at xs > valid_max to 0.
    if valid_max is not None:
        valid_max = int(valid_max)
        mask = xs <= valid_max  # same length as xs / pmf_s
        pmf_s = pmf_s.copy()
        pmf_s[~mask] = 0.0
        if renorm:
            s = pmf_s.sum()
            if s > 0:
                pmf_s /= s

    return xs, pmf_s


def parse_pr_from_label(label: str):
    """Extract a pruning ratio (in percent) from a legend label.

    If the label contains an explicit ``PR-``/``pr-``/``PR_``/``pr_`` tag and
    a number follows a (case-insensitive) ``pr`` marker, that number is
    returned as-is. Otherwise, the first number anywhere in the label is used.
    Values ``<= 1`` are treated as fractions and scaled to percent.

    Returns:
        The ratio as a float, or ``None`` for an empty label or one without digits.
    """
    import re

    if not label:
        return None
    text = label.strip()

    if any(tag in text for tag in ("PR-", "pr-", "PR_", "pr_")):
        tagged = re.search(r"(?i)pr[-_ ]*([0-9]+(?:\.[0-9]+)?)\s*%?", text)
        if tagged:
            return float(tagged.group(1))

    first_number = re.search(r"([0-9]+(?:\.[0-9]+)?)", text)
    if first_number is None:
        return None
    value = float(first_number.group(1))
    return value * 100.0 if value <= 1.0 else value


def select_overlay_indices(datasets, max_k):
    """Pick which datasets to draw in overlay line plots (index 0 is always kept).

    If ``max_k`` is ``None``/non-positive, or there are at most ``max_k``
    datasets, every index is returned. Otherwise ``max_k - 1`` of the
    non-main datasets are chosen:
      - if every non-main label yields a pruning ratio (``parse_pr_from_label``),
        they are sorted by ratio and picked at evenly spaced ranks;
      - otherwise the first ``max_k - 1`` non-main datasets in input order are used.

    Returns:
        List of dataset indices, starting with 0.
    """
    total = len(datasets)
    if max_k is None or max_k <= 0 or total <= max_k:
        return list(range(total))

    candidates = list(range(1, total))
    ranked = []
    for idx in candidates:
        ratio = parse_pr_from_label(datasets[idx].get("label", ""))
        if ratio is None:
            # One unparseable label: fall back to plain input order.
            return [0] + candidates[:max_k - 1]
        ranked.append((idx, ratio))

    ranked.sort(key=lambda pair: pair[1])
    budget = max_k - 1
    if budget >= len(ranked):
        return [0] + [idx for idx, _ in ranked]

    # Evenly spaced positions along the ratio-sorted list.
    picks = np.unique(
        np.round(np.linspace(0, len(ranked) - 1, budget)).astype(int)
    ).tolist()
    # Top up with the earliest unused positions if rounding collapsed any.
    for pos in range(len(ranked)):
        if len(picks) >= budget:
            break
        if pos not in picks:
            picks.append(pos)

    return [0] + [ranked[pos][0] for pos in sorted(picks[:budget])]


# --- 关键：合并模式固定“好看”的配色 + 用线型/marker/描边区分重叠 ---
def build_fixed_overlay_styles(datasets, main_color):
    """Assign a fixed color, linestyle and marker to every dataset for overlay plots.

    Goals (compare/overlay mode only):
      1) a fixed, clearly distinguishable palette;
      2) the main curve (index 0) keeps ``main_color``, matching the single-plot figures;
      3) overlapping curves stay distinguishable through linestyle/marker variation.

    Non-main datasets are ranked by the pruning ratio parsed from their label
    (``parse_pr_from_label``). If any label fails to parse, they are ranked by
    their label text instead. In rank order they cycle through matplotlib's
    ``C1``..``C9`` colors. ``C0`` is never reused, because it is the default
    main color. Linestyles and markers cycle by rank, with the main curve at rank 0.

    Returns:
        Dict mapping dataset index -> ``{"color", "linestyle", "marker"}``.
    """
    others = list(range(1, len(datasets)))

    pr_list = []
    all_ok = True
    for i in others:
        pr = parse_pr_from_label(datasets[i].get("label", ""))
        if pr is None:
            all_ok = False
            break
        pr_list.append((i, pr))

    if all_ok:
        pr_list.sort(key=lambda t: t[1])
        ordered = [0] + [i for i, _ in pr_list]
    else:
        ordered = [0] + sorted(others, key=lambda i: str(datasets[i].get("label", "")))

    colors = {0: main_color}
    for k, idx in enumerate(ordered[1:], start=1):
        # Cycle C1..C9; the previous ``k % 10`` produced C0 for the 10th
        # overlay, colliding with the default main color.
        colors[idx] = f"C{(k - 1) % 9 + 1}"

    linestyles = ["-", "--", "-.", ":"]
    markers = [None, "o", "s", "^", "D", "v", "P", "X"]
    styles = {}
    for rank, idx in enumerate(ordered):
        styles[idx] = dict(
            color=colors[idx],
            linestyle=linestyles[rank % len(linestyles)],
            marker=markers[rank % len(markers)],
        )

    return styles

def plot_line_visible(x, y, label, style, is_main=False, linewidth=2.4):
    """Draw one overlay curve so that it stays visible where curves overlap.

    The line gets a white outline (a stroke ``linewidth + 2`` wide) and an
    optional marker drawn every 3rd point. The main curve is drawn on top
    (zorder 10 vs 6).

    As a defensive measure, ``x``/``y`` are flattened to 1-D. If their lengths
    differ, both are truncated to the shorter one, and nothing is drawn when
    that leaves at most one point.

    Args:
        x, y: Coordinates of the curve.
        label: Legend label.
        style: Dict with ``color`` and ``linestyle`` and optionally ``marker``.
        is_main: Whether this is the main (primary) curve.
        linewidth: Line width in points.
    """
    xs = np.asarray(x).ravel()
    ys = np.asarray(y).ravel()

    if xs.size != ys.size:
        n_common = min(xs.size, ys.size)
        if n_common <= 1:
            return  # too little data to draw
        xs, ys = xs[:n_common], ys[:n_common]

    marker = style.get("marker", None)
    has_marker = marker is not None
    outline = [pe.Stroke(linewidth=linewidth + 2.0, foreground="white"), pe.Normal()]

    plt.plot(
        xs, ys,
        label=label,
        linewidth=linewidth,
        color=style["color"],
        linestyle=style["linestyle"],
        marker=marker,
        markersize=4.0 if has_marker else 0.0,
        markevery=3 if has_marker else None,
        zorder=10 if is_main else 6,
        path_effects=outline,
    )



# -------------------------
# Main
# -------------------------
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--metrics_jsonl", type=str, required=True)
    ap.add_argument("--out_dir", type=str, required=True)

    ap.add_argument("--which", type=str, default="pruned", choices=["dense", "pruned"])

    ap.add_argument(
        "--compare_other_jsonl",
        type=str,
        action="append",
        default=[],
        help="Additional metrics.jsonl paths to overlay in the same plots. Can be used multiple times."
    )
    ap.add_argument(
        "--compare_labels",
        type=str,
        nargs="*",
        default=None,
        help="Legend labels for compare_other_jsonl (same order). If omitted, use file names."
    )

    ap.add_argument("--font_size", type=int, default=12)
    ap.add_argument("--line_color", type=str, default="C0")
    ap.add_argument("--dpi", type=int, default=200)
    ap.add_argument("--max_bins", type=int, default=60)
    ap.add_argument("--topk_layers", type=int, default=10)

    ap.add_argument(
        "--prune_ratio",
        type=float,
        default=None,
        help="Pruning ratio for figure titles. Accepts 3.1 (means 3.1%%) or 0.031 (means 3.1%%)."
    )

    # compare/overlay 专用参数（单图不变）
    ap.add_argument("--max_overlay_curves", type=int, default=6,
                    help="Max number of curves to show in overlay line figures (compare mode only).")
    ap.add_argument("--save_smooth_overlay", action="store_true",
                    help="If set, also save smooth density overlays for histogram-style compare plots.")
    ap.add_argument("--smooth_bins", type=int, default=120,
                    help="Bins used for smooth density overlays (compare mode only).")
    ap.add_argument("--smooth_sigma", type=float, default=2.0,
                    help="Gaussian smoothing sigma in bin-space (compare mode only).")

    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    set_global_style(args.font_size)

    pr_tag = format_pr_tag(args.prune_ratio)

    # ---- collect jsonl specs ----
    specs = []

    main_rows = load_jsonl(args.metrics_jsonl)
    if len(main_rows) == 0:
        raise ValueError("Primary metrics_jsonl is empty.")
    specs.append({
        "path": args.metrics_jsonl,
        "rows": main_rows,
        "which": args.which,
        "label": f"{args.which}",
        "is_main": True,
    })

    compare_paths = args.compare_other_jsonl or []
    compare_labels = args.compare_labels
    if compare_labels is not None and len(compare_labels) != len(compare_paths):
        raise ValueError("Length mismatch: --compare_labels must match number of --compare_other_jsonl entries.")

    for i, p in enumerate(compare_paths):
        r = load_jsonl(p)
        if len(r) == 0:
            raise ValueError(f"Compare jsonl is empty: {p}")
        w = detect_variant(r)
        lab = compare_labels[i] if compare_labels is not None else default_label_from_path(p)
        specs.append({
            "path": p,
            "rows": r,
            "which": w,
            "label": lab,
            "is_main": False,
        })

    is_compare = (len(specs) > 1)

    # ===== 单图保持原逻辑：用 common L；合并模式：每条线用自己的 L =====
    if not is_compare:
        L_use = compute_min_common_L(specs)   # 单图不动
    else:
        L_use = None  # 合并模式不再用 min L（避免 dense 被截断）

    datasets = []
    for spec in specs:
        if not is_compare:
            # 单图：完全不动
            data = parse_metrics(spec["rows"], spec["which"], L_use=L_use)
            data["L_use"] = L_use
        else:
            # 合并：每个数据集单独取可用层数
            keys = get_keys(spec["which"])
            Ls = [int(r[keys["L"]]) for r in spec["rows"] if keys["L"] in r]
            if len(Ls) == 0:
                # fallback：用 curve 长度
                L_i = int(min(len(r[keys["curve"]]) for r in spec["rows"]))
            else:
                L_i = int(min(Ls))
            data = parse_metrics(spec["rows"], spec["which"], L_use=L_i)
            data["L_use"] = L_i

        data["label"] = spec["label"]
        data["which"] = spec["which"]
        data["is_main"] = spec["is_main"]
        data["path"] = spec["path"]
        datasets.append(data)

    # ===== 单图的 x 轴保持原样；合并模式不预先构造 xL/xD（每条线动态构造）=====
    if not is_compare:
        xL = np.arange(1, L_use + 1)       # Layer ID
        xD = np.arange(2, L_use + 1)       # To Layer (delta index)
        Lmax = L_use
    else:
        Lmax = int(max(d["L_use"] for d in datasets))

    # 颜色与线型策略：
    # - 单图：完全不动（仍然只有 args.line_color）
    # - 合并：固定配色（C1,C2,...）+ 线型/marker/描边
    if not is_compare:
        def get_color(idx, is_main):
            return args.line_color
        overlay_styles = None
    else:
        overlay_styles = build_fixed_overlay_styles(datasets, args.line_color)

        def get_color(idx, is_main):
            return overlay_styles[idx]["color"]

    # 合并：选择 Top-K 曲线（避免十几条线过糊）
    if is_compare:
        overlay_idx = select_overlay_indices(datasets, args.max_overlay_curves)
        overlay_set = set(overlay_idx)
    else:
        overlay_idx = None
        overlay_set = None

    # =========================================================
    # Plot A) Mean Margin Curve (+ IQR for main only)
    # 单图不动；合并：每条线用自己的 x 轴（1..L_i）
    # =========================================================
    plt.figure()
    for j, d in enumerate(datasets):
        curves = d["curves"]
        mean = curves.mean(axis=0)

        if not is_compare:
            col = get_color(j, d["is_main"])
            plt.plot(xL, mean, linewidth=2.0, color=col, label=d["label"])
        else:
            if j not in overlay_set:
                continue
            L_i = int(d["L_use"])
            xL_i = np.arange(1, L_i + 1)
            plot_line_visible(
                xL_i, mean,
                label=d["label"],
                style=overlay_styles[j],
                is_main=d["is_main"],
                linewidth=2.6 if d["is_main"] else 2.2
            )

        if d["is_main"]:
            q25 = np.quantile(curves, 0.25, axis=0)
            q75 = np.quantile(curves, 0.75, axis=0)
            if not is_compare:
                plt.fill_between(xL, q25, q75, alpha=0.14, color=get_color(j, d["is_main"]), zorder=1)
            else:
                # 主线 IQR 同样按主线自身 L
                L_i = int(d["L_use"])
                xL_i = np.arange(1, L_i + 1)
                plt.fill_between(xL_i, q25, q75, alpha=0.12, color=get_color(j, d["is_main"]), zorder=1)

    plt.axhline(0.0, color="black", linestyle="--", linewidth=1.0)
    plt.xlabel("Layer ID")
    plt.ylabel("Margin")
    extra = "Overlay" if is_compare else None
    if is_compare and overlay_idx is not None and len(overlay_idx) < len(datasets):
        extra = f"Overlay (Top-{len(overlay_idx)})"
    plt.title(make_title("Mean Margin Curve", pr_tag, extra))
    plt.legend()
    savefig(args.out_dir,
            "mean_margin_curve_overlay.png" if is_compare else f"mean_margin_curve_{args.which}.png",
            args.dpi)

    # =========================================================
    # Plot B) Mean Margin Gain Curve overlay
    # 单图不动；合并：每条线用自己的 x 轴（2..L_i）
    # =========================================================
    plt.figure()
    for j, d in enumerate(datasets):
        if is_compare and (j not in overlay_set):
            continue
        deltas = d["deltas"]
        mean_d = deltas.mean(axis=0)

        if not is_compare:
            col = get_color(j, d["is_main"])
            plt.plot(xD, mean_d, linewidth=2.0, color=col, label=d["label"])
        else:
            L_i = int(d["L_use"])
            xD_i = np.arange(2, L_i + 1)  # length L_i-1
            plot_line_visible(
                xD_i, mean_d,
                label=d["label"],
                style=overlay_styles[j],
                is_main=d["is_main"],
                linewidth=2.6 if d["is_main"] else 2.2
            )

    plt.axhline(0.0, color="black", linestyle="--", linewidth=1.0)
    plt.xlabel("Layer ID")
    plt.ylabel("Margin Gain")
    if not is_compare:
        extra = f"L={L_use}"
    else:
        extra = f"Lmax={Lmax}"
        if overlay_idx is not None and len(overlay_idx) < len(datasets):
            extra = f"{extra}, Top-{len(overlay_idx)}"
    plt.title(make_title("Mean Margin Gain", pr_tag, extra))
    plt.legend()
    savefig(args.out_dir,
            "mean_margin_gain_overlay.png" if is_compare else f"mean_margin_gain_{args.which}.png",
            args.dpi)

    # =========================================================
    # Plot C) Final Margin distribution
    # 单图不动；合并：颜色固定（不灰）
    # =========================================================
    if not is_compare:
        d = datasets[0]
        m_last = d["m_last"]
        p_pos = float((m_last > 0).mean())
        plt.figure()
        plt.hist(m_last, bins=args.max_bins, color=args.line_color, alpha=0.85)
        plt.axvline(0.0, color="black", linestyle="--", linewidth=1.0)
        plt.xlabel("Final Margin")
        plt.ylabel("Count")
        plt.title(make_title("Final Margin Distribution", pr_tag, f"P>0={p_pos:.3f}"))
        savefig(args.out_dir, f"final_margin_hist_{args.which}.png", args.dpi)
    else:
        plt.figure()
        all_vals = []
        for d in datasets:
            v = np.asarray(d["m_last"], dtype=np.float64)
            v = v[np.isfinite(v)]
            if v.size:
                all_vals.append(v)
        xlim = None
        if len(all_vals):
            merged = np.concatenate(all_vals, axis=0)
            xlim = tuple(np.percentile(merged, [0.5, 99.5]))

        for j, d in enumerate(datasets):
            m_last = np.asarray(d["m_last"], dtype=np.float64)
            m_last = m_last[np.isfinite(m_last)]
            if m_last.size == 0:
                continue
            col = get_color(j, d["is_main"])
            plt.hist(m_last, bins=args.max_bins, density=True, histtype="step",
                     linewidth=2.2, color=col, label=d["label"])
        plt.axvline(0.0, color="black", linestyle="--", linewidth=1.0)
        plt.xlabel("Final Margin")
        plt.ylabel("Density")
        plt.title(make_title("Final Margin Distribution", pr_tag, "Density"))
        plt.legend()
        savefig(args.out_dir, "final_margin_hist_overlay.png", args.dpi)

        if args.save_smooth_overlay:
            plt.figure()
            for j, d in enumerate(datasets):
                m_last = np.asarray(d["m_last"], dtype=np.float64)
                xs, ys = smooth_density_curve(m_last, bins=args.smooth_bins,
                                              sigma_bins=args.smooth_sigma, xlim=xlim)
                if xs is None:
                    continue
                plot_line_visible(xs, ys, d["label"], overlay_styles[j], d["is_main"], linewidth=2.2)
            plt.axvline(0.0, color="black", linestyle="--", linewidth=1.0)
            plt.xlabel("Final Margin")
            plt.ylabel("Density")
            plt.title(make_title("Final Margin Density", pr_tag, "Smooth"))
            plt.legend()
            savefig(args.out_dir, "final_margin_kde_overlay.png", args.dpi)

    # =========================================================
    # Plot D) Crossing Layer (l_star) distribution
    # 单图不动；合并：bins 用 Lmax（避免 L_use=None 崩）
    # =========================================================
    if not is_compare:
        d = datasets[0]
        l_star = d["l_star"]
        plt.figure()
        bins = min(args.max_bins, max(10, L_use + 1))
        plt.hist(l_star, bins=bins, color=args.line_color, alpha=0.85)
        plt.xlabel("Crossing Layer")
        plt.ylabel("Count")
        plt.title(make_title("Crossing Layer Distribution", pr_tag, None))
        savefig(args.out_dir, f"crossing_layer_hist_{args.which}.png", args.dpi)
    else:
        plt.figure()

        # 1) 自适应横轴范围：取所有数据的 l_star 的 min/max
        all_ls = []
        for d in datasets:
            v = np.asarray(d["l_star"], dtype=np.int64)
            v = v[np.isfinite(v)]
            if v.size:
                all_ls.append(v)

        if len(all_ls) == 0:
            # 没数据就直接跳过
            plt.close()
        else:
            all_ls = np.concatenate(all_ls, axis=0)
            xmin = int(np.nanmin(all_ls))
            xmax = int(np.nanmax(all_ls))

            # 保险：crossing layer 通常从 1 开始；如果你想强制从 1 开始就保留这行
            xmin = max(1, xmin)

            xs = np.arange(xmin, xmax + 1, dtype=np.int64)
            Lspan = xs.size

            # 2) 计算每个数据集在 [xmin, xmax] 上的 PMF
            pmfs = []
            labels = []
            colors = []
            for j, d in enumerate(datasets):
                l_star = np.asarray(d["l_star"], dtype=np.int64)
                l_star = l_star[np.isfinite(l_star)]
                l_star = l_star[(l_star >= xmin) & (l_star <= xmax)]

                cnt = np.zeros(Lspan, dtype=np.float64)
                if l_star.size > 0:
                    np.add.at(cnt, l_star - xmin, 1.0)

                pmf = cnt / max(1.0, cnt.sum())
                pmfs.append(pmf)
                labels.append(d["label"])
                colors.append(get_color(j, d["is_main"]))

            K = len(pmfs)
            width = 0.7 / max(1, K)
            offsets = (np.arange(K) - (K - 1) / 2.0) * width

            # 3) 画并排柱：加粗边线（edge + linewidth），视觉会明显更清楚
            for j in range(K):
                plt.bar(
                    xs + offsets[j],
                    pmfs[j],
                    width=width,
                    color=colors[j],
                    alpha=0.7,
                    edgecolor=colors[j],   # 边线颜色同柱体
                    linewidth=2.2,         # 你想更粗可以调到 2.8/3.0
                    label=labels[j],
                )

            plt.xlim(xmin - 0.6, xmax + 0.6)
            plt.xlabel("Crossing Layer")
            plt.ylabel("Probability")
            plt.title(make_title("Crossing Layer PMF", pr_tag))
            plt.legend()
            savefig(args.out_dir, "crossing_layer_pmf_dodged.png", args.dpi)

        # plt.figure()
        # bins = min(args.max_bins, max(10, Lmax + 1))
        # for j, d in enumerate(datasets):
        #     l_star = np.asarray(d["l_star"], dtype=np.int64)
        #     col = get_color(j, d["is_main"])
        #     plt.hist(l_star, bins=bins, density=True, histtype="step",
        #              linewidth=2.2, color=col, label=d["label"])
        # plt.xlabel("Crossing Layer")
        # plt.ylabel("Density")
        # plt.title(make_title("Crossing Layer Distribution", pr_tag, "Density"))
        # plt.legend()
        # savefig(args.out_dir, "crossing_layer_hist_overlay.png", args.dpi)

        # if args.save_smooth_overlay:
        #     plt.figure()
        #     all_ls = np.concatenate([np.asarray(dd["l_star"], dtype=np.int64) for dd in datasets], axis=0)
        #     # xmin = int(np.nanmin(all_ls))
        #     # xmax = int(np.nanmax(all_ls))
        #     # for j, d in enumerate(datasets):
        #     #     l_star = np.asarray(d["l_star"], dtype=np.int64)
        #     #     xs, ys = smooth_pmf_curve_int(l_star, xmin=xmin, xmax=xmax, sigma_bins=1.2)
        #     #     if xs is None:
        #     #         continue
        #     #     plot_line_visible(xs, ys, d["label"], overlay_styles[j], d["is_main"], linewidth=2.2)
        #     # plt.xlabel("Crossing Layer")
        #     # plt.ylabel("Probability")
        #     # plt.title(make_title("Crossing Layer PMF", pr_tag, "Smooth"))
        #     # plt.legend()
        #     # savefig(args.out_dir, "crossing_layer_pmf_overlay.png", args.dpi)
        
        #     # 固定横轴：1..Lmax（例如 32）
        #     xmin = 1
        #     xmax = int(Lmax)

        #     for j, d in enumerate(datasets):
        #         l_star = np.asarray(d["l_star"], dtype=np.int64)

        #         # dense 模型 valid_max = 32
        #         # pruned 模型 valid_max = pruned_layers（例如 16），可选：看你是否希望强制“不可达层”为 0
        #         if "pruned_layers" in d and d["pruned_layers"] is not None:
        #             valid_max = int(d["pruned_layers"])
        #         else:
        #             valid_max = int(d.get("dense_layers", Lmax))

        #         xs, ys = smooth_pmf_curve_int(
        #             l_star,
        #             xmin=xmin,
        #             xmax=xmax,
        #             sigma_bins=1.2,
        #             valid_max=valid_max,   # 如果你不想强制置零，就传 None
        #             renorm=True,
        #         )
        #         if xs is None:
        #             continue
        #         plot_line_visible(xs, ys, d["label"], overlay_styles[j], d["is_main"], linewidth=2.2)
        #     plt.xlabel("Crossing Layer")
        #     plt.ylabel("Probability")
        #     plt.title(make_title("Crossing Layer PMF", pr_tag, "Smooth"))
        #     plt.legend()
        #     savefig(args.out_dir, "crossing_layer_pmf_overlay.png", args.dpi)

    # =========================================================
    # Plot E) Correctness vs Final Margin
    # 单图不动；合并：颜色固定
    # =========================================================
    if not is_compare:
        d = datasets[0]
        plt.figure()
        plt.scatter(d["m_last"], d["corr"], s=10, alpha=0.6, color=args.line_color)
        plt.axvline(0.0, color="black", linestyle="--", linewidth=1.0)
        plt.yticks([0, 1], ["Wrong", "Correct"])
        plt.xlabel("Final Margin")
        plt.ylabel("Correctness")
        plt.title(make_title("Correctness vs Final Margin", pr_tag, None))
        savefig(args.out_dir, f"correct_vs_final_margin_{args.which}.png", args.dpi)
    else:
        plt.figure()
        for j, d in enumerate(datasets):
            col = get_color(j, d["is_main"])
            plt.scatter(d["m_last"], d["corr"], s=10, alpha=0.22, color=col, label=d["label"])
        plt.axvline(0.0, color="black", linestyle="--", linewidth=1.0)
        plt.yticks([0, 1], ["Wrong", "Correct"])
        plt.xlabel("Final Margin")
        plt.ylabel("Correctness")
        plt.title(make_title("Correctness vs Final Margin", pr_tag, "Overlay"))
        plt.legend()
        savefig(args.out_dir, "correct_vs_final_margin_overlay.png", args.dpi)

    # =========================================================
    # Plot F) Late-Stage Margin Gain (Sum)
    # 单图不动；合并：颜色固定；smooth 用可辨识画线
    # =========================================================
    metric_name_sum = "Late-Stage Gain (Sum)"
    if not is_compare:
        d = datasets[0]
        x = d["late_sum"]
        mask = np.isfinite(x)
        if mask.any():
            vals = x[mask]
            plt.figure()
            plt.hist(vals, bins=args.max_bins, color=args.line_color, alpha=0.85)
            plt.axvline(np.median(vals), color="black", linestyle="--", linewidth=1.0)
            plt.xlabel(metric_name_sum)
            plt.ylabel("Count")
            plt.title(make_title(f"{metric_name_sum} Distribution", pr_tag, None))
            savefig(args.out_dir, f"late_stage_gain_sum_hist_{args.which}.png", args.dpi)
    else:
        plt.figure()
        any_ok = False
        all_vals = []
        for d in datasets:
            x = np.asarray(d["late_sum"], dtype=np.float64)
            x = x[np.isfinite(x)]
            if x.size:
                all_vals.append(x)
        xlim = None
        if len(all_vals):
            merged = np.concatenate(all_vals, axis=0)
            xlim = tuple(np.percentile(merged, [0.5, 99.5]))

        for j, d in enumerate(datasets):
            x = np.asarray(d["late_sum"], dtype=np.float64)
            x = x[np.isfinite(x)]
            if x.size == 0:
                continue
            any_ok = True
            col = get_color(j, d["is_main"])
            plt.hist(x, bins=args.max_bins, density=True, histtype="step",
                     linewidth=2.2, color=col, label=d["label"])
        if any_ok:
            plt.xlabel(metric_name_sum)
            plt.ylabel("Density")
            plt.title(make_title(f"{metric_name_sum} Distribution", pr_tag, "Density"))
            plt.legend()
            savefig(args.out_dir, "late_stage_gain_sum_hist_overlay.png", args.dpi)
        else:
            plt.close()

        if any_ok and args.save_smooth_overlay:
            plt.figure()
            for j, d in enumerate(datasets):
                x = np.asarray(d["late_sum"], dtype=np.float64)
                xs, ys = smooth_density_curve(x, bins=args.smooth_bins,
                                              sigma_bins=args.smooth_sigma, xlim=xlim)
                if xs is None:
                    continue
                plot_line_visible(xs, ys, d["label"], overlay_styles[j], d["is_main"], linewidth=2.2)
            plt.xlabel(metric_name_sum)
            plt.ylabel("Density")
            plt.title(make_title(f"{metric_name_sum} Density", pr_tag, "Smooth"))
            plt.legend()
            savefig(args.out_dir, "late_stage_gain_sum_kde_overlay.png", args.dpi)

    # =========================================================
    # Plot F2) Late-Stage Margin Gain (Mean)
    # 单图不动；合并：颜色固定；smooth 用可辨识画线
    # =========================================================
    metric_name_mean = "Late-Stage Gain (Mean)"
    if not is_compare:
        d = datasets[0]
        x = d["late_mean"]
        mask = np.isfinite(x)
        if mask.any():
            vals = x[mask]
            plt.figure()
            plt.hist(vals, bins=args.max_bins, color=args.line_color, alpha=0.85)
            plt.axvline(np.median(vals), color="black", linestyle="--", linewidth=1.0)
            plt.xlabel(metric_name_mean)
            plt.ylabel("Count")
            plt.title(make_title(f"{metric_name_mean} Distribution", pr_tag, None))
            savefig(args.out_dir, f"late_stage_gain_mean_hist_{args.which}.png", args.dpi)
    else:
        plt.figure()
        any_ok = False
        all_vals = []
        for d in datasets:
            x = np.asarray(d["late_mean"], dtype=np.float64)
            x = x[np.isfinite(x)]
            if x.size:
                all_vals.append(x)
        xlim = None
        if len(all_vals):
            merged = np.concatenate(all_vals, axis=0)
            xlim = tuple(np.percentile(merged, [0.5, 99.5]))

        for j, d in enumerate(datasets):
            x = np.asarray(d["late_mean"], dtype=np.float64)
            x = x[np.isfinite(x)]
            if x.size == 0:
                continue
            any_ok = True
            col = get_color(j, d["is_main"])
            plt.hist(x, bins=args.max_bins, density=True, histtype="step",
                     linewidth=2.2, color=col, label=d["label"])
        if any_ok:
            plt.xlabel(metric_name_mean)
            plt.ylabel("Density")
            plt.title(make_title(f"{metric_name_mean} Distribution", pr_tag, "Density"))
            plt.legend()
            savefig(args.out_dir, "late_stage_gain_mean_hist_overlay.png", args.dpi)
        else:
            plt.close()

        if any_ok and args.save_smooth_overlay:
            plt.figure()
            for j, d in enumerate(datasets):
                x = np.asarray(d["late_mean"], dtype=np.float64)
                xs, ys = smooth_density_curve(x, bins=args.smooth_bins,
                                              sigma_bins=args.smooth_sigma, xlim=xlim)
                if xs is None:
                    continue
                plot_line_visible(xs, ys, d["label"], overlay_styles[j], d["is_main"], linewidth=2.2)
            plt.xlabel(metric_name_mean)
            plt.ylabel("Density")
            plt.title(make_title(f"{metric_name_mean} Density", pr_tag, "Smooth"))
            plt.legend()
            savefig(args.out_dir, "late_stage_gain_mean_kde_overlay.png", args.dpi)

    # =========================================================
    # Plot G) Top Amplifier Layers Frequency (single only) ——不动
    # =========================================================
    if not is_compare:
        d = datasets[0]
        amp_layers = d["amp_layers"]
        if len(amp_layers) > 0:
            amp_layers = np.array(amp_layers, dtype=np.int32)
            uniq, cnt = np.unique(amp_layers, return_counts=True)
            order = np.argsort(cnt)[::-1]
            uniq = uniq[order]
            cnt = cnt[order]
            topn = min(args.topk_layers, len(uniq))
            uniq_top = uniq[:topn]
            cnt_top = cnt[:topn]
            plt.figure(figsize=(max(6, topn * 0.6), 4))
            plt.bar([str(u) for u in uniq_top], cnt_top, color=args.line_color, alpha=0.85)
            plt.xlabel("Layer ID")
            plt.ylabel("Count")
            plt.title(make_title("Top Amplifier Layers", pr_tag, None))
            savefig(args.out_dir, f"top_amplifier_layers_{args.which}.png", args.dpi)

    # =========================================================
    # Write quick stats（单图保持原样；合并模式会写 Common L: None，不影响单图）
    # =========================================================
    stats_path = os.path.join(args.out_dir, "quick_stats_all.txt")
    with open(stats_path, "w", encoding="utf-8") as wf:
        wf.write(f"Common L: {L_use}\n")
        if pr_tag:
            wf.write(f"Pruning Ratio: {pr_tag}\n")
        wf.write(f"Num Datasets: {len(datasets)}\n\n")

        for d in datasets:
            corr = d["corr"]
            m_last = d["m_last"]
            l_star = d["l_star"]

            wf.write(f"[{d['label']}] which={d['which']} N={d['N']}\n")
            wf.write(f"  Accuracy: {corr.mean():.4f}\n")
            wf.write(f"  P(Final Margin > 0): {(m_last > 0).mean():.4f}\n")
            wf.write(f"  Mean Crossing Layer: {l_star.mean():.4f}\n")

            if np.isfinite(d["late_sum"]).any():
                wf.write(f"  Mean Late-Stage Gain (Sum): {np.nanmean(d['late_sum']):.4f}\n")
            if np.isfinite(d["late_mean"]).any():
                wf.write(f"  Mean Late-Stage Gain (Mean): {np.nanmean(d['late_mean']):.4f}\n")

            wf.write(f"  Post-Cross Persistence: {d['persist'].mean():.4f}\n")
            if np.isfinite(d["pos_ratio"]).any():
                wf.write(f"  Post-Cross Positive Ratio: {np.nanmean(d['pos_ratio']):.4f}\n")
            wf.write("\n")

    print("[Saved plots to]", args.out_dir)
    print("[Saved stats to]", stats_path)


if __name__ == "__main__":
    # Script entry point: run main() (which saves the plots and the
    # quick-stats text file to the output directory) only when this file is
    # executed directly, not when it is imported as a module.
    main()
