#!/bin/bash
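# Run experiments over a fixed list of depth / hidden_channels combinations.
# - Covers every equation listed in EQUATION_NAMES (currently only "ks"; the
#   full four-equation selection is kept commented out next to the arrays)
# - Runs at most one experiment per GPU at a time
# - Starts the next queued experiment as soon as a GPU becomes free, to keep
#   the GPUs utilized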

# Settings
# Count of GPUs to schedule work on; device indices 0 .. NUM_GPUS-1 are used.
NUM_GPUS=1

# Cap OpenMP at a single thread per process; the stated intent is better GPU
# utilization.
OMP_NUM_THREADS=1
export OMP_NUM_THREADS

# Dataset files per PDE equation. The three arrays are index-aligned: entry k
# of each one belongs to the same equation.
#
# Only "ks" is enabled. The commented-out four-equation selection is:
#   EQUATION_NAMES: "heat" "burgers" "kdv" "ks"
#   TRAIN_PATHS:    "heat1d_8192_256_train_500.parquet"
#                   "burgers1d_8192_256_train_500.parquet"
#                   "kdv1d_1024_256_train_500.parquet"
#                   "ks1d_4096_256_train_500.parquet"
#   VALID_PATHS:    "heat1d_8192_256_valid_100.parquet"
#                   "burgers1d_8192_256_valid_100.parquet"
#                   "kdv1d_1024_256_valid_100.parquet"
#                   "ks1d_4096_256_valid_100_long.parquet"
EQUATION_NAMES=("ks")
TRAIN_PATHS=("ks1d_4096_256_train_500_wo_burnin.parquet")
VALID_PATHS=("ks1d_4096_256_valid_100_long_wo_burnin.parquet")

# Experiment grid. Every entry has the form "depth:hidden_channels".
EXPERIMENT_CONFIGS=()

# Depth varied from 1 to 32 with hidden_channels held at 64
EXPERIMENT_CONFIGS+=("1:64" "2:64" "4:64" "8:64" "16:64" "32:64")

# Paired depth / hidden_channels settings
EXPERIMENT_CONFIGS+=("2:80" "4:112" "8:120" "16:112")

# hidden_channels varied from 40 to 160 with depth held at 8.
# NOTE(review): "8:120" repeats an entry from the group above (the original
# note read "duplicate, will be handled") -- confirm the scheduler skips the
# repeat, otherwise this configuration is run twice.
EXPERIMENT_CONFIGS+=("8:40" "8:80" "8:120" "8:160")

#######################################
# Launch one training run through main.py, using the
# configs/baselines/baseline_fno.yaml configuration plus per-run overrides.
# Arguments:
#   $1 - value for the `device` override (the scheduler passes a GPU index)
#   $2 - training dataset path
#   $3 - validation dataset path
#   $4 - value for model.params.depth
#   $5 - value for model.params.hidden_channels
# Outputs:
#   A start banner on stdout, followed by whatever the python command prints.
# Returns:
#   The exit status of the python command.
#######################################
run_experiment() {
    local device=$1
    local train_path=$2
    local valid_path=$3
    local depth=$4
    local hidden_channels=$5

    # Equation label for logging only: the file name up to its first
    # underscore. Directory components are stripped first so an underscore in
    # a directory name cannot truncate the label.
    local train_file=${train_path##*/}
    local equation_name=${train_file%%_*}

    echo "Starting experiment on GPU $device: equation=$equation_name, depth=$depth, hidden_channels=$hidden_channels"

    # Each override is quoted so a value containing whitespace stays a single
    # argument. python runs last, so its exit status is the function's status.
    python main.py --config configs/baselines/baseline_fno.yaml \
        --set "device=$device" \
        --set "data.train_path=$train_path" \
        --set "data.valid_path=$valid_path" \
        --set "model.params.depth=$depth" \
        --set "model.params.hidden_channels=$hidden_channels"
}

#######################################
# Schedule every equation / configuration combination across the GPUs.
#
# Builds the work list (each EQUATION_NAMES index crossed with each
# EXPERIMENT_CONFIGS entry, repeated combinations dropped), then keeps at most
# one background run_experiment per GPU index 0 .. NUM_GPUS-1. The loop polls
# once per second and launches the next queued experiment on whichever GPU
# has become free.
#
# Globals (read):
#   NUM_GPUS, EQUATION_NAMES, TRAIN_PATHS, VALID_PATHS, EXPERIMENT_CONFIGS
# Arguments:
#   None
# Outputs:
#   Progress messages on stdout; an error on stderr when NUM_GPUS is not a
#   positive integer.
# Returns:
#   1 when NUM_GPUS is invalid, 0 otherwise. Failed runs are reported in the
#   output but do not change the return status.
#######################################
main() {
    local -a all_experiments=()
    local -a gpu_pids=()
    local -a gpu_busy=()
    local -r positive_int='^[1-9][0-9]*$'
    local eq_idx config depth hidden_channels candidate queued is_duplicate i
    local total_experiments=0
    local active_experiments=0
    local completed_experiments=0
    local failed_experiments=0
    local experiment_index=0

    # Without at least one GPU slot nothing can ever be launched and the
    # scheduling loop below would never terminate.
    if ! [[ "${NUM_GPUS:-}" =~ $positive_int ]]; then
        echo "ERROR: NUM_GPUS must be a positive integer (got '${NUM_GPUS:-}')" >&2
        return 1
    fi

    # Build the work list as "eq_idx:depth:hidden_channels" strings, keeping
    # the first occurrence of each combination and skipping later repeats.
    for ((eq_idx = 0; eq_idx < ${#EQUATION_NAMES[@]}; eq_idx++)); do
        for config in "${EXPERIMENT_CONFIGS[@]}"; do
            IFS=':' read -r depth hidden_channels <<< "$config"
            candidate="$eq_idx:$depth:$hidden_channels"

            is_duplicate=0
            for queued in "${all_experiments[@]}"; do
                if [[ "$queued" == "$candidate" ]]; then
                    is_duplicate=1
                    break
                fi
            done

            if ((is_duplicate)); then
                echo "Skipping duplicate configuration: equation=${EQUATION_NAMES[$eq_idx]}, depth=$depth, hidden_channels=$hidden_channels"
                continue
            fi

            all_experiments+=("$candidate")
        done
    done

    total_experiments=${#all_experiments[@]}
    echo "Preparing to run $total_experiments total experiments"

    # Per-GPU bookkeeping: PID of the running experiment and a busy flag
    for ((i = 0; i < NUM_GPUS; i++)); do
        gpu_pids[i]=""
        gpu_busy[i]=0
    done

    echo "Starting experiment batches..."

    while ((completed_experiments < total_experiments)); do
        # Pass 1: collect experiments that have finished
        for ((i = 0; i < NUM_GPUS; i++)); do
            if ((!gpu_busy[i])); then
                continue
            fi

            # kill -0 delivers no signal; it only reports whether the PID can
            # still be signalled, i.e. whether the process is still around.
            if kill -0 "${gpu_pids[i]}" 2>/dev/null; then
                continue
            fi

            # wait returns the exit status of the finished background job
            if ! wait "${gpu_pids[i]}"; then
                echo "WARNING: Process ${gpu_pids[i]} exited with non-zero status"
                failed_experiments=$((failed_experiments + 1))
            fi

            completed_experiments=$((completed_experiments + 1))
            active_experiments=$((active_experiments - 1))
            gpu_busy[i]=0
            echo "Completed experiment $completed_experiments/$total_experiments on GPU $i"
        done

        # Pass 2: hand queued experiments to free GPUs, lowest index first
        for ((i = 0; i < NUM_GPUS && experiment_index < total_experiments; i++)); do
            if ((gpu_busy[i])); then
                continue
            fi

            IFS=':' read -r eq_idx depth hidden_channels <<< "${all_experiments[experiment_index]}"

            run_experiment "$i" \
                "${TRAIN_PATHS[$eq_idx]}" \
                "${VALID_PATHS[$eq_idx]}" \
                "$depth" \
                "$hidden_channels" &

            gpu_pids[i]=$!
            gpu_busy[i]=1
            active_experiments=$((active_experiments + 1))
            experiment_index=$((experiment_index + 1))

            echo "Scheduled experiment $experiment_index/$total_experiments on GPU $i (equation: ${EQUATION_NAMES[$eq_idx]}, depth: $depth, hidden_channels: $hidden_channels, active: $active_experiments)"
        done

        # Pause between polls; skipped once everything has been collected
        if ((completed_experiments < total_experiments)); then
            sleep 1
        fi
    done

    echo "All experiments completed. Ran $completed_experiments of $total_experiments experiments."
    if ((failed_experiments > 0)); then
        echo "WARNING: $failed_experiments of $total_experiments experiments exited with non-zero status"
    fi
    return 0
}

# Entry point: hand control (and any command-line arguments) to main
main "$@"