#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Layer-wise option-argmax frequency tracing (logit-lens) for MCQ tasks.

Outputs (in output_dir):
  - option_argmax_freq_matrix.csv
  - option_argmax_freq_matrix_wide.csv
  - option_argmax_freq_heatmap.png
  - option_argmax_freq_lines.png

If mapping is available (project pruned axis back to dense axis with gaps):
  - option_argmax_freq_matrix_wide_dense_axis.csv
  - option_argmax_freq_heatmap_dense_axis.png
  - option_argmax_freq_lines_dense_axis.png

Key choices:
  - Decision position = last prompt token (predicting the first generated token).
  - Option set auto-detected from prompt "### Options:" block; supports ABCD or AB (and numeric).
  - Logit-lens: hidden @ decision pos -> final norm -> lm_head -> logits.
  - For each layer: argmax among option logits => vote. Aggregate votes across N samples => frequency.
  - Strict: if later samples have different option set from the first detected set, skip.
"""

import os
import re
import argparse
from typing import List, Dict, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from data_utils import get_sft_dataset, collate_sft  # keep your interface

# Cap the number of threads PyTorch uses for intra-op CPU parallelism.
torch.set_num_threads(4)

# -----------------------------
# prompt parsing
# -----------------------------
# Option line inside a prompt: optional leading whitespace, one letter A-D
# (case-insensitive), a literal ".", then a whitespace character.
_OPTION_LINE_RE = re.compile(r"^\s*([A-D])\.\s", re.IGNORECASE)
# Numbered option line: a single digit 1-9, a literal ".", then whitespace.
_OPTION_LINE_NUM_RE = re.compile(r"^\s*([1-9])\.\s")

# Stand-alone (word-bounded) answer letter A-D, case-insensitive.
_ANS_RE_LETTER = re.compile(r"\b([A-D])\b", re.IGNORECASE)
# Stand-alone (word-bounded) answer digit 1-4.
_ANS_RE_NUM    = re.compile(r"\b([1-4])\b")

def get_answer_key_from_labels(labels_1d: torch.Tensor, tokenizer, option_keys: List[str]) -> str:
    """
    Recover the ground-truth option key from a 1-D label tensor.

    The supervised answer span is every position whose label is not -100.
    That whole span is decoded to text (not just its first token), stripped
    and upper-cased, and then searched:
      * first for a stand-alone letter A-D,
      * otherwise for a stand-alone digit 1-4, which is mapped to A-D.

    The first hit decides the result: it is returned when it belongs to
    ``option_keys``; otherwise "" is returned (a rejected letter does NOT
    fall through to the digit search). "" is also returned when there is no
    answer span, the decoded text is empty, or nothing matches.
    """
    supervised = labels_1d != -100
    if not supervised.any():
        return ""

    answer_token_ids = labels_1d[supervised].tolist()
    decoded = tokenizer.decode(answer_token_ids, skip_special_tokens=True).strip().upper()
    if not decoded:
        return ""

    letter_hit = _ANS_RE_LETTER.search(decoded)
    if letter_hit:
        candidate = letter_hit.group(1).upper()
    else:
        digit_hit = _ANS_RE_NUM.search(decoded)
        if not digit_hit:
            return ""
        digit_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D"}
        candidate = digit_to_letter.get(digit_hit.group(1), "")

    if candidate in option_keys:
        return candidate
    return ""


# ============================================================
# Layer alignment: project pruned-layer results back to dense axis
# ============================================================
# llama3-8b
# Pruning table: key = number of layers that remain after pruning,
# value = 0-based indices (on the dense model's layer axis) of the layers
# that were removed. For every entry here, key + len(value) == 32.
# NOTE(review): presumably looked up by the pruned model's layer count to
# project results back onto the dense axis — confirm against the caller.
REMAIN_TO_REMOVED_0BASED = {
    31: [24],
    30: [24,23],
    29: [24,23,22],
    28: [24,23,22,21],
    27: [24,23,22,21,19],
    26: [24,23,22,21,19,20],
    25: [24,23,22,21,19,20,18],
    24: [24,23,25,26,27,28,22,21],
    23: [24,23,25,26,27,28,22,21,19],
    22: [24,23,25,26,27,28,22,21,19,20],
    21: [24,23,25,26,27,28,22,21,19,20,18],
    20: [24,23,25,26,27,28,22,21,19,20,18,17],
    19: [24,23,25,26,27,28,22,21,19,20,18,17,10],
    18: [24,23,25,26,27,28,22,21,19,20,18,17,10,2],
    17: [24,23,25,26,27,28,22,21,19,20,18,17,10,2,11],
    16: [24,23,25,26,27,28,22,21,19,20,18,17,10,2,11,9],
}
# llama2-7b
# REMAIN_TO_REMOVED_0BASED = {
#     31: [25],
#     30: [25,24],
#     29: [25,24,23],
#     28: [25,24,23,21],
#     27: [25,24,23,21,20],
#     26: [25,24,23,21,20,26],
#     25: [25,24,23,21,20,26,19],
#     24: [25,24,23,21,20,26,19,22],
#     23: [25,24,23,26,21,20,27,28,29],
#     22: [25,24,23,26,21,20,27,28,29,19],
#     21: [25,24,23,26,21,20,27,28,29,19,22],
#     20: [25,24,23,26,21,20,27,28,29,19,22,14],
#     19: [25,24,23,26,21,20,27,28,29,19,22,14,12],
#     18: [25,24,23,26,21,20,27,28,29,19,22,14,12,10],
#     17: [25,24,23,26,21,20,27,28,29,19,22,14,12,10,16],
#     16: [25,24,23,26,21,20,27,28,29,19,22,14,12,10,16,15],
# }
# qwen3-4b
# REMAIN_TO_REMOVED_0BASED = {
#     35: [32],
#     34: [32,31],
#     33: [32,31,30],
#     32: [32,31,30,2],
#     31: [32,31,30,2,29],
#     30: [32,31,30,2,29,26],
#     29: [32,31,30,2,29,26,1],
#     28: [32,31,30,2,29,26,1,28],
#     27: [32,31,30,2,29,26,1,28,27],
#     26: [32,31,30,2,29,26,1,28,27,25],
#     25: [32,31,30,2,29,26,1,28,27,25,24],
#     24: [32,31,30,2,29,26,1,28,27,25,24,20],
#     23: [32,31,30,2,29,26,1,28,27,25,24,20,19],
#     22: [32,31,30,2,29,26,1,28,27,25,24,20,19,7],
#     21: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18],
#     20: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18,8],
#     19: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18,8,17],
#     18: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18,8,17,21],
#     17: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18,8,17,21,22],
#     16: [32,31,30,2,29,26,1,28,27,25,24,20,19,7,18,8,17,21,22,23],
# }
# -----------------------------
# Divergence computation utilities
# -----------------------------

def calculate_kl_divergence(p, q):
    """
    Compute the KL divergence D_KL(P || Q) in bits (log base 2).

    p: target distribution (ground-truth distribution), shape [O]
    q: predicted distribution (model argmax frequency), shape [O, L]

    Returns an array of shape [L]: one divergence value per column of ``q``.
    A small epsilon (1e-9) is added to both distributions so that zero
    entries do not produce log(0) or a division by zero.
    """
    eps = 1e-9

    # Turn p into a column [O, 1] so it broadcasts against every column of q.
    target = p.reshape(-1, 1) + eps
    predicted = q + eps

    # D_KL(P||Q) = sum_o P(o) * log2(P(o) / Q(o)), reduced over the option axis.
    log_ratio = np.log2(target / predicted)
    return np.sum(target * log_ratio, axis=0)

def _forward_fill_nan_2d(mat: np.ndarray) -> np.ndarray:
    """
    mat: [O, L] with NaN
    forward-fill NaN along axis=1 for each option row.
    If a row starts with NaN, it will back-fill using the first finite value.
    """
    out = mat.copy().astype(np.float32)
    O, L = out.shape
    for o in range(O):
        row = out[o]
        last = np.nan
        for i in range(L):
            if np.isfinite(row[i]):
                last = row[i]
            else:
                row[i] = last
        # if still NaN at the beginning (all NaN or leading NaNs), back-fill
        if not np.isfinite(row[0]):
            finite_idx = np.where(np.isfinite(row))[0]
            if finite_idx.size > 0:
                row[:finite_idx[0]] = row[finite_idx[0]]
        out[o] = row
    return out

def _split_kept_removed(freq_ffill: np.ndarray, removed_0based: List[int]):
    """
    freq_ffill: [O, L_dense] finite
    return:
      freq_kept:   [O, L_dense] values only at kept layers, 0 at removed
      freq_removed:[O, L_dense] values only at removed layers, 0 elsewhere
      removed_mask:[L_dense] bool
    """
    O, L = freq_ffill.shape
    removed_mask = np.zeros(L, dtype=bool)
    removed_mask[np.array(removed_0based, dtype=int)] = True

    freq_kept = freq_ffill.copy()
    freq_kept[:, removed_mask] = 0.0

    freq_removed = np.zeros_like(freq_ffill, dtype=np.float32)
    freq_removed[:, removed_mask] = freq_ffill[:, removed_mask]
    return freq_kept, freq_removed, removed_mask


def save_enhanced_stacked_area_plot(
    freq_mat,
    option_keys,
    gt_dist,
    out_path,
    title_suffix="Llama3-8B | ARC-Easy",
    dpi=200,
    dense_axis: bool = False,
    removed_0based: Optional[List[int]] = None,
    alpha_kept: float = 0.85,
    # The two parameters below can be tuned to make removed layers stand out more.
    removed_fade_alpha: float = 0.60,   # strength of the white overlay on removed layers (larger = more faded)
    removed_hatch: str = "////",        # hatch pattern drawn on removed layers (denser = more noticeable)
):
    """
    Save a continuous 100% stacked-area chart of option frequencies per layer,
    with a KL-divergence curve on a secondary y-axis.

    Args:
      freq_mat: [O, L] option-frequency matrix. When ``dense_axis`` is True it
        is forward-filled first, so it may contain NaN.
      option_keys: option labels; one stacked area / legend entry per option.
      gt_dist: ground-truth option distribution (indexed per option); used for
        the KL curve and printed in a text box under the plot.
      out_path: output image path; its parent directory is created if missing.
      title_suffix: plot title.
      dpi: resolution passed to ``savefig``.
      dense_axis: if True, forward-fill NaN and mark the layers listed in
        ``removed_0based``.
      removed_0based: 0-based indices of removed layers (None means none).
      alpha_kept: alpha of the stacked areas.
      removed_fade_alpha: alpha of the white overlay on removed layers.
      removed_hatch: hatch pattern of the overlay on removed layers.

    Behavior:
      - The stackplot keeps its continuous shape across all layers.
      - Each removed layer gets a one-layer-wide white fade plus a hatch
        overlay (not a large background block).
      - The KL curve is dashed black on kept layers; removed layers are drawn
        as dark-gray points, joined by solid lines only when adjacent.
    """
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    O, L = freq_mat.shape
    layers = np.arange(1, L + 1)  # 1-based layer ids used as x coordinates

    # ---------- dense axis: forward-fill to keep continuity ----------
    if dense_axis:
        # dense_axis=True with removed_0based empty/None is allowed (dense model / no pruning)
        if removed_0based is None:
            removed_0based = []
        removed_mask = np.zeros(L, dtype=bool)
        if len(removed_0based) > 0:
            removed_mask[np.array(removed_0based, dtype=int)] = True

        freq_ffill = _forward_fill_nan_2d(freq_mat)  # [O, L]
        # Re-normalize each layer column to sum to 1; the floor avoids division by zero.
        col_sum = np.sum(freq_ffill, axis=0, keepdims=True)
        col_sum = np.maximum(col_sum, 1e-12)
        freq_plot = freq_ffill / col_sum
        kl_div = calculate_kl_divergence(gt_dist, freq_plot)
    else:
        removed_mask = np.zeros(L, dtype=bool)
        freq_plot = freq_mat.astype(np.float32)
        # Same per-column normalization as the dense branch.
        col_sum = np.sum(freq_plot, axis=0, keepdims=True)
        col_sum = np.maximum(col_sum, 1e-12)
        freq_plot = freq_plot / col_sum
        kl_div = calculate_kl_divergence(gt_dist, freq_plot)

    # ---------- colors ----------
    # Fixed colors for A-D; any other key falls back to a tab10 color by position.
    base_colors = {"A": "#1f77b4", "B": "#ff7f0e", "C": "#2ca02c", "D": "#d62728"}
    colors = [base_colors.get(k, plt.cm.tab10(i)) for i, k in enumerate(option_keys)]

    fig, ax1 = plt.subplots(figsize=(12, 8))

    # ---------- 1) continuous stacked area ----------
    ax1.stackplot(
        layers,
        freq_plot,
        labels=[f"{k}" for k in option_keys],
        colors=colors,
        alpha=alpha_kept
    )

    # ---------- 2) highlight removed layers: fade + hatch (one strip per layer, no large blank trapezoid) ----------
    if dense_axis and removed_mask.any():
        for li in np.where(removed_mask)[0]:
            # Strip one layer wide, centered on the layer's 1-based x position.
            x0, x1 = (li + 1) - 0.5, (li + 1) + 0.5

            # (a) first fade with a white overlay (makes this layer look more transparent)
            ax1.axvspan(
                x0, x1,
                facecolor="white",
                alpha=removed_fade_alpha,
                linewidth=0,
                zorder=5,
            )

            # (b) then add a hatch overlay (very noticeable without relying on a background color)
            ax1.axvspan(
                x0, x1,
                facecolor="none",
                edgecolor="0.25",   # fairly dark so the hatch stays visible
                hatch=removed_hatch,
                linewidth=0.0,
                zorder=6,
            )

    # ---------- 3) KL curve: removed layers drawn in dark gray with markers ----------
    ax2 = ax1.twinx()

    if dense_axis and removed_mask.any():
        # kept segments in black
        kl_kept = kl_div.copy()
        kl_kept[removed_mask] = np.nan  # NaN breaks the dashed line at removed layers
        for sx, sy in line_segments_with_gaps(layers, kl_kept):
            ax2.plot(sx, sy, color="black", linestyle="--", linewidth=3.5)

        # removed points/segments in dark gray
        idx_removed = np.where(removed_mask)[0]
        ax2.scatter(
            layers[idx_removed],
            kl_div[idx_removed],
            color="0.35",   # dark gray
            s=60,           # larger markers
            zorder=10
        )
        # Connect only adjacent removed layers (never bridge across kept layers)
        if idx_removed.size > 1:
            for a, b in zip(idx_removed[:-1], idx_removed[1:]):
                if b == a + 1:
                    ax2.plot(
                        [layers[a], layers[b]],
                        [kl_div[a], kl_div[b]],
                        color="0.35",
                        linestyle="-",
                        linewidth=3.0,
                        alpha=1.0,
                        zorder=9
                    )
    else:
        ax2.plot(layers, kl_div, color="black", linestyle="--", linewidth=3.5)

    # 15% headroom above the largest KL value, with a minimum upper limit of 0.1.
    ax2.set_ylim(0, max(np.nanmax(kl_div) * 1.15, 0.1))

    # ---------- styling ----------
    plt.title(title_suffix, fontsize=22, pad=25)

    ax1.set_xlabel("Layer ID", fontsize=22)
    ax1.set_xticks(layers)
    ax1.tick_params(axis='x', labelsize=20, rotation=90)

    ax1.set_ylabel("Option Frequency", fontsize=20)
    ax1.set_ylim(0, 1.0)
    ax1.tick_params(axis='y', labelsize=20)

    ax2.set_ylabel("KL Divergence", fontsize=22, rotation=270, labelpad=30)
    ax2.tick_params(axis='y', labelsize=20)

    # legend: the stackplot handles plus one proxy line for the KL curve
    lines1, labels1 = ax1.get_legend_handles_labels()
    handles2 = [plt.Line2D([0], [0], color="black", linestyle="--", linewidth=3.5)]
    labels2 = ["KL Divergence"]

    ax1.legend(
        lines1 + handles2, labels1 + labels2,
        loc='upper center', bbox_to_anchor=(0.5, -0.12),
        ncol=5, fontsize=16, frameon=False
    )

    # Text box under the plot listing the ground-truth distribution per option.
    gt_text_str = "GT Dist: " + "  ".join([f"{k}: {gt_dist[i]:.2f}" for i, k in enumerate(option_keys)])
    fig.text(0.5, 0.02, gt_text_str, ha='center', fontsize=16,
             bbox=dict(facecolor='white', alpha=0.8, edgecolor='lightgray', boxstyle='round,pad=0.3'))

    plt.subplots_adjust(bottom=0.22, top=0.92, left=0.08, right=0.92)
    plt.grid(True, axis='x', linestyle=':', linewidth=1.5, alpha=0.5)

    plt.savefig(out_path, dpi=dpi, bbox_inches='tight')
    plt.close()
    print(f"[Saved Enhanced Plot with KL] {out_path}")

def save_radial_stacked_plot(
    freq_mat,
    option_keys,
    out_path,
    title_suffix="Llama3-8B | ARC-Easy",
    dpi=200,
    dense_axis: bool = False,
    removed_0based: Optional[List[int]] = None,
    alpha_kept: float = 0.80,
    alpha_removed: float = 0.20,
):
    """
    Save a polar chart with one stacked bar per layer (one slice per option).

    - dense_axis=True: ``freq_mat`` is [O, L_dense] and may hold NaN at removed
      layers. NaNs are forward-filled so the ring has no holes, and the layers
      listed in ``removed_0based`` are drawn with ``alpha_removed`` instead of
      ``alpha_kept``.
    - dense_axis=False: ``freq_mat`` is a plain [O, L] matrix; every bar uses
      ``alpha_kept``.

    In both modes each layer column is re-normalized to sum to 1 before
    plotting. The parent directory of ``out_path`` is created if needed.
    """
    target_dir = os.path.dirname(out_path) or "."
    os.makedirs(target_dir, exist_ok=True)

    shares = freq_mat.astype(np.float32)
    n_options, n_layers = shares.shape
    is_removed = np.zeros(n_layers, dtype=bool)

    if dense_axis:
        removed_idx = [] if removed_0based is None else [int(x) for x in removed_0based]
        if removed_idx:
            is_removed[np.array(removed_idx, dtype=int)] = True
        # Forward-fill the gaps so the ring looks continuous.
        shares = _forward_fill_nan_2d(freq_mat).astype(np.float32)

    # Per-layer normalization (the floor guards against all-zero columns).
    totals = np.maximum(np.sum(shares, axis=0, keepdims=True), 1e-12)
    shares = shares / totals

    # ---- polar geometry: one angular slot per layer, bars fill 80% of it ----
    thetas = np.linspace(0, 2 * np.pi, n_layers, endpoint=False)
    bar_width = (2 * np.pi) / n_layers * 0.8

    base_colors = {"A": "#1f77b4", "B": "#ff7f0e", "C": "#2ca02c", "D": "#d62728"}
    colors = [base_colors.get(k, plt.cm.tab10(i)) for i, k in enumerate(option_keys)]

    fig, ax = plt.subplots(figsize=(12, 12), subplot_kw={'projection': 'polar'})

    # Bars are stacked layer by layer so each layer can carry its own alpha.
    for layer_idx, theta in enumerate(thetas):
        layer_alpha = alpha_removed if is_removed[layer_idx] else alpha_kept
        stacked = 0.0
        for opt_idx in range(n_options):
            height = float(shares[opt_idx, layer_idx])
            ax.bar(
                theta,
                height,
                width=bar_width,
                bottom=stacked,
                color=colors[opt_idx],
                alpha=layer_alpha,
                linewidth=0.0
            )
            stacked += height

    # One legend entry per option, built from proxy lines.
    legend_handles = [plt.Line2D([0], [0], color=colors[i], lw=10) for i in range(n_options)]
    ax.legend(legend_handles, option_keys, loc='lower right', fontsize=14)

    ax.set_xticks(thetas)
    ax.set_xticklabels(np.arange(1, n_layers + 1), fontsize=14)
    ax.set_yticklabels([])
    plt.title(title_suffix, fontsize=22, pad=30)

    plt.savefig(out_path, dpi=dpi, bbox_inches='tight')
    plt.close()
    print(f"[Saved Radial Plot] {out_path}")



def build_pruned_to_dense_mapping(L_dense: int, L_pruned: int, removed_0based: List[int]) -> List[int]:
    """
    Build the pruned-axis -> dense-axis layer index mapping.

    The result has length ``L_pruned``; entry ``i`` is the 0-based dense layer
    id of pruned layer ``i``. It is simply the dense layer ids that are not in
    ``removed_0based``, in ascending order (i.e. the pruned model is assumed to
    keep the surviving layers in their original order).

    Raises ValueError when the number of surviving layers differs from
    ``L_pruned``.
    """
    dropped = {int(x) for x in removed_0based}
    survivors = [layer for layer in range(L_dense) if layer not in dropped]

    if len(survivors) == L_pruned:
        return survivors

    raise ValueError(
        f"Mapping mismatch: dense={L_dense}, removed={len(dropped)}, remain={len(survivors)} "
        f"!= L_pruned={L_pruned}. Check REMAIN_TO_REMOVED_0BASED for L_pruned={L_pruned}."
    )


def expand_freq_to_dense_axis(freq_mat: np.ndarray, pruned_to_dense: List[int], L_dense: int) -> np.ndarray:
    """
    Scatter a pruned-axis frequency matrix onto the dense layer axis.

    freq_mat:        [O, L_pruned] values on the pruned axis
    pruned_to_dense: for each pruned column, its 0-based dense column
    L_dense:         number of columns of the result

    Returns a float32 [O, L_dense] matrix that is NaN everywhere except at the
    mapped dense columns. Raises ValueError if the mapping length does not
    match the number of pruned columns.
    """
    n_options, n_pruned = freq_mat.shape
    if n_pruned != len(pruned_to_dense):
        raise ValueError("freq_mat L_pruned != mapping length")

    dense = np.full((n_options, L_dense), np.nan, dtype=np.float32)
    for pruned_col, dense_col in enumerate(pruned_to_dense):
        dense[:, dense_col] = freq_mat[:, pruned_col]
    return dense


def line_segments_with_gaps(xs: np.ndarray, ys: np.ndarray):
    """
    Cut (xs, ys) into maximal runs of consecutive positions where ys is finite.

    Returns a list of ``(xs_slice, ys_slice)`` pairs, one per run, in order.
    Returns an empty list when ys has no finite value.
    """
    valid = np.flatnonzero(np.isfinite(ys))
    if valid.size == 0:
        return []

    # A new run begins wherever two neighbouring finite positions are not adjacent.
    run_starts = np.flatnonzero(np.diff(valid) != 1) + 1
    runs = np.split(valid, run_starts)

    return [(xs[run[0]:run[-1] + 1], ys[run[0]:run[-1] + 1]) for run in runs]


def set_seed(seed: int):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices) with ``seed``."""
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


