#!/bin/bash

# Cap OpenMP at one thread per process. Several training processes run at
# once (one per GPU, see main), so this presumably keeps them from competing
# for CPU cores.
OMP_NUM_THREADS=1
export OMP_NUM_THREADS

# Settings shared by every main.py invocation.
CONFIG="/home/s1612415/RDS/recurrent-depth-pde/configs/experiment_receptive_field.yaml"
PROJECT="receptive_field" # run_experiment appends "_<equation>"

# Per-equation dataset files and epoch budgets, keyed by equation name.
declare -A TRAIN_PATHS=() VALID_PATHS=() EPOCHS=()

TRAIN_PATHS[heat]="heat1d_8192_256_train_500.parquet"
VALID_PATHS[heat]="heat1d_8192_256_valid_100.parquet"
EPOCHS[heat]=100

TRAIN_PATHS[burgers]="burgers1d_8192_256_train_500.parquet"
VALID_PATHS[burgers]="burgers1d_8192_256_valid_100.parquet"
EPOCHS[burgers]=100

TRAIN_PATHS[kdv]="kdv1d_1024_256_train_500.parquet"
VALID_PATHS[kdv]="kdv1d_1024_256_valid_100.parquet"
EPOCHS[kdv]=10

TRAIN_PATHS[ks]="ks1d_4096_256_train_500.parquet"
VALID_PATHS[ks]="ks1d_4096_256_valid_100_long.parquet"
EPOCHS[ks]=10

# Kernel sizes to sweep: every odd value from 3 through 65 (32 per equation).
KERNEL_SIZES=({3..65..2})

# Number of GPUs; main schedules runs on device indices 0..NUM_GPUS-1.
NUM_GPUS=7

#######################################
# Launch a single training run of main.py in the foreground.
# Globals:
#   CONFIG, PROJECT (read)
#   TRAIN_PATHS, VALID_PATHS, EPOCHS (read, indexed by equation)
# Arguments:
#   $1 - device value forwarded as the `device` override
#   $2 - equation name (key into the per-equation tables)
#   $3 - kernel size forwarded as model.params.recurrent_kwargs.kernel_size
# Outputs:
#   A "Starting ..." line on stdout, followed by main.py's own output.
# Returns:
#   The exit status of main.py.
#######################################
run_experiment() {
    local device=$1 equation=$2 kernel_size=$3
    # Config overrides, one "--set key=value" pair per setting.
    local -a overrides=(
        --set "project=${PROJECT}_${equation}"
        --set "device=${device}"
        --set "data.train_path=${TRAIN_PATHS[$equation]}"
        --set "data.valid_path=${VALID_PATHS[$equation]}"
        --set "model.params.recurrent_kwargs.kernel_size=${kernel_size}"
        --set "epochs=${EPOCHS[$equation]}"
    )

    echo "Starting ${equation} experiment on device $device with kernel_size $kernel_size"
    # Last command: the function's status is python's status.
    python main.py --config="$CONFIG" "${overrides[@]}"
}

#######################################
# Run the full kernel-size sweep for every equation.
#
# Equations are processed one after another (heat, burgers, kdv, ks). For each
# equation, one experiment per entry in KERNEL_SIZES is launched in the
# background via run_experiment, with at most one experiment per device index
# 0..NUM_GPUS-1. The scheduler polls once per second: it first reaps GPUs whose
# process has exited, then fills free GPUs with the next kernel sizes. All
# experiments of one equation finish before the next equation starts.
#
# Globals:
#   NUM_GPUS, KERNEL_SIZES (read)
# Outputs:
#   Progress lines on stdout; failure warnings on stderr.
# Returns:
#   0 if every experiment exited with status 0, 1 if any of them failed.
#######################################
main() {
    # Per-GPU slot state: PID of the running experiment and a busy flag.
    local -a gpu_pids=()
    local -a gpu_busy=()
    # Loop variables are local so they do not leak into the global scope.
    local i equation kernel_size
    local active_experiments completed_experiments failed_experiments
    local experiment_index total_experiments
    local total_failures=0

    for ((i = 0; i < NUM_GPUS; i++)); do
        gpu_pids[i]=""
        gpu_busy[i]=0
    done

    for equation in "heat" "burgers" "kdv" "ks"; do
        echo "Starting experiments for ${equation} equation"

        active_experiments=0
        completed_experiments=0
        failed_experiments=0
        experiment_index=0
        total_experiments=${#KERNEL_SIZES[@]}

        echo "Preparing to run $total_experiments experiments for ${equation}"

        # Keep going until all experiments for this equation are completed.
        while ((completed_experiments < total_experiments)); do
            # Reap slots whose process is gone. kill -0 only probes liveness;
            # wait then collects the real exit status of the finished job.
            for ((i = 0; i < NUM_GPUS; i++)); do
                if ((gpu_busy[i] == 1)) && ! kill -0 "${gpu_pids[i]}" 2>/dev/null; then
                    if ! wait "${gpu_pids[i]}"; then
                        echo "WARNING: Process ${gpu_pids[i]} exited with non-zero status" >&2
                        failed_experiments=$((failed_experiments + 1))
                    fi
                    completed_experiments=$((completed_experiments + 1))
                    active_experiments=$((active_experiments - 1))
                    gpu_busy[i]=0
                    echo "Completed experiment $completed_experiments/$total_experiments for ${equation} on GPU $i"
                fi
            done

            # Hand the next kernel sizes to free GPUs, lowest index first.
            for ((i = 0; i < NUM_GPUS && experiment_index < total_experiments; i++)); do
                if ((gpu_busy[i] == 0)); then
                    kernel_size=${KERNEL_SIZES[experiment_index]}
                    run_experiment "$i" "$equation" "$kernel_size" &
                    gpu_pids[i]=$!
                    gpu_busy[i]=1
                    active_experiments=$((active_experiments + 1))
                    experiment_index=$((experiment_index + 1))
                    echo "Scheduled experiment $experiment_index/$total_experiments for ${equation} on GPU $i (kernel_size: $kernel_size, active: $active_experiments)"
                fi
            done

            # Sleep briefly to avoid hammering the CPU with checks.
            sleep 1
        done

        echo "All experiments for ${equation} completed."
        if ((failed_experiments > 0)); then
            echo "WARNING: $failed_experiments/$total_experiments experiments for ${equation} failed" >&2
            total_failures=$((total_failures + failed_experiments))
        fi
    done

    echo "All experiments for all equations completed."
    # Surface failures through the exit status instead of always succeeding.
    if ((total_failures > 0)); then
        echo "WARNING: $total_failures experiment(s) failed in total" >&2
        return 1
    fi
    return 0
}

# Entry point: hand control (and any command-line arguments) to main.
main "$@"