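# analyze_influence_scores.py
# Features:
# - Read the pkl produced by the finder
# - Print/export the positive Top-100 entries (including decoded text snippets)
# - Plot positive/negative score distributions (histogram/ECDF) and score-vs-sample-index scatter plots
# - Count how samples concentrate in index blocks (k*102400, (k+1)*102400) (plot + CSV)
# - New: for the score>0 subset of the positives, plot the density distribution; compute the share of the total score held by the top 10%/20%/50% of samples (sorted by descending score); fit a power function to the cumulative share
# - Write all outputs to the configured output folder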

import os
import json
import math
import pickle
from typing import List, Dict, Any, Tuple

import numpy as np
import matplotlib
matplotlib.use("Agg")  # 无GUI环境下输出图片
import matplotlib.pyplot as plt

from transformers import GPTNeoXTokenizerFast

# =============== Configuration (edit these to your actual paths) ===============
PKL_PATH = r"influence_copytarget_qkvo124_L3H3.pkl"
TOKENIZER_PATH = r""   # your tokenizer directory
DATA_NPY_PATH = r"" # NPY path; if left empty, the path stored in the pkl's config is used
OUTPUT_DIR = r"" # output directory

# Text decoding
SEQ_LENGTH_DEFAULT = 2048  # used when the pkl config does not provide SEQ_LENGTH
SNIPPET_CHARS = 40000      # maximum number of characters kept per decoded sample snippet

# Block statistics
BLOCK_SIZE = 102400        # width of each index block: k*BLOCK_SIZE ~ (k+1)*BLOCK_SIZE

# Number of histogram bins
HIST_BINS = 80

# =============== 基础工具 ===============
def ensure_dir(path: str):
    """Create directory ``path`` (including missing parents) if it is absent.

    An empty ``path`` denotes the current working directory, which already
    exists, so it is treated as a no-op. Without this guard
    ``os.makedirs("")`` raises ``FileNotFoundError``.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)

def load_pkl(pkl_path: str) -> Dict[str, Any]:
    """Unpickle and return the object stored in the file at ``pkl_path``.

    NOTE: unpickling can execute arbitrary code; only load files that come
    from a trusted source.
    """
    with open(pkl_path, mode="rb") as handle:
        return pickle.load(handle)

def load_tokenizer(tokenizer_path: str):
    """Build a ``GPTNeoXTokenizerFast`` from ``tokenizer_path``.

    The tokenizer is requested with ``local_files_only=True``.
    """
    return GPTNeoXTokenizerFast.from_pretrained(tokenizer_path, local_files_only=True)

def load_npy(npy_path: str) -> np.memmap:
    """Open the ``.npy`` file at ``npy_path`` memory-mapped in read-only mode."""
    return np.load(npy_path, mmap_mode="r")

def decode_text(tokenizer, tokens: np.ndarray) -> str:
    """Decode an array of token ids into text.

    The ids are passed to ``tokenizer.decode`` as a plain Python list, with
    special tokens skipped and tokenization spaces cleaned up.
    """
    return tokenizer.decode(
        tokens.tolist(),
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

def summary_stats(arr: np.ndarray) -> Dict[str, float]:
    """Return descriptive statistics of ``arr`` as a dict.

    Keys, in order: ``count``, ``mean``, ``std``, ``min``, the percentiles
    ``p01``/``p05``/``p50``/``p95``/``p99``, and ``max``. An empty array
    yields an empty dict.
    """
    n = arr.size
    if n == 0:
        return {}
    stats: Dict[str, float] = {
        "count": int(n),
        "mean": float(np.mean(arr)),
        "std": float(np.std(arr)),
        "min": float(np.min(arr)),
    }
    # Percentile keys are zero-padded to two digits (p01, p05, ...).
    for pct in (1, 5, 50, 95, 99):
        stats[f"p{pct:02d}"] = float(np.percentile(arr, pct))
    stats["max"] = float(np.max(arr))
    return stats

def save_csv(filepath: str, rows: List[Dict[str, Any]], header: List[str]):
    """Write ``rows`` to ``filepath`` as UTF-8 CSV with the columns in ``header``.

    Keys missing from a row are written as empty cells; keys not listed in
    ``header`` are dropped.
    """
    import csv

    with open(filepath, "w", newline="", encoding="utf-8") as out:
        dict_writer = csv.DictWriter(out, fieldnames=header)
        dict_writer.writeheader()
        dict_writer.writerows(
            {col: row.get(col, "") for col in header} for row in rows
        )

def barplot_counts(counts: Dict[int, int], title: str, xlabel: str, ylabel: str, out_png: str):
    """Save a bar chart of ``counts`` (key -> count) to ``out_png``.

    Bars are ordered by ascending key and labelled with the key's string
    form, rotated 90 degrees. Nothing is drawn when ``counts`` is empty.
    """
    if not counts:
        return
    ordered_keys = sorted(counts)
    bar_labels = [str(key) for key in ordered_keys]
    bar_heights = [counts[key] for key in ordered_keys]

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.bar(bar_labels, bar_heights, color="#4472C4")
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(90)
    fig.tight_layout()
    fig.savefig(out_png, dpi=150)
    plt.close(fig)

def hist_plot(values: np.ndarray, title: str, xlabel: str, out_png: str, bins: int = HIST_BINS, density: bool = False):
    """Save a histogram of ``values`` to ``out_png``.

    With ``density=True`` the histogram is normalised and the y-axis is
    labelled "density"; otherwise it shows raw counts. Empty input is skipped.
    """
    if values.size == 0:
        return
    y_label = "density" if density else "count"

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(values, bins=bins, color="#4472C4", alpha=0.85, edgecolor="k", density=density)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(y_label)
    fig.tight_layout()
    fig.savefig(out_png, dpi=150)
    plt.close(fig)

def ecdf_plot(values: np.ndarray, title: str, xlabel: str, out_png: str):
    """Save the empirical CDF of ``values`` to ``out_png`` (skipped when empty)."""
    n = values.size
    if n == 0:
        return
    sorted_vals = np.sort(values)
    # The i-th smallest value (1-based) sits at cumulative probability i/n.
    cum_prob = np.arange(1, n + 1) / n

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(sorted_vals, cum_prob, lw=2, color="#C00000")
    ax.grid(alpha=0.3)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel("ECDF")
    fig.tight_layout()
    fig.savefig(out_png, dpi=150)
    plt.close(fig)

def scatter_plot(x: np.ndarray, y: np.ndarray, title: str, xlabel: str, ylabel: str, out_png: str, s=5, alpha=0.3):
    """Save a scatter plot of ``y`` against ``x`` to ``out_png``.

    ``s`` and ``alpha`` are the marker size and opacity handed to
    matplotlib. Nothing is drawn when either array is empty.
    """
    if not (x.size and y.size):
        return

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.scatter(x, y, s=s, alpha=alpha)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(out_png, dpi=150)
    plt.close(fig)

# =============== 新增：累计占比与幂函数拟合工具 ===============
def cumulative_share(scores: np.ndarray, mode: str = "top") -> Tuple[np.ndarray, np.ndarray]:
    """Compute the cumulative-share curve of ``scores``.

    Args:
        scores: Array-like of scores (lists are accepted; input is flattened).
        mode: ``"top"`` accumulates in descending score order and returns
            ``(p, S_top(p))`` with ``p = j / n``. Any other value (normally
            ``"bottom"``) accumulates in ascending order and returns
            ``(q, S_bottom(q))``.

    Returns:
        ``(frac, share)`` arrays of length ``n``. Two empty arrays are
        returned when there are no scores or when their sum is not a
        positive finite number, because shares are undefined in that case.
    """
    # Convert first so plain sequences work; the original read `.size`
    # before conversion and therefore crashed on lists.
    s = np.asarray(scores, dtype=np.float64).ravel()
    empty = (np.array([]), np.array([]))
    if s.size == 0:
        return empty
    total = s.sum()
    # A NaN/inf total would slip past `total <= 0` and yield an all-NaN curve.
    if not np.isfinite(total) or total <= 0:
        return empty
    ordered = np.sort(s)
    if mode == "top":
        ordered = ordered[::-1]
    share = np.cumsum(ordered) / total
    frac = np.arange(1, s.size + 1, dtype=np.float64) / s.size
    return frac, share

def top_share_at_percent(scores: np.ndarray, p: float) -> float:
    """Return the share of the total score held by the top ``p`` fraction.

    Samples are ranked by descending score and the first ``ceil(p * n)`` of
    them (at least one, at most all ``n``) are summed.

    Args:
        scores: Array-like of scores (lists are accepted; input is flattened).
        p: Fraction of samples to take, expected in ``(0, 1]``.

    Returns:
        The top-``p`` sum divided by the total, or NaN when there are no
        scores or the total is not positive.
    """
    s = np.asarray(scores, dtype=np.float64).ravel()
    n = s.size
    if n == 0:
        return float("nan")
    total = s.sum()
    if not (total > 0):  # also catches a NaN total
        return float("nan")
    # p * n can land a hair above an integer through binary rounding
    # (0.07 * 100 == 7.000000000000001), which would make ceil() take one
    # sample too many; shave off a tiny relative tolerance first.
    m = int(math.ceil(p * n * (1.0 - 1e-12)))
    m = min(n, max(1, m))
    top = np.sort(s)[::-1][:m]
    return float(top.sum() / total)

def fit_power_on_bottom_share(frac: np.ndarray, share: np.ndarray) -> Tuple[float, float]:
    """Fit ``S_bottom(q) ~ q**alpha`` and return ``(alpha, R2)``.

    ``alpha`` is chosen by least squares over a fixed 1-D grid of 1996
    values in ``[0.05, 5.0]`` (no SciPy needed); on ties the smallest alpha
    wins. Only points with ``frac > 0`` are used. Returns ``(nan, nan)``
    when either input is empty.
    """
    if frac.size == 0 or share.size == 0:
        return float("nan"), float("nan")

    positive = frac > 0
    q = frac[positive]
    observed = share[positive]

    # Total sum of squares; the epsilon avoids dividing by zero on a flat curve.
    total_ss = float(np.sum((observed - observed.mean()) ** 2)) + 1e-12

    grid = np.linspace(0.05, 5.0, 1996)
    errors = [float(np.sum((observed - np.power(q, a)) ** 2)) for a in grid]
    # min() keeps the first index whose error is strictly smallest.
    best = min(range(len(errors)), key=errors.__getitem__)

    r2 = 1.0 - (errors[best] / total_ss)
    return float(grid[best]), float(r2)

def plot_cumshare_top_with_fit(scores: np.ndarray, alpha: float, out_png: str, title_suffix: str = ""):
    """Plot the top cumulative-share curve together with its power-law fit.

    - Solid line: empirical ``S_top(p)``.
    - Dashed line: fitted ``S_top_fit(p) = 1 - (1 - p)**alpha``.

    Nothing is drawn when ``cumulative_share`` returns an empty curve.
    """
    top_frac, empirical = cumulative_share(scores, mode="top")
    if top_frac.size == 0:
        return
    fitted = 1.0 - np.power(1.0 - top_frac, alpha)

    heading = "Cumulative Share (Top p)"
    if title_suffix:
        heading += f" - {title_suffix}"

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(top_frac, empirical, label="Empirical S_top(p)", color="#1f77b4", lw=2)
    ax.plot(top_frac, fitted, label=f"Fit: 1-(1-p)^{alpha:.3f}", color="#d62728", lw=2, ls="--")
    ax.set_xlabel("Top fraction p")
    ax.set_ylabel("Cumulative share of total score")
    ax.set_title(heading)
    ax.grid(alpha=0.3)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_png, dpi=150)
    plt.close(fig)

def _influencer_arrays(items: List[Dict[str, Any]]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Convert influencer records into (scores, sample_indices, sum_losses) arrays.

    Records without ``sum_loss`` get NaN; an empty list yields empty arrays.
    """
    scores = np.array([float(x["projection_score"]) for x in items], dtype=np.float64)
    indices = np.array([int(x["sample_index"]) for x in items], dtype=np.int64)
    losses = np.array([float(x.get("sum_loss", np.nan)) for x in items], dtype=np.float64)
    return scores, indices, losses


def _counts_by_block(idxs: np.ndarray) -> Dict[int, int]:
    """Count how many sample indices fall into each block ``idx // BLOCK_SIZE``."""
    if idxs.size == 0:
        return {}
    blocks = (idxs // BLOCK_SIZE).astype(int)
    uniq, cnts = np.unique(blocks, return_counts=True)
    return {int(k): int(v) for k, v in zip(uniq, cnts)}


def _export_top_positive(out_dir: str, scores: np.ndarray, indices: np.ndarray, losses: np.ndarray,
                         seq_len: int, npy_path: Any, top_n: int = 100) -> None:
    """Decode and export the ``top_n`` highest-scoring positive samples (TXT + CSV).

    The tokenizer and the token array are loaded only when there is at least
    one sample to decode.

    Raises:
        ValueError: if samples must be decoded but no NPY path is configured.
    """
    topk = min(top_n, int(scores.size))
    if topk == 0:
        print("正向列表为空，跳过Top-100导出。")
        return
    # An empty string is as unusable as None, so test truthiness.
    if not npy_path:
        raise ValueError("未找到 NPY 路径，请在脚本顶部 DATA_NPY_PATH 指定或 pkl 的 config 中包含 DATA_NPY_PATH。")

    tokenizer = load_tokenizer(TOKENIZER_PATH)
    arr = load_npy(npy_path)

    txt_entries: List[str] = []
    csv_rows: List[Dict[str, Any]] = []
    ranked = np.argsort(scores)[::-1][:topk]  # descending by score
    for rank, i in enumerate(ranked, start=1):
        score = float(scores[i])
        idx = int(indices[i])
        loss = float(losses[i])  # stays NaN when the record had no sum_loss

        # One row per sample; only the first seq_len token ids are decoded.
        # NOTE(review): the original comment says rows have shape [2049] — confirm.
        input_tokens = arr[idx][:seq_len].astype(np.int64)
        text = decode_text(tokenizer, input_tokens)
        if isinstance(text, str):
            snippet = text[:SNIPPET_CHARS] + ("..." if len(text) > SNIPPET_CHARS else "")
        else:
            snippet = ""

        block_k = idx // BLOCK_SIZE
        txt_entries.append(
            f"Rank {rank}\n"
            f"  sample_index: {idx} (block k={block_k}, range=[{block_k*BLOCK_SIZE}, {(block_k+1)*BLOCK_SIZE}))\n"
            f"  projection_score: {score:.6f}\n"
            f"  sum_loss: {loss:.6f}\n"
            f"  snippet: {snippet}\n"
            f"{'-'*80}\n"
        )
        csv_rows.append({
            "rank": rank,
            "sample_index": idx,
            "block_k": block_k,
            "projection_score": score,
            "sum_loss": loss,
            "snippet": snippet,
        })

    txt_path = os.path.join(out_dir, "top100_positive.txt")
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write("\n".join(txt_entries))
    csv_path = os.path.join(out_dir, "top100_positive.csv")
    save_csv(csv_path, csv_rows, header=["rank", "sample_index", "block_k", "projection_score", "sum_loss", "snippet"])
    print(f"已生成: {txt_path}")
    print(f"已生成: {csv_path}")


def _plot_score_distributions(out_dir: str, scores: np.ndarray, indices: np.ndarray,
                              prefix: str, label: str) -> None:
    """Write histogram, ECDF and score-vs-index scatter plots for one score set."""
    hist_plot(scores, f"{label} Scores Histogram", "projection_score",
              os.path.join(out_dir, f"{prefix}_scores_hist.png"))
    ecdf_plot(scores, f"{label} Scores ECDF", "projection_score",
              os.path.join(out_dir, f"{prefix}_scores_ecdf.png"))
    scatter_plot(indices, scores, f"{label}: Score vs Sample Index", "sample_index", "projection_score",
                 os.path.join(out_dir, f"{prefix}_score_vs_index.png"), s=5, alpha=0.4)


def _export_block_counts(out_dir: str, indices: np.ndarray, prefix: str, label: str) -> None:
    """Write the per-block sample counts of one score set as CSV and bar chart."""
    counts = _counts_by_block(indices)
    rows = [
        {"block_k": k, "start": k * BLOCK_SIZE, "end": (k + 1) * BLOCK_SIZE, "count": c}
        for k, c in sorted(counts.items())
    ]
    save_csv(os.path.join(out_dir, f"{prefix}_block_counts.csv"), rows,
             header=["block_k", "start", "end", "count"])
    # The title follows BLOCK_SIZE so it stays correct if the constant changes.
    barplot_counts(counts, f"{label} Counts by Blocks (k = idx//{BLOCK_SIZE})", "block k", "count",
                   os.path.join(out_dir, f"{prefix}_block_counts.png"))


def _analyze_positive_concentration(out_dir: str, pos_scores_pos: np.ndarray) -> Tuple[float, float, float, float, float]:
    """Analyse how concentrated the strictly positive scores are.

    Writes the density histogram, the top cumulative-share plot (with the
    power-law fit) and the cumulative-share CSV.

    Returns:
        ``(share_10, share_20, share_50, alpha, r2)``; all NaN when there
        are no scores or their sum is not positive.
    """
    nan = float("nan")
    if pos_scores_pos.size == 0 or not (np.sum(pos_scores_pos) > 0):
        print("正向且 score>0 的样本为空或总和为0，跳过密度/累计占比/拟合。")
        return nan, nan, nan, nan, nan

    density_png = os.path.join(out_dir, "pos_scores_pos_density.png")
    hist_plot(pos_scores_pos, "Positive (>0) Scores Density", "projection_score", density_png,
              bins=HIST_BINS, density=True)
    print(f"已生成: {density_png}")

    # Share of the total score held by the top 10% / 20% / 50% of samples.
    share_10 = top_share_at_percent(pos_scores_pos, 0.10)
    share_20 = top_share_at_percent(pos_scores_pos, 0.20)
    share_50 = top_share_at_percent(pos_scores_pos, 0.50)
    print(f"Top10% share={share_10:.4f}, Top20% share={share_20:.4f}, Top50% share={share_50:.4f}")

    # The power law is fitted on the bottom curve: S_bottom(q) ~ q^alpha.
    q, s_bottom = cumulative_share(pos_scores_pos, mode="bottom")
    alpha, r2 = fit_power_on_bottom_share(q, s_bottom)

    # The plot shows the equivalent top curve: S_top(p) ~ 1 - (1 - p)^alpha.
    cumshare_png = os.path.join(out_dir, "pos_scores_pos_cumshare_top.png")
    plot_cumshare_top_with_fit(pos_scores_pos, alpha, cumshare_png, title_suffix="Positive (>0)")
    print(f"已生成: {cumshare_png}")

    p_emp, s_top_emp = cumulative_share(pos_scores_pos, mode="top")
    rows = [{"p_top": float(p), "share": float(s)} for p, s in zip(p_emp, s_top_emp)]
    cumshare_csv = os.path.join(out_dir, "pos_scores_pos_cumshare_top.csv")
    save_csv(cumshare_csv, rows, header=["p_top", "share"])
    print(f"已生成: {cumshare_csv}")

    return share_10, share_20, share_50, alpha, r2


def _write_summary(summary_path: str, pos_stats: Dict[str, float], neg_stats: Dict[str, float],
                   pos_scores_pos: np.ndarray, shares: Tuple[float, float, float],
                   alpha: float, r2: float) -> None:
    """Write ``summary.txt`` from scratch: base statistics plus concentration metrics.

    The file is rewritten on every run so repeated runs do not pile up
    duplicate sections. NaN values format as ``nan`` on their own.
    """
    share_10, share_20, share_50 = shares
    sum_gt0 = float(np.sum(pos_scores_pos)) if pos_scores_pos.size > 0 else float("nan")
    rule = "=" * 60
    lines = [
        rule,
        "Score Summary Statistics",
        rule,
        f"positive_scores: {json.dumps(pos_stats, ensure_ascii=False)}",
        f"negative_scores: {json.dumps(neg_stats, ensure_ascii=False)}",
        "",
        rule,
        "Positive (>0) Score Concentration & Power Fit",
        rule,
        f"count_pos_gt0: {int(pos_scores_pos.size)}",
        f"sum_pos_gt0: {sum_gt0}",
        f"top10_percent_share: {share_10}",
        f"top20_percent_share: {share_20}",
        f"top50_percent_share: {share_50}",
        f"power_fit_alpha (S_bottom(q)~q^alpha): {alpha}",
        f"power_fit_R2: {r2}",
    ]
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def main():
    """Run the full influence-score analysis and write all artefacts.

    Steps: load the finder's pkl, export the positive Top-100 with decoded
    snippets, plot the positive/negative score distributions, export
    per-block sample counts, analyse the concentration of the strictly
    positive scores, and write ``summary.txt``.

    Raises:
        ValueError: if positive samples must be decoded but no NPY path is
            configured (neither ``DATA_NPY_PATH`` nor the pkl's config).
    """
    # An empty OUTPUT_DIR means "current directory"; os.makedirs("") would raise.
    out_dir = OUTPUT_DIR or "."
    ensure_dir(out_dir)

    data = load_pkl(PKL_PATH)
    cfg = data.get("config") or {}
    pos_list = data.get("positive_influencers") or []
    neg_list = data.get("negative_influencers") or []

    seq_len = int(cfg.get("SEQ_LENGTH", SEQ_LENGTH_DEFAULT))
    # The script-level path wins; otherwise fall back to the pkl's config.
    npy_path = DATA_NPY_PATH or cfg.get("DATA_NPY_PATH")

    pos_scores, pos_indices, pos_losses = _influencer_arrays(pos_list)
    neg_scores, neg_indices, _ = _influencer_arrays(neg_list)

    # Positive Top-100, ordered by descending score.
    _export_top_positive(out_dir, pos_scores, pos_indices, pos_losses, seq_len, npy_path)

    # Distribution statistics and plots.
    pos_stats = summary_stats(pos_scores)
    neg_stats = summary_stats(neg_scores)
    _plot_score_distributions(out_dir, pos_scores, pos_indices, "pos", "Positive")
    _plot_score_distributions(out_dir, neg_scores, neg_indices, "neg", "Negative")

    # Per-block sample counts.
    _export_block_counts(out_dir, pos_indices, "pos", "Positive")
    _export_block_counts(out_dir, neg_indices, "neg", "Negative")

    # Concentration analysis on the strictly positive scores.
    pos_scores_pos = pos_scores[pos_scores > 0]
    share_10, share_20, share_50, alpha, r2 = _analyze_positive_concentration(out_dir, pos_scores_pos)

    summary_path = os.path.join(out_dir, "summary.txt")
    _write_summary(summary_path, pos_stats, neg_stats, pos_scores_pos,
                   (share_10, share_20, share_50), alpha, r2)
    print(f"汇总已更新: {summary_path}")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
