#!/usr/bin/env python3
"""
专门处理5样本格式的n-gram提取脚本
支持从每5条记录对应一个问题的jsonl文件中提取n-gram (n=3-5)
将结果分成5个文件夹保存，每个文件夹对应一个样本位置
使用流式处理和内存优化策略，避免内存溢出
"""

import json
import os
import argparse
from collections import Counter
from typing import List, Dict, Iterator
import re
from tqdm import tqdm
import pickle
import gc
import psutil


def get_memory_usage():
    """Return the resident set size (RSS) of this process in megabytes."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / 1024 / 1024


def clean_text(text: str) -> str:
    """Normalize the whitespace of *text*.

    Leading and trailing whitespace is stripped and every remaining run of
    whitespace characters is collapsed into a single space. Punctuation and
    all other characters are left untouched. Any non-string input yields an
    empty string.
    """
    if isinstance(text, str):
        trimmed = text.strip()
        return re.sub(r'\s+', ' ', trimmed)
    return ""


def extract_ngrams_streaming(text: str, n: int) -> Iterator[str]:
    """Lazily yield every n-gram of whitespace-separated tokens in *text*.

    Each n-gram is the space-joined window of *n* consecutive tokens, produced
    in left-to-right order. Nothing is yielded when *text* is empty or holds
    fewer than *n* tokens.
    """
    if not text:
        return

    tokens = text.split()
    if len(tokens) < n:
        return

    for start in range(len(tokens) - n + 1):
        yield ' '.join(tokens[start:start + n])


def process_5samples_file(file_path: str, max_memory_mb: int = 1000) -> Dict[int, Dict[int, Counter]]:
    """Count 3-, 4- and 5-grams per sample position in a 5-sample JSONL file.

    The file is read line by line as one JSON object per line, where every
    consecutive group of five records belongs to the same question. The k-th
    record of each group (k = 1..5) contributes the n-grams of its
    ``generated_text`` field to ``result[k][n]``.

    Blank lines are ignored. A record that cannot be parsed or processed is
    reported and skipped, but it still occupies its slot in the group so that
    the records after it stay aligned with their sample position.

    Args:
        file_path: Path of the UTF-8 encoded JSONL file to read.
        max_memory_mb: Memory threshold in MB. It is checked every 1000
            records; exceeding it prints a warning and runs ``gc.collect()``.

    Returns:
        ``{sample_id: {n: Counter}}`` for sample_id in 1..5 and n in 3..5.
        If reading the file fails, the error is printed and whatever has been
        counted so far (possibly nothing) is returned.
    """
    ngram_sizes = (3, 4, 5)
    sample_counters = {
        sample_id: {n: Counter() for n in ngram_sizes}
        for sample_id in range(1, 6)
    }

    record_count = 0  # number of non-blank lines (records) seen so far

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                stripped = line.strip()
                if not stripped:
                    # Blank lines are not records and must not shift positions.
                    continue

                # Position is fixed *before* processing: records 1, 6, 11, ...
                # map to sample 1, records 2, 7, 12, ... to sample 2, etc.
                current_sample = (record_count % 5) + 1
                record_count += 1

                try:
                    data = json.loads(stripped)
                    output_text = data.get('generated_text', '')

                    if output_text:
                        cleaned_text = clean_text(output_text)
                        if cleaned_text:
                            counters = sample_counters[current_sample]
                            for n in ngram_sizes:
                                counters[n].update(
                                    extract_ngrams_streaming(cleaned_text, n)
                                )

                    # Check memory usage once every 1000 records.
                    if record_count % 1000 == 0:
                        current_memory = get_memory_usage()
                        if current_memory > max_memory_mb:
                            print(f"警告: 内存使用超过 {max_memory_mb}MB ({current_memory:.1f}MB)，强制垃圾回收")
                            gc.collect()

                except json.JSONDecodeError:
                    print(f"警告: {file_path} 第{line_num}行JSON解析失败")
                except Exception as e:
                    # Best effort: one bad record must not abort the whole file.
                    print(f"警告: {file_path} 第{line_num}行处理失败: {e}")

    except Exception as e:
        # Best effort: report the failure and return what was counted so far.
        print(f"错误: 无法读取文件 {file_path}: {e}")

    return sample_counters


def save_sample_results(file_path: str, sample_counters: Dict[int, Dict[int, Counter]], output_dir: str):
    """Write the n-gram counts of every sample position to disk.

    For each sample position 1..5 the directory
    ``<output_dir>/<input file stem>_sample_<id>`` is created. For each n in
    (3, 4, 5) whose counter is non-empty, three files are written into it:

    * ``ngram_<n>_full.pkl`` - the pickled ``Counter``;
    * ``ngram_<n>_deduplicated.txt`` - a header followed by one
      ``ngram<TAB>count`` line per distinct n-gram, most frequent first;
    * ``ngram_<n>_stats.txt`` - totals, mean/max/min frequency and a
      histogram of frequencies.

    An empty counter produces no files; the directory is created regardless.
    """
    # Input file name without directory and extension.
    stem = os.path.splitext(os.path.basename(file_path))[0]

    for sample_id in (1, 2, 3, 4, 5):
        sample_output_dir = os.path.join(output_dir, f"{stem}_sample_{sample_id}")
        os.makedirs(sample_output_dir, exist_ok=True)

        for n in (3, 4, 5):
            counter = sample_counters[sample_id][n]
            if not counter:
                continue

            frequencies = list(counter.values())
            distinct = len(frequencies)
            total = sum(frequencies)

            # Full counter, pickled.
            with open(os.path.join(sample_output_dir, f"ngram_{n}_full.pkl"), 'wb') as pkl_file:
                pickle.dump(counter, pkl_file)

            # Distinct n-grams ordered by descending frequency.
            listing_path = os.path.join(sample_output_dir, f"ngram_{n}_deduplicated.txt")
            with open(listing_path, 'w', encoding='utf-8') as listing:
                listing.writelines([
                    f"文件: {file_path}\n",
                    f"样本位置: {sample_id}\n",
                    f"去重后的 {n}-gram 结果 (按频次排序):\n",
                    "=" * 60 + "\n",
                    f"总唯一n-gram数量: {distinct}\n",
                    f"总出现次数: {total}\n",
                    "=" * 60 + "\n",
                ])
                # Rows are formatted lazily so no second copy of the data is built.
                listing.writelines(
                    f"{ngram}\t{count}\n" for ngram, count in counter.most_common()
                )

            # Summary statistics plus a histogram of frequencies.
            histogram = Counter(frequencies)
            summary_path = os.path.join(sample_output_dir, f"ngram_{n}_stats.txt")
            with open(summary_path, 'w', encoding='utf-8') as summary:
                summary.writelines([
                    f"文件: {file_path}\n",
                    f"样本位置: {sample_id}\n",
                    f"{n}-gram 统计信息 (去重后):\n",
                    "=" * 40 + "\n",
                    f"总出现次数: {total}\n",
                    f"唯一n-gram数量: {distinct}\n",
                    f"平均频率: {total / distinct:.2f}\n",
                    f"最高频率: {max(frequencies)}\n",
                    f"最低频率: {min(frequencies)}\n",
                    "\n频率分布:\n",
                    "-" * 20 + "\n",
                ])
                summary.writelines(
                    f"出现{freq}次: {histogram[freq]}个n-gram\n"
                    for freq in sorted(histogram, reverse=True)
                )

        print(f"保存了样本 {sample_id} 的n-gram结果到: {sample_output_dir}")


def main():
    """Command-line entry point.

    Parses ``--input_file``, ``--output_dir`` and ``--max_memory_mb``, counts
    the n-grams of the input file per sample position, writes the results
    below the output directory and prints progress and memory usage along the
    way. Returns early (after printing an error) when the input file does not
    exist or when processing raises.
    """
    parser = argparse.ArgumentParser(description="从5样本格式的jsonl文件中提取n-gram")
    parser.add_argument("--input_file", type=str, 
                       default="/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/l1-8b-5samples_8192_test.jsonl",
                       help="输入的5样本格式jsonl文件路径")
    parser.add_argument("--output_dir", type=str, 
                       default="/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/ngram_results",
                       help="输出目录")
    parser.add_argument("--max_memory_mb", type=int, default=10240,
                       help="最大内存使用量（MB）")
    
    args = parser.parse_args()
    
    print(f"开始处理5样本文件: {args.input_file}")
    print(f"内存限制: {args.max_memory_mb}MB")
    print(f"初始内存使用: {get_memory_usage():.1f}MB")
    
    # Bail out early when the input file is missing.
    if not os.path.exists(args.input_file):
        print(f"错误: 输入文件不存在: {args.input_file}")
        return
    
    os.makedirs(args.output_dir, exist_ok=True)
    
    try:
        print("开始处理文件...")
        initial_memory = get_memory_usage()
        
        sample_counters = process_5samples_file(args.input_file, args.max_memory_mb)
        save_sample_results(args.input_file, sample_counters, args.output_dir)
        
        final_memory = get_memory_usage()
        print(f"完成处理，内存使用: {initial_memory:.1f}MB -> {final_memory:.1f}MB")
        
        # Drop the counters before the final memory report.
        del sample_counters
        gc.collect()
        
    except Exception as e:
        print(f"错误: 处理文件时发生异常: {e}")
        gc.collect()
        return
    
    print(f"\n所有结果已保存到: {args.output_dir}")
    print(f"最终内存使用: {get_memory_usage():.1f}MB")
    print("每个样本的结果保存在对应的子目录中:")
    # List the directories that save_sample_results actually creates
    # ("<input file stem>_sample_<id>") instead of a non-matching placeholder.
    file_name = os.path.splitext(os.path.basename(args.input_file))[0]
    for sample_id in range(1, 6):
        print(f"  - {file_name}_sample_{sample_id}/ (第{sample_id}个样本)")


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
