
import pandas as pd
import json
import os
from typing import Dict, List, Any
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
import torch

def load_jsonl(file_path: str) -> List[Dict]:
    """Load a JSON Lines file into a list of parsed records.

    Each non-blank line is parsed with ``json.loads``. Blank or whitespace-only
    lines (e.g. a trailing empty line) are skipped instead of raising
    ``json.JSONDecodeError``.

    Args:
        file_path: path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        The parsed records, in file order.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # json.loads('') raises; an empty line carries no record.
                continue
            data.append(json.loads(stripped))
    return data

def load_parquet(file_path: str) -> pd.DataFrame:
    """Read the Parquet file at *file_path* and return its contents as a DataFrame."""
    frame = pd.read_parquet(file_path)
    return frame


def batch_calculate_response_lengths(responses: List[str], tokenizer, special_token_ids: set, device: torch.device, batch_size: int = 32, max_length: int = 32010) -> List[int]:
    """
    Compute the token length of each response, tokenizing in batches.

    Each response is tokenized without added special tokens. Its length is the
    number of non-padding tokens whose id is not in ``special_token_ids``.

    Args:
        responses: response texts.
        tokenizer: callable tokenizer (Hugging Face style) that accepts the
            keyword arguments used below and returns a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        special_token_ids: token ids excluded from the count.
        device: device the counting is done on.
        batch_size: number of responses tokenized per call.
        max_length: truncation limit passed to the tokenizer (bounds memory).
            A response longer than this is counted as at most ``max_length``
            tokens.

    Returns:
        One length per response, in input order.
    """
    response_lengths: List[int] = []
    total_batches = (len(responses) + batch_size - 1) // batch_size

    # Build the lookup tensor once. Force an integer dtype: for an empty id set
    # torch.tensor([]) would default to float32 and not match the integer input_ids.
    special_tokens_tensor = torch.tensor(sorted(special_token_ids), dtype=torch.long, device=device)

    for start in range(0, len(responses), batch_size):
        batch_idx = start // batch_size + 1
        batch_responses = responses[start:start + batch_size]

        # Progress report every 10 batches and on the last one.
        if batch_idx % 10 == 0 or batch_idx == total_batches:
            print(f"处理批次 {batch_idx}/{total_batches} ({batch_idx/total_batches*100:.1f}%)")

        batch_tokens = tokenizer(
            batch_responses,
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        input_ids = batch_tokens['input_ids'].to(device)
        attention_mask = batch_tokens['attention_mask'].to(device)

        # A token counts when it is not padding and not a special token.
        # Summing this mask per row replaces the per-sample Python loop.
        counted = attention_mask.bool() & ~torch.isin(input_ids, special_tokens_tensor)
        response_lengths.extend(counted.sum(dim=1).tolist())

        # Free this batch's tensors before the next one.
        del input_ids, attention_mask, batch_tokens, counted
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    return response_lengths


_DEFAULT_TOKENIZER_PATH = "/mnt/shared-storage-user/p1-shared/Qwen/Qwen3-4B-Base"
# Number of leading prompt characters used as the exact-match lookup key.
_PROMPT_KEY_LEN = 100


def _collect_special_token_ids(tokenizer) -> set:
    """Return the tokenizer's BOS/EOS/PAD/UNK token ids, skipping any that are unset (None)."""
    ids = {
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
        tokenizer.unk_token_id,
    }
    ids.discard(None)
    return ids


def _jsonl_prompt(item: Dict) -> str:
    """Return a JSONL record's prompt text ('prompt', falling back to 'input'), or ''."""
    return item.get('prompt', '') or item.get('input', '') or ''


