import json
import argparse
from collections import defaultdict
import os
from tqdm import tqdm

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``input`` (required path of the jsonl file to
        clean) and ``output`` (optional destination path, ``None`` when the
        flag is omitted).
    """
    cli = argparse.ArgumentParser(description='清理和整理jsonl文件')
    # (flag, keyword arguments for add_argument) pairs, registered in order.
    option_specs = (
        ('--input', {'type': str, 'required': True,
                     'help': '输入的jsonl文件路径'}),
        ('--output', {'type': str,
                      'help': '输出的jsonl文件路径，如果不指定则在输入文件名后加上_cleaned'}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def _sorted_keys(keys):
    """Return ``keys`` as a sorted list, tolerating mutually incomparable types.

    Homogeneous keys (all ints, all strs, ...) sort exactly as ``sorted`` does.
    A mix such as int and str ids makes plain ``sorted`` raise ``TypeError``;
    in that case keys are grouped by type name and sorted within each group.
    """
    keys = list(keys)
    try:
        return sorted(keys)
    except TypeError:
        return sorted(keys, key=lambda k: (type(k).__name__, k))


def _load_records(input_file):
    """Parse ``input_file`` into ``(records, stats)``.

    ``records`` maps ``id -> {sample_index: record}``; for a duplicated
    ``(id, sample_index)`` pair the last occurrence wins. ``stats`` holds the
    ``total``, ``valid``, ``duplicate`` and ``invalid`` line counters, and every
    line lands in exactly one of the last three.
    """
    records = defaultdict(dict)  # {id: {sample_index: record}}
    stats = {'total': 0, 'valid': 0, 'duplicate': 0, 'invalid': 0}

    # First pass only counts lines so tqdm can display a real progress bar.
    with open(input_file, 'r', encoding='utf-8') as f:
        stats['total'] = sum(1 for _ in f)

    with open(input_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, total=stats['total'], desc="处理数据"):
            try:
                record = json.loads(line.strip())
            except json.JSONDecodeError as e:
                # tqdm.write keeps the progress bar intact, unlike print().
                tqdm.write(f"错误: JSON解析失败 - {str(e)}")
                stats['invalid'] += 1
                continue

            # Valid JSON may still be a list/number/string; the membership
            # tests and subscripts below only make sense on an object.
            if (not isinstance(record, dict)
                    or 'id' not in record or 'sample_index' not in record):
                tqdm.write("警告: 记录缺少必要字段 id 或 sample_index")
                stats['invalid'] += 1
                continue

            record_id = record['id']
            sample_index = record['sample_index']

            # Both values are used as dict keys; JSON arrays/objects are
            # unhashable. Check before touching the defaultdict so no empty
            # bucket is created for a rejected line.
            try:
                hash((record_id, sample_index))
            except TypeError:
                tqdm.write(f"警告: id 或 sample_index 不可哈希 - ID: {record_id}, Sample Index: {sample_index}")
                stats['invalid'] += 1
                continue

            samples = records[record_id]
            if sample_index in samples:
                tqdm.write(f"警告: 发现重复的记录 - ID: {record_id}, Sample Index: {sample_index}")
                stats['duplicate'] += 1
            else:
                stats['valid'] += 1
            # Either way the latest occurrence replaces any earlier one.
            samples[sample_index] = record

    return records, stats


def _write_records(records, ordered_ids, output_file):
    """Write all records to ``output_file`` as UTF-8 jsonl, ordered by id then sample_index."""
    with open(output_file, 'w', encoding='utf-8') as f:
        for record_id in ordered_ids:
            samples = records[record_id]
            for sample_index in _sorted_keys(samples):
                # ensure_ascii=False emits raw non-ASCII text, which is why the
                # file is opened with an explicit UTF-8 encoding.
                json.dump(samples[sample_index], f, ensure_ascii=False)
                f.write('\n')


def clean_jsonl(input_file: str, output_file: str, expected_samples: int = 100):
    """Validate, deduplicate and sort the records of a jsonl file.

    A line is kept only if it is a JSON object containing hashable ``id`` and
    ``sample_index`` fields. When the same ``(id, sample_index)`` pair occurs
    more than once, the last occurrence is kept. Kept records are written to
    ``output_file`` one per line, ordered by ``id`` and then ``sample_index``.
    Summary statistics are printed afterwards. A warning then lists every id
    whose number of distinct ``sample_index`` values differs from
    ``expected_samples``.

    Args:
        input_file: Path of the jsonl file to read (decoded as UTF-8).
        output_file: Path of the cleaned jsonl file (written as UTF-8,
            overwritten if it exists).
        expected_samples: Expected number of distinct ``sample_index`` values
            per ``id``. Defaults to 100.
    """
    print(f"正在读取文件: {input_file}")
    records, stats = _load_records(input_file)
    ordered_ids = _sorted_keys(records)

    print(f"正在写入清理后的文件: {output_file}")
    _write_records(records, ordered_ids, output_file)

    print("\n处理完成！统计信息：")
    print(f"总行数: {stats['total']}")
    print(f"有效记录数: {stats['valid']}")
    print(f"重复记录数: {stats['duplicate']}")
    print(f"无效行数: {stats['invalid']}")
    print(f"最终记录数: {sum(len(samples) for samples in records.values())}")

    # Flag ids with too few OR too many distinct sample_index values; ids are
    # unique, so listing them in id order matches sorting (id, count) pairs.
    incomplete_ids = [
        (record_id, len(records[record_id]))
        for record_id in ordered_ids
        if len(records[record_id]) != expected_samples
    ]
    if incomplete_ids:
        print(f"\n警告：以下ID的sample_index数量不等于{expected_samples}个：")
        for record_id, count in incomplete_ids:
            print(f"ID: {record_id}, 数量: {count}")

def main():
    """Command-line entry point: resolve input/output paths, then clean the file."""
    args = parse_args()

    # Without --output, write "<input stem>_cleaned<input extension>" next to the input.
    if args.output is None:
        stem, suffix = os.path.splitext(args.input)
        args.output = stem + "_cleaned" + suffix

    # Make sure the destination directory exists; a bare filename means the CWD.
    target_dir = os.path.dirname(args.output) or '.'
    os.makedirs(target_dir, exist_ok=True)

    clean_jsonl(args.input, args.output)


if __name__ == "__main__":
    main()