import argparse
import os
import subprocess
import sys

# Command-line interface: which bash_scripts/*.sh jobs to submit via sbatch.
# Each scalar flag takes exactly one value. (Previously nargs='?' let a bare
# "-L" / "-D" satisfy required=True while yielding None, and a bare "-n" /
# "-d" replace the default 5 with None.)
parser = argparse.ArgumentParser(
    description='Submit bash_scripts/*.sh jobs via sbatch for each '
                'initialization scale, dataset index and task'
)
parser.add_argument('-L', type=int, required=True, help="Depth")
parser.add_argument('-D', type=int, required=True, help="Input Dimension")
# Optional: when given, an "_N=<width>" segment is part of the script name.
parser.add_argument('-N', type=int, required=False, help="Width")
# NOTE(review): -P and -n are parsed but not used by the submission loop in
# this file; presumably kept for CLI parity with related scripts - verify.
parser.add_argument('-P', type=int, required=False, help="Dataset size")
parser.add_argument('-t', type=str, required=True, nargs='+', help="Task")
parser.add_argument('-s', type=float, required=True, nargs='+', help="Initialization scale")
parser.add_argument('-n', type=int, required=False, default=5, help="Number of repeats")
parser.add_argument('-d', type=int, required=False, default=5, help="Number of datasets")

args = parser.parse_args()


# Expose the parsed CLI values under descriptive module-level names.
depth, dim, width, dataset_size = args.L, args.D, args.N, args.P
ts, sigmas = args.t, args.s  # lists: task names, initialization scales
num_repeats, num_datasets = args.n, args.d

# Submit one sbatch job per (initialization scale, dataset index, task).
# Script names follow the pattern
#   bash_scripts/L=<depth>_D=<dim>[_N=<width>]_t=<task>_s=<sigma:.2f>_d=<idx>.sh
# where the "_N=<width>" segment appears only when -N was given.
width_tag = "" if width is None else f"_N={width}"
for sigma in sigmas:
    for d_key in range(num_datasets):
        for t in ts:
            file_name = (
                f"bash_scripts/L={depth}_D={dim}{width_tag}"
                f"_t={t}_s={sigma:.2f}_d={d_key}.sh"
            )
            # Pass argv as a list (no shell) so task names containing spaces
            # or shell metacharacters are delivered to sbatch verbatim.
            result = subprocess.run(["sbatch", file_name], check=False)
            if result.returncode != 0:
                # Keep going with the remaining jobs, but don't fail silently.
                print(
                    f"sbatch failed (exit {result.returncode}) for {file_name}",
                    file=sys.stderr,
                )
      