def _extract_prompt_text(prompt: Any) -> str:
    """Return the text of the first message of a parquet ``prompt`` cell, or ''.

    Parquet list columns are loaded as ``np.ndarray``. Testing such an array
    with ``if prompt:`` raises ValueError when it holds more than one message,
    so sequences are checked by length instead.
    """
    if isinstance(prompt, (list, np.ndarray)):
        if len(prompt) == 0:
            return ''
        first = prompt[0]
        # NOTE(review): only the first message is used; if prompts begin with a
        # system message this returns the system text — confirm against the data.
        text = first.get('content', '') if isinstance(first, dict) else str(first)
        return text or ''
    if not prompt:
        return ''
    return str(prompt)


def _fuzzy_find(original_prompt: str, jsonl_data: List[Dict], indices: range) -> tuple:
    """Return ``(record, index)`` for the first record in *indices* whose prompt contains,
    or is contained in, *original_prompt*; ``(None, -1)`` if none does."""
    for i in indices:
        candidate = _jsonl_prompt(jsonl_data[i])
        # An empty string is a substring of every string and would match anything.
        if not candidate:
            continue
        if original_prompt in candidate or candidate in original_prompt:
            return jsonl_data[i], i
    return None, -1


def _find_match(original_prompt: str, jsonl_data: List[Dict], jsonl_lookup: Dict, last_matched_idx: int) -> tuple:
    """Find the JSONL record for a non-empty *original_prompt*.

    Strategy 1: exact lookup on the first ``_PROMPT_KEY_LEN`` characters.
    Strategy 2: substring search from just after the previous match to the end.
    Strategy 3: if that fails and an earlier match exists, substring search
    from the start up to and including the previous match.

    Returns:
        ``(record or None, updated last-matched index)``.
    """
    hit = jsonl_lookup.get(original_prompt[:_PROMPT_KEY_LEN])
    if hit is not None:
        return hit[0], hit[1]

    start_idx = last_matched_idx + 1
    record, found_idx = _fuzzy_find(original_prompt, jsonl_data, range(start_idx, len(jsonl_data)))
    if record is None and last_matched_idx >= 0:
        record, found_idx = _fuzzy_find(original_prompt, jsonl_data, range(min(start_idx, len(jsonl_data))))
    if record is None:
        return None, last_matched_idx
    return record, found_idx


def _as_dict(value: Any) -> Dict:
    """Shallow-copy *value* when it is a dict; otherwise (e.g. a None cell) return a new empty dict."""
    return dict(value) if isinstance(value, dict) else {}


def _print_length_stats(response_lengths: List[int], filter_length: int) -> None:
    """Print summary statistics and a bucketed distribution of *response_lengths* (must be non-empty)."""
    total = len(response_lengths)
    print("\n=== 原始Response长度统计 ===")
    print(f"总样本数: {total}")
    print(f"平均长度: {np.mean(response_lengths):.1f} tokens")
    print(f"中位数长度: {np.median(response_lengths):.1f} tokens")
    print(f"最小长度: {np.min(response_lengths)} tokens")
    print(f"最大长度: {np.max(response_lengths)} tokens")
    print(f"标准差: {np.std(response_lengths):.1f} tokens")

    over_count = sum(1 for x in response_lengths if x > filter_length)
    print(f"大于{filter_length} tokens的样本数: {over_count} ({over_count/total*100:.1f}%)")

    print("\n=== 长度分布 ===")
    # (exclusive lower bound or None, inclusive upper bound, label)
    buckets = [
        (None, 500, "0-500"),
        (500, 1000, "501-1000"),
        (1000, 2000, "1001-2000"),
        (2000, 3000, "2001-3000"),
        (3000, 5000, "3001-5000"),
        (5000, filter_length, f"5001-{filter_length}"),
    ]
    for low, high, label in buckets:
        count = sum(1 for x in response_lengths if (low is None or x > low) and x <= high)
        print(f"{label} tokens: {count} 个样本 ({count/total*100:.1f}%)")
    print(f"{filter_length + 1}+ tokens: {over_count} 个样本 ({over_count/total*100:.1f}%)")


