#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import json
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Use the tab20 colormap, which provides 20 distinct colors.
tab20 = plt.cm.tab20


# -------------------------
# IO
# -------------------------
def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Every line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped, all others are decoded with ``json.loads``.
    The file is opened as UTF-8 text.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


def set_global_style(font_size: int):
    """Apply a consistent font-size scheme to matplotlib's global rcParams.

    Titles are two points larger than ``font_size``; tick labels and legend
    text are one point smaller; axis labels use ``font_size`` itself.
    """
    slightly_smaller = font_size - 1
    overrides = {
        "font.size": font_size,
        "axes.titlesize": font_size + 2,
        "axes.labelsize": font_size,
    }
    for rc_key in ("xtick.labelsize", "ytick.labelsize", "legend.fontsize"):
        overrides[rc_key] = slightly_smaller
    plt.rcParams.update(overrides)


def default_label_from_path(p):
    """Derive a label from a path: its file name without the last extension."""
    stem, _extension = os.path.splitext(os.path.basename(p))
    return stem

# -------------------------------------------------------------------------
# -------------------------llama3-8b arc_challenge-------------------------
# Pruning schedule used to map pruned-model layers back onto dense layer ids.
#   key   : number of layers that remain after pruning
#   value : ids of the dense layers that were removed to reach that count
# get_survivors_dense_ids() treats the ids as 0-based and requires that
# (dense layer count - number of distinct removed ids) == key, otherwise it
# raises. When --remove_ids_are_1based is passed, main() subtracts 1 from
# every id before the table is used.
REMAIN_TO_REMOVED_0BASED = {
    31: [24],
    30: [24,23],
    29: [24,23,22],
    28: [24,23,22,21],
    27: [24,23,22,21,19],
    26: [24,23,22,21,19,20],
    25: [24,23,22,21,19,20,18],
    24: [24,23,25,26,27,28,22,21],
    23: [24,23,25,26,27,28,22,21,19],
    22: [24,23,25,26,27,28,22,21,19,20],
    21: [24,23,25,26,27,28,22,21,19,20,18],
    20: [24,23,25,26,27,28,22,21,19,20,18,17],
    19: [24,23,25,26,27,28,22,21,19,20,18,17,10],
    18: [24,23,25,26,27,28,22,21,19,20,18,17,10,2],
    17: [24,23,25,26,27,28,22,21,19,20,18,17,10,2,11],
    16: [24,23,25,26,27,28,22,21,19,20,18,17,10,2,11,9],
}
# -------------------------llama3-8b  mmlu-------------------------
# REMAIN_TO_REMOVED_0BASED = {
#     31: [23],
#     30: [23, 22],
#     29: [23,22,21],
#     28: [23,22,21,19],
#     27: [23,22,21,19,20],
#     26: [23,22,21,19,20,18],
#     25: [23,22,21,19,20,18,27],
#     24: [23,24,25,26,27,22,28,21],
#     23: [23,24,25,26,27,22,28,21,19],
#     22: [23,24,25,26,27,22,28,21,19,20],
#     21: [23,24,25,26,27,22,28,21,19,20,18],
#     20: [23,24,25,26,27,22,28,21,19,20,18,17],
#     19: [23,24,25,26,27,22,28,21,19,20,18,17,10],
#     18: [23,24,25,26,27,22,28,21,19,20,18,17,10,9],
#     17: [23,24,25,26,27,22,28,21,19,20,18,17,10,9,12],
#     16: [23,24,25,26,27,22,28,21,19,20,18,17,10,9,12,13],
# }
# -------------------------------------------------------------------------
# -------------------------llama2-7b arc_challenge-------------------------
# REMAIN_TO_REMOVED_0BASED = {
#     31: [25],
#     30: [25,24],
#     29: [25,24,23],
#     28: [25,24,23,21],
#     27: [25,24,23,21,20],
#     26: [25,24,23,21,20,26],
#     25: [25,24,23,21,20,26,19],
#     24: [25,24,23,21,20,26,19,22],
#     23: [25,24,23,26,21,20,27,28,29],
#     22: [25,24,23,26,21,20,27,28,29,19],
#     21: [25,24,23,26,21,20,27,28,29,19,22],
#     20: [25,24,23,26,21,20,27,28,29,19,22,14],
#     19: [25,24,23,26,21,20,27,28,29,19,22,14,12],
#     18: [25,24,23,26,21,20,27,28,29,19,22,14,12,10],
#     17: [25,24,23,26,21,20,27,28,29,19,22,14,12,10,16],
#     16: [25,24,23,26,21,20,27,28,29,19,22,14,12,10,16,15],
# }
# -------------------------------------------------------------------------
# -------------------------qwen3-4b arc_challenge-------------------------
# REMAIN_TO_REMOVED_0BASED = {
#     35: [32],
#     34: [32,31],
#     33: [32,31,30],
#     32: [32,31,30,2],
#     31: [32,31,30,2,29],
#     30: [32,31,30,2,29,26],
#     29: [32,31,30,2,29,26,1],
#     28: [32,31,30,2,29,26,1,28],
#     27: [32,31,30,2,29,26,1,28,27],
#     26: [32,31,30,2,29,26,1,28,27,25],
#     25: [32,31,30,2,29,26,1,28,27,25,24],
#     24: [32,31,30,2,29,26,1,28,27,25,24,20],
#     23: [32,31,30,2,29,26,1,28,27,25,24,20,19],
#     22: [32,31,30,2,29,26,1,28,27,25,24,20,19,7],
#     21: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18],
#     20: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18,8],
#     19: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18,8,17],
#     18: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18,8,17,21],
#     17: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18,8,17,21,22],
#     16: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18,8,17,21,22,23],
# }

