
import pandas as pd
import json
import os
from typing import Dict, List, Any, Optional, Tuple
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
import torch
import re
import difflib
import unicodedata


def load_jsonl(file_path: str) -> List[Dict]:
    """Load a JSON Lines file into a list of parsed records.

    Blank / whitespace-only lines (e.g. a trailing newline at the end of the
    file) are skipped instead of being handed to ``json.loads``. The file is
    opened as ``utf-8-sig`` so an optional UTF-8 byte-order mark in front of
    the first record does not break parsing.

    Args:
        file_path: Path of the ``.jsonl`` file to read.

    Returns:
        The parsed records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is extended with the file path and the 1-based line
            number so the offending record can be located.
    """
    records: List[Dict] = []
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        for line_no, raw_line in enumerate(f, start=1):
            payload = raw_line.strip()
            if not payload:
                # Empty line: nothing to parse.
                continue
            try:
                records.append(json.loads(payload))
            except json.JSONDecodeError as err:
                # Keep the exception type, add where it happened.
                raise json.JSONDecodeError(
                    f"{err.msg} (file {file_path}, line {line_no})",
                    err.doc, err.pos) from err
    return records


def load_parquet(file_path: str) -> pd.DataFrame:
    """Read the parquet file at ``file_path`` and return it as a DataFrame."""
    frame = pd.read_parquet(file_path)
    return frame

# def _normalize_prompt_text(text: Any) -> str:
#     """尽量稳定地将prompt归一化为字符串，用于匹配key。"""
#     if text is None:
#         return ""
#     if isinstance(text, str):
#         s = text.strip()
#         # 统一空白符，提升“顺序滑动匹配”的鲁棒性
#         s = re.sub(r"\s+", " ", s)
#         return s
#     # 有些jsonl可能把prompt写成list/dict等，尽量转成字符串
#     s = str(text).strip()
#     s = re.sub(r"\s+", " ", s)
#     return s


def _extract_prompt_from_jsonl_item(jsonl_item: Dict[str, Any]) -> str:
    """Extract the bare question text from a jsonl record.

    The text is read from the ``prompt`` field, falling back to ``input``
    when ``prompt`` is missing or empty. Chat-template / instruction wrappers
    are then stripped so that only the user question remains.

    Non-string values (a list, a dict, a number, ...) are converted with
    ``str()`` and ``None`` is treated as an empty prompt, so a malformed
    record yields a string instead of raising ``AttributeError``.

    Args:
        jsonl_item: One parsed jsonl record.

    Returns:
        The stripped prompt text; ``""`` when no prompt is available.
    """
    raw_prompt = jsonl_item.get('prompt', '') or jsonl_item.get('input', '')
    if raw_prompt is None:
        return ""
    prompt_text = raw_prompt if isinstance(raw_prompt, str) else str(raw_prompt)

    # Keep what lies between the last "user" and the following "assistant".
    # NOTE(review): this splits on the bare words, so a question that itself
    # contains "user" or "assistant" is truncated here; kept as-is because the
    # downstream matching is fuzzy and tuned against this output.
    prompt_text = prompt_text.split('user')[-1].split('assistant')[0]
    # ChatML-style markers. After the split above the text can no longer
    # contain "user" / "assistant", so this step cannot match anything.
    prompt_text = prompt_text.split(
        '<|im_start|>user')[-1].split('<|im_end|> <|im_start|>assistant')[0]
    # Plain "User: ... Assistant:" transcripts.
    prompt_text = prompt_text.split(
        "User: ")[-1].split("Assistant:")[0].strip()
    prompt_text = prompt_text.replace("<|endoftext|>", "")
    # Drop the answer-format instruction appended to the question.
    prompt_text = prompt_text.replace(
        "You must put your answer inside \\boxed{} and Your final answer will be extracted automatically by the \\boxed{} tag.", "").strip()
    return prompt_text


