import json
import random
# Configuration: each entry pairs a result file path with the model name used to label it.
# All result files live in the same directory, so the prefix is written once.
_results_dir = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/results_oct/"
files_and_models = [
    (_results_dir + file_name, model_label)
    for file_name, model_label in (
        ("Qwen3-4B-Base_32768_test.jsonl", "Qwen3-4B-Base"),
        ("DAPO-Qwen3-4B-Base-deepscaler-40k-BASELINE-8k-step330-valid_32768_test.jsonl", "DAPO-BASELINE-8k"),
        ("DAPO-Qwen3-4B-Base-deepscaler-40k-BASELINE-16k-step290-valid_32768_test.jsonl", "DAPO-BASELINE-16k"),
        ("add1k-new-60steps-continue-overlong-filter-70step-valid_32768_test.jsonl", "DAPO-stage2-overlong-filter"),
        ("add1k-new-60steps-continue-max12k-150step-valid_32768_test.jsonl", "DAPO-stage2-max12k"),
    )
]

# Read every record from each result file and group the records by model name.
# Loading is best-effort: a file that cannot be read or parsed is reported and
# skipped, so its model is then absent from data_by_model.
data_by_model = {}
for file_path, model_name in files_and_models:
    try:
        model_data = []
        # Explicit UTF-8 so decoding does not depend on the platform's default locale.
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # letting json.loads fail and discard the whole file.
                if not line.strip():
                    continue
                item = json.loads(line)
                # data_source is no longer filtered; every record is kept.
                model_data.append({
                    # Zero-based position among the parsed records of this file.
                    "id": len(model_data),
                    'model': model_name,
                    'prompt': item.get('prompt', ''),
                    'output': item.get('generated_text', ''),
                    'correctness': item.get('correctness', False),
                    'answer': item.get('answer', ''),
                    'data_source': item.get('data_source', '')
                })
        # Only register the model once the whole file has been read successfully.
        data_by_model[model_name] = model_data
        print(f"从 {model_name} 读取了 {len(model_data)} 条数据")
    except Exception as e:
        # Report the failure and carry on with the remaining files.
        print(f"读取文件失败 {file_path}: {e}")

# Questions are aligned across models by id, so only ids below the smallest
# per-model record count exist for every loaded model. With no loaded models
# the count is 0.
min_count = min((len(records) for records in data_by_model.values()), default=0)
print(f"所有模型中最少有 {min_count} 条数据")

# Models that must all answer incorrectly, and models of which at least one
# must answer correctly.
baseline_models = ["Qwen3-4B-Base", "DAPO-BASELINE-8k", "DAPO-BASELINE-16k"]
stage2_models = ["DAPO-stage2-overlong-filter", "DAPO-stage2-max12k"]

# A required model may be absent from data_by_model (e.g. its file failed to
# load). Stop with a clear message instead of hitting a bare KeyError below.
missing_models = [
    name for name in baseline_models + stage2_models
    if name not in data_by_model
]
if missing_models:
    raise SystemExit(f"缺少以下模型的数据，无法筛选题目: {missing_models}")

# Resolve each model's record list once, outside the per-question scan.
baseline_records = [data_by_model[name] for name in baseline_models]
stage2_records = [data_by_model[name] for name in stage2_models]

# Select question ids where every baseline model is wrong and at least one
# stage2 model is correct.
# NOTE(review): records are compared by position, which assumes every file
# lists the questions in the same order — confirm against the eval pipeline.
selected_ids = [
    idx
    for idx in range(min_count)
    if not any(records[idx]['correctness'] for records in baseline_records)
    and any(records[idx]['correctness'] for records in stage2_records)
]

print(f"找到 {len(selected_ids)} 个满足条件的题目")

# Randomly keep at most 10 of the qualifying question ids. random.sample raises
# ValueError when asked for more items than the population holds, so the sample
# size is capped at the number of ids actually available.
sample_size = min(10, len(selected_ids))
selected_ids = random.sample(selected_ids, sample_size)
print(f"选择的题目id: {selected_ids}")

# Collect the records with the chosen ids from every model, keeping each
# model's records in their original order.
selected_id_set = set(selected_ids)  # set for constant-time membership tests
all_selected = []
for model_name, model_data in data_by_model.items():
    picked = [record for record in model_data if record['id'] in selected_id_set]
    all_selected.extend(picked)
    print(f"从 {model_name} 选取了 {len(picked)} 个题目")

print(f"总共选取 {len(all_selected)} 条数据")

# Save the selected records to a JSON file.
output_file = "samples.json"
# ensure_ascii=False writes non-ASCII characters verbatim, so the file is opened
# as UTF-8 explicitly; a locale-dependent default encoding could fail to encode
# them or produce a file that is not valid UTF-8 JSON.
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(all_selected, f, ensure_ascii=False, indent=2)

print(f"已保存到 {output_file}")