def get_survivors_dense_ids(L_dense: int, remain_layers: int):
    """
    Return the 0-based dense layer ids that survive pruning, in ascending order.

    The removed ids are looked up in ``REMAIN_TO_REMOVED_0BASED`` under the
    key ``remain_layers``; every id in ``range(L_dense)`` that is not listed
    there is a survivor.

    returns:
      list of surviving dense layer ids, or None when ``remain_layers`` has
      no entry in the mapping.

    raises:
      ValueError: when the number of survivors differs from ``remain_layers``
      (e.g. duplicated / out-of-range removed ids, or a 1-based list used as
      0-based).
    """
    if remain_layers not in REMAIN_TO_REMOVED_0BASED:
        return None

    removed = set(REMAIN_TO_REMOVED_0BASED[remain_layers])
    # range() already yields ascending ids, so no extra sort is required.
    survivors = [i for i in range(L_dense) if i not in removed]

    if len(survivors) != remain_layers:
        raise ValueError(
            f"[Mapping ERROR] remain={remain_layers}, survivors={len(survivors)} != remain. "
            f"L_dense={L_dense}, removed={len(removed)}. Check your remove-layer list base (0/1-based)."
        )
    return survivors


def align_pruned_mean_to_dense(pruned_mean: np.ndarray, survivors_0based: list[int]):
    """
    Pair each pruned-layer mean with the 1-based id of its dense layer.

    returns:
      x: int32 array of 1-based dense layer ids (survivor id + 1)
      y: float32 copy of ``pruned_mean``

    raises:
      ValueError: when the two inputs differ in length.
    """
    n_values, n_survivors = len(pruned_mean), len(survivors_0based)
    if n_values != n_survivors:
        raise ValueError("Length mismatch between pruned_mean and survivors list.")

    one_based_ids = [dense_idx + 1 for dense_idx in survivors_0based]
    x = np.array(one_based_ids, dtype=np.int32)
    y = pruned_mean.astype(np.float32)
    return x, y


def build_full_nan(L_dense: int, survivors_0based: list[int], pruned_mean: np.ndarray):
    """
    Scatter pruned-layer means onto a dense-length float32 vector.

    Positions listed in ``survivors_0based`` receive the matching value from
    ``pruned_mean``; every other position stays NaN.
    """
    aligned = np.empty((L_dense,), dtype=np.float32)
    aligned.fill(np.nan)
    for dense_idx, value in zip(survivors_0based, pruned_mean):
        aligned[dense_idx] = float(value)
    return aligned


def build_dense_full(L_dense: int, dense_mean: np.ndarray):
    """
    Return an independent float32 copy of ``dense_mean``.

    raises:
      ValueError: when ``dense_mean`` does not have exactly ``L_dense`` entries.
    """
    if len(dense_mean) == L_dense:
        # np.array copies by default, so the caller's data is never aliased.
        return np.array(dense_mean, dtype=np.float32)
    raise ValueError("dense_mean length mismatch.")


