#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import argparse

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def safe_token_label(s: str, max_len: int = 18) -> str:
    """Return a short, single-line label for a token string.

    Line breaks and tabs are swapped for visible glyphs so the label stays on
    one line when used as an axis tick or legend entry, and over-long labels
    are truncated with a trailing ellipsis.

    Args:
        s: Token text. ``None`` yields an empty label; any other value is
            converted with ``str()``.
        max_len: Maximum number of characters in the returned label,
            including the ellipsis. Values ``<= 0`` yield an empty label.

    Returns:
        The sanitized label, at most ``max_len`` characters long.
    """
    if s is None:
        return ""

    text = str(s)
    # Handle both a real line break and the escaped two-character form
    # (backslash followed by "n"): either may appear in the token text, and a
    # real line break would otherwise produce a multi-line label.
    replacements = (
        ("\r\n", "↵"),
        ("\\n", "↵"),
        ("\n", "↵"),
        ("\r", "↵"),
        ("\t", "⇥"),
    )
    for raw, glyph in replacements:
        text = text.replace(raw, glyph)

    # A non-positive budget leaves no room even for the ellipsis; without this
    # guard the negative slice below would return almost the whole string.
    if max_len <= 0:
        return ""

    # Keep labels short; reserve one character for the ellipsis.
    if len(text) > max_len:
        return text[: max_len - 1] + "…"
    return text


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the token-contribution plots."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--csv", required=True, help="trace_token_contrib.csv path")
    ap.add_argument("--out", required=True, help="output figure path, e.g., out.png")
    ap.add_argument("--metric", default="attn_mean", choices=["attn_mean", "attn_max"])
    ap.add_argument("--step_tag", default=None, help="optional: filter by step_tag")
    ap.add_argument("--layer_base1", action="store_true",
                    help="If set, y-axis will display layer as 1..N instead of 0..N-1")
    ap.add_argument("--max_xticks", type=int, default=10000,
                    help="max number of x tick labels to show (sparse labeling)")
    ap.add_argument("--fig_w", type=float, default=18.0)
    ap.add_argument("--fig_h", type=float, default=7.0)
    ap.add_argument("--dpi", type=int, default=200)
    ap.add_argument("--plot_lines", action="store_true",
                    help="Also plot line curves: one curve per token")
    ap.add_argument("--topk_tokens", type=int, default=200000,
                    help="Plot only top-K tokens by mean contribution")
    ap.add_argument("--token_stride", type=int, default=1,
                    help="Subsample tokens by stride (e.g., 5 means every 5 tokens)")
    ap.add_argument("--normalize", action="store_true",
                    help="Normalize each token curve to [0,1]")
    return ap


def _lines_output_path(out_path: str) -> str:
    """Return the line-plot path: ``out_path`` with "_lines" before its extension."""
    # splitext keeps this correct for any extension (out.pdf -> out_lines.pdf);
    # a plain ".png" string replace would leave non-PNG paths unchanged and the
    # line plot would then overwrite the heatmap.
    root, ext = os.path.splitext(out_path)
    return f"{root}_lines{ext}"


def _save_heatmap(pivot: pd.DataFrame, tok_map: dict, args: argparse.Namespace) -> None:
    """Render the layer x token heatmap from ``pivot`` and save it to ``args.out``.

    Raises:
        ValueError: If ``pivot`` contains no finite value.
    """
    mat = pivot.to_numpy()  # shape [num_layers, num_tokens]
    layers = pivot.index.to_list()
    token_pos = pivot.columns.to_list()
    token_labels = [safe_token_label(tok_map.get(p, "")) for p in token_pos]

    # Use quantiles for the colour range so a few extreme values do not wash
    # out the contrast of everything else.
    finite_vals = mat[np.isfinite(mat)]
    if finite_vals.size == 0:
        raise ValueError("Matrix has no finite values.")
    vmin = np.quantile(finite_vals, 0.02)
    vmax = np.quantile(finite_vals, 0.98)
    if vmin == vmax:
        # Degenerate quantile range: fall back to min/max with a tiny offset
        # so the colour scale has non-zero width.
        vmin = float(np.min(finite_vals))
        vmax = float(np.max(finite_vals) + 1e-6)

    fig = plt.figure(figsize=(args.fig_w, args.fig_h))
    im = plt.imshow(mat, aspect="auto", interpolation="nearest", vmin=vmin, vmax=vmax)

    title = f"Token contribution heatmap ({args.metric})"
    # Same condition as the row filter, so the title always reflects it.
    if args.step_tag is not None:
        title += f" | step_tag={args.step_tag}"
    plt.title(title)
    plt.xlabel("Tokens (history order)")
    plt.ylabel("Layer" + (" (1-based)" if args.layer_base1 else " (0-based)"))

    # y ticks: one per layer.
    ytick_labels = [(layer + 1) if args.layer_base1 else layer for layer in layers]
    plt.yticks(ticks=np.arange(len(layers)), labels=ytick_labels)

    # x ticks: label only an evenly spaced subset when there are many tokens.
    n_tok = len(token_pos)
    max_xticks = max(args.max_xticks, 0)  # a negative count would crash linspace
    if n_tok <= max_xticks:
        xticks = np.arange(n_tok)
    else:
        xticks = np.linspace(0, n_tok - 1, max_xticks).round().astype(int)
        xticks = np.unique(xticks)

    plt.xticks(
        ticks=xticks,
        labels=[token_labels[i] for i in xticks],
        rotation=60,
        ha="right",
    )

    cbar = plt.colorbar(im)
    cbar.set_label(args.metric)

    plt.tight_layout()
    plt.savefig(args.out, dpi=args.dpi)
    plt.close(fig)  # release the figure once it is on disk
    print(f"[Saved] {args.out}")