def normalize_option_keys(keys: List[str]) -> List[str]:
    """
    Bring a list of raw option keys into canonical letter form.

    Keys are stringified, stripped, and de-duplicated (first occurrence wins;
    empty strings are dropped). Then:
      - exactly {1,2}       -> [A, B]
      - exactly {1,2,3,4}   -> [A, B, C, D]
      - exactly {A,B}       -> [A, B]
      - exactly {A,B,C,D}   -> [A, B, C, D]
    (these four cases return the canonical order regardless of input order).

    Any other set is mapped key by key, keeping input order: digits 1-4
    become A-D, purely alphabetic keys are upper-cased, everything else is
    left as is; duplicates produced by the mapping are removed.
    """
    stripped = (str(raw).strip() for raw in keys)
    distinct = list(dict.fromkeys(tok for tok in stripped if tok))
    key_set = set(distinct)

    canonical_sets = (
        ({"1", "2"}, ["A", "B"]),
        ({"1", "2", "3", "4"}, ["A", "B", "C", "D"]),
        ({"A", "B"}, ["A", "B"]),
        ({"A", "B", "C", "D"}, ["A", "B", "C", "D"]),
    )
    for members, canonical in canonical_sets:
        if key_set == members:
            return list(canonical)

    digit_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D"}

    def to_canonical(tok: str) -> str:
        if tok in digit_to_letter:
            return digit_to_letter[tok]
        return tok.upper() if tok.isalpha() else tok

    return list(dict.fromkeys(to_canonical(tok) for tok in distinct))


