import os
import subprocess

# ---------------------------------------------------------------------------
# Sweep configuration (fixed for every generated job).
# ---------------------------------------------------------------------------

# Values forwarded verbatim to run_exp.py as --lr / --steps / --repetitions.
lr = 0.5
steps = 1_000
repetitions = 10

# Modes to launch. "random" and "SGD" get one job each; "SAM" gets one job
# per (rho, alpha) pair from the two grids below.
modes = ["random", "SGD", "SAM"]
rho = [0.01]
alpha = [0.1, 0.03, 0.05, 0.07]

# Directory that receives the generated *.sh batch scripts; created on demand.
job_dir = "jobs"
os.makedirs(name=job_dir, exist_ok=True)

for m in modes:
    if m == "SGD" or m == "random":
        job_name = f"job_mode_{m}"
        script_path = os.path.join(job_dir, f"{job_name}.sh")
        r = 0
        a = 0
        with open(script_path, "w") as f:
            f.write(f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --output={job_name}.out
#SBATCH --error={job_name}.err
#SBATCH --time=04:00:00
#SBATCH --gpus-per-node=1

module load gcc/12.3.0
module load cuda/12.1.1
module load rcac
module load conda
conda activate env

cd $SLURM_SUBMIT_DIR/linear-stability

python run_exp.py \\
    --lr {lr} \\
    --steps {steps} \\
    --repetitions {repetitions} \\
    --mode {m} \\
    --rho {r} \\
    --alpha {a}
""")
        os.system(f"sbatch {script_path}")

    elif m == "SAM":
        for r in rho:
            for a in alpha:
                job_name = f"job_mode_{m}_r{r}_a{a}"
                script_path = os.path.join(job_dir, f"{job_name}.sh")

                with open(script_path, "w") as f:
                    f.write(f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --output={job_name}.out
#SBATCH --error={job_name}.err
#SBATCH --time=04:00:00
#SBATCH --cpus-per-task=1
#SBATCH --gpus-per-node=1

module load gcc/12.3.0
module load cuda/12.1.1
module load rcac
module load conda
conda activate env

cd $SLURM_SUBMIT_DIR/linear-stability

python run_exp.py \\
    --lr {lr} \\
    --steps {steps} \\
    --repetitions {repetitions} \\
    --mode {m} \\
    --rho {r} \\
    --alpha {a}
""")

                os.system(f"sbatch {script_path}")