# -------------------------
# Parse one combined jsonl (dense+pruned in same row)
# -------------------------
def parse_mean_curves_and_deltas_from_combined_jsonl(rows):
    """
    Average the per-row margin curves of a combined (dense + pruned) jsonl.

    Row fields that are read (all optional per row):
      dense_layers / pruned_layers               : layer counts
      m_curve_dense / m_curve_pruned             : per-layer margin values
      delta_m_curve_dense / delta_m_curve_pruned : per-layer margin deltas

    L_dense is the minimum ``dense_layers`` over the rows (or the minimum
    ``m_curve_dense`` length when that field is absent everywhere); L_pruned
    is the maximum ``pruned_layers`` (or the maximum ``m_curve_pruned``
    length). Curves are truncated to that length, padded with NaN when
    shorter, and non-finite entries are turned into NaN, so the per-layer
    mean (``np.nanmean``) ignores them. Rows that lack a given field are
    simply skipped for that field.

    If no row carries a delta field, the delta mean is derived as the first
    difference of the corresponding mean curve.

    returns:
      dense_mean: [L_dense]
      pruned_mean: [L_pruned]
      dense_delta_mean: [L_dense-1]   (delta for to_layer 2..L_dense)
      pruned_delta_mean: [L_pruned-1] (delta for pruned to_layer 2..L_pruned)
      L_dense, L_pruned

    raises:
      ValueError: when no row provides the dense or the pruned curve.
    """

    def _resolve_depth(count_key, curve_key, pick):
        # Prefer the explicit layer-count field; fall back to curve lengths.
        counts = [int(r[count_key]) for r in rows if count_key in r]
        if counts:
            return int(pick(counts))
        lengths = [len(r[curve_key]) for r in rows if curve_key in r]
        if not lengths:
            raise ValueError(f"No row contains '{count_key}' or '{curve_key}'.")
        return int(pick(lengths))

    def _to_float_array_with_nan(x, L):
        """
        Convert a list-like x to a float32 array of exactly length L.
        Longer inputs are truncated, shorter ones are padded with NaN, and
        any non-finite values (nan/inf) or None are converted to np.nan.
        """
        L = max(int(L), 0)
        arr = np.full((L,), np.nan, dtype=np.float32)
        vals = np.array(x[:L], dtype=np.float32)
        # Also neutralise inf/-inf so they cannot poison the mean.
        vals[~np.isfinite(vals)] = np.nan
        arr[:len(vals)] = vals
        return arr

    def _nanmean_over_rows(key, L):
        # Only rows that actually carry `key` contribute to the mean.
        curves = [_to_float_array_with_nan(r[key], L) for r in rows if key in r]
        if not curves:
            raise ValueError(f"No row contains '{key}'.")
        return np.nanmean(np.array(curves, dtype=np.float32), axis=0)

    L_dense = _resolve_depth("dense_layers", "m_curve_dense", min)
    L_pruned = _resolve_depth("pruned_layers", "m_curve_pruned", max)

    # ---- nan-robust mean curves ----
    dense_mean = _nanmean_over_rows("m_curve_dense", L_dense)
    pruned_mean = _nanmean_over_rows("m_curve_pruned", L_pruned)

    # ---- deltas: use stored values when available, else first difference ----
    if any("delta_m_curve_dense" in r for r in rows):
        dense_delta_mean = _nanmean_over_rows("delta_m_curve_dense", L_dense - 1)
    else:
        dense_delta_mean = (dense_mean[1:] - dense_mean[:-1]).astype(np.float32)

    if any("delta_m_curve_pruned" in r for r in rows):
        pruned_delta_mean = _nanmean_over_rows("delta_m_curve_pruned", L_pruned - 1)
    else:
        pruned_delta_mean = (pruned_mean[1:] - pruned_mean[:-1]).astype(np.float32)

    return dense_mean, pruned_mean, dense_delta_mean, pruned_delta_mean, L_dense, L_pruned