def extract_option_keys_from_prompt(prompt_text: str) -> List[str]:
    """
    Detect the option keys listed in a prompt.

    If the prompt contains "### Options:", only the text from that marker up
    to (not including) the next "### Answer" is scanned; otherwise the whole
    prompt is scanned. Each line that starts with a lettered option ("A. ")
    contributes its upper-cased letter, and each line that starts with a
    numbered option ("1. ") contributes its digit. The collected keys are
    passed through ``normalize_option_keys``.
    """
    _, marker, tail = prompt_text.partition("### Options:")
    if marker:
        options_block = (marker + tail).partition("### Answer")[0]
    else:
        options_block = prompt_text

    found = []
    for line in options_block.splitlines():
        letter_hit = _OPTION_LINE_RE.match(line)
        if letter_hit:
            found.append(letter_hit.group(1).upper())
            continue
        digit_hit = _OPTION_LINE_NUM_RE.match(line)
        if digit_hit:
            found.append(digit_hit.group(1))

    return normalize_option_keys(found)


# -----------------------------
# model + data
# -----------------------------
def load_model_and_tokenizer(model_path: str, dtype: str = "bf16", force_eager_attn: bool = True):
    """
    Load a causal LM and its tokenizer from ``model_path``.

    Tokenizer: fast tokenizer, right padding; if it has no pad token, the EOS
    token is used as pad token.

    Model: loaded with the torch dtype selected by ``dtype`` ("bf16", "fp16"
    or "fp32"; any other value raises KeyError) and ``device_map=None``. When
    ``force_eager_attn`` is True, eager attention is requested both through
    ``attn_implementation`` and through the config object; if preparing the
    config fails, a warning is printed and loading continues without it.

    Returns ``(model, tokenizer)`` with the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

    dtype_by_name = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
    load_kwargs = {
        "torch_dtype": dtype_by_name[dtype],
        "device_map": None,
        "attn_implementation": "eager" if force_eager_attn else None,
    }

    if force_eager_attn:
        try:
            cfg = AutoConfig.from_pretrained(model_path)
            cfg._attn_implementation = "eager"
        except Exception as e:
            print(f"[Warn] AutoConfig eager attn setup failed: {e}")
        else:
            load_kwargs["config"] = cfg

    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    model.eval()
    return model, tokenizer


def build_eval_dataloader(tokenizer, sft_dataset: str, max_length: int, seed: int,
                          eval_split: str, num_workers: int, num_eval_samples: int):
    """
    Build the evaluation dataset and a DataLoader over it.

    The dataset comes from ``get_sft_dataset`` (named ``sft_dataset``, split
    ``eval_split``, ``num_eval_samples`` passed as ``num_samples``). The loader
    yields one sample per batch, in dataset order, collated by ``collate_sft``,
    with pinned memory enabled.

    Returns ``(dataloader, dataset)``.
    """
    eval_dataset = get_sft_dataset(
        name=sft_dataset,
        tokenizer=tokenizer,
        max_length=max_length,
        seed=seed,
        num_samples=num_eval_samples,
        split=eval_split,
    )

    loader_options = {
        "batch_size": 1,
        "shuffle": False,
        "collate_fn": collate_sft,
        "num_workers": num_workers,
        "pin_memory": True,
    }
    eval_loader = DataLoader(eval_dataset, **loader_options)

    return eval_loader, eval_dataset


# -----------------------------
# option token id resolution
# -----------------------------
def candidate_token_ids_for_key(tokenizer, key: str) -> List[int]:
    """
    Collect the single-token ids that can represent option ``key``.

    Four surface forms are tried, in this order: ``key``, ``" " + key``,
    ``"\\n" + key`` and ``"\\n " + key``. Each is encoded without special
    tokens, and a form counts only if it encodes to exactly one token.
    The result keeps first-seen order and contains no duplicates.
    """
    surface_forms = (key, " " + key, "\n" + key, "\n " + key)
    encodings = (tokenizer.encode(text, add_special_tokens=False) for text in surface_forms)
    single_token_ids = [enc[0] for enc in encodings if len(enc) == 1]
    return list(dict.fromkeys(single_token_ids))


def choose_best_token_id(tokenizer, model, device, prompt_ids_1d: torch.Tensor, key: str) -> int:
    """
    Choose which single-token encoding of option ``key`` to use.

    The model is run once on the prompt (no KV cache, no gradients), and among
    the candidates returned by ``candidate_token_ids_for_key`` the one with the
    highest next-token logit at the last prompt position is returned. On ties
    the earliest candidate wins.

    Raises ValueError if ``key`` has no single-token encoding.
    """
    candidates = candidate_token_ids_for_key(tokenizer, key)
    if not candidates:
        raise ValueError(f"Cannot find any single-token encoding for option key={key!r}")

    batch_ids = prompt_ids_1d.unsqueeze(0).to(device)
    full_mask = torch.ones_like(batch_ids, dtype=torch.long, device=device)

    with torch.no_grad():
        result = model(input_ids=batch_ids, attention_mask=full_mask, use_cache=False)
    next_token_logits = result.logits[0, -1, :]

    # max() keeps the first maximal element, matching a strict ">" scan.
    return max(candidates, key=lambda tid: float(next_token_logits[tid].item()))


# -----------------------------
# logit-lens helpers
# -----------------------------
def get_final_norm_module(model) -> torch.nn.Module:
    """
    Return the model's final normalization module.

    Looks for ``model.model.norm`` first, then ``model.transformer.ln_f``.
    Raises AttributeError if neither attribute path exists.
    """
    known_paths = (("model", "norm"), ("transformer", "ln_f"))
    for parent_name, norm_name in known_paths:
        if not hasattr(model, parent_name):
            continue
        parent = getattr(model, parent_name)
        if hasattr(parent, norm_name):
            return getattr(parent, norm_name)
    raise AttributeError("Cannot locate final norm module (model.model.norm or transformer.ln_f).")


def get_lm_head_module(model) -> torch.nn.Module:
    """
    Return the model's output projection module.

    Uses ``model.lm_head`` when present, otherwise ``model.embed_out``.
    Raises AttributeError if the model has neither attribute.
    """
    for attr_name in ("lm_head", "embed_out"):
        if hasattr(model, attr_name):
            return getattr(model, attr_name)
    raise AttributeError("Cannot locate lm_head module.")


def extract_prompt_ids(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Return the prompt token ids of the first sample in ``batch``.

    A position belongs to the prompt when its label is -100 and its attention
    mask is 1 (so positions with attention mask 0 are excluded even though
    their label may also be -100).
    """
    token_ids, attn_mask, labels = (
        batch[name][0] for name in ("input_ids", "attention_mask", "labels")
    )
    is_prompt = labels.eq(-100) & attn_mask.eq(1)
    return token_ids[is_prompt]


