# Base names of the input datasets; each one is read from
# "dataset/<name>.jsonl" by the loading loop further down.
files = [
    'olympiad-bench-OE',
    'olympiad-bench',
    'omni_math_proof_new',
    'putnam_archive_filtered',
    'usamo_processed_filtered',
]

import os
from copy import deepcopy
import json
import random

# Flatten every non-geometry problem into one record per reference solution.
# Each record is a deep copy of the problem whose 'solution' field holds a
# single solution, plus 'proof_id' (index of that solution within the problem)
# and 'source' (dataset name + line index + solution index).
data_list = []

for file in files:
    file_path = f"dataset/{file}.jsonl"
    with open(file_path, 'r', encoding='utf-8') as f:
        # Iterate the handle lazily; readlines() would load the whole file
        # into memory just to walk it once.
        for i, line in enumerate(f):
            if not line.strip():
                # Blank lines (e.g. a trailing empty line) are not JSON
                # records and would make json.loads raise. They are skipped
                # but still counted by `i`, so 'source' indices stay equal to
                # physical line numbers.
                continue
            data = json.loads(line)
            # `or ''` also covers a subfield that is present but null.
            if 'geometry' in (data.get('subfield') or '').lower():
                continue
            # `or []` also covers a solution that is present but null.
            solutions = data.get('solution') or []
            if isinstance(solutions, str):
                # NOTE(review): the dataset schema is not visible here. A bare
                # string would otherwise be enumerated character by character,
                # so it is treated as a single solution -- confirm against
                # the actual data.
                solutions = [solutions]
            for j, solution in enumerate(solutions):
                new_data = deepcopy(data)
                new_data['solution'] = solution
                new_data['proof_id'] = j
                new_data['source'] = f"{file}_problem_{i}_proof_{j}"
                data_list.append(new_data)

def _sentence_cut_points(text):
    """Return the offsets just past every period+newline in *text*, ascending."""
    cuts = []
    start = 0
    while True:
        hit = text.find('.\n', start)
        if hit == -1:
            return cuts
        # +2 so that a prefix cut at this offset keeps the '.' and the newline.
        cuts.append(hit + 2)
        start = hit + 2


def _sample_up_to(candidates, k):
    """Return min(k, len(candidates)) distinct random items from *candidates*."""
    # random.sample accepts an empty population when k is 0, so no separate
    # emptiness guard is needed.
    return random.sample(candidates, k=min(k, len(candidates)))


def _corrupt_solution(solution, k):
    """Build corrupted variants of *solution* by cutting at sentence ends.

    Three strategies are applied in order, each contributing at most *k*
    variants chosen at random:

    1. prefix: keep only the text before a cut lying in the first ~3/4;
    2. suffix: keep only the text after a cut lying in the last ~3/4;
    3. mask: drop the span between a cut in the first third and a cut
       beyond the second third.

    Returns the list of corrupted solution strings (possibly empty).
    """
    total = len(solution)
    cuts = _sentence_cut_points(solution)

    # When the solution ends with period+newline, `total` itself is a cut
    # point. Using it for a suffix yields an empty string, and (for very short
    # solutions) using it for a prefix yields the unmodified solution, so it
    # is excluded from both strategies.
    interior = [c for c in cuts if c < total]

    variants = []

    prefix_cuts = [c for c in interior if c <= total - total // 4]
    variants.extend(solution[:c] for c in _sample_up_to(prefix_cuts, k))

    suffix_cuts = [c for c in interior if c >= total // 4]
    variants.extend(solution[c:] for c in _sample_up_to(suffix_cuts, k))

    # `cuts` is ascending and every kept end lies strictly after every kept
    # start, so the plain cross product enumerates exactly the valid
    # (start, end) pairs. An end equal to `total` is kept on purpose: the
    # result is then a short prefix, which is still a strictly shorter text.
    third = total // 3
    mask_starts = [c for c in cuts if c <= third]
    mask_ends = [c for c in cuts if c > 2 * third]
    mask_pairs = [(a, b) for a in mask_starts for b in mask_ends]
    variants.extend(
        solution[:a] + solution[b:] for a, b in _sample_up_to(mask_pairs, k)
    )

    return variants


# Every record in data_list yields up to 3 * num_samples negatives; each one
# is a deep copy of the record with only its 'solution' text replaced.
simple_negative = []
num_samples = 3
for data in data_list:
    for corrupted in _corrupt_solution(data['solution'], num_samples):
        new_data = deepcopy(data)
        new_data['solution'] = corrupted
        simple_negative.append(new_data)

# Report how many negatives were produced, then dump them as JSON Lines
# (one object per line, non-ASCII characters written verbatim).
print(f"Total {len(simple_negative)} simple negative samples generated.")

output_path = 'dataset/simple_negative.jsonl'
with open(output_path, 'w', encoding='utf-8') as sink:
    sink.writelines(
        json.dumps(record, ensure_ascii=False) + '\n'
        for record in simple_negative
    )