import json
import sys
from pathlib import Path

# Default input file paths (used by main() when no paths are given on the
# command line).
default_file1_path = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/add1k-max9k-redo-110step-stage2-grpo/valid/110_16384.jsonl"
default_file2_path = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-4b-oct/add1k-max9k-redo-110step-stage2-grpo/valid/140_16384.jsonl"

# Default output path (main() writes the right-to-wrong samples here).
default_output_path = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/eval_scripts/analysis/samples/right_to_wrong_samples.jsonl"

def load_jsonl(file_path):
    """Load a JSON Lines file into a list of parsed records.

    Blank or whitespace-only lines (for example a trailing empty line at the
    end of the file) are skipped instead of being passed to ``json.loads``,
    which would reject them.

    Args:
        file_path: Path of the UTF-8 encoded ``.jsonl`` file to read.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON; the
            message includes the file path and 1-based line number.
    """
    records = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            text = raw_line.strip()
            if not text:
                continue
            try:
                records.append(json.loads(text))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type, but say where in the
                # file the bad line is so it can actually be found.
                raise json.JSONDecodeError(
                    f"{err.msg} (in {file_path}, line {line_no})",
                    err.doc,
                    err.pos,
                ) from err
    return records

def _write_jsonl(path, items):
    """Write *items* to *path* as JSON Lines (UTF-8, one object per line)."""
    with open(path, 'w', encoding='utf-8') as f:
        for item in items:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')


def _derive_wrong_to_right_path(output_path):
    """Return a sibling of *output_path* for the wrong-to-right samples.

    The result is guaranteed to differ from *output_path*, so the second
    output file can never overwrite the first one.
    """
    out = Path(output_path)
    if 'right_to_wrong' in out.name:
        new_name = out.name.replace('right_to_wrong', 'wrong_to_right')
    else:
        # A custom name without the marker would otherwise map onto itself.
        new_name = f"{out.stem}_wrong_to_right{out.suffix}"
    return str(out.parent / new_name)


def main():
    """Compare per-sample scores of two JSONL files and save the flipped samples.

    Command line: ``<file1_path> <file2_path> [output_path]``. With no
    arguments the module-level defaults are used. Samples are paired by
    position; a sample counts as correct when its ``score`` field equals 1.

    Two files are written: samples correct in file 1 but not in file 2 go to
    ``output_path``; samples correct in file 2 but not in file 1 go to a
    sibling file whose name is derived from ``output_path``.

    Exits with status 1 if either input file does not exist.
    """
    # Parse command-line arguments.
    args = sys.argv[1:]
    file1_path = default_file1_path
    file2_path = default_file2_path
    output_path = default_output_path
    if len(args) == 3:
        file1_path, file2_path, output_path = args
    elif len(args) == 2:
        file1_path, file2_path = args
    elif args:
        # Keep the fall-back to defaults, but do not ignore the arguments
        # silently.
        print(f"Warning: expected 2 or 3 arguments, got {len(args)}; using default paths")
        print(f"Usage: python {sys.argv[0]} <file1_path> <file2_path> [output_path]")

    print(f"Loading file 1: {file1_path}")
    try:
        data1 = load_jsonl(file1_path)
    except FileNotFoundError:
        print(f"Error: File not found: {file1_path}")
        print(f"\nUsage: python {sys.argv[0]} <file1_path> <file2_path> [output_path]")
        sys.exit(1)
    print(f"Loaded {len(data1)} samples from file 1")

    print(f"Loading file 2: {file2_path}")
    try:
        data2 = load_jsonl(file2_path)
    except FileNotFoundError:
        print(f"Error: File not found: {file2_path}")
        print(f"\nUsage: python {sys.argv[0]} <file1_path> <file2_path> [output_path]")
        sys.exit(1)
    print(f"Loaded {len(data2)} samples from file 2")

    # The two files are assumed to be aligned by position (sample i in file 1
    # corresponds to sample i in file 2); no unique ID is used for matching.
    if len(data1) != len(data2):
        print(f"Warning: File sizes don't match! File1: {len(data1)}, File2: {len(data2)}")

    # Classify every compared pair exactly once so the four categories plus
    # the skipped count always add up to the number of compared samples.
    right_to_wrong = []
    wrong_to_right = []
    correct_in_both = 0
    wrong_in_both = 0
    skipped = 0

    for i, (item1, item2) in enumerate(zip(data1, data2)):
        score1 = item1.get('score')
        score2 = item2.get('score')

        if score1 is None or score2 is None:
            print(f"Warning: Sample {i} missing score field")
            skipped += 1
            continue

        correct1 = score1 == 1
        correct2 = score2 == 1
        if correct1 and correct2:
            correct_in_both += 1
        elif not correct1 and not correct2:
            wrong_in_both += 1
        else:
            record = {
                'index': i,
                'file1_score': score1,
                'file2_score': score2,
                'file1_data': item1,
                'file2_data': item2
            }
            if correct1:
                right_to_wrong.append(record)
            else:
                wrong_to_right.append(record)

    # Statistics.
    print(f"\n{'='*60}")
    print(f"Statistics:")
    print(f"{'='*60}")
    print(f"Total samples compared: {min(len(data1), len(data2))}")
    print(f"Correct in both files: {correct_in_both}")
    print(f"Wrong in both files: {wrong_in_both}")
    print(f"Wrong to Right samples: {len(wrong_to_right)}")
    print(f"Right to Wrong samples: {len(right_to_wrong)}")
    if skipped:
        print(f"Skipped (missing score): {skipped}")

    # Per-file accuracy; each file is guarded on its own length so an empty
    # file can never cause a division by zero.
    if data1 or data2:
        print()
    if data1:
        file1_correct = sum(1 for item in data1 if item.get('score') == 1)
        print(f"File 1 accuracy: {file1_correct}/{len(data1)} = {file1_correct/len(data1)*100:.2f}%")
    if data2:
        file2_correct = sum(1 for item in data2 if item.get('score') == 1)
        print(f"File 2 accuracy: {file2_correct}/{len(data2)} = {file2_correct/len(data2)*100:.2f}%")

    # Save results.
    print(f"\n{'='*60}")
    print(f"Saving results...")

    # Make sure the destination directory exists before writing anything.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    # Samples that went from right to wrong.
    print(f"Saving Right to Wrong samples to: {output_path}")
    _write_jsonl(output_path, right_to_wrong)
    print(f"Saved {len(right_to_wrong)} samples to {output_path}")

    # Samples that went from wrong to right, in a distinct sibling file.
    wrong_to_right_path = _derive_wrong_to_right_path(output_path)
    print(f"\nSaving Wrong to Right samples to: {wrong_to_right_path}")
    _write_jsonl(wrong_to_right_path, wrong_to_right)
    print(f"Saved {len(wrong_to_right)} samples to {wrong_to_right_path}")

    print(f"{'='*60}")

# Run the comparison only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()