def layerwise_argmax_option_for_prompt(
    model,
    device,
    prompt_ids_1d: torch.Tensor,
    option_keys: List[str],
    option_token_ids: Dict[str, int],
    max_prompt_tokens: int = 256
) -> Tuple[np.ndarray, int]:
    """
    Logit-lens vote per layer: which option has the largest logit at the
    decision position (the last prompt token).

    Args:
      model: causal LM that supports ``output_hidden_states=True`` and returns
        ``logits`` and ``hidden_states``.
      device: device the inputs are moved to.
      prompt_ids_1d: 1-D tensor of prompt token ids.
      option_keys: option labels, in the order used for the argmax index.
      option_token_ids: token id to read for each option key.
      max_prompt_tokens: if truthy and the prompt is longer, only the LAST
        ``max_prompt_tokens`` tokens are kept.

    Returns:
      ``(argmax_ids, L)`` where ``L = len(hidden_states) - 1`` is the number of
      layers and ``argmax_ids`` is an int32 array of shape [L];
      ``argmax_ids[i]`` is the index into ``option_keys`` of the winning option
      after layer ``i + 1`` (ties go to the earliest option).

    Notes:
      - Intermediate layers use the logit lens: hidden state -> final norm ->
        lm_head.
      - The LAST layer uses the model's own output logits. Hugging Face decoder
        models store the last ``hidden_states`` entry after the final norm has
        already been applied, so normalizing it again would project a
        double-normalized state and could disagree with the model's real
        prediction.
    """
    if max_prompt_tokens and prompt_ids_1d.numel() > max_prompt_tokens:
        prompt_ids_1d = prompt_ids_1d[-max_prompt_tokens:]

    input_ids = prompt_ids_1d.unsqueeze(0).to(device)
    attn = torch.ones_like(input_ids, dtype=torch.long, device=device)

    final_norm = get_final_norm_module(model)
    lm_head = get_lm_head_module(model)

    pos = -1  # last prompt token
    option_ids = [int(option_token_ids[k]) for k in option_keys]

    # Everything runs under no_grad: the lens projections go through lm_head
    # parameters and would otherwise build an autograd graph per layer.
    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            attention_mask=attn,
            use_cache=False,
            output_hidden_states=True,
        )

        hs = out.hidden_states  # hs[0] is the embedding output, hs[i] follows layer i
        L = len(hs) - 1
        argmax_ids = np.zeros((L,), dtype=np.int32)

        for li in range(1, L + 1):
            if li == L:
                # Already final-normed inside the model: use the true logits.
                logits = out.logits[0, pos, :]
            else:
                logits = lm_head(final_norm(hs[li][0, pos, :]))

            # Gather all option logits at once (single device->host transfer).
            opt_logits = logits[option_ids].float().cpu().numpy()
            argmax_ids[li - 1] = int(opt_logits.argmax())

    return argmax_ids, L