# -------------------------
# Δm alignment helpers
# -------------------------
def build_dense_delta_full_nan(L_dense: int, dense_delta_mean: np.ndarray):
    """
    Place the dense deltas on a dense-length float32 vector.

    Slot 0 stays NaN (there is no delta into the first layer); slots 1..L-1
    hold the deltas for to_layer = 2..L.

    raises:
      ValueError: when ``dense_delta_mean`` does not have ``L_dense - 1`` entries.
    """
    padded = np.full((L_dense,), np.nan, dtype=np.float32)
    expected_len = L_dense - 1
    if len(dense_delta_mean) == expected_len:
        padded[1:] = dense_delta_mean
        return padded
    raise ValueError("dense_delta_mean length mismatch.")


def build_pruned_delta_full_nan_on_dense(L_dense: int, pruned_delta_mean: np.ndarray, survivors_0based: list[int]):
    """
    Scatter pruned deltas onto a dense-length float32 vector (NaN elsewhere).

    The j-th pruned delta is written at the dense index of the (j+1)-th
    survivor, i.e. at the layer the delta leads *to*; the first survivor
    never receives a value.

    raises:
      ValueError: when ``pruned_delta_mean`` does not have one entry fewer
      than ``survivors_0based``.
    """
    on_dense = np.full((L_dense,), np.nan, dtype=np.float32)

    expected_len = len(survivors_0based) - 1
    if len(pruned_delta_mean) != expected_len:
        raise ValueError(
            f"pruned_delta_mean length mismatch: got {len(pruned_delta_mean)}, "
            f"expected {len(survivors_0based) - 1}"
        )

    # Skip the first survivor: deltas start at the second surviving layer.
    for dense_to_idx, delta in zip(survivors_0based[1:], pruned_delta_mean):
        on_dense[dense_to_idx] = float(delta)

    return on_dense


def find_positive_runs_on_delta(delta_full_nan: np.ndarray, valid_from_layer=2, min_len=2):
    """
    Locate maximal stretches of consecutive layers whose delta is finite and > 0.

    Layers are 1-based; scanning starts at ``valid_from_layer`` and ends at
    the last layer. Only stretches of at least ``min_len`` layers are kept.

    returns:
      list of (start_to, end_to) tuples, both ends inclusive, 1-based.
    """
    n_layers = len(delta_full_nan)
    found = []
    run_start = None  # 1-based first layer of the run currently open, if any

    for to_layer in range(valid_from_layer, n_layers + 1):
        value = delta_full_nan[to_layer - 1]
        is_positive = np.isfinite(value) and (value > 0)

        if is_positive:
            if run_start is None:
                run_start = to_layer
            continue

        if run_start is not None:
            # The run ended on the previous layer.
            if (to_layer - run_start) >= min_len:
                found.append((run_start, to_layer - 1))
            run_start = None

    # A run still open here extends through the last layer.
    if run_start is not None and (n_layers - run_start + 1) >= min_len:
        found.append((run_start, n_layers))

    return found


# -------------------------
# Plots
# -------------------------
def plot_aligned_step(out_dir, name, dense_mean, pruned_series, L_dense, dpi):
    """
    Draw the dense mean-margin curve plus one step curve per pruned setting.

    Each entry of ``pruned_series`` supplies ``full_nan`` (dense-aligned
    values with NaN gaps), ``x`` / ``y`` (marker positions) and ``label``.
    Gaps are forward-filled with the last finite value; leading gaps take the
    first finite value. The figure is saved to ``{out_dir}/{name}``.
    """

    def _fill_gaps(values):
        # Forward-fill: carry the last finite value into non-finite slots.
        carried = np.nan
        for pos in range(L_dense):
            if np.isfinite(values[pos]):
                carried = values[pos]
            else:
                values[pos] = carried

        # Leading slots have nothing to carry; copy the first finite value back.
        if not np.isfinite(values[0]):
            finite_positions = np.where(np.isfinite(values))[0]
            if finite_positions.size > 0:
                values[:finite_positions[0]] = values[finite_positions[0]]
        return values

    fig, ax = plt.subplots(figsize=(10, 6))
    layer_axis = np.arange(1, L_dense + 1)

    ax.plot(layer_axis, dense_mean, linewidth=2.8, label="dense")

    palette = plt.rcParams['axes.prop_cycle'].by_key().get('color', []) or \
        ["C0","C1","C2","C3","C4","C5","C6","C7","C8","C9"]

    for series_no, series in enumerate(pruned_series, start=1):
        # Offset by one so the first pruned curve does not reuse the dense colour.
        colour = palette[series_no % len(palette)]
        filled = _fill_gaps(series["full_nan"].copy())

        ax.step(layer_axis, filled, where="post", linewidth=2.2, color=colour, label=series["label"])
        ax.scatter(series["x"], series["y"], s=34, color=colour, edgecolor="white", linewidth=0.6, zorder=5)

    ax.axhline(0.0, color="black", linestyle="--", linewidth=1.0)
    ax.set_xlabel("Dense Layer ID (aligned)")
    ax.set_ylabel("Mean Margin")
    ax.set_title("Mean Margin Curve (Aligned) | Step Plot (forward-fill) + block markers")
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"{out_dir}/{name}", dpi=dpi)
    plt.close(fig)


