import argparse
import subprocess
import sys
import time

# Command-line arguments controlling the ListOps scaling sweep.
parser = argparse.ArgumentParser(description='Generate scaling data for ListOps.')
parser.add_argument('--num_train', type=int, default=50_000, help='Number of training examples to generate.')
parser.add_argument('--max_depth', type=int, default=2, help='Maximum depth of the operations.')
# Previously hard-coded; the default keeps the original output location.
parser.add_argument('--save_dir', type=str, default='../data/listops-scaling/',
                    help='Directory passed to generate_data.py via --save_dir.')

args = parser.parse_args()

# Base training-set size; each launched run scales it by its modulus.
num_train_base = args.num_train
max_depth = args.max_depth
save_dir = args.save_dir

# Candidate moduli: the primes in the half-open range [m, n).
m, n = 11, 114


def _is_prime(k):
    """Return True if the integer k is prime.

    Uses trial division with exact integer arithmetic (``d * d <= k``) rather
    than a floating-point square root, and treats every k < 2 as non-prime so
    lowering ``m`` cannot leak 0 or 1 into the list.
    """
    if k < 2:
        return False
    d = 2
    while d * d <= k:
        if k % d == 0:
            return False
        d += 1
    return True


primes = [k for k in range(m, n) if _is_prime(k)]

# Only sample every 3rd prime to reduce the total number of runs.
primes = primes[::3]

# Launch one generate_data.py job per (function set, modulus) pair in
# parallel, then block until every job has exited.
processes = []
# Function sets used in earlier sweeps:
#   'add', 'max median add', 'prod add', 'prod nadd', 'nadd'

for funcs in ['prod max', 'prod']:
    # Each sampled prime p is used as a modulus, and so is p + 1.
    for mod in primes + [i + 1 for i in primes]:
        # Training-set size grows linearly with the modulus
        # (num_train_base / 10 examples per unit of mod). Integer arithmetic
        # avoids float rounding.
        num_train = num_train_base * mod // 10
        cmd = [
            sys.executable, 'generate_data.py',
            '--num_train', str(num_train),
            '--max_depth', str(max_depth),
            # The old shell command word-split `funcs`; pass each name
            # as its own argument to keep the same argv.
            '--funcs_to_use', *funcs.split(),
            '--mod', str(mod),
            '--save_dir', save_dir,
        ]
        # No shell and no trailing '&'. Popen already runs the child
        # concurrently, and holding the real child (not a shell that exits
        # at once) is what lets wait() below actually block on it.
        processes.append(subprocess.Popen(cmd))

        # Stagger launches by 0.5 s.
        time.sleep(0.5)

# Wait for all processes to finish and report any that failed.
failures = []
for p in processes:
    returncode = p.wait()
    if returncode != 0:
        failures.append((returncode, p.args))

if failures:
    for returncode, failed_cmd in failures:
        print(f"generate_data.py exited with {returncode}: {' '.join(failed_cmd)}",
              file=sys.stderr)
    sys.exit(1)