import itertools
import os

# SLURM batch-script template, filled in below via cmd.format(name=..., params=...).
#   {name}   -> used both as the SLURM job name and as --wandb_run_name.
#   {params} -> space-separated "--key value" options passed to reactions_task.py.
# Because str.format is used, any literal brace added to this template must be
# doubled ({{ / }}). $HOME is not a format field; bash expands it when the job runs.
# NOTE(review): `conda activate` followed by `source activate` looks like a
# fallback for shells where conda's hook is not initialised -- confirm which of
# the two the cluster actually needs.
cmd = """#!/bin/bash
#SBATCH --job-name={name}
#SBATCH --qos=big
#SBATCH --gres=gpu:1
#SBATCH --mem=64G
#SBATCH --partition=dgx
#SBATCH --cpus-per-task=9

cd $HOME/synflownet
conda activate synflow
source activate synflow

python src/gflownet/tasks/reactions_task.py {params} --wandb_run_name {name}
"""

def expand_grid(params_list):
    """Expand hyperparameter grids into one dict per run.

    Args:
        params_list: Sequence of grids. Each grid maps a command-line option
            name to the list of values to sweep for that option.

    Returns:
        A flat list of ``{option: value}`` dicts covering the Cartesian
        product of every grid. Grids keep their order in ``params_list``;
        within a grid the order is ``itertools.product`` order (last key
        varies fastest). An empty grid yields a single run with no options.
    """
    grid_dicts = []
    for params in params_list:
        keys = list(params)
        # dicts preserve insertion order, so params.values() lines up with keys.
        for combo in itertools.product(*params.values()):
            grid_dicts.append(dict(zip(keys, combo)))
    return grid_dicts


def remove_old_scripts(directory='.'):
    """Delete previously generated ``run_*.sh`` scripts from *directory*.

    Scripts from an earlier, larger sweep would otherwise linger next to the
    freshly generated ones. Only regular files named ``run_*.sh`` are removed;
    the directory is scanned and cleaned at the same location.
    """
    for entry in os.listdir(directory):
        if not (entry.startswith('run_') and entry.endswith('.sh')):
            continue
        path = os.path.join(directory, entry)
        if os.path.isfile(path):
            os.remove(path)


def main():
    """Write one SLURM script per grid point plus a ``run_all.sh`` submitter.

    Scripts are written to the current working directory. ``run_all.sh``
    refers to them as ``scripts/gmum/run_<i>.sh``, so this presumably has to
    be run from ``scripts/gmum`` and ``run_all.sh`` executed from the repo
    root -- TODO confirm before changing either path.
    """
    params_list = [
        {
            'setup': ['synflow_128'],
            'task': ['seh_reaction'],
            'reinforce': [False],
            'action_embedding': [True],
            'temp': [32],
            'seed': [0, 1, 2],
        },
        {
            'setup': ['synflow_128'],
            'task': ['gsk', 'jnk3'],
            'reinforce': [False],
            'action_embedding': [True],
            'temp': [16],
            'seed': [0, 1, 2],
        },
    ]
    grid_dicts = expand_grid(params_list)

    # Clear stale run_*.sh files in the same directory the new ones go to.
    remove_old_scripts('.')

    submit_lines = []
    for i, param_dict in enumerate(grid_dicts):
        param_str = ' '.join(f'--{k} {v}' for k, v in param_dict.items())
        # Used by the template as both the SLURM job name and the wandb run name.
        name = '_'.join(f'{k}_{v}' for k, v in param_dict.items())
        script_name = f'run_{i}.sh'
        with open(script_name, 'w') as f:
            f.write(cmd.format(name=name, params=param_str))
        submit_lines.append(f'sbatch scripts/gmum/{script_name}\n')

    with open('run_all.sh', 'w') as f:
        f.writelines(submit_lines)


if __name__ == '__main__':
    main()
