import json
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer
import argparse
import os

# Paper-style figure output (same look as the plot_length_ngram main figure).
# 'Agg' is a non-interactive backend, so no display is needed.
matplotlib.use('Agg')
plt.rcParams.update({
    # Embed fonts as TrueType (Type 42) rather than Type 3 in PDF/PS output,
    # and do not substitute the 14 PDF core fonts.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'pdf.use14corefonts': False,
    # Font preference list and ASCII minus sign on tick labels.
    'font.sans-serif': ['Arial', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

# Default data / tokenizer paths and run settings (used as CLI defaults in main()).
# Earlier inputs, kept for reference:
# path = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/grpo-qwen3-8b-deepscaler-BASELINE_12388_test.jsonl"
# path="/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_sep/Qwen3-4B-Base-valid-temp0.6_32768_test.jsonl"
# path="/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/add1k-new-60steps-continue-40step-valid_32768_test.jsonl"
# model_name = "add1k-new-60steps-continue-40step-valid"
model_path = "/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B-Base"
path = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_dec/grpo-step430-valid-all_32768_test.jsonl"
model_name = "GRPO"
max_length = 32000  # token cap: histogram range and exponential-fit cutoff

def load_data(file_path):
    """Load a JSON Lines file into a list of records.

    Args:
        file_path: Path to a UTF-8 encoded JSONL file.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing empty line) are not records;
            # json.loads('') would raise JSONDecodeError on them.
            if not line:
                continue
            data.append(json.loads(line))
    return data

def tokenize_texts(texts, tokenizer):
    """Return the token count of each text, excluding special tokens.

    Args:
        texts: Iterable of strings.
        tokenizer: Object with ``encode(text, add_special_tokens=...)`` that
            returns a sequence of token ids.

    Returns:
        List of ints, one per input text, in input order.
    """
    return [len(tokenizer.encode(t, add_special_tokens=False)) for t in texts]


def _extract_texts_and_correctness(data):
    """兼容不同字段名，提取生成文本与正确性标签（如果有）。"""
    texts = []
    correctness = []
    has_correctness = False

    for item in data:
        if not isinstance(item, dict):
            continue
        txt = (
            item.get("generated_text")
            or item.get("output")
            or item.get("response")
            or item.get("text")
        )
        if txt is None:
            continue
        texts.append(txt)

        if "correctness" in item:
            has_correctness = True
            correctness.append(bool(item.get("correctness")))

    if not has_correctness:
        correctness = None

    return texts, correctness


def plot_length_frequency_spectrum(lengths, save_path, model_name, bin_width=256, max_len=None, fit_exponential=True):
    """Plot the length-frequency spectrum of a single run and save it as PDF.

    The x axis is sequence length ``L`` (histogram bin centres). The y axis is
    the number of samples per bin on a log scale. Optionally, a least-squares
    line is fitted to ``log(count) = a - lambda * L`` and overlaid, with the
    estimated lambda shown in a text box.

    Args:
        lengths: Sequence of per-sample token lengths.
        save_path: Output path. The figure is always written in PDF format.
        model_name: Run name used in the title.
        bin_width: Histogram bin width in tokens (clamped to >= 1).
        max_len: Upper end of the histogram range and of the fit range.
            Defaults to ``max(lengths)``. Values beyond the last bin edge are
            not counted by ``np.histogram``.
        fit_exponential: Whether to fit and draw the exponential decay.
    """
    if not lengths:
        print("没有可用的长度数据，跳过绘图")
        return

    if max_len is None:
        max_len = max(lengths)
    # Keep at least one bin even if every sample has length 0.
    max_len = max(1, int(max_len))

    # Fixed-width bins give a smoother spectrum than one bin per length.
    bin_width = max(1, int(bin_width))
    bins = np.arange(0, max_len + bin_width, bin_width)

    counts, edges = np.histogram(lengths, bins=bins)
    centers = (edges[:-1] + edges[1:]) / 2

    # Palette shared with the plot_length_ngram main figure.
    COLOR_SPECTRUM = '#00468B'  # dark blue
    COLOR_FIT = '#AE1029'       # crimson
    BG = '#FAFAFA'

    fig, ax = plt.subplots(1, 1, figsize=(12, 7))

    # Frame, background and grid styled like the main figure.
    ax.set_facecolor(BG)
    for side in ['left', 'bottom', 'top', 'right']:
        ax.spines[side].set_linewidth(1.8)
        ax.spines[side].set_color('#666666')
        ax.spines[side].set_visible(True)
    ax.grid(True, axis='both', alpha=0.3, color='#CCCCCC', linewidth=0.8, linestyle='-')
    ax.set_axisbelow(True)
    ax.tick_params(axis='both', labelcolor='#333333', labelsize=16, length=6, width=1.5)

    ax.set_yscale("log")

    # Main curve plus a translucent band. The band half-width is a moving
    # average of absolute bin-to-bin changes: a visual proxy, not a
    # statistical confidence interval.
    counts_array = np.array(counts, dtype=float)
    window = min(5, len(counts) // 10) if len(counts) > 10 else 1
    if window > 1:
        std_approx = np.convolve(
            np.abs(np.diff(np.concatenate([[counts_array[0]], counts_array]))),
            np.ones(window) / window,
            mode='same'
        ) * 1.2
        # Floor at 1 because a log axis cannot show values <= 0.
        lower = np.maximum(1.0, counts_array - std_approx)
        upper = np.maximum(1.0, counts_array + std_approx)
        ax.fill_between(centers, lower, upper, color=COLOR_SPECTRUM, alpha=0.20, zorder=1)

    ln_spec = ax.plot(centers, counts_array, linewidth=3.0, color=COLOR_SPECTRUM, label='Length Spectrum', zorder=4)[0]

    ax.set_xlabel(r"Sequence Length ($L$)", fontsize=20, fontweight='bold', labelpad=10)
    ax.set_ylabel("Frequency (Log Scale)", fontsize=20, fontweight='bold', labelpad=10)
    ax.set_title(
        f"{model_name}: The Length-Frequency Spectrum ($\\log P(L)=a-\\lambda L$)",
        fontsize=24,
        fontweight="bold",
        pad=18,
    )

    # Optional fit: log(count) = a - lambda * L, drawn as a red dashed line.
    ln_fit = None
    if fit_exponential:
        mask = counts > 0  # log() is undefined for empty bins
        x = centers[mask]
        y = counts[mask]
        # Use the caller's cap (max_len), not a module global, as the upper
        # fit bound so any partial bin past the cap is excluded. Also drop the
        # first (shortest) bin so very short outputs do not dominate.
        fit_mask = (x <= max_len) & (x >= bin_width)
        xf = x[fit_mask]
        yf = y[fit_mask]
        if len(xf) >= 2:
            coef = np.polyfit(xf, np.log(yf), 1)  # log(y) = m * x + b
            m, b = float(coef[0]), float(coef[1])
            lam = -m

            y_fit = np.exp(m * xf + b)
            ln_fit = ax.plot(xf, y_fit, linestyle="--", linewidth=3.0, color=COLOR_FIT, alpha=0.95, label='Exp fit', zorder=5)[0]
            ax.text(
                0.02,
                0.96,
                rf"Fit: $\log P(L)=a-\lambda L$" "\n" rf"$\lambda \approx {lam:.4e}$",
                transform=ax.transAxes,
                va="top",
                ha="left",
                fontsize=18,
                bbox=dict(boxstyle="round,pad=0.5", facecolor="white", alpha=0.95, edgecolor="#999999", linewidth=1.2),
                family='monospace'
            )

    # Tight lower y bound so low-count detail stays visible.
    ax.set_ylim(bottom=1.0)

    # Large figure-level legend below the axes, as in the main figure.
    handles = [ln_spec]
    if ln_fit is not None:
        handles.append(ln_fit)

    labels = [h.get_label() for h in handles]
    legend = fig.legend(
        handles, labels,
        loc='lower center',
        bbox_to_anchor=(0.5, -0.1),
        ncol=min(3, len(handles)),
        frameon=True,
        framealpha=0.95,
        edgecolor='#888888',
        fancybox=True,
        shadow=False,
        prop={'weight': 'bold', 'size': 18},
    )
    legend.get_frame().set_linewidth(1.5)

    plt.subplots_adjust(top=0.86, bottom=0.18, left=0.10, right=0.97)
    plt.savefig(
        save_path,
        format='pdf',
        dpi=600,
        bbox_inches="tight",
        metadata={'Creator': 'matplotlib', 'Producer': 'matplotlib'},
    )
    print(f"已保存图像: {save_path}")
    plt.close(fig)


def plot_length_frequency_spectrum_multi(
    run_name2lengths: dict,
    save_path: str,
    max_length_tokens: int,
    bin_width: int = 256,
    fit_exponential: bool = True,
):
    """多 run 版本：所有 run 叠在同一张图里（统一坐标范围）。"""
    run_items = [(k, v) for k, v in run_name2lengths.items() if v]
    if not run_items:
        print("没有可用的长度数据，跳过绘图")
        return

    # 统一 bins（用 max_length_tokens 或全局最大长度）
    all_lengths = [x for _, ls in run_items for x in ls]
    max_len = int(max_length_tokens) if max_length_tokens is not None else int(max(all_lengths))
    bin_width = max(1, int(bin_width))
    bins = np.arange(0, max_len + bin_width, bin_width)
    centers = (bins[:-1] + bins[1:]) / 2

    # 配色：与主图一致的红/紫/蓝系，补橙/绿用于更多 run
    RUN_COLORS = ['#00468B', '#9B59B6', '#AE1029', '#FF7F0E', '#2CA02C']
    COLOR_FIT_FALLBACK = '#AE1029'
    BG = '#FAFAFA'

    # 先算统一 y 上限，避免每个子图尺度不同
    global_max = 1.0
    for _, ls in run_items:
        counts, _ = np.histogram(ls, bins=bins)
        if counts.size > 0:
            global_max = max(global_max, float(np.max(counts)))

    fig, ax = plt.subplots(1, 1, figsize=(12, 7))

    # 统一风格（对齐 plot_length_ngram 主图）
    ax.set_facecolor(BG)
    for side in ['left', 'bottom', 'top', 'right']:
        ax.spines[side].set_linewidth(1.8)
        ax.spines[side].set_color('#666666')
        ax.spines[side].set_visible(True)
    ax.grid(True, axis='both', alpha=0.3, color='#CCCCCC', linewidth=0.8, linestyle='-')
    ax.set_axisbelow(True)
    ax.tick_params(axis='both', labelcolor='#333333', labelsize=16, length=6, width=1.5)
    ax.set_yscale("log")
    ax.set_ylim(1.0, global_max * 1.2)
    ax.set_xlim(0, max_len)

    ax.set_xlabel(r"Sequence Length ($L$)", fontsize=20, fontweight='bold', labelpad=10)
    ax.set_ylabel("Frequency (Log Scale)", fontsize=20, fontweight='bold', labelpad=10)
    ax.set_title(
        "The Length-Frequency Spectrum ($\\log P(L)=a-\\lambda L$)",
        fontsize=24,
        fontweight="bold",
        pad=18,
    )

    # max_length 竖线（只画一次）
    # ln_max = ax.axvline(x=max_length_tokens, color="black", linestyle="--", alpha=0.7, linewidth=2.0, label='max_length')

    # 画每个 run 的谱线 + 拟合线（拟合线不进 legend，避免太乱）
    lambda_lines = []
    run_handles = []
    run_labels = []

    for i, (run_name, ls) in enumerate(run_items):
        color = RUN_COLORS[i % len(RUN_COLORS)]
        counts, _ = np.histogram(ls, bins=bins)
        counts_array = np.array(counts, dtype=float)
        # 让曲线更顺滑：在 log(count) 空间做滑动平均，再 exp 回来
        # 这样在 log-y 图上视觉更平滑，同时保留指数衰减形状
        smooth_window = 11  # 越大越平滑（保持奇数）
        # 不要超过序列长度；同时根据数据长度给一点自适应
        target = max(3, (len(counts_array) // 10) * 2 + 1)  # odd
        smooth_window = min(smooth_window, target)
        if smooth_window >= len(counts_array):
            smooth_window = len(counts_array) if (len(counts_array) % 2 == 1) else max(1, len(counts_array) - 1)
        if smooth_window < 3:
            counts_plot = np.maximum(1.0, counts_array)
        else:
            if smooth_window % 2 == 0:
                smooth_window += 1
            kernel = np.ones(smooth_window, dtype=float) / float(smooth_window)
            log_counts = np.log(np.maximum(1.0, counts_array))
            log_counts_smooth = np.convolve(log_counts, kernel, mode='same')
            counts_plot = np.maximum(1.0, np.exp(log_counts_smooth))

        # 阴影带：多 run 情况下容易糊，这里更轻一点
        window = min(5, len(counts) // 10) if len(counts) > 10 else 1
        if window > 1:
            std_approx = np.convolve(
                np.abs(np.diff(np.concatenate([[counts_array[0]], counts_array]))),
                np.ones(window) / window,
                mode='same'
            ) * 1.0
            lower = np.maximum(1.0, counts_array - std_approx)
            upper = np.maximum(1.0, counts_array + std_approx)
            ax.fill_between(centers, lower, upper, color=color, alpha=0.08, zorder=1)

        ln = ax.plot(centers, counts_plot, linewidth=3.0, color=color, label=run_name, zorder=4)[0]
        run_handles.append(ln)
        run_labels.append(run_name)

        if fit_exponential:
            mask = (counts > 0)
            x = centers[mask]
            y = counts[mask]
            if len(x) >= 2:
                fit_mask = np.ones_like(x, dtype=bool)
                if max_length_tokens is not None:
                    fit_mask &= (x <= max_length_tokens)
                fit_mask &= (x >= bin_width)
                xf = x[fit_mask]
                yf = y[fit_mask]
                if len(xf) >= 2:
                    coef = np.polyfit(xf, np.log(yf), 1)
                    m, b = float(coef[0]), float(coef[1])
                    lam = -m
                    y_fit = np.exp(m * xf + b)
                    ax.plot(
                        xf, y_fit,
                        linestyle="--",
                        linewidth=3.0,
                        color=color if i < len(RUN_COLORS) else COLOR_FIT_FALLBACK,
                        alpha=0.95,
                        label="_nolegend_",
                        zorder=5
                    )
                    lambda_lines.append((run_name, lam))

    # λ 汇总框（右上角）
    if lambda_lines:
        text = "\n".join([f"{name}: λ≈{lam:.3e}" for name, lam in lambda_lines])
        ax.text(
            0.98, 0.96,
            text,
            transform=ax.transAxes,
            va="top",
            ha="right",
            fontsize=16,
            bbox=dict(boxstyle="round,pad=0.5", facecolor="white", alpha=0.95, edgecolor="#999999", linewidth=1.2),
            family='monospace'
        )

    # legend：run 曲线 + max_length，再加一个“Exp fit(--)”说明
    from matplotlib.lines import Line2D
    fit_handle = Line2D([0], [0], color="#333333", linestyle="--", linewidth=3.0, label="Exp fit (--)")
    handles = run_handles + [fit_handle]
    labels = run_labels + [fit_handle.get_label()]

    legend = fig.legend(
        handles,
        labels,
        loc='lower center',
        bbox_to_anchor=(0.5, -0.10),
        ncol=min(4, len(handles)),
        frameon=True,
        framealpha=0.95,
        edgecolor='#888888',
        fancybox=True,
        shadow=False,
        prop={'weight': 'bold', 'size': 18},
    )
    legend.get_frame().set_linewidth(1.5)

    plt.subplots_adjust(top=0.86, bottom=0.14, left=0.10, right=0.97)
    plt.savefig(
        save_path,
        format='pdf',
        dpi=600,
        bbox_inches="tight",
        metadata={'Creator': 'matplotlib', 'Producer': 'matplotlib'},
    )
    print(f"已保存图像: {save_path}")
    plt.close(fig)

def plot_length_distribution(data, tokenizer):
    """Backward-compatible entry point: plot the spectrum of one loaded run.

    Uses the module-level ``model_name`` (title and file-name prefix) and
    ``max_length`` (histogram range and fit cutoff), with 256-token bins.

    Args:
        data: Decoded JSONL records (see ``_extract_texts_and_correctness``).
        tokenizer: Object with ``encode(text, add_special_tokens=...)``.
    """
    generated_texts, _ = _extract_texts_and_correctness(data)

    print("Tokenizing texts...")
    lengths = tokenize_texts(generated_texts, tokenizer)

    # plot_length_frequency_spectrum always writes PDF, so use a matching
    # extension. The previous ".png" name produced a PDF with a wrong suffix.
    save_path = f"{model_name}_length_frequency_spectrum.pdf"
    plot_length_frequency_spectrum(
        lengths=lengths,
        save_path=save_path,
        model_name=model_name,
        bin_width=256,
        max_len=max_length,
        fit_exponential=True,
    )

def main():
    """Command-line entry point.

    Loads a tokenizer and reads one or more JSONL result files. It then
    tokenizes each generated text and plots the length-frequency spectrum:
    a single-run figure for one input, or an overlay of all runs for several
    inputs.

    If anything fails after argument parsing, a hint is printed and the
    exception is re-raised, so the traceback is shown and the process exits
    with a non-zero status.
    """
    parser = argparse.ArgumentParser(description="Plot length-frequency spectrum (log scale)")
    parser.add_argument("--tokenizer", type=str, default=model_path, help="Tokenizer/model path")
    parser.add_argument("--inputs", type=str, nargs="+", default=[path], help="One or more JSONL files")
    parser.add_argument("--names", type=str, nargs="*", help="Optional run names (same order as inputs)")
    parser.add_argument("--max-length", type=int, default=max_length, help="max_length tokens (histogram range & fit cutoff)")
    parser.add_argument("--bin-width", type=int, default=256, help="Histogram bin width (tokens)")
    parser.add_argument("--output", type=str, default=None, help="Output file path (pdf)")
    args = parser.parse_args()

    try:
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)

        # Load and tokenize every run.
        run_name2lengths = {}
        for idx, p in enumerate(args.inputs):
            print(f"Loading data: {p}")
            data = load_data(p)
            texts, _ = _extract_texts_and_correctness(data)
            print(f"Tokenizing {len(texts)} samples...")
            lengths = tokenize_texts(texts, tokenizer)

            if args.names and idx < len(args.names):
                base_name = args.names[idx]
            else:
                base_name = os.path.splitext(os.path.basename(p))[0]
            # Inputs that share a name (e.g. the same file name in different
            # directories) must not silently overwrite each other.
            run_name = base_name
            dup = 2
            while run_name in run_name2lengths:
                run_name = f"{base_name} ({dup})"
                dup += 1
            run_name2lengths[run_name] = lengths

        if args.output:
            save_path = args.output
        elif len(run_name2lengths) == 1:
            only_name = next(iter(run_name2lengths))
            save_path = f"{only_name}_length_frequency_spectrum.pdf"
        else:
            save_path = "length_frequency_spectrum_multi.pdf"

        if len(run_name2lengths) == 1:
            run_name, lengths = next(iter(run_name2lengths.items()))
            plot_length_frequency_spectrum(
                lengths=lengths,
                save_path=save_path,
                model_name=run_name,
                bin_width=args.bin_width,
                max_len=args.max_length,
                fit_exponential=True,
            )
        else:
            plot_length_frequency_spectrum_multi(
                run_name2lengths=run_name2lengths,
                save_path=save_path,
                max_length_tokens=args.max_length,
                bin_width=args.bin_width,
                fit_exponential=True,
            )

    except Exception as e:
        print(f"Error occurred: {e}")
        print("Please check if the file path and model path are correct")
        # Re-raise so the failure is visible (traceback, non-zero exit)
        # instead of the script appearing to finish successfully.
        raise

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()