# -----------------------------
# plotting
# -----------------------------
def _get_option_color_map(option_keys: List[str]):
    """
    Fix colors per option so that:
      - solid segments and dashed bridges share the exact same color
      - pruned/dense plots are visually comparable if option sets match
    """
    base = {
        "A": "#1f77b4",  # matplotlib default blue
        "B": "#ff7f0e",  # default orange
        "C": "#2ca02c",  # default green
        "D": "#d62728",  # default red
    }
    cmap = {}
    for i, k in enumerate(option_keys):
        if k in base:
            cmap[k] = base[k]
        else:
            # fallback: deterministic from tab10
            cmap[k] = plt.cm.tab10(i % 10)
    return cmap


def save_freq_heatmap(freq_mat: np.ndarray, option_keys: List[str], out_path: str,
                      title: str, dpi: int, fig_w: float, fig_h: float,
                      dense_axis: bool = False):
    """
    Save a heatmap of the [O, L] argmax-frequency matrix.

    Rows are options (labelled with ``option_keys``), columns are layers
    (tick labels are 1-based). The color scale is fixed to [0, 1] with the
    viridis colormap; masked/invalid cells are drawn fully transparent.
    ``dense_axis`` only changes the x-axis label. The parent directory of
    ``out_path`` is created if needed.
    """
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    plt.figure(figsize=(fig_w, fig_h))

    # Copy the colormap so that setting the "bad" color does not touch the shared one.
    palette = plt.cm.viridis.copy()
    palette.set_bad(color=(1, 1, 1, 0))

    image = plt.imshow(
        freq_mat,
        aspect="auto",
        interpolation="nearest",
        vmin=0.0,
        vmax=1.0,
        cmap=palette,
    )
    plt.colorbar(image, label="Argmax frequency")
    plt.title(title)

    if dense_axis:
        x_label = "Dense Layer ID (1-based)"
    else:
        x_label = "Layer (1-based)"
    plt.xlabel(x_label)
    plt.ylabel("Option")
    plt.yticks(np.arange(len(option_keys)), option_keys)

    # Columns are 0-based positions; show them with 1-based layer numbers.
    n_layers = freq_mat.shape[1]
    plt.xticks(np.arange(n_layers), np.arange(1, n_layers + 1))

    plt.tight_layout()
    plt.savefig(out_path, dpi=dpi)
    plt.close()
    print(f"[Saved] {out_path}")


