#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# tab20 colormap: provides 20 distinct colors, used to color the pruned series.
tab20 = plt.cm.tab20


# -------------------------
# IO
# -------------------------
def load_jsonl(path):
    """Read a JSON Lines file and return its decoded records as a list.

    Each line is stripped of surrounding whitespace; blank lines are skipped
    and every remaining line is decoded with ``json.loads``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped if text]


def set_global_style(font_size: int):
    """Apply a base font size to matplotlib's global rcParams.

    General text and axis labels use ``font_size``; axes titles use
    ``font_size + 2``; tick labels and legend text use ``font_size - 1``.
    """
    smaller = font_size - 1
    overrides = {
        "font.size": font_size,
        "axes.titlesize": font_size + 2,
        "axes.labelsize": font_size,
    }
    for key in ("xtick.labelsize", "ytick.labelsize", "legend.fontsize"):
        overrides[key] = smaller
    plt.rcParams.update(overrides)


def default_label_from_path(p):
    """Return the file name of *p* without its directory and final extension."""
    stem, _extension = os.path.splitext(os.path.basename(p))
    return stem

# -------------------------
# Manual mapping: removed layer ids (0-based) for each series
# -------------------------
def get_survivors_from_removed(L_dense: int, removed_0based: list[int], expect_remain: int | None = None):
    removed = sorted(set(int(x) for x in removed_0based))
    for x in removed:
        if x < 0 or x >= L_dense:
            raise ValueError(f"removed layer idx out of range: {x} (L_dense={L_dense})")

    survivors = [i for i in range(L_dense) if i not in set(removed)]
    survivors.sort()

    if expect_remain is not None and len(survivors) != int(expect_remain):
        raise ValueError(
            f"[Mapping ERROR] expect_remain={expect_remain}, but survivors={len(survivors)}. "
            f"removed={removed}, L_dense={L_dense}"
        )
    return survivors


def align_pruned_mean_to_dense(pruned_mean: np.ndarray, survivors_0based: list[int]):
    """Pair each pruned-layer value with its 1-based dense layer id.

    Returns:
        ``(x, y)``: ``x`` is an int32 array with ``survivors_0based[k] + 1``,
        and ``y`` is a float32 copy of ``pruned_mean``.

    Raises:
        ValueError: if the two inputs differ in length.
    """
    n_values, n_ids = len(pruned_mean), len(survivors_0based)
    if n_values != n_ids:
        raise ValueError("Length mismatch between pruned_mean and survivors list.")
    layer_ids_1based = np.array(survivors_0based, dtype=np.int32) + 1
    values = pruned_mean.astype(np.float32)
    return layer_ids_1based, values


def build_full_nan(L_dense: int, survivors_0based: list[int], pruned_mean: np.ndarray):
    """Scatter a pruned-model curve onto the dense layer axis.

    Args:
        L_dense: length of the output (number of dense layers).
        survivors_0based: 0-based dense layer id for each pruned layer, in order.
        pruned_mean: one value per pruned layer, same length as
            ``survivors_0based``.

    Returns:
        float32 array of length ``L_dense``: position ``survivors_0based[k]``
        holds ``pruned_mean[k]``; every other position is NaN.

    Raises:
        ValueError: if the two sequences differ in length (checked explicitly,
            as ``align_pruned_mean_to_dense`` does, rather than letting the
            extra entries be silently dropped).
    """
    if len(pruned_mean) != len(survivors_0based):
        raise ValueError("Length mismatch between pruned_mean and survivors list.")
    full = np.full((L_dense,), np.nan, dtype=np.float32)
    # intp dtype keeps an empty survivor list a valid (integer) index array.
    full[np.asarray(survivors_0based, dtype=np.intp)] = np.asarray(pruned_mean, dtype=np.float32)
    return full


def build_dense_full(L_dense: int, dense_mean: np.ndarray):
    """Return a float32 copy of the dense curve after validating its length.

    Raises:
        ValueError: if ``len(dense_mean) != L_dense``.
    """
    if L_dense != len(dense_mean):
        raise ValueError("dense_mean length mismatch.")
    return np.array(dense_mean, dtype=np.float32, copy=True)


# -------------------------
# Parse one combined jsonl (dense+pruned in same row)
# -------------------------
def _resolve_layer_count(rows, layers_key, curve_key):
    """Return the common layer count for one model across all rows.

    Uses the smallest declared ``layers_key`` value; when no row declares it,
    falls back to the shortest ``curve_key`` list.
    """
    declared = [int(r[layers_key]) for r in rows if layers_key in r]
    if declared:
        return min(declared)
    lengths = [len(r[curve_key]) for r in rows if curve_key in r]
    if not lengths:
        raise ValueError(f"No row provides '{layers_key}' or '{curve_key}'.")
    return min(lengths)


def _mean_curve_and_delta(rows, layers_key, curve_key, delta_key):
    """Average one model's margin curve and its layer-to-layer delta over rows.

    Rows that carry ``delta_key`` contribute their stored deltas; rows without
    it contribute the first difference of their own (truncated) curve. When no
    row stores deltas, the delta is the first difference of the mean curve.

    Returns:
        (mean [n_layers], delta_mean [n_layers - 1], n_layers)
    """
    n_layers = _resolve_layer_count(rows, layers_key, curve_key)
    curves = np.array([np.array(r[curve_key][:n_layers], dtype=np.float32) for r in rows])
    mean = curves.mean(axis=0)

    if any(delta_key in r for r in rows):
        # Guard n_layers == 0 so the slice below is [:0], not [:-1].
        n_delta = max(n_layers - 1, 0)
        deltas = np.array(
            [
                np.array(r[delta_key][:n_delta], dtype=np.float32)
                if delta_key in r
                else np.diff(curve)  # per-row fallback instead of KeyError
                for r, curve in zip(rows, curves)
            ],
            dtype=np.float32,
        )
        delta_mean = deltas.mean(axis=0)
    else:
        delta_mean = (mean[1:] - mean[:-1]).astype(np.float32)
    return mean, delta_mean, n_layers


def parse_mean_curves_and_deltas_from_combined_jsonl(rows):
    """Average dense and pruned margin curves (and their deltas) over jsonl rows.

    Each row is a dict with ``m_curve_dense`` / ``m_curve_pruned`` lists and
    optionally ``dense_layers`` / ``pruned_layers`` and
    ``delta_m_curve_dense`` / ``delta_m_curve_pruned``. All curves are
    truncated to the smallest layer count before averaging.

    returns:
      dense_mean: [L_dense]
      pruned_mean: [L_pruned]
      dense_delta_mean: [L_dense-1]   (delta for to_layer 2..L_dense)
      pruned_delta_mean: [L_pruned-1] (delta for pruned to_layer 2..L_pruned)
      L_dense, L_pruned

    Raises:
        ValueError: if ``rows`` is empty or no layer count can be determined.
    """
    if not rows:
        raise ValueError("No rows to parse (empty jsonl?).")
    dense_mean, dense_delta_mean, L_dense = _mean_curve_and_delta(
        rows, "dense_layers", "m_curve_dense", "delta_m_curve_dense"
    )
    pruned_mean, pruned_delta_mean, L_pruned = _mean_curve_and_delta(
        rows, "pruned_layers", "m_curve_pruned", "delta_m_curve_pruned"
    )
    return dense_mean, pruned_mean, dense_delta_mean, pruned_delta_mean, L_dense, L_pruned


# -------------------------
# Δm alignment helpers
# -------------------------
def build_dense_delta_full_nan(L_dense: int, dense_delta_mean: np.ndarray):
    """Place dense delta-margin values on the dense layer axis.

    Index ``i`` of the float32 result holds the delta arriving at 1-based
    layer ``i + 1``. Index 0 (the first layer has no predecessor) stays NaN.

    Raises:
        ValueError: if ``len(dense_delta_mean) != L_dense - 1``.
    """
    aligned = np.full((L_dense,), np.nan, dtype=np.float32)
    expected_len = L_dense - 1
    if expected_len != len(dense_delta_mean):
        raise ValueError("dense_delta_mean length mismatch.")
    aligned[1:] = dense_delta_mean
    return aligned


def build_pruned_delta_full_nan_on_dense(L_dense: int, pruned_delta_mean: np.ndarray, survivors_0based: list[int]):
    """Place pruned delta-margin values on the dense layer axis.

    pruned_delta_mean[k-1] aligns to dense_to_layer = survivors[k] (0-based index).
    Dense positions that receive no delta are NaN.

    Raises:
        ValueError: if ``len(pruned_delta_mean) != len(survivors_0based) - 1``.
    """
    aligned = np.full((L_dense,), np.nan, dtype=np.float32)

    n_got = len(pruned_delta_mean)
    n_expected = len(survivors_0based) - 1
    if n_got != n_expected:
        raise ValueError(
            f"pruned_delta_mean length mismatch: got {n_got}, "
            f"expected {n_expected}"
        )

    # Skip survivors[0]: the first surviving layer has no incoming delta.
    for dense_to_0, delta in zip(survivors_0based[1:], pruned_delta_mean):
        aligned[dense_to_0] = float(delta)
    return aligned


def find_positive_runs_on_delta(delta_full_nan: np.ndarray, valid_from_layer=2, min_len=2):
    """Find maximal runs of consecutive layers whose delta is finite and > 0.

    Layers are 1-based ``to_layer`` ids; entry ``layer - 1`` of
    ``delta_full_nan`` is the delta arriving at ``layer``. Scanning starts at
    ``valid_from_layer``. NaN entries break a run.

    Returns:
        List of ``(start_to, end_to)`` pairs (1-based, inclusive) for runs of
        at least ``min_len`` layers, in ascending order.
    """
    n_layers = len(delta_full_nan)
    found = []
    run_start = None

    # Scan one layer past the end; that sentinel step is always "not positive",
    # so a run reaching the last layer is closed by the same code path.
    for layer in range(valid_from_layer, n_layers + 2):
        if layer <= n_layers:
            value = delta_full_nan[layer - 1]
            positive = bool(np.isfinite(value) and value > 0)
        else:
            positive = False

        if positive:
            if run_start is None:
                run_start = layer
            continue

        if run_start is not None:
            if layer - run_start >= min_len:
                found.append((run_start, layer - 1))
            run_start = None

    return found


# -------------------------
# Plots
# -------------------------
def plot_aligned_step(out_dir, name, dense_mean, pruned_series, L_dense, dpi):
    """Plot the dense curve and each pruned series as a step line on the dense axis.

    Each pruned series' ``full_nan`` row is forward-filled so a surviving
    layer's value is held across the NaN (removed) layers that follow it. A
    leading NaN prefix is then back-filled with the first finite value. The
    series' survivor points (``x`` / ``y``) are overlaid as markers. The
    figure is saved to ``f"{out_dir}/{name}"`` at the given ``dpi``.
    """
    plt.figure(figsize=(10, 6))
    ax = plt.gca()
    layer_ids = np.arange(1, L_dense + 1)

    ax.plot(layer_ids, dense_mean, linewidth=2.8, label="dense")

    palette = plt.rcParams['axes.prop_cycle'].by_key().get('color', [])
    palette = palette or ["C0", "C1", "C2", "C3", "C4", "C5", "C6", "C7", "C8", "C9"]
    n_colors = len(palette)

    for series_idx, series in enumerate(pruned_series):
        # Offset by one so the first pruned series skips the first cycle color.
        color = palette[(series_idx + 1) % n_colors]
        filled = series["full_nan"].copy()

        # Forward fill: carry the most recent finite value over NaN slots.
        held = np.nan
        for pos in range(L_dense):
            if np.isfinite(filled[pos]):
                held = filled[pos]
            else:
                filled[pos] = held

        # Back fill any leading NaNs with the first finite value, if one exists.
        if not np.isfinite(filled[0]):
            finite_positions = np.where(np.isfinite(filled))[0]
            if finite_positions.size > 0:
                first = finite_positions[0]
                filled[:first] = filled[first]

        ax.step(layer_ids, filled, where="post", linewidth=2.2, color=color, label=series["label"])
        ax.scatter(series["x"], series["y"], s=34, color=color, edgecolor="white", linewidth=0.6, zorder=5)

    ax.axhline(0.0, color="black", linestyle="--", linewidth=1.0)
    ax.set_xlabel("Dense Layer ID (aligned)")
    ax.set_ylabel("Mean Margin")
    ax.set_title("Mean Margin Curve (Aligned) | Step Plot (forward-fill) + block markers")
    ax.legend()
    plt.tight_layout()
    plt.savefig(f"{out_dir}/{name}", dpi=dpi)
    plt.close()


def plot_aligned_scatter_segment(out_dir, name, dense_mean, pruned_series, L_dense, dpi,
                                 title="Llama3-8B | Hellaswag"):
    """Plot the dense curve plus each pruned series as markers joined by segments.

    For each pruned series, its survivor points (``x`` = 1-based dense layer
    ids, ``y`` = mean margin) are scattered. Consecutive points are joined:
    - solid when their dense ids are adjacent (no removed layer between them);
    - dashed and fainter when the segment spans removed layers.

    Args:
        out_dir, name: the figure is saved to ``os.path.join(out_dir, name)``.
        dense_mean: dense curve, plotted against layer ids ``1..L_dense``.
        pruned_series: dicts with ``label``, ``x`` and ``y`` keys.
        L_dense: number of dense layers.
        dpi: output resolution.
        title: axes title. The default keeps the previously hard-coded
            value, so existing callers are unaffected.
    """
    plt.figure(figsize=(10, 6))
    ax = plt.gca()
    x_dense = np.arange(1, L_dense + 1)

    ax.plot(x_dense, dense_mean, linewidth=3.5, label="dense")

    tab20_colors = [tab20(i) for i in range(20)]
    adjacent_style = dict(linewidth=3, linestyle='-', alpha=1, zorder=4)
    gap_style = dict(linewidth=2.5, linestyle='--', alpha=0.7, zorder=3)

    for idx, s in enumerate(pruned_series):
        # idx + 1 skips tab20's first entry, presumably so it does not match
        # the dense line's default (first cycle) color.
        col = tab20_colors[(idx + 1) % 20]
        x = np.asarray(s["x"], dtype=np.int32)
        y = np.asarray(s["y"], dtype=np.float32)

        ax.scatter(x, y, s=38, color=col, edgecolor="white", linewidth=0.6,
                   label=s["label"], zorder=5)

        for i in range(len(x) - 1):
            style = adjacent_style if (x[i + 1] - x[i]) == 1 else gap_style
            ax.plot([x[i], x[i + 1]], [y[i], y[i + 1]], color=col, **style)

    ax.axhline(0.0, color="black", linestyle="--", linewidth=1.0)
    ax.set_xlabel("Layer ID", fontsize=22)
    ax.set_ylabel("Decision Margin", fontsize=22)
    ax.set_title(title, fontsize=22)
    ax.tick_params(axis='x', labelsize=20)
    ax.tick_params(axis='y', labelsize=20)
    ax.legend(fontsize=14)
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, name), dpi=dpi)
    plt.close()


def plot_aligned_heatmap(out_dir, name, heat_rows, L_dense, dpi, mark_positive_runs=True):
    """Draw a mean-margin heatmap with one row per series on the dense layer axis.

    Layout:
      - pruned rows sorted by ``L_pruned`` ascending (drawn bottom-up), with
        the dense row(s) on top;
      - square cells, with NaN (removed-layer) cells left transparent;
      - white horizontal separators between rows;
      - if ``mark_positive_runs`` is true, a red rectangle around each run
        of consecutive layers with delta-margin > 0, as found by
        ``find_positive_runs_on_delta`` on the row's ``delta_full_nan``. The
        rectangle's start is shifted left by one layer so it also covers the
        layer the increase starts from.

    Args:
        heat_rows: dicts with ``kind`` ("dense" or "pruned"), ``label``,
            ``L_pruned``, ``full_nan`` and ``delta_full_nan``.
        L_dense: number of dense layers (heatmap width).
        dpi: output resolution. The figure is saved to ``out_dir/name``.

    Does nothing when ``heat_rows`` has no dense or pruned rows.
    """
    dense_rows = [r for r in heat_rows if r.get("kind") == "dense"]
    pruned_rows = sorted(
        (r for r in heat_rows if r.get("kind") == "pruned"),
        key=lambda r: r["L_pruned"],
    )
    rows_sorted = pruned_rows + dense_rows

    K = len(rows_sorted)
    if K == 0:
        return

    mat = np.stack([r["full_nan"] for r in rows_sorted], axis=0).astype(np.float32)
    labels = [r["label"] for r in rows_sorted]
    mat_m = np.ma.masked_invalid(mat)

    cell = 0.35
    fig_w = max(10, cell * L_dense)
    fig_h = max(4.5, cell * K + 1.5)
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))

    cmap = plt.get_cmap("viridis").copy()
    cmap.set_bad(color=(1, 1, 1, 0.0))  # masked (NaN) cells are fully transparent

    x_edges = np.arange(0, L_dense + 1)
    y_edges = np.arange(0, K + 1)

    im = ax.pcolormesh(x_edges, y_edges, mat_m, cmap=cmap, shading="flat")
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal', shrink=0.5)
    cbar.set_label("Mean Margin")

    ax.set_aspect("equal")

    ax.set_xticks(np.arange(L_dense) + 0.5)
    show_every = 2
    ax.set_xticklabels([str(i + 1) if (i % show_every == 0) else "" for i in range(L_dense)], rotation=0)

    ax.set_yticks(np.arange(K) + 0.5)
    ax.set_yticklabels(labels)

    for y in range(1, K):
        ax.axhline(y, color="white", linewidth=2.0, alpha=0.9)

    if mark_positive_runs:
        for row_idx, r in enumerate(rows_sorted):
            for start_to, end_to in find_positive_runs_on_delta(r["delta_full_nan"]):
                # The cell for 1-based layer j spans x in [j - 1, j]. Shifting the
                # start left by one layer puts the left edge at start_to - 2, so
                # the box covers layers start_to - 1 .. end_to.
                ax.add_patch(Rectangle(
                    (start_to - 2, row_idx), end_to - start_to + 2, 1,
                    fill=False, edgecolor="red", linewidth=2.0, zorder=4,
                ))

    ax.set_xlabel("Layer ID")
    ax.set_title("Mean Margin Heatmap")

    plt.tight_layout()
    plt.savefig(f"{out_dir}/{name}", dpi=dpi)
    plt.close(fig)


# -------------------------
# Main
# -------------------------
def main():
    """CLI entry point: align pruned mean-margin curves to the dense layer axis and plot them.

    The dense baseline comes from ``--metrics_jsonl``. Each
    ``--compare_other_jsonl`` file contributes one pruned series. Its surviving
    dense layers are derived from the matching ``--compare_removed`` entry and
    cross-checked against that file's ``pruned_layers``. The scatter/segment
    plot is always written. The step plot and the heatmap are opt-in via
    ``--plot_step`` / ``--plot_heatmap``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--metrics_jsonl", type=str, required=True,
                    help="Combined jsonl. Used to define the dense baseline (dense row).")
    ap.add_argument("--out_dir", type=str, required=True)

    ap.add_argument("--compare_other_jsonl", type=str, action="append", default=[],
                    help="Overlay pruned settings (each combined jsonl).")
    ap.add_argument("--compare_labels", type=str, nargs="*", default=None,
                    help="Labels for compare_other_jsonl.")

    ap.add_argument("--include_main_pruned", action="store_true",
                    help="If set, also include the pruned row from --metrics_jsonl in heatmap/plots. Default OFF.")
    ap.add_argument("--main_removed", type=str, default="",
                    help="Removed layers for the pruned row of --metrics_jsonl "
                         "(used only with --include_main_pruned). Same format as --compare_removed.")

    ap.add_argument("--font_size", type=int, default=14)
    ap.add_argument("--dpi", type=int, default=250)
    ap.add_argument("--remove_ids_are_1based", action="store_true",
                    help="Interpret --compare_removed / --main_removed ids as 1-based.")

    ap.add_argument(
        "--compare_removed",
        type=str,
        action="append",
        default=[],
        help='Removed layers for each --compare_other_jsonl. '
             'Use comma-separated 0-based ids, e.g. "16,18" or "17" '
             '(1-based if --remove_ids_are_1based is set). '
             'Must have same count as compare_other_jsonl.'
    )
    ap.add_argument("--plot_step", action="store_true",
                    help="Also save the forward-filled step plot.")
    ap.add_argument("--plot_heatmap", action="store_true",
                    help="Also save the aligned heatmap.")

    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    set_global_style(args.font_size)

    compare_paths = args.compare_other_jsonl or []
    compare_labels = args.compare_labels
    if compare_labels is not None and len(compare_labels) != len(compare_paths):
        raise ValueError("Length mismatch: --compare_labels must match number of --compare_other_jsonl entries.")
    # Validate up front instead of failing later with an IndexError.
    if len(args.compare_removed) != len(compare_paths):
        raise ValueError(
            "Length mismatch: --compare_removed must be given once per --compare_other_jsonl entry "
            f"(got {len(args.compare_removed)} removed lists for {len(compare_paths)} files)."
        )

    # The 1-based option is applied here, when the removed ids are parsed, so
    # every series is converted uniformly. Out-of-range results (e.g. a 0 given
    # as a 1-based id) are rejected by get_survivors_from_removed.
    id_offset = 1 if args.remove_ids_are_1based else 0

    def _parse_removed_str(s: str) -> list[int]:
        """Parse a comma-separated id list such as "16,18" into 0-based ints."""
        return [int(x) - id_offset for x in s.split(",") if x.strip() != ""]

    # ---- load main (defines dense baseline) ----
    rows_main = load_jsonl(args.metrics_jsonl)
    dense_mean, pruned_mean_main, dense_delta_main, pruned_delta_main, L_dense_main, L_pruned_main = \
        parse_mean_curves_and_deltas_from_combined_jsonl(rows_main)

    L_dense = int(L_dense_main)
    dense_delta_full = build_dense_delta_full_nan(L_dense, dense_delta_main)

    # ---- pruned specs: only the compare files unless --include_main_pruned ----
    specs = []
    if args.include_main_pruned:
        specs.append({
            "label": default_label_from_path(args.metrics_jsonl),
            "pruned_mean": pruned_mean_main,
            "pruned_delta": pruned_delta_main,
            "L_pruned": int(L_pruned_main),
            "removed": _parse_removed_str(args.main_removed),
        })

    for i, p in enumerate(compare_paths):
        rows = load_jsonl(p)
        _, p_mean, _, p_delta, Ld, Lp = parse_mean_curves_and_deltas_from_combined_jsonl(rows)
        if int(Ld) != L_dense:
            print(f"[WARN] {p}: dense_layers={Ld} differs from baseline dense_layers={L_dense}.")
        lab = compare_labels[i] if compare_labels is not None else default_label_from_path(p)

        specs.append({
            "label": lab,
            "pruned_mean": p_mean,
            "pruned_delta": p_delta,
            "L_pruned": int(Lp),
            "removed": _parse_removed_str(args.compare_removed[i]),
        })

    pruned_series = []
    heat_rows = []

    # dense row always present and label is exactly "dense"
    heat_rows.append({
        "kind": "dense",
        "label": "dense",
        "L_pruned": L_dense,
        "full_nan": build_dense_full(L_dense, dense_mean),
        "delta_full_nan": dense_delta_full,
    })

    for sp in specs:
        Lp = int(sp["L_pruned"])
        pruned_mean_i = np.asarray(sp["pruned_mean"], dtype=np.float32)
        pruned_delta_i = np.asarray(sp["pruned_delta"], dtype=np.float32)

        # Derive survivors from the removed list; Lp (= pruned_layers) is the
        # consistency check (raises ValueError on mismatch).
        survivors = get_survivors_from_removed(
            L_dense=L_dense,
            removed_0based=sp["removed"],
            expect_remain=Lp
        )

        x, y = align_pruned_mean_to_dense(pruned_mean_i, survivors_0based=survivors)
        full_nan = build_full_nan(L_dense, survivors, pruned_mean_i)
        delta_full_nan = build_pruned_delta_full_nan_on_dense(L_dense, pruned_delta_i, survivors_0based=survivors)

        pruned_series.append({
            "label": sp["label"],
            "L_pruned": Lp,
            "x": x,
            "y": y,
            "full_nan": full_nan,
        })

        heat_rows.append({
            "kind": "pruned",
            "label": sp["label"],
            "L_pruned": Lp,
            "full_nan": full_nan,
            "delta_full_nan": delta_full_nan,
        })

        if survivors:
            print(f"[OK] {sp['label']}: pruned_layers={Lp}, survivors(min,max)={min(survivors)}..{max(survivors)}")

    # ---- plots ----
    plot_aligned_scatter_segment(
        out_dir=args.out_dir,
        name="mean_margin_aligned_scatter_segment.png",
        dense_mean=dense_mean,
        pruned_series=pruned_series,
        L_dense=L_dense,
        dpi=args.dpi
    )

    if args.plot_step:
        plot_aligned_step(
            out_dir=args.out_dir,
            name="mean_margin_aligned_step.png",
            dense_mean=dense_mean,
            pruned_series=pruned_series,
            L_dense=L_dense,
            dpi=args.dpi
        )

    if args.plot_heatmap:
        plot_aligned_heatmap(
            out_dir=args.out_dir,
            name="mean_margin_aligned_heatmap.png",
            heat_rows=heat_rows,
            L_dense=L_dense,
            dpi=args.dpi
        )

    print("[Saved aligned mean margin plots to]", args.out_dir)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