def _save_line_plot(df: pd.DataFrame, tok_map: dict, args: argparse.Namespace) -> None:
    """Plot one metric-vs-layer curve per selected token and save it next to ``args.out``."""
    # Re-pivot as token_pos x layer so each row is one token's curve.
    pivot_tok = df.pivot_table(
        index="token_pos",
        columns="layer",
        values=args.metric,
        aggfunc="mean",
    ).sort_index(axis=0).sort_index(axis=1)

    # Rank tokens by their mean value across layers (strongest first).
    token_strength = pivot_tok.mean(axis=1)
    selected_tokens = token_strength.sort_values(ascending=False).index.tolist()

    # Stride subsampling is applied to the ranked list, then the top-K cut.
    if args.token_stride > 1:
        selected_tokens = selected_tokens[::args.token_stride]
    # Clamp so a negative K selects nothing instead of slicing from the end.
    selected_tokens = selected_tokens[:max(args.topk_tokens, 0)]

    if not selected_tokens:
        print("[Warn] No tokens selected for line plot.")
        return

    layers_sorted = pivot_tok.columns.to_numpy()
    # Shift the x values so they match the "(1-based)" axis label.
    x_vals = layers_sorted + 1 if args.layer_base1 else layers_sorted

    fig = plt.figure(figsize=(args.fig_w, args.fig_h))
    for tok in selected_tokens:
        y = pivot_tok.loc[tok].to_numpy()

        if args.normalize:
            ymin, ymax = np.nanmin(y), np.nanmax(y)
            # Flat (or all-NaN) curves are left untouched to avoid dividing by zero.
            if ymax > ymin:
                y = (y - ymin) / (ymax - ymin)

        label = safe_token_label(tok_map.get(tok, f"tok{tok}"))
        plt.plot(x_vals, y, alpha=0.7, linewidth=1.2, label=label)

    plt.xlabel("Layer" + (" (1-based)" if args.layer_base1 else " (0-based)"))
    plt.ylabel(args.metric + (" (normalized)" if args.normalize else ""))
    plt.title(f"Token-wise contribution curves ({args.metric})")

    # With many curves, move the legend outside the axes.
    if len(selected_tokens) <= 15:
        plt.legend(fontsize=8)
    else:
        plt.legend(fontsize=7, ncol=2, bbox_to_anchor=(1.02, 1), loc="upper left")

    plt.tight_layout()
    out_line = _lines_output_path(args.out)
    plt.savefig(out_line, dpi=args.dpi, bbox_inches="tight")
    plt.close(fig)
    print(f"[Saved] {out_line}")


def main():
    """Command-line entry point: plot token-contribution figures from a CSV.

    Reads the CSV given by ``--csv``, optionally keeps only the rows whose
    ``step_tag`` equals ``--step_tag``, and saves a layer x token heatmap of
    the chosen metric to ``--out``. With ``--plot_lines`` a second figure with
    one curve per token is saved to the same path with ``_lines`` inserted
    before the file extension.

    Raises:
        ValueError: If required CSV columns are missing, the ``step_tag``
            filter leaves no rows, or the metric matrix has no finite value.
    """
    args = _build_arg_parser().parse_args()

    df = pd.read_csv(args.csv)

    # Validate columns before touching them so a missing column (including
    # step_tag when filtering is requested) gives one clear error.
    need_cols = {"layer", "token_pos", "token_str", args.metric}
    if args.step_tag is not None:
        need_cols.add("step_tag")
    missing = need_cols - set(df.columns)
    if missing:
        raise ValueError(f"CSV missing columns: {missing}")

    if args.step_tag is not None:
        df = df[df["step_tag"] == args.step_tag].copy()
        if df.empty:
            raise ValueError(f"No rows after filtering step_tag={args.step_tag!r}")

    # Layer x token_pos matrix; duplicate (layer, token_pos) rows are averaged
    # and missing combinations become NaN.
    pivot = df.pivot_table(
        index="layer",
        columns="token_pos",
        values=args.metric,
        aggfunc="mean",
    ).sort_index(axis=0).sort_index(axis=1)

    # One token string per token_pos, taken from the lowest layer present.
    # NOTE(review): assumes token_str is the same across layers for a given
    # token_pos - confirm against the CSV producer.
    tok_map = (df.sort_values(["token_pos", "layer"])
                 .drop_duplicates(subset=["token_pos"], keep="first")
                 .set_index("token_pos")["token_str"]
                 .to_dict())

    os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
    _save_heatmap(pivot, tok_map, args)

    if args.plot_lines:
        _save_line_plot(df, tok_map, args)


# Run the command-line entry point only when this file is executed as a
# script; importing the module has no side effects.
if __name__ == "__main__":
    main()