def save_freq_lines(freq_mat: np.ndarray, option_keys: List[str], out_path: str,
                    title: str, dpi: int, fig_w: float, fig_h: float,
                    dense_axis: bool = False, draw_gap_bridges: bool = True):
    """
    Save a line chart with one frequency curve per option across layers.

    ``freq_mat`` is [O, L]; NaN entries are gaps. For each option:
      - every run of consecutive finite values is drawn as a solid line
        (only the first run carries the legend label);
      - if ``draw_gap_bridges`` is True, a dashed line joins the two finite
        points on either side of each gap.
    All pieces belonging to one option use the same fixed color.
    ``dense_axis`` only changes the x-axis label. The parent directory of
    ``out_path`` is created if needed.
    """
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    n_options, n_layers = freq_mat.shape
    layer_ids = np.arange(1, n_layers + 1)

    palette = _get_option_color_map(option_keys)

    plt.figure(figsize=(fig_w, fig_h))

    for row_idx in range(n_options):
        name = option_keys[row_idx]
        series = freq_mat[row_idx].astype(np.float32)
        line_color = palette.get(name, None)

        # 1) solid runs; label only the first so the legend has one entry per option
        solid_runs = line_segments_with_gaps(layer_ids, series)
        for run_no, (run_x, run_y) in enumerate(solid_runs):
            plt.plot(
                run_x, run_y,
                linewidth=2.0,
                color=line_color,
                label=(name if run_no == 0 else None),
            )

        # 2) dashed bridges over gaps, in the option's color
        if not draw_gap_bridges:
            continue
        present = np.flatnonzero(np.isfinite(series))
        for left, right in zip(present[:-1], present[1:]):
            if right > left + 1:
                plt.plot(
                    [layer_ids[left], layer_ids[right]],
                    [series[left], series[right]],
                    linestyle="--",
                    linewidth=1.5,
                    color=line_color,
                )

    plt.title(title)
    if dense_axis:
        plt.xlabel("Dense Layer ID (1-based)")
    else:
        plt.xlabel("Layer (1-based)")
    plt.ylabel("Option Frequency")
    plt.ylim(0.0, 1.0)
    plt.xticks(layer_ids)
    plt.grid(True, linestyle="--", linewidth=0.8, alpha=0.6)
    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=dpi)
    plt.close()
    print(f"[Saved] {out_path}")


