#!/bin/bash
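# Script to run recurrent depth experiments in parallel
# - Runs every enabled equation (currently only "burgers"; heat, kdv and ks are
#   commented out below) with every EXP_CONFIGS x CONCAT_METHODS combination
# - Runs at most one experiment per GPU at a time; a new experiment is started
#   as soon as a GPU becomes free
# - With the current settings each equation runs 6 experiments:
#   - Depth 1, k_bar 32, FourierLayer, combine method "projection":
#     channels 32, 48, 64, 96, 128, 180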

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Number of GPUs to use; experiments are assigned device ids 0 .. NUM_GPUS-1.
NUM_GPUS=7

# Cap OpenMP at one thread per process (several experiments run concurrently).
OMP_NUM_THREADS=1
export OMP_NUM_THREADS

# PDE data sets. The three arrays are parallel: index N of each one describes
# the same equation, so keep them the same length and in the same order.
# Disabled entries, for reference:
#   names: "heat" "burgers" "kdv" "ks"
#   train: "heat1d_8192_256_train_500.parquet" "burgers1d_8192_256_train_500.parquet"
#          "kdv1d_1024_256_train_500.parquet" "ks1d_4096_256_train_500_wo_burnin.parquet"
#   valid: "heat1d_8192_256_valid_100.parquet" "burgers1d_8192_256_valid_100.parquet"
#          "kdv1d_1024_256_valid_100.parquet" "ks1d_4096_256_valid_100_long_wo_burnin.parquet"
EQUATION_NAMES=(
    "burgers"
)
TRAIN_PATHS=(
    "burgers1d_8192_256_train_500.parquet"
)
VALID_PATHS=(
    "burgers1d_8192_256_valid_100.parquet"
)

# Combine/concatenation methods to sweep over.
# Full list, for reference: "add" "weighted_add" "projection" "concat"
CONCAT_METHODS=(
    "projection"
)

# Base experiment configurations, one per entry, as
#   depth:k_bar:channels:layer_type
# (layer_type is used for both encoder_class and decoder_class).
EXP_CONFIGS=(
    "1:32:32:FourierLayer"
    "1:32:48:FourierLayer"
    "1:32:64:FourierLayer"
    "1:32:96:FourierLayer"
    "1:32:128:FourierLayer"
    "1:32:180:FourierLayer"
)

# Cross every base configuration with every combine method, producing
#   depth:k_bar:channels:layer_type:combine_method
FULL_CONFIGS=()
for base_config in "${EXP_CONFIGS[@]}"; do
    for method in "${CONCAT_METHODS[@]}"; do
        FULL_CONFIGS[${#FULL_CONFIGS[@]}]="${base_config}:${method}"
    done
done

# Function to run a single experiment
#
# Invokes `python main.py` with configs/experiment_concat_test.yaml plus one
# `--set key=value` override per argument. Every override is passed as a single
# quoted word, so values containing spaces or glob characters arrive intact.
#
# Arguments:
#   $1 - device value passed as `device=` (callers pass a GPU index)
#   $2 - training data path   (data.train_path)
#   $3 - validation data path (data.valid_path)
#   $4 - recurrent depth      (model.params.recurrent_depth)
#   $5 - k_bar                (model.params.recurrent_k_distribution.k_bar)
#   $6 - hidden channels      (model.params.hidden_channels)
#   $7 - layer type, used for both encoder_class and decoder_class
#   $8 - combine method       (model.params.combine_method, also appended to tags)
# Outputs:
#   One "Starting experiment" line on stdout, followed by python's own output.
# Returns:
#   The exit status of the python command.
run_experiment() {
    local device=$1
    local train_path=$2
    local valid_path=$3
    local depth=$4
    local k_bar=$5
    local hidden_channels=$6
    local layer_type=$7
    local combine_method=$8

    # The same layer type is used on both sides of the model
    local encoder_class=$layer_type
    local decoder_class=$layer_type

    # Equation name for logging: the file name up to its first underscore.
    # The directory part is stripped first so that an underscore in a parent
    # directory does not truncate the name.
    local train_file=${train_path##*/}
    local equation_name=${train_file%%_*}

    echo "Starting experiment on GPU $device: equation=$equation_name, depth=$depth, k_bar=$k_bar, hidden_channels=$hidden_channels, encoder/decoder=$layer_type, combine_method=$combine_method"

    # Build the overrides as an array so each key=value stays one argument
    local -a overrides=(
        --set "device=$device"
        --set "data.train_path=$train_path"
        --set "data.valid_path=$valid_path"
        --set "model.params.recurrent_depth=$depth"
        --set "model.params.hidden_channels=$hidden_channels"
        --set "model.params.recurrent_k_distribution.k_bar=$k_bar"
        --set "model.params.encoder_class=$encoder_class"
        --set "model.params.decoder_class=$decoder_class"
        --set "model.params.combine_method=$combine_method"
        --set "tags+=[\"$combine_method\"]"
    )

    # Last command in the function: its exit status becomes the return value
    python main.py --config configs/experiment_concat_test.yaml "${overrides[@]}"
}

# Main function to orchestrate all experiments
#
# Builds the list of experiments (every equation index crossed with every
# FULL_CONFIGS entry) and runs them with at most one experiment per GPU at a
# time. GPUs are polled once per second; a finished GPU is reaped and, if work
# remains, immediately given the next experiment.
#
# Globals (read): NUM_GPUS, EQUATION_NAMES, TRAIN_PATHS, VALID_PATHS,
#                 FULL_CONFIGS; calls run_experiment
# Outputs: progress messages on stdout, configuration errors on stderr
# Returns: 0 when every experiment exited with status 0; 1 when NUM_GPUS is
#          not a positive integer or at least one experiment failed
main() {
    # Without at least one GPU nothing could ever be scheduled and the
    # polling loop below would never terminate.
    local num_gpus=${NUM_GPUS:-}
    if [[ ! "$num_gpus" =~ ^[1-9][0-9]*$ ]]; then
        echo "ERROR: NUM_GPUS must be a positive integer (got '$num_gpus')" >&2
        return 1
    fi

    # Expand the work list: one "eq_idx:depth:k_bar:channels:layer:method"
    # entry per experiment. All work variables are local so nothing leaks
    # into the caller's scope.
    local -a experiment_configs=()
    local eq_idx config
    for ((eq_idx = 0; eq_idx < ${#EQUATION_NAMES[@]}; eq_idx++)); do
        for config in "${FULL_CONFIGS[@]}"; do
            experiment_configs+=("$eq_idx:$config")
        done
    done

    # Count total experiments
    local total_experiments=${#experiment_configs[@]}
    echo "Preparing to run $total_experiments total experiments"

    # Per-GPU state: PID of the running experiment and a busy flag
    local -a gpu_pids=()
    local -a gpu_busy=()
    local i
    for ((i = 0; i < num_gpus; i++)); do
        gpu_pids[i]=""
        gpu_busy[i]=0
    done

    local active_experiments=0
    local completed_experiments=0
    local failed_experiments=0
    local experiment_index=0
    local depth k_bar hidden_channels layer_type combine_method

    echo "Starting experiment batches..."

    # Keep going until all experiments are completed
    while ((completed_experiments < total_experiments)); do
        # First, reap any experiments that have finished
        for ((i = 0; i < num_gpus; i++)); do
            ((gpu_busy[i] == 1)) || continue
            # kill -0 only probes for existence; failure means the job exited
            if ! kill -0 "${gpu_pids[i]}" 2>/dev/null; then
                # wait returns the exit status of the finished background job
                if ! wait "${gpu_pids[i]}"; then
                    echo "WARNING: Process ${gpu_pids[i]} exited with non-zero status"
                    failed_experiments=$((failed_experiments + 1))
                fi
                completed_experiments=$((completed_experiments + 1))
                active_experiments=$((active_experiments - 1))
                gpu_busy[i]=0
                echo "Completed experiment $completed_experiments/$total_experiments on GPU $i"
            fi
        done

        # Then hand the next pending experiments to free GPUs, lowest index first
        for ((i = 0; i < num_gpus; i++)); do
            ((experiment_index < total_experiments)) || break
            ((gpu_busy[i] == 0)) || continue

            config=${experiment_configs[experiment_index]}
            IFS=':' read -r eq_idx depth k_bar hidden_channels layer_type combine_method <<< "$config"

            # Start the experiment on this GPU in the background
            run_experiment "$i" \
                           "${TRAIN_PATHS[$eq_idx]}" \
                           "${VALID_PATHS[$eq_idx]}" \
                           "$depth" \
                           "$k_bar" \
                           "$hidden_channels" \
                           "$layer_type" \
                           "$combine_method" &

            # Store the process ID and mark GPU as busy
            gpu_pids[i]=$!
            gpu_busy[i]=1
            active_experiments=$((active_experiments + 1))
            experiment_index=$((experiment_index + 1))

            echo "Scheduled experiment $experiment_index/$total_experiments on GPU $i (equation: ${EQUATION_NAMES[$eq_idx]}, depth: $depth, k_bar: $k_bar, hidden_channels: $hidden_channels, layer_type: $layer_type, combine_method: $combine_method, active: $active_experiments)"
        done

        # Pause between polls, but not after the last experiment was reaped
        if ((completed_experiments < total_experiments)); then
            sleep 1
        fi
    done

    echo "All experiments completed. Ran $completed_experiments of $total_experiments experiments."

    # Surface failures through the exit status so callers can detect them
    if ((failed_experiments > 0)); then
        echo "WARNING: $failed_experiments of $total_experiments experiments exited with non-zero status"
        return 1
    fi
    return 0
}

# Execute the main function, forwarding any command-line arguments.
# As the last command, its return value becomes the script's exit status.
main "$@"