def _extract_responses_from_jsonl_item(jsonl_item: Dict[str, Any]) -> List[str]:
    """
    Extract the response text of a jsonl record as a one-element list.

    The value is read from ``output``; when that is missing or the empty
    string, ``generated_text`` is used instead. Supported shapes:
    - str: the ``<|endoftext|>`` marker is removed; a whitespace-only result
      becomes ``""`` (otherwise surrounding whitespace is kept).
    - list: only the first candidate is used, so that one jsonl line always
      counts as exactly one generation; it is stringified, stripped, and the
      ``<|endoftext|>`` marker is removed.
    - anything else: stringified, stripped, marker removed.
    A missing value, an empty list, or a ``None`` first candidate yields ``[""]``.
    """
    end_marker = "<|endoftext|>"

    raw = jsonl_item.get('output', None)
    if raw is None or raw == "":
        raw = jsonl_item.get('generated_text', None)

    if isinstance(raw, str):
        text = raw.replace(end_marker, "")
        return [text if text.strip() != "" else ""]

    if isinstance(raw, list):
        # One jsonl line == one generation: look at the first candidate only.
        raw = raw[0] if len(raw) > 0 else None

    if raw is None:
        return [""]
    return [str(raw).strip().replace(end_marker, "")]


def _extract_score_from_jsonl_item(jsonl_item: Dict[str, Any]) -> float:
    """Return the record's score as a float.

    Reads ``score``; when that is absent or ``None`` it falls back to
    ``correctness`` (default 0). Values that cannot be converted to a float
    yield ``0.0``.
    """
    raw_score = jsonl_item.get("score", None)
    if raw_score is None:
        raw_score = jsonl_item.get("correctness", 0)
    try:
        numeric = float(raw_score)
    except Exception:
        # Best effort: any unconvertible value counts as a zero score.
        return 0.0
    return numeric


def _approx_prompt_match(a: str, b: str, min_len: int = 20, ratio_threshold: float = 0.8) -> bool:
    """
    Decide whether two prompts are "roughly the same".

    Checks, in order (the first hit wins):
    1. exact equality;
    2. one string contains the other, only when both are at least
       ``min_len`` characters long (avoids accidental matches of very short
       strings);
    3. ``difflib.SequenceMatcher`` ratio of at least ``ratio_threshold``;
    4. after removing every character outside ``[a-zA-Z0-9]`` and
       lower-casing: equality or containment, only when both reduced strings
       are longer than ``min_len``.

    An empty (falsy) ``a`` or ``b`` never matches.
    """
    if not a or not b:
        return False
    if a == b:
        return True

    both_long_enough = len(a) >= min_len and len(b) >= min_len
    if both_long_enough and (a in b or b in a):
        return True

    # Similarity fallback: tolerant of small formatting differences.
    similarity = difflib.SequenceMatcher(None, a, b).ratio()
    if similarity >= ratio_threshold:
        return True

    # Last resort: compare letters and digits only.
    alnum_a = re.sub(r'[^a-zA-Z0-9]', '', a).lower()
    alnum_b = re.sub(r'[^a-zA-Z0-9]', '', b).lower()
    if len(alnum_a) <= min_len or len(alnum_b) <= min_len:
        return False
    return alnum_a == alnum_b or alnum_a in alnum_b or alnum_b in alnum_a


def _build_jsonl_prompt_groups(
    jsonl_data: List[Dict[str, Any]],
    max_group_size: int = 8,
) -> List[Dict[str, Any]]:
    """
    Group jsonl records by runs of consecutive identical prompts.

    Each group is a dict ``{"prompt_norm": <prompt text>, "items": [(index,
    record), ...]}`` where ``index`` is the record's position in
    ``jsonl_data``. A group holds at most ``max_group_size`` records; once it
    is full, the next record with the same prompt starts a new group. This
    covers:
    - the same prompt generated several times in a row,
    - a prompt generated only once,
    - runs that are shorter than expected because some generations are missing.

    Records whose extracted prompt is empty are skipped and do not end the
    current run.
    """
    groups: List[Dict[str, Any]] = []

    for position, record in enumerate(jsonl_data):
        prompt_key = _extract_prompt_from_jsonl_item(record)
        if not prompt_key:
            # Empty prompt (missing / dirty data): ignore without breaking the run.
            continue

        open_group = groups[-1] if groups else None
        extends_open_group = (
            open_group is not None
            and open_group["prompt_norm"] == prompt_key
            and len(open_group["items"]) < max_group_size
        )
        if extends_open_group:
            open_group["items"].append((position, record))
        else:
            groups.append(
                {"prompt_norm": prompt_key, "items": [(position, record)]})

    return groups