def _save_length_histogram(response_lengths: List[int], filter_length: int, plot_file: str) -> None:
    """Save a histogram of *response_lengths* with mean/median/filter_length markers to *plot_file*."""
    mean_len = np.mean(response_lengths)
    median_len = np.median(response_lengths)
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.hist(response_lengths, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
        plt.axvline(mean_len, color='red', linestyle='--', linewidth=2, label=f'平均值: {mean_len:.1f}')
        plt.axvline(median_len, color='green', linestyle='--', linewidth=2, label=f'中位数: {median_len:.1f}')
        plt.axvline(filter_length, color='orange', linestyle='-', linewidth=3, label=f'filter_length ({filter_length} tokens)')
        plt.xlabel('Token数量')
        plt.ylabel('频次')
        plt.title('Response长度分布直方图')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(plot_file, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure; pyplot otherwise keeps it alive for the whole process.
        plt.close(fig)
    print(f"直方图已保存到: {plot_file}")


def process_data(parquet_file: str, jsonl_file: str, output_file: str, batch_size: int = 32, filter_length: int = 7192, tokenizer_path: str = _DEFAULT_TOKENIZER_PATH):
    """
    Join a deepscaler-style parquet file with test-result JSONL records and write a new parquet file.

    Each parquet row is matched to a JSONL record by prompt text. The matched
    record's response ('output', falling back to 'generated_text') is tokenized
    to measure its length. Rows without a match or without a response are
    skipped. Also saves a length histogram PNG next to *output_file*.

    Args:
        parquet_file: path of the source parquet file.
        jsonl_file: path of the test-result JSONL file.
        output_file: path of the parquet file to write.
        batch_size: tokenization batch size, default 32.
        filter_length: length cap. Longer responses are reported, and their
            recorded ``extra_info['original_response_length']`` is clamped to
            this value. ``reward_model['num_tokens']`` is set to it for every row.
        tokenizer_path: tokenizer name or path for ``AutoTokenizer.from_pretrained``.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    special_token_ids = _collect_special_token_ids(tokenizer)

    print("正在加载parquet文件...")
    df_parquet = load_parquet(parquet_file)
    print(f"Parquet文件包含 {len(df_parquet)} 行数据")

    print("正在加载jsonl文件...")
    jsonl_data = load_jsonl(jsonl_file)
    print(f"JSONL文件包含 {len(jsonl_data)} 行数据")

    # Exact-match index keyed on the first _PROMPT_KEY_LEN prompt characters.
    # Later records sharing a key prefix overwrite earlier ones.
    jsonl_lookup = {}
    for jsonl_idx, jsonl_item in enumerate(jsonl_data):
        prompt_text = _jsonl_prompt(jsonl_item)
        if prompt_text:
            jsonl_lookup[prompt_text[:_PROMPT_KEY_LEN]] = (jsonl_item, jsonl_idx)
    print(f"创建了包含 {len(jsonl_lookup)} 个条目的jsonl查找字典")

    matched_data = []
    matched_count = 0
    unmatched_count = 0
    # Position of the previous match; fuzzy search starts just after it.
    last_matched_jsonl_idx = -1

    print("开始第一次匹配...")
    for idx, row in df_parquet.iterrows():
        original_prompt = _extract_prompt_text(row.get('prompt', []))

        matching_result = None
        if original_prompt:
            matching_result, last_matched_jsonl_idx = _find_match(
                original_prompt, jsonl_data, jsonl_lookup, last_matched_jsonl_idx
            )

        if matching_result is None:
            unmatched_count += 1
            if unmatched_count <= 10:  # only print the first 10 unmatched samples
                print(f"No matching result found for prompt: {original_prompt[:50]}... at idx: {idx}")
            continue

        matched_count += 1
        response = matching_result.get('output', '') or matching_result.get('generated_text', '')
        if not response:
            print(f"Warning: No response found for matched result at idx {idx}")
            continue

        # Defer length computation so all responses can be tokenized in batches.
        matched_data.append({
            'data_source': row.get('data_source', ''),
            'original_prompt': original_prompt,
            'ability': row.get('ability', 'math'),
            'reward_model': row.get('reward_model', {}),
            'extra_info': row.get('extra_info', {}),
            'matching_result': matching_result,
            'response': response,
        })

    batch_responses = [item['response'] for item in matched_data]
    print(f"\n开始批量计算 {len(batch_responses)} 个response的长度...")
    if not batch_responses:
        print("没有找到任何有效的response")
        return
    response_lengths = batch_calculate_response_lengths(
        batch_responses, tokenizer, special_token_ids, device, batch_size=batch_size
    )
    print(f"批量计算完成，共计算了 {len(response_lengths)} 个response长度")

    print("开始构建最终数据...")
    processed_data = []
    for i, (item, response_length) in enumerate(zip(matched_data, response_lengths)):
        if response_length > filter_length:
            print(f"Response length {response_length} is greater than filter_length ({filter_length}), item: {i}")
            # The recorded length is clamped to filter_length.
            response_length = filter_length

        new_prompt = [{'role': 'user', 'content': item['original_prompt']}]

        new_reward_model = _as_dict(item['reward_model'])
        new_reward_model['num_tokens'] = filter_length

        new_extra_info = _as_dict(item['extra_info'])
        new_extra_info['original_response_length'] = response_length
        result = item['matching_result']
        new_extra_info['ori_acc'] = result.get("score", result.get("correctness", 0))

        processed_data.append({
            'data_source': item['data_source'],
            'prompt': new_prompt,
            'ability': item['ability'],
            'reward_model': new_reward_model,
            'extra_info': new_extra_info,
        })

    print("\n=== 数据匹配统计 ===")
    print(f"Parquet文件总行数: {len(df_parquet)}")
    print(f"JSONL文件总行数: {len(jsonl_data)}")
    print(f"成功匹配的样本数: {matched_count}")
    print(f"未匹配的样本数: {unmatched_count}")
    print(f"匹配率: {matched_count/len(df_parquet)*100:.1f}%")
    print(f"最终处理的数据行数: {len(processed_data)}")

    _print_length_stats(response_lengths, filter_length)

    # Ensure the output directory exists before writing the plot and the parquet file.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # splitext (not str.replace) so the plot can never overwrite output_file and
    # directory names containing ".parquet" are left alone.
    plot_file = os.path.splitext(output_file)[0] + '_length_distribution.png'
    _save_length_histogram(response_lengths, filter_length, plot_file)

    new_df = pd.DataFrame(processed_data)
    print(new_df.iloc[0]['prompt'])
    print(f"\n正在保存到 {output_file}...")
    new_df.to_parquet(output_file, index=False)
    print(f"成功保存 {len(new_df)} 行数据到 {output_file}")

if __name__ == "__main__":
    import sys

    # Input/output paths. Alternative dataset configurations, kept for reference:
    # parquet_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/data/l1/deepscaler.parquet"
    # jsonl_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results/Qwen3-8b-deepscaler_test.jsonl"
    # parquet_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/POLARIS/parquet/stage1/qwen3-4b-s1.parquet"
    # jsonl_file = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-8b-new/Qwen3-4B-Base-polaris-30k/valid/0_32000.jsonl"
    # output_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/polaris/qwen3-4b-s1-ori_length_8192.parquet"
    parquet_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/l1/deepscaler_qwen3_polaris.parquet"
    jsonl_file = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-8b-new/Qwen3-4B-Base-deepscaler-47k/valid/0_32000.jsonl"
    output_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/l1/deepscaler_qwen3_polaris_all16k.parquet"

    # Abort early when an input file is missing. sys.exit is used because the
    # builtin exit() is provided by the site module and is not always available.
    for kind, path in (("parquet", parquet_file), ("jsonl", jsonl_file)):
        if not os.path.exists(path):
            print(f"错误：找不到{kind}文件 {path}", file=sys.stderr)
            sys.exit(1)

    # A batch size of 64 speeds up tokenization.
    process_data(parquet_file, jsonl_file, output_file, batch_size=64, filter_length=16384)
