#!/bin/bash

# Directory that receives one log file per training job. Stop right away if it
# cannot be created: otherwise every job's output redirection would fail while
# the script still reported the jobs as started.
if ! mkdir -p logs; then
    echo "Error: could not create log directory 'logs'" >&2
    exit 1
fi

# Table of training jobs, one "<model> <dataset>" pair per entry.
# The 1-based position of an entry is the job number accepted as the script's
# optional argument.
COMBINATIONS=(
    "resnet9 mnist"
    "lenet mnist"
    "lenet svhn"
    "resnet9 svhn"
    "lenet cifar10"
    "resnet9 cifar10"
    "resnet18 cifar10"
    "resnet9 cifar100"
    "resnet18 cifar100"
    "resnet18 tinyimagenet"
)
# Number of entries in the table; also the upper bound for the job-number argument.
TOTAL_TASKS=${#COMBINATIONS[@]}
# Neither the table nor its size is modified after this point.
readonly COMBINATIONS TOTAL_TASKS
#######################################
# Decide which training jobs to launch, based on the script's arguments.
#
# With exactly one argument, that argument must be a job number between 1 and
# TOTAL_TASKS; only that job is selected. With any other argument count
# (none, or two and more) every job is selected.
#
# Globals:
#   COMBINATIONS          (read)    table of "<model> <dataset>" entries
#   TOTAL_TASKS           (read)    number of entries in COMBINATIONS
#   SELECTED_COMBINATIONS (written) entries that will be launched
# Arguments:
#   the script's own arguments ("$@")
# Outputs:
#   a one-line summary on stdout; error messages on stderr
# Returns:
#   0 on success; exits the script with status 1 on an invalid argument
#######################################
select_combinations() {
    if (( $# != 1 )); then
        SELECTED_COMBINATIONS=("${COMBINATIONS[@]}")
        echo "Starting $TOTAL_TASKS training jobs..."
        return 0
    fi

    local requested=$1
    if [[ ! $requested =~ ^[0-9]+$ ]]; then
        echo "Error: argument must be a number" >&2
        exit 1
    fi

    # Strip leading zeros before doing arithmetic: bash treats a leading 0 as
    # octal, so "010" would otherwise pick job 8 and "08" would be a syntax
    # error. An all-zero argument collapses to 0 and is rejected below.
    local job=${requested#"${requested%%[!0]*}"}
    job=${job:-0}

    # The length test rejects huge digit strings before they can wrap around
    # in integer arithmetic and land inside the valid range by accident.
    if (( ${#job} > ${#TOTAL_TASKS} || job < 1 || job > TOTAL_TASKS )); then
        echo "Error: argument must be between 1 and $TOTAL_TASKS" >&2
        exit 1
    fi

    SELECTED_COMBINATIONS=("${COMBINATIONS[job - 1]}")
    echo "Only run training job $requested: ${SELECTED_COMBINATIONS[0]}"
}

select_combinations "$@"
# Launch every selected job as a detached background process, one log file per
# job, and report progress after each launch.
STARTED_TASKS=0
for combo in "${SELECTED_COMBINATIONS[@]}"; do
    # Each entry is "<model> <dataset>". Splitting with `read` leaves the
    # script's positional parameters untouched and avoids pathname expansion
    # of the entry; any extra words are discarded.
    read -r model dataset _ <<< "$combo"
    log_file="logs/${model}_${dataset}.log"
    echo "Starting training: model=$model, dataset=$dataset"
    # nohup + & keeps the job running after this script (and the terminal)
    # exits; stdout and stderr both go to the job's log file.
    nohup python train.py --model "$model" --dataset "$dataset" --epochs 200 > "$log_file" 2>&1 &
    pid=$!
    echo "PID: $pid, log file: $log_file"
    STARTED_TASKS=$((STARTED_TASKS + 1))
    echo "$STARTED_TASKS/${#SELECTED_COMBINATIONS[@]} jobs started"
    # Pause for a second between launches so the jobs do not all start at once.
    sleep 1
done
echo "All training jobs started!"
echo "Use 'ps aux | grep train.py' to check running jobs"
echo "Use 'tail -f logs/<model>_<dataset>.log' to view a specific job's log"
