import os
import re
from collections import defaultdict
import json
from tqdm import tqdm

# Run configuration. The defaults reproduce the original hard-coded values;
# each one can be overridden through an environment variable so the script can
# be pointed at another task / results directory without editing it.
task = os.environ.get('MERGE_TASK', 'aime_2025')
amlt_results_dir = os.environ.get(
    'MERGE_RESULTS_DIR', '/n/fs/nlp-il-scale/anonymous_generations/aime-06-21'
)
# Number of .json files a run directory must contain to be considered complete
# (run directories with any other count are skipped when collecting files).
EXPECTED_NUM_JSON_FILES = int(os.environ.get('MERGE_EXPECTED_NUM_JSON_FILES', '30'))

# Discover the top-level run directories. Sorting makes every downstream step
# (including the tie-break when de-duplicating seeds) independent of the
# filesystem's arbitrary listdir order.
high_dirs = sorted(os.listdir(amlt_results_dir))

# Group run dirs by (prefix, seed) so duplicate launches of the same seed can be
# detected, and count the .json files found anywhere under each dir.
# A dir name is expected to embed its seed as 'see_<digits>' followed by '-' or
# the end of the name; everything before that marker is the "prefix".
seed_to_dir = defaultdict(lambda: defaultdict(list))
dir_to_file_count = defaultdict(int)
for high_dir in high_dirs:
    match = re.search(r'see_(\d+)(?:-|$)', high_dir)
    if not match:
        print('WARNING: No seed found in dir name', high_dir)
        continue
    seed = match.group(1)

    # Slice at the regex match itself; splitting on the string 'see_<seed>'
    # could cut at an earlier, longer occurrence (e.g. 'see_123x' for seed 12).
    prefix = high_dir[:match.start()]

    seed_to_dir[prefix][seed].append(high_dir)
    for _root, _dirs, walk_files in os.walk(os.path.join(amlt_results_dir, high_dir)):
        dir_to_file_count[high_dir] += sum(1 for name in walk_files if name.endswith('.json'))

# Keep exactly one run directory per (prefix, seed). When the same seed was
# launched more than once, keep the attempt with the most .json files.
filtered_high_dirs = []
for prefix, seed_dirs in seed_to_dir.items():
    for seed, dirs_for_seed in seed_dirs.items():
        if len(dirs_for_seed) > 1:
            # Report the prefix too: the same seed legitimately recurs across prefixes.
            print('WARNING: Seed', seed, 'for prefix', repr(prefix), 'has',
                  len(dirs_for_seed), 'dirs:', dirs_for_seed)
            # max() returns the first maximal element, so ties keep list order.
            filtered_high_dirs.append(max(dirs_for_seed, key=lambda d: dir_to_file_count[d]))
        else:
            filtered_high_dirs.append(dirs_for_seed[0])

high_dirs = filtered_high_dirs
print('Done filtering high dirs!')

# Gather every .json file under each surviving run dir. A run dir contributes
# only if it holds exactly EXPECTED_NUM_JSON_FILES files; otherwise it is
# reported and skipped entirely.
json_files = []
for high_dir in high_dirs:
    found = [
        os.path.join(walk_root, name)
        for walk_root, _subdirs, names in os.walk(os.path.join(amlt_results_dir, high_dir))
        for name in names
        if name.endswith('.json')
    ]
    if len(found) != EXPECTED_NUM_JSON_FILES:
        print('WARNING: High dir', high_dir, 'has', len(found), 'json files, expected', EXPECTED_NUM_JSON_FILES)
        continue
    json_files.extend(found)

print(len(json_files))
# Policy (model) tags recognised in result file names. Matching is by
# substring and the FIRST tag found in this order wins, so preserve the order
# when adding tags that might be substrings of one another.
KNOWN_POLICIES = (
    'qwen-25-3b',
    'llama-3-8b',
    'llama-3-3b',
    'mistral-7b',
    'phi-35-mini',
    'qwen-25-05b',
    'qwen-25-7b',
    'qwen-25-14b',
    'qwen-25-32b',
    'qwen-25-coder-3b',
    'qwen-25-coder-7b',
    'qwen-25-coder-14b',
    'phi-4',
    'phi-3-medium',
)
PROMPT_IDX_RE = re.compile(r'prompt-idx-(\d+)')

# Group result files as policy -> prompt index (string) -> list of file paths
# (one entry per seed/run that produced that prompt).
model_prompt_idx_to_files = defaultdict(lambda: defaultdict(list))
for file in json_files:
    file_name = os.path.basename(file)

    policy = next((tag for tag in KNOWN_POLICIES if tag in file_name), None)
    if policy is None:
        raise ValueError(f'Unknown policy: {file_name}')

    # Fail with a clear message instead of an AttributeError on `None.group`.
    prompt_match = PROMPT_IDX_RE.search(file_name)
    if prompt_match is None:
        raise ValueError(f'No prompt index found in file name: {file_name}')
    prompt_idx = prompt_match.group(1)

    model_prompt_idx_to_files[policy][prompt_idx].append(file)

# Sanity check: the number of seed files per prompt may differ between
# policies, but within one policy every prompt must have the same count.
for p, prompt_files in model_prompt_idx_to_files.items():
    prompt_idx_to_num_files = {idx: len(files_for_prompt) for idx, files_for_prompt in prompt_files.items()}
    seed_lens = set(prompt_idx_to_num_files.values())
    if len(seed_lens) != 1:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise AssertionError(f'Policy {p} has {seed_lens} seeds: {prompt_idx_to_num_files}')
    print(f'Policy {p} has {seed_lens.pop()} seeds')

# Merge every seed's records for each (policy, prompt) into one JSON list and
# write it to ./data/<task>/<policy>/, named after the input file with its
# '--seed-<n>' tag removed.
for policy in tqdm(model_prompt_idx_to_files):
    print('policy', policy)
    out_dir = f'./data/{task}/{policy}/'
    os.makedirs(out_dir, exist_ok=True)  # once per policy, not per prompt

    for prompt_idx in tqdm(model_prompt_idx_to_files[policy]):
        # Sort so the merged record order does not depend on filesystem walk order.
        seed_files = sorted(model_prompt_idx_to_files[policy][prompt_idx])

        data_with_all_seeds = []
        for file in seed_files:
            try:
                with open(file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except (OSError, ValueError) as e:
                # Best-effort: unreadable / malformed seed files are reported and
                # skipped (json.JSONDecodeError and UnicodeDecodeError are ValueErrors).
                print('WARNING: Error loading file', file, e)
                continue
            # presumably each file holds a JSON list of records — TODO confirm
            data_with_all_seeds.extend(data)

        if not data_with_all_seeds:
            print('WARNING: No data loaded for policy', policy, 'prompt', prompt_idx)

        # Derive the output name explicitly instead of relying on the leaked
        # loop variable; files without a seed tag keep their name unchanged.
        out_name = os.path.basename(seed_files[-1])
        seed_match = re.search(r'seed-(\d+)', out_name)
        if seed_match is None:
            print('WARNING: No seed found in file name', out_name)
        else:
            out_name = out_name.replace(f'--seed-{seed_match.group(1)}', '')

        with open(os.path.join(out_dir, out_name), 'w', encoding='utf-8') as f:
            json.dump(data_with_all_seeds, f)


# # os.makedirs(f'./data/{task}/{policy}/', exist_ok=True)
    
