import argparse
import os
import subprocess
import sys
from itertools import product

parser = argparse.ArgumentParser(f'add multi-gpu option')
parser.add_argument("--parallel", action="store_true",
                    help="If true, use DataParallel for multi-gpu training")
parallel_arg = parser.parse_args()
parallel = parallel_arg.parallel

ks = [2]
kernels = ["spd"]
models = ["KEGIN"]
grid = product(kernels, ks, models)

if parallel:
    for parameters in grid:
        kernel, k, model = parameters
        script = f"python train_EXP.py --model_name={model} --kernel={kernel} --K={k} --num_layer=2 --reprocess --wo_path_encoding --parallel --sample"
        os.system(script)

else:
    for parameters in grid:
        kernel, k, model = parameters
        script = f"python train_EXP.py --model_name={model}  --kernel={kernel} --K={k} --num_layer=2 --wo_path_encoding --reprocess --sample"
        os.system(script)

