import os
import argparse
from itertools import product

# Command-line interface: a single switch that selects which mode flag is
# forwarded to the training script.
# The text is passed as `description`; the first positional parameter of
# ArgumentParser is `prog`, which would replace the program name in the
# usage line instead of describing the script.
parser = argparse.ArgumentParser(description="add multi-gpu option")
parser.add_argument("--parallel", action="store_true",
                    help="If true, use DataParallel for multi-gpu training")
parallel_arg = parser.parse_args()
parallel = parallel_arg.parallel

# Values to sweep over. Each list currently holds one entry, so the grid
# has exactly one combination; add entries to widen the sweep.
ks = [4]            # forwarded as --K
kernels = ["gd"]    # forwarded as --kernel
models = ["KEGIN"]  # forwarded as --model_name

# Cartesian product yielding (kernel, k, model) tuples.
# NOTE: `product` returns a one-shot iterator, so it can be consumed once.
grid = product(kernels, ks, models)

# Both modes launch the same training command; only the trailing flag
# differs, so choose it once instead of duplicating the whole loop.
mode_flag = "--parallel" if parallel else "--sample"

# Run train_CSL.py once per (kernel, k, model) combination in the grid.
for kernel, k, model in grid:
    script = (
        f"python3 train_CSL.py --model_name={model} --kernel={kernel} --K={k}"
        f"  --num_layer=2 --reprocess --wo_path_encoding {mode_flag}"
    )
    status = os.system(script)
    # os.system does not raise on failure; report a non-zero status so a
    # failed run is visible, but keep going with the remaining grid points.
    if status != 0:
        print(f"Command exited with non-zero status {status}: {script}")