def plot_aligned_scatter_segment(out_dir, name, dense_mean, pruned_series, L_dense, dpi,
                                 title="Llama3-8B | ARC-Challenge"):
    """
    Plot the dense margin curve with pruned settings overlaid as markers.

    Each entry of ``pruned_series`` supplies ``x`` (1-based dense layer ids),
    ``y`` (values) and ``label``. Consecutive markers are joined by a thick
    solid segment when their layer ids are adjacent, and by a thin dashed
    segment when pruned layers lie between them. The figure is saved to
    ``{out_dir}/{name}``.

    params:
      out_dir, name : output directory and file name
      dense_mean    : dense values for layers 1..L_dense
      pruned_series : list of dicts as described above
      L_dense       : number of dense layers (x-axis length)
      dpi           : resolution passed to savefig
      title         : figure title; defaults to the previously hard-coded
                      text so existing callers are unaffected
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    x_dense = np.arange(1, L_dense + 1)

    ax.plot(x_dense, dense_mean, linewidth=3, label="dense")

    # Take the 20 colours of the tab20 colormap.
    tab20_colors = [tab20(i) for i in range(20)]

    for idx, s in enumerate(pruned_series):
        # Offset by one so the first pruned series skips the first colour.
        col = tab20_colors[(idx + 1) % 20]
        x = np.asarray(s["x"], dtype=np.int32)
        y = np.asarray(s["y"], dtype=np.float32)

        ax.scatter(x, y, s=150, color=col, edgecolor="white", linewidth=0.6,
                   label=s["label"], zorder=5)

        for i in range(len(x) - 1):
            adjacent = (x[i + 1] - x[i]) == 1
            if adjacent:
                # Neighbouring dense layers: no pruned layer in between.
                ax.plot([x[i], x[i + 1]], [y[i], y[i + 1]],
                        linewidth=3.0, color=col, linestyle='-', alpha=0.9, zorder=4)
            else:
                # A gap in layer ids means layers were removed here.
                ax.plot([x[i], x[i + 1]], [y[i], y[i + 1]],
                        linewidth=1.6, color=col, linestyle='--', alpha=0.7, zorder=3)

    ax.axhline(0.0, color="black", linestyle="--", linewidth=1.0)
    ax.set_xlabel("Layer ID", fontsize=22)
    ax.set_ylabel("Decision Margin", fontsize=22)
    ax.set_title(title, fontsize=22)
    ax.legend(loc='upper left', ncol=2, fontsize=14)
    ax.tick_params(axis='x', labelsize=20)
    ax.tick_params(axis='y', labelsize=20)
    fig.tight_layout()
    fig.savefig(f"{out_dir}/{name}", dpi=dpi)
    plt.close(fig)


def plot_aligned_heatmap(out_dir, name, heat_rows, L_dense, dpi):
    """
    Draw a layer-aligned heatmap of mean margins and save it to
    ``{out_dir}/{name}``.

    Each entry of ``heat_rows`` is a dict with ``kind`` ("dense" or
    "pruned"), ``label``, ``L_pruned`` and ``full_nan`` (length-``L_dense``
    values, NaN where the layer is absent). Entries with any other ``kind``
    are ignored.

    Layout:
      - pruned rows first, ordered by ascending ``L_pruned``, dense rows last
        (pcolormesh draws the first row at the bottom)
      - square cells, NaN cells transparent
      - white horizontal separators between rows
      - every second layer id labelled on the x axis

    Nothing is drawn or saved when ``heat_rows`` holds no dense/pruned row.
    NOTE(review): Δm>0 run highlighting is not implemented here.
    """
    dense_rows = [r for r in heat_rows if r.get("kind") == "dense"]
    pruned_rows = [r for r in heat_rows if r.get("kind") == "pruned"]
    pruned_rows = sorted(pruned_rows, key=lambda r: r["L_pruned"])
    rows_sorted = pruned_rows + dense_rows

    K = len(rows_sorted)
    if K == 0:
        return

    mat = np.stack([r["full_nan"] for r in rows_sorted], axis=0).astype(np.float32)
    labels = [r["label"] for r in rows_sorted]

    # Mask NaN/inf so the colormap's "bad" colour is used for missing layers.
    mat_m = np.ma.masked_invalid(mat)

    # Figure size scales with the grid, with a minimum so small grids stay legible.
    cell = 0.35
    fig_w = max(10, cell * L_dense)
    fig_h = max(4.5, cell * K + 1.5)
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))

    cmap = plt.get_cmap("viridis").copy()
    cmap.set_bad(color=(1, 1, 1, 0.0))  # fully transparent for masked cells

    x_edges = np.arange(0, L_dense + 1)
    y_edges = np.arange(0, K + 1)

    im = ax.pcolormesh(x_edges, y_edges, mat_m, cmap=cmap, shading="flat")
    cbar = fig.colorbar(im, ax=ax, orientation='horizontal', shrink=0.5)
    cbar.set_label("Mean Margin")

    ax.set_aspect("equal")

    # Ticks sit at cell centres; label only every `show_every`-th layer.
    ax.set_xticks(np.arange(L_dense) + 0.5)
    show_every = 2
    ax.set_xticklabels([str(i + 1) if (i % show_every == 0) else "" for i in range(L_dense)], rotation=0)

    ax.set_yticks(np.arange(K) + 0.5)
    ax.set_yticklabels(labels)

    for y in range(1, K):
        ax.axhline(y, color="white", linewidth=2.0, alpha=0.9)

    ax.set_xlabel("Layer ID")
    ax.set_title("Mean Margin Heatmap")

    fig.tight_layout()
    fig.savefig(f"{out_dir}/{name}", dpi=dpi)
    plt.close(fig)


# -------------------------
# Main
# -------------------------
def main():
    """
    CLI entry point.

    Loads the main combined jsonl (which defines the dense baseline), then
    each ``--compare_other_jsonl`` file, maps every pruned curve onto dense
    layer ids via ``REMAIN_TO_REMOVED_0BASED`` and writes the aligned
    scatter/segment plot into ``--out_dir``.

    raises:
      ValueError: when ``--compare_labels`` does not match the number of
      compare files, or when a pruned layer count has no mapping entry.
    """
    # Declared up front because --remove_ids_are_1based rebinds the table.
    global REMAIN_TO_REMOVED_0BASED

    ap = argparse.ArgumentParser()
    ap.add_argument("--metrics_jsonl", type=str, required=True,
                    help="Combined jsonl. Used to define the dense baseline (dense row).")
    ap.add_argument("--out_dir", type=str, required=True)

    ap.add_argument("--compare_other_jsonl", type=str, action="append", default=[],
                    help="Overlay pruned settings (each combined jsonl).")
    ap.add_argument("--compare_labels", type=str, nargs="*", default=None,
                    help="Labels for compare_other_jsonl.")

    ap.add_argument("--include_main_pruned", action="store_true",
                    help="If set, also include the pruned row from --metrics_jsonl in heatmap/plots. Default OFF.")

    ap.add_argument("--font_size", type=int, default=14)
    ap.add_argument("--dpi", type=int, default=250)
    ap.add_argument("--remove_ids_are_1based", action="store_true")

    args = ap.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    set_global_style(args.font_size)

    compare_paths = args.compare_other_jsonl or []
    compare_labels = args.compare_labels
    if compare_labels is not None and len(compare_labels) != len(compare_paths):
        raise ValueError("Length mismatch: --compare_labels must match number of --compare_other_jsonl entries.")

    if args.remove_ids_are_1based:
        # Convert the table to 0-based ids once, before any lookup.
        REMAIN_TO_REMOVED_0BASED = {k: [x - 1 for x in v] for k, v in REMAIN_TO_REMOVED_0BASED.items()}

    # ---- load main (defines dense baseline) ----
    rows_main = load_jsonl(args.metrics_jsonl)
    dense_mean_main, pruned_mean_main, dense_delta_main, pruned_delta_main, L_dense_main, L_pruned_main = \
        parse_mean_curves_and_deltas_from_combined_jsonl(rows_main)

    L_dense = int(L_dense_main)
    if L_dense != 32:
        print(f"[WARN] dense_layers from file = {L_dense}. Mapping assumes 32, may be invalid.")

    dense_mean = dense_mean_main
    dense_delta_full = build_dense_delta_full_nan(L_dense, dense_delta_main)

    # ---- pruned specs: only the compare files unless --include_main_pruned ----
    specs = []
    if args.include_main_pruned:
        specs.append({
            "label": default_label_from_path(args.metrics_jsonl),
            "pruned_mean": pruned_mean_main,
            "pruned_delta": pruned_delta_main,
            "L_pruned": int(L_pruned_main),
        })

    for i, p in enumerate(compare_paths):
        rows = load_jsonl(p)
        # Only the pruned side of a compare file is plotted; its dense curves are ignored.
        _, p_mean, _, p_delta, Ld, Lp = parse_mean_curves_and_deltas_from_combined_jsonl(rows)
        if int(Ld) != L_dense:
            # The pruned curve is placed on the baseline's layer axis, so a
            # different dense depth means the alignment is questionable.
            print(f"[WARN] {p}: dense_layers={int(Ld)} differs from baseline dense_layers={L_dense}; "
                  f"aligned positions may be invalid.")
        lab = compare_labels[i] if compare_labels is not None else default_label_from_path(p)
        specs.append({
            "label": lab,
            "pruned_mean": p_mean,
            "pruned_delta": p_delta,
            "L_pruned": int(Lp),
        })

    pruned_series = []
    heat_rows = []

    # dense row always present and label is exactly "dense"
    heat_rows.append({
        "kind": "dense",
        "label": "dense",
        "L_pruned": L_dense,
        "full_nan": build_dense_full(L_dense, dense_mean),
        "delta_full_nan": dense_delta_full,
    })

    for sp in specs:
        Lp = int(sp["L_pruned"])
        pruned_mean_i = np.asarray(sp["pruned_mean"], dtype=np.float32)
        pruned_delta_i = np.asarray(sp["pruned_delta"], dtype=np.float32)

        survivors = get_survivors_dense_ids(L_dense=L_dense, remain_layers=Lp)
        if survivors is None:
            raise ValueError(
                f"No mapping found for pruned_layers(remain)={Lp}. "
                f"Supported remain: {sorted(REMAIN_TO_REMOVED_0BASED.keys())}."
            )

        x, y = align_pruned_mean_to_dense(pruned_mean_i, survivors_0based=survivors)
        full_nan = build_full_nan(L_dense, survivors, pruned_mean_i)
        delta_full_nan = build_pruned_delta_full_nan_on_dense(L_dense, pruned_delta_i, survivors_0based=survivors)

        pruned_series.append({
            "label": sp["label"],
            "L_pruned": Lp,
            "x": x,
            "y": y,
            "full_nan": full_nan,
        })

        heat_rows.append({
            "kind": "pruned",
            "label": sp["label"],
            "L_pruned": Lp,
            "full_nan": full_nan,
            "delta_full_nan": delta_full_nan,
        })

        print(f"[OK] {sp['label']}: pruned_layers={Lp}, survivors(min,max)={min(survivors)}..{max(survivors)}")

    # ---- plots ----
    plot_aligned_scatter_segment(
        out_dir=args.out_dir,
        name="mean_margin_aligned_scatter_segment.png",
        dense_mean=dense_mean,
        pruned_series=pruned_series,
        L_dense=L_dense,
        dpi=args.dpi
    )

    # plot_aligned_step(
    #     out_dir=args.out_dir,
    #     name="mean_margin_aligned_step.png",
    #     dense_mean=dense_mean,
    #     pruned_series=pruned_series,
    #     L_dense=L_dense,
    #     dpi=args.dpi
    # )

    # plot_aligned_heatmap(
    #     out_dir=args.out_dir,
    #     name="mean_margin_aligned_heatmap.png",
    #     heat_rows=heat_rows,
    #     L_dense=L_dense,
    #     dpi=args.dpi
    # )

    print("[Saved aligned mean margin plots to]", args.out_dir)

if __name__ == "__main__":
    main()
