import itertools

# SLURM batch-script template; one rendered copy is written per grid point.
# Placeholders (filled via str.format):
#   {name}   - SLURM job name; also passed to script.opt as `-o logs/{name}`
#   {params} - space-separated "--flag value" arguments for `python -m script.opt`
# Because the template goes through str.format, any literal brace added to the
# shell text below must be doubled ({{ / }}).
# NOTE(review): both `conda activate` and `source activate` are attempted,
# presumably so activation works under either conda shell setup — confirm.
# The script has no `set -e`, so a failing activation line does not abort it.
cmd = """#!/bin/bash
#SBATCH --job-name={name}
#SBATCH --qos=big
#SBATCH --gres=gpu:1
#SBATCH --mem=64G
#SBATCH --partition=dgx
#SBATCH --cpus-per-task=9

cd $HOME/RxnFlow
conda activate rxn_flow
source activate rxn_flow

python -m script.opt {params} -o logs/{name} --override
"""

if __name__ == '__main__':
    import os

    # Hyper-parameter grids. Each dict maps a `script.opt` CLI flag to the list
    # of values to sweep; the Cartesian product of a dict's value lists gives
    # one job per combination. Several dicts can be listed to sweep disjoint
    # sub-grids.
    params_list = [
        {
            'setup': ['rgfn_new_filtered'],
            'task': ['gsk'],
            'seed': [0, 1, 2],
            'subsampling_ratio': [1.0],
        },
    ]

    # Expand every grid into concrete {flag: value} dicts with itertools.product.
    all_grid_dicts = []
    for params in params_list:
        keys, values = zip(*params.items())
        all_grid_dicts.extend(
            dict(zip(keys, combo)) for combo in itertools.product(*values)
        )

    # Remove stale run_*.sh scripts from a previous (possibly larger) sweep so
    # leftover run_N.sh files do not linger. Only regular files matching
    # run_*.sh are deleted; other run_* files or directories are left alone.
    for fname in os.listdir('.'):
        if (
            fname.startswith('run_')
            and fname.endswith('.sh')
            and os.path.isfile(fname)
        ):
            os.remove(fname)

    # Render one batch script per grid point and collect its submit command.
    submit_lines = []
    for i, param_dict in enumerate(all_grid_dicts):
        param_str = ' '.join(f'--{k} {v}' for k, v in param_dict.items())
        name = '_'.join(f'{k}_{v}' for k, v in param_dict.items())
        with open(f'run_{i}.sh', 'w') as fh:
            fh.write(cmd.format(name=name, params=param_str))
        # NOTE(review): scripts are written to the current directory but
        # submitted as scripts/gmum/run_{i}.sh, so this generator presumably
        # runs from scripts/gmum/ while run_all.sh runs from the repo root —
        # confirm.
        submit_lines.append(f'sbatch scripts/gmum/run_{i}.sh\n')

    # run_all.sh submits every generated job in order.
    with open('run_all.sh', 'w') as fh:
        fh.writelines(submit_lines)