def batch_calculate_response_lengths(responses: List[str], tokenizer, special_token_ids: set, device: torch.device, batch_size: int = 32) -> List[int]:
    """
    Compute the token length of each response.

    Every response is NFKC-normalized and then encoded on its own with
    ``tokenizer.encode(..., add_special_tokens=False)``; the length of the
    returned token list is recorded. Progress is printed every 1000
    responses and after the last one.

    Args:
        responses: Response texts. A ``None`` entry is treated as ``""``.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
        special_token_ids: Unused; kept for call compatibility.
        device: Unused; kept for call compatibility.
        batch_size: Unused; kept for call compatibility.

    Returns:
        One token count per response, in input order.
    """
    total_responses = len(responses)
    print(f"开始计算 {total_responses} 个response的长度...")

    response_lengths: List[int] = []
    for done, response in enumerate(responses, start=1):
        if done % 1000 == 0 or done == total_responses:
            print(
                f"处理进度: {done}/{total_responses} ({done/total_responses*100:.1f}%)")

        # NFKC folds compatibility characters into their canonical form
        # before the text is tokenized.
        normalized_response = unicodedata.normalize(
            'NFKC', "" if response is None else response)

        # Encoding a single string produces no padding, so the raw token
        # count is the response length.
        tokens = tokenizer.encode(
            normalized_response, add_special_tokens=False)
        response_lengths.append(len(tokens))

    return response_lengths


def _first_prompt_content(prompt: Any) -> str:
    """Return the text of the first message stored in a parquet ``prompt`` cell."""
    if prompt is None:
        return ''
    if isinstance(prompt, (list, np.ndarray)):
        # Use len() instead of truthiness: bool() of an ndarray with more
        # than one element raises ValueError.
        if len(prompt) == 0:
            return ''
        first = prompt[0]
        if isinstance(first, dict):
            return first.get('content', '') or ''
        return str(first)
    return str(prompt) if prompt else ''


def _as_plain_dict(value: Any) -> Dict[str, Any]:
    """Return a shallow dict copy of ``value``, or ``{}`` when it is not a dict."""
    return dict(value) if isinstance(value, dict) else {}


def _preview(text: str, limit: int = 80) -> str:
    """Shorten ``text`` to ``limit`` characters (plus ``...``) for log output."""
    return text[:limit] + "..." if len(text) > limit else text


