#!/usr/bin/env python3
import subprocess
import itertools
import time

# Hyperparameter grids. Every (learning_rate, weight_decay, k) combination is
# submitted as its own job, so the job count is
# len(learning_rates) * len(weight_decays) * len(ks)  (10 * 20 * 1 = 200 as written).
learning_rates = [0.0025, 0.001, 0.00075, 0.0005, 0.00025, 0.0001, 0.000075, 0.00005, 0.000025, 0.00001]
weight_decays = [
    0.001, 0.00075, 0.0005, 0.00025, 0.0001, 0.000075, 0.00005, 0.000025, 0.00001, 0.0000075,
    0.000005, 0.0000025, 0.000001, 0.00000075, 0.0000005, 0.00000025, 0.0000001, 0.000000075, 0.00000005, 0.00000001
]

# List of k values to iterate over (k = number of batches in the training set).
ks = [50]

# Fixed values for the remaining positional arguments. All are strings because
# they are passed verbatim on the sbatch command line; they must match what the
# training script expects.
p = "59"
batch_size = "59"
optimizer = "adam"          # Change to "SGD..." if desired.
epochs = "2000"
batch_experiment = "random_random"
num_neurons = "8"
zeta = "1"
momentum = "0.0"            # Not used if optimizer is "adam", but required.
injected_noise = "0.0"      # Set as needed.
num_mlp_layers = "1"
# Random seeds "1" through "19" inclusive (range's upper bound is exclusive).
# All of them are passed to every job.
# NOTE(review): an earlier comment said "1 to 10"; confirm that 19 seeds is intended.
random_seeds = [str(seed) for seed in range(1, 20)]

# Batch script handed to sbatch, given as a bare filename exactly as it would be
# typed on the command line (e.g. `sbatch modular_addition_2.sh ...`).
sbatch_script = "modular_addition_2.sh"

# Submit one sbatch job per (learning_rate, weight_decay, k) combination.
# Failures are reported as they happen and also collected, so a few errors
# among many submissions are summarized at the end instead of scrolling away.
failed_submissions = []
for lr, wd, k_val in itertools.product(learning_rates, weight_decays, ks):
    # Hyperparameters reach sbatch as strings. str() renders small floats in
    # scientific notation (e.g. 0.000075 -> "7.5e-05"), so the receiving script
    # must accept that form.
    lr_str = str(lr)
    wd_str = str(wd)
    k_str = str(k_val)

    # The training set is k full batches: training_set_size = k * batch_size.
    training_set_size = str(int(k_val) * int(batch_size))

    # Positional argument order:
    #   <learning_rate> <weight_decay> <p> <batch_size> <optimizer> <epochs>
    #   <k> <batch_experiment> <num_neurons> <zeta> <training_set_size>
    #   <momentum> <injected_noise> <num_mlp_layers>
    #   <random_seed> [<random_seed> ...]
    args = [
        lr_str,
        wd_str,
        p,
        batch_size,
        optimizer,
        epochs,
        k_str,
        batch_experiment,
        num_neurons,
        zeta,
        training_set_size,
        momentum,
        injected_noise,
        num_mlp_layers,
    ] + random_seeds  # Every seed in random_seeds goes to this single job.

    cmd = ["sbatch", sbatch_script] + args

    print("Submitting job with command:")
    print(" ".join(cmd))

    # subprocess.run blocks until sbatch exits; check=True turns a non-zero
    # exit status into CalledProcessError.
    try:
        result = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        print("Error submitting job:")
        print(e.stderr.decode())
        failed_submissions.append((lr_str, wd_str, k_str))
    else:
        print("Job submitted successfully. Output:")
        print(result.stdout.decode())
        # Short pause between submissions, presumably to avoid hammering the
        # scheduler with back-to-back requests.
        time.sleep(0.03)

# End-of-run summary so failed combinations are easy to find and resubmit.
if failed_submissions:
    print("{} job submission(s) failed:".format(len(failed_submissions)))
    for failed_lr, failed_wd, failed_k in failed_submissions:
        print("  learning_rate={} weight_decay={} k={}".format(failed_lr, failed_wd, failed_k))
