#!/bin/bash

# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------
# Total number of data indices to cover (valid indices: 0 .. TOTAL_DATA_SIZE-1).
TOTAL_DATA_SIZE=1632
# Upper bound on how many consecutive indices a single job receives.
MAX_DATA_PER_JOB=42
# Device numbers the jobs are rotated over; each is passed on as "cuda:<n>".
GPUS=({0..7})
# Combined with the GPU count, this sets how many jobs are launched before
# the script pauses to wait for them.
MAX_JOBS_PER_GPU=5

# Per-job log files are written here; -p makes this a no-op when the
# directory is already present.
mkdir -p logs_imagen

#######################################
# Print the inclusive index range handled by one job.
# Globals:   MAX_DATA_PER_JOB (read), TOTAL_DATA_SIZE (read)
# Arguments: $1 - zero-based job number
# Outputs:   "<start> <end>" on stdout, where <end> is capped at
#            TOTAL_DATA_SIZE - 1
#######################################
calculate_indices() {
    local job_number=$1
    local chunk_first=$((job_number * MAX_DATA_PER_JOB))
    local chunk_last=$((chunk_first + MAX_DATA_PER_JOB - 1))
    local data_last=$((TOTAL_DATA_SIZE - 1))

    # The final chunk may be shorter: never report an index past the data end.
    if (( chunk_last > data_last )); then
        chunk_last=$data_last
    fi

    printf '%s %s\n' "$chunk_first" "$chunk_last"
}

# Launch one background job per chunk of indices until TOTAL_DATA_SIZE is
# covered. Jobs are assigned to GPUS round-robin by job number, and each job's
# stdout and stderr go to its own file under logs_imagen/.
#
# After every (number of GPUs * MAX_JOBS_PER_GPU) launches the script blocks
# until all background jobs started so far have exited. This loop does not
# wait after a final, partial batch; those jobs are started with nohup and
# keep running in the background.

# Number of jobs launched so far; also selects the chunk and the GPU.
job_counter=0

# Launch count after which the script pauses (loop-invariant, so computed once).
batch_size=$(( ${#GPUS[@]} * MAX_JOBS_PER_GPU ))

# Process data in chunks
current_start=0
while (( current_start < TOTAL_DATA_SIZE )); do
    # calculate_indices prints "<start> <end>"; read -a splits that line into
    # the array without relying on unquoted command-substitution word
    # splitting (which would also glob-expand the output).
    read -r -a indices < <(calculate_indices "$job_counter")
    start_idx=${indices[0]:-}
    end_idx=${indices[1]:-}

    # An empty or backwards range means current_start would never advance
    # (e.g. MAX_DATA_PER_JOB=0), so the loop would spawn jobs forever.
    if [[ -z "$start_idx" || -z "$end_idx" ]] || (( end_idx < start_idx )); then
        echo "Error: invalid index range '${start_idx}'..'${end_idx}' for job ${job_counter}; aborting" >&2
        exit 1
    fi

    # Calculate which GPU to use (rotating through GPUs)
    gpu_idx=$(( job_counter % ${#GPUS[@]} ))
    gpu_id=${GPUS[gpu_idx]}
    log_file="logs_imagen/job_${start_idx}_${end_idx}_gpu${gpu_id}.log"

    # Launch the job detached (nohup + &), with stdout and stderr both sent to
    # the per-job log file. The meaning of the positional arguments is defined
    # by image_generation.py, which is not part of this file.
    echo "Launching job for indices $start_idx to $end_idx on GPU $gpu_id"
    nohup python -u image_generation.py "$start_idx" "$end_idx" "sdxl" "cuda:${gpu_id}" "data/flickr_train_it1_prompts.csv" "it-1" > "$log_file" 2>&1 &

    # Advance to the next chunk. Plain assignment is used instead of
    # ((job_counter++)), whose exit status is non-zero when the old value is 0.
    job_counter=$((job_counter + 1))
    current_start=$((end_idx + 1))

    # If we've launched max jobs per GPU across all GPUs, wait for all
    # background jobs to finish before launching more.
    if (( job_counter % batch_size == 0 )); then
        echo "Waiting for current batch of jobs to complete..."
        wait
    fi
done

