#!/usr/bin/env python3
"""
比较两个数据集的input字段重复比例
"""

import json
import argparse
from typing import Set, Dict, Any
from collections import Counter

def load_jsonl(file_path: str) -> list:
    """Read a JSON Lines file and return its decoded records.

    Each non-blank line is parsed with ``json.loads``; whitespace-only lines
    are ignored.

    Args:
        file_path: Path of the UTF-8 encoded JSONL file.

    Returns:
        The decoded records in file order. An empty list is returned (after
        printing an error message) when the file does not exist or when any
        line is not valid JSON; records decoded before the bad line are
        discarded in that case.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            stripped_lines = (raw.strip() for raw in handle)
            records = [json.loads(text) for text in stripped_lines if text]
    except FileNotFoundError:
        print(f"错误: 找不到文件 {file_path}")
        return []
    except json.JSONDecodeError as e:
        print(f"错误: 解析JSON失败 {file_path}: {e}")
        return []
    else:
        return records

def extract_inputs(data: list, field_name: str = 'input') -> Set[str]:
    """Collect the distinct user-message texts stored under ``field_name``.

    For each dict record that has ``field_name`` with a string value, the text
    between the first ``"user\\n"`` marker and the following ``"\\nassistant"``
    marker is extracted and stripped of surrounding whitespace. When the
    ``"user\\n"`` marker is absent, the whole stripped value is used instead,
    so plain (non chat-formatted) fields can be compared without crashing.

    Args:
        data: Decoded records; entries that are not dicts are ignored.
        field_name: Key whose value is inspected in every record.

    Returns:
        The set of extracted texts. Records lacking the field, or whose value
        is not a string, contribute nothing.
    """
    inputs = set()
    for item in data:
        if not isinstance(item, dict) or field_name not in item:
            continue
        value = item[field_name]
        if not isinstance(value, str):
            # Non-string values cannot carry the chat markers; skip them
            # rather than abort the whole comparison on one odd record.
            continue
        segments = value.split("user\n")
        if len(segments) > 1:
            # Same slice as before: the text after the first "user\n" marker,
            # cut at the next "\nassistant" marker (if any).
            text = segments[1].split("\nassistant")[0]
        else:
            # No user marker: fall back to the raw value instead of raising
            # IndexError.
            text = value
        inputs.add(text.strip())
    return inputs

def calculate_overlap_ratio(set1: Set[str], set2: Set[str]) -> Dict[str, float]:
    """Compute size and overlap statistics for two sets.

    Args:
        set1: First collection of values.
        set2: Second collection of values.

    Returns:
        A dict with the intersection, per-set and union sizes, the fraction of
        each set that also appears in the other, and the Jaccard similarity
        (intersection size over union size). Any ratio whose denominator is
        zero is reported as 0.
    """
    shared_count = len(set1 & set2)
    combined_count = len(set1 | set2)
    first_count = len(set1)
    second_count = len(set2)

    def _fraction_of(total: int):
        # Avoid dividing by zero when a set (or the union) is empty.
        if total > 0:
            return shared_count / total
        return 0

    return {
        'intersection_size': shared_count,
        'set1_size': first_count,
        'set2_size': second_count,
        'union_size': combined_count,
        # Share of set1 that is also present in set2.
        'overlap_ratio_1_to_2': _fraction_of(first_count),
        # Share of set2 that is also present in set1.
        'overlap_ratio_2_to_1': _fraction_of(second_count),
        # Jaccard similarity of the two sets.
        'jaccard_similarity': _fraction_of(combined_count),
    }

def analyze_duplicates(data: list, field_name: str = 'input') -> Dict[str, Any]:
    """Summarise how often values of ``field_name`` repeat within one dataset.

    Only dict records that contain ``field_name`` are counted; the raw field
    value is used as the counting key.

    Args:
        data: Decoded records of a single dataset.
        field_name: Key whose values are counted.

    Returns:
        A dict with the number of records, the number of records carrying the
        field, the number of distinct values, the number of surplus
        occurrences (occurrences beyond the first of each value), and the
        surplus divided by the number of field-carrying records (0 when there
        are none).
    """
    frequency = Counter(
        record[field_name]
        for record in data
        if isinstance(record, dict) and field_name in record
    )
    present = sum(frequency.values())
    distinct = len(frequency)
    # Every value contributes (count - 1) extra occurrences, which sums to
    # the total number of occurrences minus the number of distinct values.
    surplus = present - distinct

    return {
        'total_items': len(data),
        'total_inputs': present,
        'unique_inputs': distinct,
        'duplicate_count': surplus,
        'duplicate_ratio': surplus / present if present > 0 else 0,
    }

def print_different_inputs(inputs1: Set[str], inputs2: Set[str], max_examples: int = 10):
    """Print the values that appear in only one of the two datasets.

    Prints how many values are exclusive to each dataset, followed by up to
    ``max_examples`` example values per dataset. Examples longer than 200
    characters are cut to 200 characters and suffixed with ``...``. A dataset
    with no exclusive values gets no example section.

    Args:
        inputs1: Values extracted from the first dataset.
        inputs2: Values extracted from the second dataset.
        max_examples: Upper bound on the examples listed per dataset.
    """

    def _emit_examples(header: str, texts: Set[str]) -> None:
        # Print a numbered, length-limited preview of the given values.
        if not texts:
            return
        print(header)
        for rank, text in enumerate(list(texts)[:max_examples], start=1):
            ellipsis = '...' if len(text) > 200 else ''
            print(f"  {rank}. {text[:200]}{ellipsis}")

    exclusive_to_first = inputs1 - inputs2
    exclusive_to_second = inputs2 - inputs1

    print("\n=== 不同的Input分析 ===")
    print(f"仅在数据集1中的input数量: {len(exclusive_to_first)}")
    print(f"仅在数据集2中的input数量: {len(exclusive_to_second)}")

    _emit_examples(
        f"\n仅在数据集1中的input示例 (最多显示{max_examples}个):",
        exclusive_to_first,
    )
    _emit_examples(
        f"\n仅在数据集2中的input示例 (最多显示{max_examples}个):",
        exclusive_to_second,
    )

def main():
    """Command-line entry point: compare one field across two JSONL datasets.

    Parses the command-line arguments, loads both files, prints per-dataset
    duplicate statistics and cross-dataset overlap statistics, optionally
    prints example values that occur in only one dataset, and finally writes
    the statistics as JSON to the ``--output`` path.

    Returns early, without writing anything, when either dataset could not be
    loaded or contains no records.
    """
    # Local import: only needed here, to create the output directory.
    import os

    def _print_duplicate_stats(title: str, stats: dict) -> None:
        # Print the duplicate summary of one dataset under the given title.
        print(title)
        print(f"  总记录数: {stats['total_items']}")
        print(f"  总{args.field}数: {stats['total_inputs']}")
        print(f"  唯一{args.field}数: {stats['unique_inputs']}")
        print(f"  重复{args.field}数: {stats['duplicate_count']}")
        print(f"  内部重复比例: {stats['duplicate_ratio']:.4f}")

    parser = argparse.ArgumentParser(description='比较两个数据集的指定字段重复比例')
    parser.add_argument('file1', help='第一个数据集文件路径')
    parser.add_argument('file2', help='第二个数据集文件路径')
    parser.add_argument('--field', '-f', default='input', help='要比较的字段名 (默认: input)')
    parser.add_argument('--output', '-o', help='输出结果到文件', default="analysis/test.json")
    parser.add_argument('--show-different', action='store_true', help='显示不同的input示例')
    parser.add_argument('--max-examples', type=int, default=10, help='显示的最大示例数量 (默认: 10)')

    args = parser.parse_args()

    print("正在加载数据集...")

    # Load both datasets; load_jsonl returns an empty list on failure.
    data1 = load_jsonl(args.file1)
    data2 = load_jsonl(args.file2)

    if not data1 or not data2:
        print("无法加载数据集，请检查文件路径")
        return

    print(f"数据集1: {len(data1)} 条记录")
    print(f"数据集2: {len(data2)} 条记录")
    print(f"比较字段: {args.field}")

    # Distinct extracted values of the selected field, per dataset.
    inputs1 = extract_inputs(data1, args.field)
    inputs2 = extract_inputs(data2, args.field)

    print(f"数据集1 unique {args.field}s: {len(inputs1)}")
    print(f"数据集2 unique {args.field}s: {len(inputs2)}")

    # Duplicates inside each dataset.
    print(f"\n=== 数据集内部重复分析 ({args.field}字段) ===")
    dup1 = analyze_duplicates(data1, args.field)
    dup2 = analyze_duplicates(data2, args.field)

    _print_duplicate_stats("数据集1:", dup1)
    _print_duplicate_stats("数据集2:", dup2)

    # Overlap between the two datasets.
    print(f"\n=== 数据集间重叠分析 ({args.field}字段) ===")
    overlap_stats = calculate_overlap_ratio(inputs1, inputs2)

    print(f"交集大小: {overlap_stats['intersection_size']}")
    print(f"并集大小: {overlap_stats['union_size']}")
    print(f"数据集1在数据集2中的重复比例: {overlap_stats['overlap_ratio_1_to_2']:.4f}")
    print(f"数据集2在数据集1中的重复比例: {overlap_stats['overlap_ratio_2_to_1']:.4f}")
    print(f"Jaccard相似度: {overlap_stats['jaccard_similarity']:.4f}")

    # Optionally list values found in only one dataset.
    if args.show_different:
        print_different_inputs(inputs1, inputs2, args.max_examples)

    # Persist the statistics as JSON.
    if args.output:
        result = {
            'field_name': args.field,
            'dataset1_stats': dup1,
            'dataset2_stats': dup2,
            'overlap_stats': overlap_stats
        }
        # The default output path lives in a sub-directory; create it so the
        # finished analysis is not lost to a FileNotFoundError on open().
        output_dir = os.path.dirname(args.output)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"\n结果已保存到: {args.output}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()