import json
from collections import defaultdict

def read_jsonl_file(file_path):
    """Load a JSONL file into a list of parsed records.

    Every line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped, the rest are decoded with ``json.loads``.

    Args:
        file_path: Path of the UTF-8 encoded JSONL file to read.

    Returns:
        A list holding one parsed JSON value per non-blank line, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]

def write_jsonl_file(data, file_path):
    """Write an iterable of JSON-serializable records to a JSONL file.

    The target file is opened in write mode (truncating any existing
    content) and each record is emitted as one line of JSON. Non-ASCII
    characters are written verbatim rather than as ``\\uXXXX`` escapes.

    Args:
        data: Iterable of JSON-serializable objects.
        file_path: Path of the UTF-8 encoded output file.
    """
    with open(file_path, 'w', encoding='utf-8') as sink:
        sink.writelines(
            f"{json.dumps(record, ensure_ascii=False)}\n" for record in data
        )

def _first_correctness(record):
    """Return the first correctness flag of *record*, or 0 when unavailable.

    Tolerates a missing record, a missing/``None`` ``correctness`` field, an
    empty list, and a scalar value stored instead of a list.
    """
    if not record:
        return 0
    value = record.get('correctness', 0)
    if isinstance(value, (list, tuple)):
        # An empty list carries no verdict; treat it as "not correct".
        return value[0] if value else 0
    return 0 if value is None else value


def categorize_by_difficulty(data_1_5b, data_7b, data_32b):
    """Split problems into four difficulty levels by which model solves them.

    Records from the three result sets are matched on their ``'problem'``
    value. A model counts as having solved a problem when the first entry of
    its record's ``'correctness'`` field equals 1; a missing record or a
    missing/empty ``'correctness'`` field counts as unsolved.

    - Level 1 (easiest): the 1.5B model is correct.
    - Level 2: 1.5B is wrong, 7B is correct.
    - Level 3: 1.5B and 7B are wrong, 32B is correct.
    - Level 4 (hardest): no model is correct.

    Args:
        data_1_5b: Iterable of record dicts produced by the 1.5B model.
        data_7b: Iterable of record dicts produced by the 7B model.
        data_32b: Iterable of record dicts produced by the 32B model.

    Returns:
        A tuple ``(level_1, level_2, level_3, level_4)`` of lists. Levels 1-3
        hold the record of the smallest model that solved the problem;
        level 4 holds the record of the smallest model that has one.

    Raises:
        KeyError: If a record has no ``'problem'`` key.
    """
    model_order = ('1.5b', '7b', '32b')

    # Group records by problem text. If a problem appears more than once in
    # the same result set, the last occurrence wins.
    problem_dict = defaultdict(dict)
    for tag, records in zip(model_order, (data_1_5b, data_7b, data_32b)):
        for item in records:
            problem_dict[item['problem']][tag] = item

    level_1, level_2, level_3, level_4 = [], [], [], []
    solved_buckets = (level_1, level_2, level_3)

    for model_results in problem_dict.values():
        # Smallest model first: the first one that is correct sets the level.
        for tag, bucket in zip(model_order, solved_buckets):
            if _first_correctness(model_results.get(tag)) == 1:
                bucket.append(model_results[tag])
                break
        else:
            # Nobody solved it: keep the record of the smallest model that
            # has one (every grouped problem has at least one record).
            for tag in model_order:
                if tag in model_results:
                    level_4.append(model_results[tag])
                    break

    return level_1, level_2, level_3, level_4

def main():
    """Split the step-1 results of three models into difficulty-level files.

    Reads the 1.5B, 7B and 32B JSONL result files, groups the problems into
    four difficulty levels with ``categorize_by_difficulty``, writes one JSONL
    file per level (creating the output directory if needed) and prints the
    per-level counts and percentages.
    """
    import os  # local import: only needed to create the output directory

    # Input file paths
    file_1_5b = '/home/cwy/project/LLM_Inference/Curriculum_Reinforced_learning/data/step1/step1_init_data-r1_1.5b.jsonl'
    file_7b = '/home/cwy/project/LLM_Inference/Curriculum_Reinforced_learning/data/step1/step1_init_data-r1_7b.jsonl'
    file_32b = '/home/cwy/project/LLM_Inference/Curriculum_Reinforced_learning/data/step1/step1_init_data-qwq_32b.jsonl'

    # Output file paths, ordered from level 1 to level 4.
    # NOTE(review): 'evel_2.jsonl' and '_level_4.jsonl' look like typos next to
    # 'level_1.jsonl' / 'level_3.jsonl'. They are kept unchanged here because
    # downstream consumers may depend on the existing names -- confirm.
    output_files = [
        '/home/cwy/project/LLM_Inference/Curriculum_Reinforced_learning/data/step1/difficulty_splits/level_1.jsonl',  # easiest
        '/home/cwy/project/LLM_Inference/Curriculum_Reinforced_learning/data/step1/difficulty_splits/evel_2.jsonl',
        '/home/cwy/project/LLM_Inference/Curriculum_Reinforced_learning/data/step1/difficulty_splits/level_3.jsonl',
        '/home/cwy/project/LLM_Inference/Curriculum_Reinforced_learning/data/step1/difficulty_splits/_level_4.jsonl'   # hardest
    ]

    print("ڶȡļ...")
    data_1_5b = read_jsonl_file(file_1_5b)
    data_7b = read_jsonl_file(file_7b)
    data_32b = read_jsonl_file(file_32b)

    print(f"1.5B: {len(data_1_5b)}")
    print(f"7B: {len(data_7b)}")
    print(f"32B: {len(data_32b)}")

    print("ڰѶȷ...")
    level_1, level_2, level_3, level_4 = categorize_by_difficulty(data_1_5b, data_7b, data_32b)

    print(f"Ѷȵȼ1 (): {len(level_1)} ")
    print(f"Ѷȵȼ2: {len(level_2)} ")
    print(f"Ѷȵȼ3: {len(level_3)} ")
    print(f"Ѷȵȼ4 (): {len(level_4)} ")

    # Save each level to its own file.
    print("ڱ...")
    for records, path in zip((level_1, level_2, level_3, level_4), output_files):
        # open(..., 'w') does not create missing parent directories.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        write_jsonl_file(records, path)

    print("ɣļѱ档")

    # Summary statistics
    total = len(level_1) + len(level_2) + len(level_3) + len(level_4)

    def percent(count):
        # Guard against empty inputs, where total is 0.
        return count / total * 100 if total else 0.0

    print(f"\nܼ: {total} Ψһ")
    print("Ѷȼռ:")
    print(f"Level 1 (): {percent(len(level_1)):.2f}%")
    print(f"Level 2: {percent(len(level_2)):.2f}%")
    print(f"Level 3: {percent(len(level_3)):.2f}%")
    print(f"Level 4 (): {percent(len(level_4)):.2f}%")

# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()