import json
import os
# Small one-off utility script; it is safe to delete once no longer needed.
def process_protein_scores(file_path, output_dir, sample_size=5000):
    """Split a JSON object of scores into lowest / middle / highest subsets.

    Reads *file_path*, which must contain a JSON object whose values are
    mutually comparable (e.g. per-sample losses). Entries are sorted by value
    in ascending order, and three JSON files are written into *output_dir*:

    * ``lowest_{sample_size}.json``  - the *sample_size* smallest entries,
    * ``middle_{sample_size}.json``  - *sample_size* entries centred on the
      middle position of the sorted order,
    * ``highest_{sample_size}.json`` - the *sample_size* largest entries.

    With fewer than ``3 * sample_size`` entries the subsets may overlap; a
    notice is printed in that case. A non-positive *sample_size* yields
    empty subsets.

    Problems (missing or unreadable file, invalid JSON, a top-level value
    that is not an object, values that cannot be ordered) are reported with
    ``print`` and the function returns without writing any output.

    Args:
        file_path: Path to the input JSON file.
        output_dir: Directory for the output files; created if missing.
        sample_size: Number of entries per subset (default 5000).
    """
    # Step 1: load the file and check how many entries it has.
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        return
    except json.JSONDecodeError as e:
        print(f"Failed to parse JSON: {e}")
        return
    except (OSError, UnicodeDecodeError) as e:
        # Permission errors, paths that are directories, non-UTF-8 content, ...
        print(f"Failed to read {file_path}: {e}")
        return

    # The splitting below relies on key -> value pairs.
    if not isinstance(data, dict):
        print("Invalid JSON format: Expected a dictionary.")
        return

    total_entries = len(data)
    print(f"Total entries in the file: {total_entries}")

    # Step 2: sort entries by value (the loss) in ascending order.
    try:
        sorted_items = sorted(data.items(), key=lambda item: item[1])
    except TypeError as e:
        # Raised when values cannot be ordered against each other,
        # e.g. a mix of numbers and strings, or nested objects.
        print(f"Cannot sort entries by value: {e}")
        return

    if 0 < total_entries < 3 * sample_size:
        print(f"Note: fewer than {3 * sample_size} entries; "
              f"the three subsets may overlap.")

    lowest, middle, highest = _slice_extremes(sorted_items, sample_size)

    # Step 3: save each subset as its own JSON file.
    os.makedirs(output_dir, exist_ok=True)

    output_files = {
        f"lowest_{sample_size}.json": dict(lowest),
        f"middle_{sample_size}.json": dict(middle),
        f"highest_{sample_size}.json": dict(highest),
    }

    for file_name, content in output_files.items():
        output_path = os.path.join(output_dir, file_name)
        with open(output_path, 'w', encoding='utf-8') as out_file:
            json.dump(content, out_file, indent=4)
        print(f"Saved {file_name} with {len(content)} entries to {output_path}")


def _slice_extremes(sorted_items, sample_size):
    """Return (lowest, middle, highest) slices of *sorted_items*, each at most *sample_size* long."""
    if sample_size <= 0:
        return [], [], []
    total = len(sorted_items)
    lowest = sorted_items[:sample_size]
    # Centre the middle window on total // 2; clamp so the index stays valid.
    middle_start = max(0, total // 2 - sample_size // 2)
    middle = sorted_items[middle_start:middle_start + sample_size]
    # Explicit start index rather than [-sample_size:], so the slice stays
    # correct for any positive size and never wraps around.
    highest = sorted_items[max(0, total - sample_size):]
    return lowest, middle, highest

# Example usage: split the loss record of the "alltest_forward" run into
# lowest / middle / highest subsets written next to the input file.
file_path = "/storage/tancheng/dataset_condensation/dataset_condensation/results/alltest_forward/loss_record.json"
output_dir = "/storage/tancheng/dataset_condensation/dataset_condensation/results/alltest_forward"

# Run only when executed as a script, so importing this module has no
# file-system side effects.
if __name__ == "__main__":
    process_protein_scores(file_path, output_dir)
