#!/usr/bin/env python3
"""
内存优化的n-gram提取脚本
支持从大量jsonl文件中提取output字段的n-gram (n=3-5)
使用流式处理和内存优化策略，避免内存溢出
"""

import json
import os
import glob
import argparse
from collections import Counter, defaultdict
from multiprocessing import Pool, cpu_count
from typing import List, Dict, Tuple, Iterator
import re
from tqdm import tqdm
import pickle
import gc
import psutil
import sys


def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MB.

    The value is the ``rss`` figure reported by ``psutil`` for this process,
    converted from bytes by dividing by 1024 twice.
    """
    current_process = psutil.Process(os.getpid())
    rss_bytes = current_process.memory_info().rss
    return rss_bytes / 1024 / 1024


def clean_text(text: str) -> str:
    """Normalize the whitespace of *text*.

    Leading and trailing whitespace is stripped and every remaining run of
    whitespace is collapsed into a single space. Punctuation and all other
    characters are left untouched.

    Args:
        text: The string to normalize.

    Returns:
        The normalized string, or ``""`` when *text* is not a ``str``.
    """
    if isinstance(text, str):
        trimmed = text.strip()
        return re.sub(r'\s+', ' ', trimmed)
    return ""


def extract_ngrams_streaming(text: str, n: int) -> Iterator[str]:
    """Lazily yield the word-level n-grams of *text*.

    The text is split on whitespace and each window of *n* consecutive words
    is yielded as a single space-joined string, in order of appearance.

    Args:
        text: The text to extract n-grams from.
        n: Number of words per n-gram.

    Yields:
        One string per n-gram. Nothing is yielded when *text* is empty, when
        it contains fewer than *n* words, or when *n* is smaller than 1.
    """
    # A non-positive n has no meaningful n-grams; without this guard the
    # slicing below would produce empty or truncated strings.
    if not text or n < 1:
        return

    # Split once; the range below is empty when there are fewer than n words.
    words = text.split()
    for start in range(len(words) - n + 1):
        yield ' '.join(words[start:start + n])


def process_single_file_streaming(file_path: str, max_memory_mb: int = 1000) -> Dict[int, Counter]:
    """Count the 3-, 4- and 5-grams of one JSONL file, reading it line by line.

    Each non-blank line is parsed as JSON and the n-grams of its
    ``generated_text`` value (after whitespace normalization) are added to the
    counters. Lines that fail to parse or to process are reported on stdout
    and skipped. Blank lines are skipped silently.

    Args:
        file_path: Path of the JSONL file to read (decoded as UTF-8).
        max_memory_mb: Soft threshold in MB. Every 1000 processed lines the
            process memory is checked; when it exceeds this value a warning is
            printed and ``gc.collect()`` is called. The counters themselves
            are never trimmed, so this does not cap memory usage.

    Returns:
        A dict mapping each n in (3, 4, 5) to a ``Counter`` of n-gram strings.
        If the file cannot be opened or read, the error is printed and the
        counts gathered so far (possibly empty) are returned.
    """
    ngram_counters = {3: Counter(), 4: Counter(), 5: Counter()}
    line_count = 0

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                record = line.strip()
                if not record:
                    # A blank line is not a record; parsing it would only
                    # produce a misleading JSON warning.
                    continue

                try:
                    data = json.loads(record)
                    output_text = data.get('generated_text', '')

                    if output_text:
                        cleaned_text = clean_text(output_text)
                        if cleaned_text:
                            # Feed the generator straight into the counter so
                            # no intermediate n-gram list is built.
                            for n, counter in ngram_counters.items():
                                counter.update(extract_ngrams_streaming(cleaned_text, n))

                    line_count += 1

                    # Check memory usage once every 1000 processed lines.
                    if line_count % 1000 == 0:
                        current_memory = get_memory_usage()
                        if current_memory > max_memory_mb:
                            print(f"警告: 内存使用超过 {max_memory_mb}MB ({current_memory:.1f}MB)，强制垃圾回收")
                            gc.collect()

                except json.JSONDecodeError:
                    print(f"警告: {file_path} 第{line_num}行JSON解析失败")
                    continue
                except Exception as e:
                    # Best effort: e.g. a JSON value that is not an object has
                    # no .get(); report the line and keep going.
                    print(f"警告: {file_path} 第{line_num}行处理失败: {e}")
                    continue

    except Exception as e:
        print(f"错误: 无法读取文件 {file_path}: {e}")

    return ngram_counters


def save_single_file_results_optimized(file_path: str, ngram_counters: Dict[int, Counter], output_dir: str):
    """Write the n-gram counts of one input file into its own sub-directory.

    The sub-directory is ``<output_dir>/<basename of file_path without
    extension>`` and is created if missing. For every n in (3, 4, 5) whose
    counter is non-empty, three files are written:

    * ``ngram_{n}_full.pkl`` - the pickled counter;
    * ``ngram_{n}_deduplicated.txt`` - a header followed by one
      ``ngram<TAB>count`` line per distinct n-gram, most frequent first;
    * ``ngram_{n}_stats.txt`` - totals, mean/max/min frequency and a
      frequency histogram.

    Empty counters produce no files. A summary line is printed at the end.

    Args:
        file_path: Path of the input file the counts were computed from.
        ngram_counters: Dict with keys 3, 4 and 5 mapping to n-gram counters.
        output_dir: Directory under which the per-file sub-directory is made.
    """
    # NOTE(review): only the basename is used, so two inputs with the same
    # file name in different directories end up in the same sub-directory.
    stem, _extension = os.path.splitext(os.path.basename(file_path))
    target_dir = os.path.join(output_dir, stem)
    os.makedirs(target_dir, exist_ok=True)

    for n in (3, 4, 5):
        counts = ngram_counters[n]
        if not counts:
            continue

        distinct = len(counts)
        occurrences = sum(counts.values())

        # Full counter, pickled.
        with open(os.path.join(target_dir, f"ngram_{n}_full.pkl"), 'wb') as pkl_file:
            pickle.dump(counts, pkl_file)

        # Human-readable listing, ordered by descending count.
        listing_path = os.path.join(target_dir, f"ngram_{n}_deduplicated.txt")
        with open(listing_path, 'w', encoding='utf-8') as listing:
            rule = "=" * 60 + "\n"
            listing.writelines([
                f"文件: {file_path}\n",
                f"去重后的 {n}-gram 结果 (按频次排序):\n",
                rule,
                f"总唯一n-gram数量: {distinct}\n",
                f"总出现次数: {occurrences}\n",
                rule,
            ])
            # most_common() returns the complete sorted list; the lines
            # themselves are formatted and written one at a time.
            listing.writelines(
                f"{ngram}\t{count}\n" for ngram, count in counts.most_common()
            )

        # Summary statistics. `counts` is non-empty here, so max/min are safe.
        highest = max(counts.values())
        lowest = min(counts.values())
        histogram = Counter(counts.values())

        summary_path = os.path.join(target_dir, f"ngram_{n}_stats.txt")
        with open(summary_path, 'w', encoding='utf-8') as summary:
            summary.write(f"文件: {file_path}\n")
            summary.write(f"{n}-gram 统计信息 (去重后):\n")
            summary.write("=" * 40 + "\n")
            summary.write(f"总出现次数: {occurrences}\n")
            summary.write(f"唯一n-gram数量: {distinct}\n")
            summary.write(f"平均频率: {occurrences / distinct:.2f}\n")
            summary.write(f"最高频率: {highest}\n")
            summary.write(f"最低频率: {lowest}\n")

            # Histogram: how many distinct n-grams occur exactly `freq` times,
            # highest frequency first.
            summary.write("\n频率分布:\n")
            summary.write("-" * 20 + "\n")
            for freq in sorted(histogram, reverse=True):
                summary.write(f"出现{freq}次: {histogram[freq]}个n-gram\n")

    print(f"保存了文件 {stem} 的n-gram结果到: {target_dir}")


def process_single_file_with_save_optimized(args_tuple):
    """Extract and save the n-grams of one file; worker entry point.

    Takes a single tuple so it can be passed directly to ``Pool.map``.

    Args:
        args_tuple: ``(file_path, output_dir, max_memory_mb)``.

    Returns:
        Always a dict of three empty counters (keys 3, 4, 5). The real counts
        are written to disk and deliberately not returned, so the caller does
        not accumulate them. Any exception raised while processing or saving
        is printed and swallowed.
    """
    file_path, output_dir, max_memory_mb = args_tuple

    try:
        base_name = os.path.basename(file_path)
        print(f"开始处理文件: {base_name}")
        memory_before = get_memory_usage()

        # Count, then persist this file's n-grams.
        counters = process_single_file_streaming(file_path, max_memory_mb)
        save_single_file_results_optimized(file_path, counters, output_dir)

        memory_after = get_memory_usage()
        print(f"完成文件 {base_name}，内存使用: {memory_before:.1f}MB -> {memory_after:.1f}MB")

        # Drop the reference so the collection below can reclaim the counters.
        del counters
    except Exception as e:
        print(f"错误: 处理文件 {file_path} 时发生异常: {e}")

    # Both the success and the failure path end with a forced collection and
    # an empty result.
    gc.collect()
    return {n: Counter() for n in (3, 4, 5)}


def find_jsonl_files(directory: str) -> List[str]:
    """Recursively find the ``*_test.jsonl`` files under *directory*.

    Only files whose name ends in ``_test.jsonl`` are matched; other
    ``.jsonl`` files are ignored.

    Args:
        directory: Root directory to search (the directory itself and all of
            its sub-directories).

    Returns:
        The matching paths in sorted order, so that runs process files in a
        reproducible sequence. Empty when nothing matches or the directory
        does not exist.
    """
    pattern = os.path.join(directory, "**", "*_test.jsonl")
    # glob returns matches in arbitrary order; sort for determinism.
    return sorted(glob.glob(pattern, recursive=True))


def main():
    """Command-line entry point: extract and save n-grams for every input file.

    Parses the command-line options, finds the input files under
    ``--input_dir`` and processes them one by one, writing each file's results
    to its own sub-directory of ``--output_dir``.

    Processing is single-process by default. Pass ``--multi_process`` to use a
    ``multiprocessing.Pool`` of ``--num_processes`` workers that is fed
    ``--batch_size`` files at a time; if the pool fails, all files are
    processed again sequentially. ``--single_process`` is still accepted and
    selects the default behavior.
    """
    parser = argparse.ArgumentParser(description="从jsonl文件中提取n-gram - 内存优化版本")
    parser.add_argument(
        "--input_dir", type=str,
        default="/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep",
        help="包含jsonl文件的目录路径")
    parser.add_argument(
        "--output_dir", type=str,
        default="/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/results_sep/ngram_results",
        help="输出目录")
    # Deliberately small default number of workers.
    parser.add_argument("--num_processes", type=int, default=2,
                        help="并行进程数")
    # By default only one file is handed to the pool at a time.
    parser.add_argument("--batch_size", type=int, default=1,
                        help="每个批次处理的文件数")
    # Soft per-process memory threshold.
    parser.add_argument("--max_memory_mb", type=int, default=800,
                        help="每个进程的最大内存使用量（MB）")
    # Single-process is the default. Because this flag is store_true with a
    # True default it cannot switch the mode off on its own; --multi_process
    # below writes False to the same destination to make the pool reachable.
    parser.add_argument("--single_process", action="store_true", default=True,
                        help="强制使用单进程处理（避免多进程问题）")
    parser.add_argument("--multi_process", dest="single_process", action="store_false",
                        help="使用多进程处理（默认使用单进程）")

    args = parser.parse_args()

    print(f"开始处理目录: {args.input_dir}")
    print(f"内存限制: {args.max_memory_mb}MB")
    print(f"初始内存使用: {get_memory_usage():.1f}MB")

    if args.single_process:
        print("使用单进程处理")
    else:
        print(f"使用 {args.num_processes} 个进程")

    # Locate the input files.
    jsonl_files = find_jsonl_files(args.input_dir)
    if not jsonl_files:
        print(f"在 {args.input_dir} 中未找到jsonl文件")
        return

    print(f"找到 {len(jsonl_files)} 个jsonl文件")

    # Make sure the output directory exists.
    os.makedirs(args.output_dir, exist_ok=True)

    if args.single_process:
        # Sequential mode: one file at a time in this process.
        print("使用单进程逐个处理文件...")
        for i, file_path in enumerate(tqdm(jsonl_files, desc="处理文件")):
            try:
                print(f"\n处理进度: {i+1}/{len(jsonl_files)}")
                print(f"当前内存使用: {get_memory_usage():.1f}MB")

                process_single_file_with_save_optimized((file_path, args.output_dir, args.max_memory_mb))

                # Force a garbage collection after every third file.
                if (i + 1) % 3 == 0:
                    print("执行垃圾回收...")
                    gc.collect()

            except Exception as file_error:
                print(f"处理文件 {file_path} 失败: {file_error}")
                gc.collect()
                continue
    else:
        # Pool mode: files are submitted in small batches.
        try:
            with Pool(processes=args.num_processes) as pool:
                print(f"使用 {args.num_processes} 个进程处理 {len(jsonl_files)} 个文件...")

                # One argument tuple per file.
                worker_args = [(file_path, args.output_dir, args.max_memory_mb) for file_path in jsonl_files]

                batch_size = max(1, args.batch_size)
                total_batches = (len(worker_args) + batch_size - 1) // batch_size

                for i in range(0, len(worker_args), batch_size):
                    batch = worker_args[i:i + batch_size]
                    print(f"处理批次 {i//batch_size + 1}/{total_batches}")
                    print(f"当前内存使用: {get_memory_usage():.1f}MB")

                    # Workers save their own results and return only empty
                    # counters, so the return value is not kept.
                    pool.map(process_single_file_with_save_optimized, batch)

                    # Force a garbage collection after every batch.
                    gc.collect()

        except Exception as e:
            print(f"多进程处理出错，改用单进程处理: {e}")
            # Fall back to sequential processing of the whole file list; files
            # already handled by the pool are processed again.
            for file_path in tqdm(jsonl_files, desc="单进程处理文件"):
                try:
                    process_single_file_with_save_optimized((file_path, args.output_dir, args.max_memory_mb))
                    gc.collect()
                except Exception as file_error:
                    print(f"处理文件 {file_path} 失败: {file_error}")
                    gc.collect()
                    continue

    print(f"\n所有结果已保存到: {args.output_dir}")
    print(f"最终内存使用: {get_memory_usage():.1f}MB")
    print("每个文件的单独结果保存在对应的子目录中")


if __name__ == "__main__":
    main()