# -----------------------------
# main
# -----------------------------
def parse_args():
    """
    Parse the command-line arguments of the script (read from ``sys.argv``).

    Required: ``--model_name_or_path`` and ``--output_dir``. All other options
    have defaults.

    Eager attention is on by default. ``--force_eager_attn`` is kept for
    backward compatibility (it was, and still is, a no-op because the default
    is already True); ``--no_force_eager_attn`` is the switch that actually
    turns it off.

    Returns the ``argparse.Namespace`` with the parsed values.
    """
    p = argparse.ArgumentParser("Layer-wise option argmax frequency (logit-lens)")

    p.add_argument("--model_name_or_path", type=str, required=True)
    p.add_argument("--output_dir", type=str, required=True)

    p.add_argument("--sft_dataset", type=str, default="hellaswag")
    p.add_argument("--eval_split", type=str, default="validation",
                   choices=["train", "validation", "valid", "test"])
    p.add_argument("--max_length", type=int, default=512)

    p.add_argument("--num_eval_samples", type=int, default=500)
    p.add_argument("--num_workers", type=int, default=2)
    p.add_argument("--seed", type=int, default=42)

    p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
    p.add_argument("--force_eager_attn", action="store_true", default=True)
    # store_true with default=True can never yield False, so provide an
    # explicit off-switch that writes to the same destination.
    p.add_argument("--no_force_eager_attn", dest="force_eager_attn", action="store_false",
                   help="Do not force the eager attention implementation.")

    p.add_argument("--max_prompt_tokens", type=int, default=256)

    # plotting
    p.add_argument("--dpi", type=int, default=200)
    p.add_argument("--fig_w", type=float, default=12.0)
    p.add_argument("--fig_h", type=float, default=4.5)

    # dense layer count for projection plot (default 32 for LLaMA-32 layers)
    p.add_argument("--dense_layers", type=int, default=32)

    return p.parse_args()


