#!/bin/bash
#
# Runs every script in SCRIPTS once per seed in SEEDS (for each provenance in
# PROVS), spreading the runs over the GPU ids given on the command line.
#
# Usage: <this script> GPU_ID [GPU_ID ...]
#
# Each positional argument is exported as CUDA_VISIBLE_DEVICES for one trial.
# Trials are launched in batches of at most one per GPU id; the next batch
# starts only after every trial of the current batch has exited.  When all
# batches are done, a wandb alert is sent.

SCRIPTS=("sum_2" "sum_4" "sum_4_self")
SEEDS=(
    1234 5073 7762 3576 2369 3490 4054 5735 1484 9087 1915 1813 8597 6089 9644 1292 6210 7823 7205 5692
    7780 2061 8519 2423 7419 2259 5620 5970 6887 6430 6540 5441 4582 3335 5976 4762 5566 5069 7520 8132
)
PROVS=("dtkp-am")

# Without at least one GPU id no trial can ever be launched; bail out early
# instead of looping forever with nothing to schedule.
if (($# == 0)); then
    printf 'Usage: %s GPU_ID [GPU_ID ...]\n' "${0##*/}" >&2
    exit 2
fi

gpus=("$@")
n_gpus=${#gpus[@]}
n_seeds=${#SEEDS[@]}

for SCRIPT in "${SCRIPTS[@]}"; do
    for PROV in "${PROVS[@]}"; do
        # Walk the seed list in chunks of n_gpus: the i-th seed of a chunk
        # runs on the i-th GPU id.
        for ((start = 0; start < n_seeds; start += n_gpus)); do
            pids=()

            for ((i = 0; i < n_gpus && start + i < n_seeds; i++)); do
                seed=${SEEDS[start + i]}
                export CUDA_VISIBLE_DEVICES="${gpus[i]}"
                # Trial output is discarded on purpose (stdout and stderr go to /dev/null).
                nohup python "$SCRIPT.py" --device="cuda" --provenance="$PROV" --seed="$seed" --n-epochs=5 > /dev/null 2>&1 &
                pids+=("$!")
            done

            # Barrier: wait for the whole batch before reusing the GPUs.
            # A failed trial is reported but does not stop the sweep.
            for pid in "${pids[@]}"; do
                wait "$pid" ||
                    printf 'warning: trial with pid %s (%s, %s) exited with status %s\n' \
                        "$pid" "$SCRIPT" "$PROV" "$?" >&2
            done
        done
    done
done

# Notify via wandb that all trials are done.
printf "import wandb\nwandb.login()\nwandb.init()\nwandb.alert(title='Trials finished',text='run_mnist finished')" | python