import subprocess
import time

import argparse
# Command-line interface: both options are integers with defaults, so the
# script can be run without any arguments.
parser = argparse.ArgumentParser(
    description='Generate scaling data for ListOps.',
)
parser.add_argument(
    '--num_train',
    type=int,
    default=50_000,
    help='Number of training examples to generate.',
)
parser.add_argument(
    '--max_depth',
    type=int,
    default=2,
    help='Maximum depth of the operations.',
)

args = parser.parse_args()

# Output directory handed to every generate_data.py invocation.
save_dir = '../data/listops-scaling/'
# --num_train is treated as a base count: the launch loop scales it with the
# modulus (it is used unchanged when the modulus is 10).
num_train_base, max_depth = args.num_train, args.max_depth

# Launch one generate_data.py job per (function set, modulus) pair, run them
# in parallel, then block until every job has exited.
#
# Each job is started directly from an argv list (no shell, no trailing "&").
# Going through `shell=True` with "&" made Popen track a shell that exits
# immediately after backgrounding the job, so wait() returned before any data
# had been generated.
processes = []
# Alternative function sets: ['add', 'max median add']
for funcs in ['add', 'prod add', 'prod']:
    for exponent in range(5):
        # The modulus doubles at each step: 10, 20, 40, 80, 160.
        # int() keeps it integral should a fractional exponent
        # (e.g. exponent + 0.5) be tried.
        mod = int(10 * 2**exponent)
        # Training-set size scales linearly with the modulus; it equals
        # num_train_base when mod == 10.
        num_train = int(num_train_base / 10 * mod)
        cmd = [
            "python", "generate_data.py",
            "--num_train", str(num_train),
            "--max_depth", str(max_depth),
            # A multi-function set such as 'prod add' is passed as separate
            # argv entries, matching the word splitting the shell performed.
            "--funcs_to_use", *funcs.split(),
            "--mod", str(mod),
            "--save_dir", save_dir,
        ]
        processes.append(subprocess.Popen(cmd))

        # Short pause between launches (presumably to stagger start-up).
        time.sleep(0.5)

# Wait for all processes to finish and report any that failed, instead of
# silently discarding their exit codes.
for proc in processes:
    returncode = proc.wait()
    if returncode != 0:
        print(f"Command failed with exit code {returncode}: {' '.join(proc.args)}")