def main():
    """Run the layer-wise option-argmax frequency pipeline end to end.

    Steps:
      1. Parse CLI arguments, load the model/tokenizer and build the
         evaluation dataloader.
      2. For every usable sample, run the logit-lens pass and add one vote
         per layer for the option with the highest logit; also count the
         ground-truth option extracted from the labels.
      3. Normalise votes into a frequency matrix, optionally project it onto
         the dense layer axis, and write CSVs and plots into
         ``args.output_dir``.

    Raises:
        RuntimeError: if fewer than two samples were usable, or if no
            ground-truth label matched any detected option.
    """
    torch.backends.cudnn.enabled = False
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    print("====== Option argmax frequency args ======")
    for k, v in vars(args).items():
        print(f"{k}: {v}")
    print("=========================================")

    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Device] {device}")

    print(f"[Load] model from {args.model_name_or_path}")
    model, tokenizer = load_model_and_tokenizer(
        args.model_name_or_path, dtype=args.dtype, force_eager_attn=args.force_eager_attn
    )
    model.to(device).eval()

    dataloader, dataset = build_eval_dataloader(
        tokenizer=tokenizer,
        sft_dataset=args.sft_dataset,
        max_length=args.max_length,
        seed=args.seed,
        eval_split=args.eval_split,
        num_workers=args.num_workers,
        num_eval_samples=args.num_eval_samples,
    )

    print(f"[Data] dataset size (after num_samples): {len(dataset)}")

    # Accumulators; option-dependent ones are created lazily once the option
    # set of the first usable sample is known.
    option_keys: Optional[List[str]] = None
    option_token_ids: Dict[str, int] = {}
    vote_counts: Optional[np.ndarray] = None  # [O, L]
    gt_counts: Dict[str, int] = {}
    n_used = 0
    n_option_mismatch = 0  # samples skipped because their option set differs
    n_gt_unmatched = 0     # used samples whose label matched no option

    for batch in dataloader:
        if n_used >= args.num_eval_samples:
            break

        # 1. Extract the prompt and the option keys it advertises.
        batch_gpu = {k: v.to(device) for k, v in batch.items()}
        prompt_ids = extract_prompt_ids(batch_gpu)
        if prompt_ids.numel() == 0:
            continue

        prompt_text = tokenizer.decode(prompt_ids.tolist(), skip_special_tokens=False)
        keys_now = extract_option_keys_from_prompt(prompt_text)

        # 2. Fix the option set on the first usable sample. This must happen
        #    before the label is parsed so the answer key is extracted against
        #    the detected options rather than a hard-coded A-D fallback.
        if option_keys is None:
            option_keys = list(keys_now) if keys_now else ["A", "B", "C", "D"]
            for k in option_keys:
                option_token_ids[k] = choose_best_token_id(tokenizer, model, device, prompt_ids, k)
            gt_counts = {k: 0 for k in option_keys}
            print(f"[Options] detected: {option_keys}")
        elif keys_now and list(keys_now) != option_keys:
            # Strict mode (see module docstring): a sample whose option set
            # differs from the first detected one is skipped entirely.
            n_option_mismatch += 1
            continue

        # 3. Ground-truth option from the labels. Indexing [0] takes the first
        #    sequence of the batch.
        labels = batch["labels"][0]
        answer_key = get_answer_key_from_labels(labels, tokenizer, option_keys)
        if answer_key in gt_counts:
            gt_counts[answer_key] += 1
        else:
            # Counted and reported once after the loop instead of per sample.
            n_gt_unmatched += 1

        # 4. Logit-lens pass: one option vote per layer.
        argmax_ids, L = layerwise_argmax_option_for_prompt(
            model=model, device=device, prompt_ids_1d=prompt_ids.detach(),
            option_keys=option_keys, option_token_ids=option_token_ids,
            max_prompt_tokens=args.max_prompt_tokens,
        )

        if vote_counts is None:
            vote_counts = np.zeros((len(option_keys), L), dtype=np.int64)

        for li in range(L):
            vote_counts[argmax_ids[li], li] += 1

        n_used += 1
        if n_used % 25 == 0:
            print(f"[Progress] used {n_used}/{args.num_eval_samples} samples")

    if n_option_mismatch:
        print(f"[Warn] skipped {n_option_mismatch} sample(s) whose option set differs from {option_keys}")
    if n_gt_unmatched:
        print(f"[Warn] {n_gt_unmatched} sample(s) had a label not in options {option_keys}; "
              f"excluded from GT distribution")

    # Validate before touching the accumulators: with no usable sample
    # vote_counts is still None and n_used is 0.
    if vote_counts is None or option_keys is None or n_used < 2:
        raise RuntimeError(f"Too few valid samples collected: n_used={n_used}")

    freq_mat = vote_counts.astype(np.float32) / float(n_used)  # [O, L], in [0,1]

    # Ground-truth option distribution over the samples with a matched label.
    gt_total = sum(gt_counts.values())
    if gt_total == 0:
        raise RuntimeError("No valid ground truth labels found.")
    gt_dist = np.array([gt_counts[k] / gt_total for k in option_keys], dtype=np.float32)

    print(f"[GT Distribution] {dict(zip(option_keys, gt_dist))}")

    # Plot titles follow the actual run configuration instead of a fixed
    # model/dataset label.
    model_tag = os.path.basename(os.path.normpath(args.model_name_or_path)) or args.model_name_or_path
    title_suffix = f"{model_tag} | {args.sft_dataset}"

    # ============================================================
    # Project to dense axis with gaps (if mapping is available)
    # ============================================================
    L_pruned = freq_mat.shape[1]
    L_dense = int(args.dense_layers)

    freq_mat_dense = None
    removed = REMAIN_TO_REMOVED_0BASED.get(L_pruned, None)
    if removed is None:
        print(f"[Warn] No removed-layer list found for L_pruned={L_pruned}. Will only save pruned-axis plots.")
    else:
        pruned_to_dense = build_pruned_to_dense_mapping(
            L_dense=L_dense,
            L_pruned=L_pruned,
            removed_0based=removed
        )
        freq_mat_dense = expand_freq_to_dense_axis(
            freq_mat=freq_mat,
            pruned_to_dense=pruned_to_dense,
            L_dense=L_dense
        )

        wide_dense_df = pd.DataFrame(
            freq_mat_dense,
            index=option_keys,
            columns=[f"dense_layer_{i}" for i in range(1, L_dense + 1)]
        )
        wide_dense_path = os.path.join(args.output_dir, "option_argmax_freq_matrix_wide_dense_axis.csv")
        wide_dense_df.to_csv(wide_dense_path)
        print(f"[Saved] {wide_dense_path}")

    # Save CSV (pruned axis): wide form, one row per option
    wide_df = pd.DataFrame(freq_mat, index=option_keys, columns=[f"layer_{i}" for i in range(1, freq_mat.shape[1] + 1)])
    wide_path = os.path.join(args.output_dir, "option_argmax_freq_matrix_wide.csv")
    wide_df.to_csv(wide_path)
    print(f"[Saved] {wide_path}")

    # Save CSV (pruned axis): long form, one row per (option, layer)
    long_rows = [
        {
            "option": opt,
            "layer": li + 1,
            "argmax_freq": float(freq_mat[oi, li]),
            "n_used": int(n_used),
        }
        for oi, opt in enumerate(option_keys)
        for li in range(freq_mat.shape[1])
    ]
    long_df = pd.DataFrame(long_rows)
    long_path = os.path.join(args.output_dir, "option_argmax_freq_matrix.csv")
    long_df.to_csv(long_path, index=False)
    print(f"[Saved] {long_path}")

    # Heatmap (pruned axis)
    heatmap_path = os.path.join(args.output_dir, "option_argmax_freq_heatmap.png")
    save_freq_heatmap(
        freq_mat,
        option_keys=option_keys,
        out_path=heatmap_path,
        title=f"Layer-wise argmax frequency (N={n_used}) | {args.sft_dataset}/{args.eval_split}",
        dpi=args.dpi,
        fig_w=args.fig_w,
        fig_h=max(args.fig_h, 3.5),
        dense_axis=False
    )

    # Dense-axis stacked-area and radial plots (only when a mapping exists).
    # NOTE(review): the module docstring also lists dense-axis heatmap and
    # line plots as outputs; they are not produced here - confirm intent.
    if freq_mat_dense is not None and removed is not None:
        enhanced_dense_path = os.path.join(args.output_dir, "option_argmax_kl_divergence_plot_dense_axis.png")
        save_enhanced_stacked_area_plot(
            freq_mat_dense,
            option_keys=option_keys,
            gt_dist=gt_dist,
            out_path=enhanced_dense_path,
            title_suffix=title_suffix,
            dense_axis=True,
            removed_0based=removed,
            alpha_kept=0.85,
        )

        radial_dense_path = os.path.join(args.output_dir, "option_argmax_radial_dense_axis.png")
        save_radial_stacked_plot(
            freq_mat_dense,
            option_keys=option_keys,
            out_path=radial_dense_path,
            title_suffix=title_suffix,
            dense_axis=True,
            removed_0based=removed,
            alpha_kept=0.80,
            alpha_removed=0.20,
        )

    # Stacked-area plot (pruned axis)
    enhanced_path = os.path.join(args.output_dir, "option_argmax_kl_divergence_plot.png")
    save_enhanced_stacked_area_plot(
        freq_mat,
        option_keys=option_keys,
        gt_dist=gt_dist,
        out_path=enhanced_path,
        title_suffix=title_suffix
    )

    # Radial plot (pruned axis)
    radial_path = os.path.join(args.output_dir, "option_argmax_radial.png")
    save_radial_stacked_plot(
        freq_mat,
        option_keys=option_keys,
        out_path=radial_path,
        title_suffix=title_suffix
    )

    # Lines (pruned axis)
    lines_path = os.path.join(args.output_dir, "option_argmax_freq_lines.png")
    save_freq_lines(
        freq_mat,
        option_keys=option_keys,
        out_path=lines_path,
        title=f"Option argmax frequency vs layer (N={n_used})",
        dpi=args.dpi,
        fig_w=args.fig_w,
        fig_h=max(args.fig_h, 4.5),
        dense_axis=False,
        draw_gap_bridges=False
    )

    print("[Done]")


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