def _save_length_histograms(response_lengths: List[int], target_lengths: List[int], filter_length: int, plot_file: str) -> None:
    """Save side-by-side histograms of response and target lengths to ``plot_file``."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # 1024-wide bins spanning the whole plotted x-range, so that long
    # responses are not silently left out of the histogram. If filter_length
    # coincides with a bin edge, nudge that edge up slightly so a value equal
    # to filter_length lands in the bin to its left.
    x_max = 17 * 1024
    bin_edges = np.array(
        [edge + 0.001 if edge == filter_length else edge
         for edge in range(0, x_max + 1, 1024)],
        dtype=float)

    panels = [
        (ax1, response_lengths, 'skyblue', 'Response Length Distribution'),
        (ax2, target_lengths, 'lightcoral', 'Target Length Distribution'),
    ]
    for ax, lengths, color, title in panels:
        mean_length = np.mean(lengths)
        ax.hist(lengths, bins=bin_edges, alpha=0.7,
                color=color, edgecolor='black')
        ax.axvline(mean_length, color='red', linestyle='--',
                   linewidth=2, label=f'average: {mean_length:.1f}')
        ax.axvline(filter_length, color='orange', linestyle='-',
                   linewidth=3, label='filter_length tokens')
        ax.set_xlabel('Token Length')
        ax.set_ylabel('Frequency')
        ax.set_title(title)
        ax.set_xlim(0, x_max)
        ax.set_xticks(np.arange(0, x_max + 1, 1024))
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    try:
        fig.savefig(plot_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)


def process_data(parquet_file: str, jsonl_file: str, output_file: str, batch_size: int = 32, filter_length: int = 7192, tokenizer=None):
    """
    Align a prompt parquet file with a jsonl file of generation results and
    write a new parquet file annotated with token budgets.

    Parquet rows and jsonl groups are matched in order with fuzzy prompt
    comparison. For every matched row the response length is measured in
    tokens and ``reward_model['num_tokens']`` is set to the response length
    plus 1000, where the response length is first capped at ``filter_length``.
    ``extra_info`` receives ``ori_acc`` (mean score), ``gen_count`` and
    ``avg_response_len``. The prompt is rewritten to a single user message.

    Side effects: prints progress and statistics, sets
    ``tokenizer.pad_token`` to ``tokenizer.eos_token``, writes
    ``<output stem>_length_distribution.png`` next to ``output_file`` and
    writes ``output_file`` itself. Returns early (writing nothing) when no
    parquet row could be matched.

    Args:
        parquet_file: Path of the source prompt parquet file.
        jsonl_file: Path of the generation-result jsonl file.
        output_file: Path of the parquet file to write.
        batch_size: Forwarded to ``batch_calculate_response_lengths``.
        filter_length: Response lengths above this value are capped to it
            when computing ``num_tokens``; also drawn as a marker in the plot.
        tokenizer: Tokenizer used to measure response lengths. Required.

    Raises:
        ValueError: If ``tokenizer`` is ``None``.
    """
    if tokenizer is None:
        raise ValueError("process_data requires a tokenizer (got None)")

    tokenizer.pad_token = tokenizer.eos_token
    # Device selection
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # Collect the special token ids the tokenizer defines
    special_token_ids = {
        token_id
        for token_id in (tokenizer.bos_token_id, tokenizer.eos_token_id,
                         tokenizer.pad_token_id, tokenizer.unk_token_id)
        if token_id is not None
    }

    # Load inputs
    print("正在加载parquet文件...")
    df_parquet = load_parquet(parquet_file)
    print(f"Parquet文件包含 {len(df_parquet)} 行数据")

    print("正在加载jsonl文件...")
    jsonl_data = load_jsonl(jsonl_file)
    print(f"JSONL文件包含 {len(jsonl_data)} 行数据")

    matched_data = []  # one entry per successfully matched parquet row
    matched_count = 0
    unmatched_count = 0
    batch_responses = []  # all responses, flattened, for one length pass

    # Group the jsonl records by consecutive identical prompt. The cap is 1,
    # i.e. every jsonl line forms its own group.
    group_size_cap = 1
    jsonl_groups = _build_jsonl_prompt_groups(
        jsonl_data, max_group_size=group_size_cap)
    print(
        f"创建了jsonl顺序组：groups={len(jsonl_groups)} (max_group_size={group_size_cap})")

    if len(jsonl_groups) > 0:
        group_sizes = np.array([len(g["items"])
                               for g in jsonl_groups], dtype=np.int32)
        uniq_sizes, uniq_freqs = np.unique(group_sizes, return_counts=True)
        print("\n=== jsonl按连续prompt聚合：组大小分布 ===")
        print(
            f"组大小统计: min={group_sizes.min()}, median={np.median(group_sizes):.0f}, mean={group_sizes.mean():.3f}, max={group_sizes.max()}")
        print(
            f"按(组大小 -> 组数量)汇总: {list(zip(uniq_sizes.tolist(), uniq_freqs.tolist()))}")

    # Diagnostics: lines carrying several candidates (only the first is used)
    # and lines whose extracted response is blank.
    multi_candidate_lines = 0
    empty_response_lines = 0
    for it in jsonl_data:
        out = it.get('output', None)
        if out is None or out == "":
            out = it.get('generated_text', None)

        if isinstance(out, list) and len(out) > 1:
            multi_candidate_lines += 1
        # The extractor always returns at least one element, so "empty" has
        # to be judged by content rather than by list length.
        extracted = _extract_responses_from_jsonl_item(it)
        if not any(r.strip() for r in extracted):
            empty_response_lines += 1
    print("\n=== jsonl response字段诊断 ===")
    print(
        f"output/generated_text 为 list 且长度>1 的行数: {multi_candidate_lines} （已默认只取第一个）")
    print(f"提取不到有效response的行数(空/缺失): {empty_response_lines}")

    # === Sequential alignment ===
    # For each parquet row, scan forward from the current jsonl group pointer
    # until a prompt roughly matches. This copes with jsonl records that are
    # missing, which would otherwise shift all indices.
    group_ptr = 0
    total_skipped_groups = 0
    # Tunable: larger values recover from bigger gaps but are slower.
    max_lookahead_groups = 500
    # Tunable: similarity threshold of the fuzzy prompt match (higher = stricter).
    approx_ratio_threshold = 0.88

    print("开始顺序对齐匹配(parquet顺序 <-> jsonl组顺序)...")
    parquet_rows = list(df_parquet.iterrows())
    for row_idx, (idx, row) in enumerate(parquet_rows):
        data_source = row.get('data_source', '')
        ability = row.get('ability', 'math')
        reward_model = row.get('reward_model', {})
        extra_info = row.get('extra_info', {})
        original_prompt_norm = _first_prompt_content(row.get('prompt', []))

        matched_prompt_norm: Optional[str] = None
        matched_items: Optional[List[Tuple[int, Dict[str, Any]]]] = None

        if original_prompt_norm and group_ptr < len(jsonl_groups):
            search_ptr = group_ptr
            looked = 0
            while search_ptr < len(jsonl_groups) and looked <= max_lookahead_groups:
                cand_prompt_norm = jsonl_groups[search_ptr]["prompt_norm"]
                if _approx_prompt_match(original_prompt_norm, cand_prompt_norm, ratio_threshold=approx_ratio_threshold):
                    matched_prompt_norm = cand_prompt_norm
                    matched_items = jsonl_groups[search_ptr]["items"]
                    # Consume groups in order: continue after the match.
                    total_skipped_groups += max(0, search_ptr - group_ptr)
                    group_ptr = search_ptr + 1
                    break
                search_ptr += 1
                looked += 1

        if not matched_items:
            unmatched_count += 1

            # If the NEXT parquet row matches the current jsonl position,
            # this row is simply an extra parquet row; report it as skipped.
            next_row_fits = False
            if (original_prompt_norm and group_ptr < len(jsonl_groups)
                    and row_idx + 1 < len(parquet_rows)):
                _, next_row = parquet_rows[row_idx + 1]
                next_prompt_norm = _first_prompt_content(
                    next_row.get('prompt', []))
                if next_prompt_norm:
                    cand_prompt_norm = jsonl_groups[group_ptr]["prompt_norm"]
                    next_row_fits = _approx_prompt_match(
                        next_prompt_norm, cand_prompt_norm, ratio_threshold=approx_ratio_threshold)

            if next_row_fits:
                if unmatched_count <= 10:
                    preview = _preview(original_prompt_norm)
                    print(
                        f"Skipping parquet idx={idx} (next row can match current jsonl position), prompt={preview}")
            elif unmatched_count <= 20:  # print a few more to ease debugging
                preview_p = _preview(original_prompt_norm)
                print(f"No matching jsonl group found for parquet idx={idx}")
                print(f"  Parquet prompt: {preview_p}")
                if group_ptr < len(jsonl_groups):
                    preview_j = _preview(jsonl_groups[group_ptr]["prompt_norm"])
                    print(
                        f"  JSONL current group_ptr={group_ptr} prompt: {preview_j}")
            continue

        # One parquet row == one prompt; the group holds its generation(s).
        responses: List[str] = []
        scores: List[float] = []
        for _, it in matched_items:
            responses.extend(_extract_responses_from_jsonl_item(it))
            scores.append(_extract_score_from_jsonl_item(it))

        # Blank responses are kept (they count as length 0) but reported.
        if not any(r.strip() for r in responses):
            print(
                f"Warning: matched prompt but no valid responses at parquet idx {idx}")

        matched_count += 1

        # Remember where this row's responses sit in the flat list so the
        # per-prompt average length can be computed afterwards.
        start = len(batch_responses)
        batch_responses.extend(responses)
        end = len(batch_responses)

        matched_data.append({
            'data_source': data_source,
            'original_prompt': original_prompt_norm,
            'ability': ability,
            'reward_model': reward_model,
            'extra_info': extra_info,
            'matched_prompt_norm': matched_prompt_norm,
            'responses_count': len(responses),
            'scores': scores,
            'resp_slice': (start, end),
        })

    # Measure all responses in one pass
    print(f"\n开始批量计算 {len(batch_responses)} 个response的长度...")
    if batch_responses:
        flat_response_lengths = batch_calculate_response_lengths(
            batch_responses, tokenizer, special_token_ids, device, batch_size=batch_size
        )
        print(f"批量计算完成，共计算了 {len(flat_response_lengths)} 个response长度")
    else:
        print("没有找到任何有效的response")
        return

    # Build the output rows from the precomputed lengths
    print("开始构建最终数据...")
    processed_data = []
    target_lengths = []
    response_lengths = []  # per-prompt mean response length, rounded
    for i, matched_item in enumerate(matched_data):
        start, end = matched_item['resp_slice']
        per_prompt_lengths = flat_response_lengths[start:end]
        if not per_prompt_lengths:
            # Should not happen; kept as a safety net.
            continue
        avg_response_length = float(np.mean(per_prompt_lengths))
        # Integer length used for num_tokens and plotting.
        response_length = int(round(avg_response_length))
        response_lengths.append(response_length)

        # Token budget: measured length (capped at filter_length) + 1000.
        if response_length > filter_length:
            num_tokens = filter_length + 1000
            print(
                f"Response length {response_length} > filter_length ({filter_length}), setting num_tokens to {filter_length} + 1000, item: {i}")
        else:
            num_tokens = response_length + 1000

        new_prompt = [
            {'role': 'user', 'content': matched_item['original_prompt']}
        ]

        # Copy the nested dicts so the source rows are not mutated; a missing
        # (non-dict) value becomes an empty dict.
        new_reward_model = _as_plain_dict(matched_item['reward_model'])
        new_reward_model['num_tokens'] = num_tokens

        new_extra_info = _as_plain_dict(matched_item['extra_info'])

        # Mean score over the matched jsonl records (works for 1 or many).
        scores = matched_item.get('scores', [])
        new_extra_info['ori_acc'] = float(np.mean(scores)) if scores else 0.0
        new_extra_info['gen_count'] = int(
            matched_item.get('responses_count', 1))
        new_extra_info['avg_response_len'] = float(avg_response_length)

        new_row = {
            'data_source': matched_item['data_source'],
            'prompt': new_prompt,
            'ability': matched_item['ability'],
            'reward_model': new_reward_model,
            'extra_info': new_extra_info
        }
        target_lengths.append(num_tokens)
        processed_data.append(new_row)

    # Matching statistics
    print("\n=== 数据匹配统计 ===")
    print(f"Parquet文件总行数: {len(df_parquet)}")
    print(f"JSONL文件总行数: {len(jsonl_data)}")
    print(f"JSONL聚合后的组数: {len(jsonl_groups)}")
    print(f"成功匹配的样本数: {matched_count}")
    print(f"未匹配的样本数: {unmatched_count}")
    print(f"匹配率: {matched_count/len(df_parquet)*100:.1f}%")
    print(f"最终处理的数据行数: {len(processed_data)}")
    print(f"对齐过程中跳过的jsonl组数(用于重对齐): {total_skipped_groups}")

    # Length statistics
    print("\n=== Response长度统计（按每个prompt的多次生成取平均） ===")
    print(f"总样本数: {len(response_lengths)}")
    if response_lengths:
        print(f"平均长度: {np.mean(response_lengths):.1f} tokens")
        print(f"中位数长度: {np.median(response_lengths):.1f} tokens")
        print(f"最小长度: {np.min(response_lengths)} tokens")
        print(f"最大长度: {np.max(response_lengths)} tokens")
        print(f"标准差: {np.std(response_lengths):.1f} tokens")
    else:
        print("没有有效的response数据")
        return

    total_samples = len(response_lengths)
    over_filter_length_count = sum(
        1 for x in response_lengths if x > filter_length)
    print(
        f"大于filter_length tokens的样本数: {over_filter_length_count} ({over_filter_length_count/total_samples*100:.1f}%)")

    # Length distribution. Buckets are (label, exclusive lower, inclusive
    # upper) and are contiguous as long as filter_length >= 5000.
    print("\n=== 长度分布 ===")
    length_buckets = [
        ("0-500", -1, 500),
        ("501-1000", 500, 1000),
        ("1001-2000", 1000, 2000),
        ("2001-3000", 2000, 3000),
        ("3001-5000", 3000, 5000),
        ("5001-filter_length", 5000, filter_length),
    ]
    for label, lower, upper in length_buckets:
        bucket_count = sum(1 for x in response_lengths if lower < x <= upper)
        print(
            f"{label} tokens: {bucket_count} 个样本 ({bucket_count/total_samples*100:.1f}%)")
    print(
        f"filter_length+ tokens: {over_filter_length_count} 个样本 ({over_filter_length_count/total_samples*100:.1f}%)")

    # Histogram. Derive the file name from the output stem so the plot can
    # never share a path with the parquet output.
    plot_file = os.path.splitext(output_file)[0] + '_length_distribution.png'
    _save_length_histograms(
        response_lengths, target_lengths, filter_length, plot_file)
    print(f"直方图已保存到: {plot_file}")

    # Write the new parquet file
    new_df = pd.DataFrame(processed_data)
    print(f"\n正在保存到 {output_file}...")
    new_df.to_parquet(output_file, index=False)
    print(f"成功保存 {len(new_df)} 行数据到 {output_file}")


if __name__ == "__main__":
    # Script entry point: choose a tokenizer, the input parquet / jsonl files
    # and the output path, then run process_data. Alternative choices are
    # kept below as commented-out lines.

    # Tokenizer used to measure response lengths
    # tokenizer = AutoTokenizer.from_pretrained("/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Qwen3-4B-Base")
    # tokenizer = AutoTokenizer.from_pretrained("/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/OctoThinker-3B-Long-Base")
    # tokenizer= AutoTokenizer.from_pretrained("/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B-Instruct-2507")
    # tokenizer = AutoTokenizer.from_pretrained("/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B")
    # tokenizer = AutoTokenizer.from_pretrained("/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Llama-3.2-3B-Instruct")
    # tokenizer = AutoTokenizer.from_pretrained("/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/Qwen2.5-Math-7B-16k-think")
    tokenizer = AutoTokenizer.from_pretrained(
        "/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-8B-Base")

    # Source prompt dataset
    # parquet_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/l1/deepscaler_qwen3_polaris.parquet"
    # parquet_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/deepmath-5k_qwen3.parquet"
    # parquet_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/dapo-math-17k_qwen3_polaris.parquet"
    # parquet_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/dapo-math-17k-octothinker.parquet"
    parquet_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/qwen3-4b-s1.parquet"

    # Generation results to align against the dataset
    # jsonl_file = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-8b-new/octothinker-deepmath/valid/0_16384.jsonl"
    # jsonl_file = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-8b-new/gspo-stage1-dapo-math/valid/0_16384.jsonl"
    # jsonl_file = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-8b-new/qwen3-8b-base-dapo-math/valid/0_16384.jsonl"
    jsonl_file = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-8b-new/qwen3-8b-base-polaris/valid/0_16384.jsonl"
    # jsonl_file = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-8b-new/sft-dapo-math/valid/0_32768.jsonl"

    # Destination parquet file
    # output_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/dapo-math-17k-octothinker_add1k.parquet"
    # output_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/gspo_stage1_dapo_math_add1k_max12k.parquet"
    # output_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/qwen3-8b-base-dapo-math_add1k.parquet"
    output_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/qwen3-8b-base-polaris_add1k.parquet"
    # output_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/polaris_qwen3-4b-add1k_max16k.parquet"

    # Abort with exit status 1 when an input file is missing. SystemExit is
    # raised directly because the bare exit() helper is installed by the
    # site module and is meant for interactive sessions.
    if not os.path.exists(parquet_file):
        print(f"错误：找不到parquet文件 {parquet_file}")
        raise SystemExit(1)

    if not os.path.exists(jsonl_file):
        print(f"错误：找不到jsonl文件 {jsonl_file}")
        raise SystemExit(1)

    # Run the pipeline (batch_size=256, filter_length=8192)
    process_data(parquet_file, jsonl_file, output_file,
                 batch_size=256, filter_length=8192, tokenizer=tokenizer)
