#!/bin/bash

# Require exactly five positional arguments. The usage text is a diagnostic,
# so it is written to stderr (like the error messages in run_experiment)
# before exiting with status 1.
if [ "$#" -ne 5 ]; then
    cat >&2 <<USAGE
Usage: $0 EXPERIMENT_DIR EXPERIMENT_TYPE MAX_CONCURRENCY WORKER_ID TOTAL_WORKERS
  EXPERIMENT_DIR: Directory to save results
  EXPERIMENT_TYPE: Type of experiment to run (logreg, resnet, gpt2)
  MAX_CONCURRENCY: Maximum number of parallel jobs on this worker (0 for sequential)
  WORKER_ID: ID of this worker (0-based, from 0 to TOTAL_WORKERS-1)
  TOTAL_WORKERS: Total number of workers across all machines
USAGE
    exit 1
fi

# Positional arguments, in the order given by the usage text.
EXPERIMENT_DIR="$1"
EXPERIMENT_TYPE="$2"
MAX_CONCURRENCY="$3"
WORKER_ID="$4"
TOTAL_WORKERS="$5"

# Define CONFIG_DIR based on EXPERIMENT_DIR: the *.yaml configs are looked up
# in this sub-directory.
CONFIG_DIR="$EXPERIMENT_DIR/source_folder"

# The numeric arguments are used in `[ ... -lt ... ]` tests and in $(( ))
# arithmetic further down. A non-integer value makes `[` fail with an error
# (which would silently skip the range check below), and $(( )) would
# evaluate arbitrary text as an expression, so reject anything that is not a
# plain (optionally negative) decimal integer up front.
for arg_name in MAX_CONCURRENCY WORKER_ID TOTAL_WORKERS; do
    arg_value="${!arg_name}"
    if [[ ! "$arg_value" =~ ^-?[0-9]+$ ]]; then
        echo "Error: $arg_name must be an integer, got '$arg_value'" >&2
        exit 1
    fi
    # Re-write the value in canonical base-10 form: $(( )) would otherwise
    # read a zero-padded value such as "010" as octal (and reject "08").
    arg_digits="${arg_value#-}"
    arg_sign="${arg_value%"$arg_digits"}"
    printf -v "$arg_name" '%s' "${arg_sign}$((10#$arg_digits))"
done
unset arg_name arg_value arg_digits arg_sign

# Validate worker ID and total workers. Passing this check also guarantees
# TOTAL_WORKERS >= 1, so the later division by TOTAL_WORKERS is safe.
if [ "$WORKER_ID" -lt 0 ] || [ "$WORKER_ID" -ge "$TOTAL_WORKERS" ]; then
    echo "Error: WORKER_ID must be between 0 and TOTAL_WORKERS-1" >&2
    exit 1
fi

# Export EXPERIMENT_DIR and EXPERIMENT_TYPE: run_experiment reads both, and in
# parallel mode it runs inside child `bash -c` processes that only see
# exported variables.
export EXPERIMENT_DIR
export EXPERIMENT_TYPE

# Create output directories. Abort right away if this fails: otherwise every
# experiment would fail later, one by one, when its output redirections
# cannot be opened.
if ! mkdir -p -- \
    "$EXPERIMENT_DIR" \
    "$EXPERIMENT_DIR/results" \
    "$EXPERIMENT_DIR/stdout" \
    "$EXPERIMENT_DIR/stderr"; then
    echo "Error: could not create output directories under $EXPERIMENT_DIR" >&2
    exit 1
fi

# Function to run a single experiment.
#
# Globals:
#   EXPERIMENT_TYPE          (read) selects the Python module to run:
#                            logreg, resnet or gpt2
#   EXPERIMENT_DIR           (read) its results/, stdout/ and stderr/
#                            sub-directories receive the outputs
#   EXPERIMENT_MAX_ATTEMPTS  (read, optional) number of attempts per
#                            experiment; defaults to 1 (no retry)
# Arguments:
#   $1 - path to the experiment's .yaml config file
# Outputs:
#   Progress on stdout, failure diagnostics on stderr. The Python process'
#   own stdout/stderr are redirected to files named after the config.
# Returns:
#   0 if an attempt succeeded, 1 if all attempts failed.
#   Terminates the calling shell with status 1 on an unknown EXPERIMENT_TYPE.
run_experiment() {
    local config_file="$1"
    local base_name exp_script
    local max_attempts="${EXPERIMENT_MAX_ATTEMPTS:-1}"
    local attempt=1
    local exit_code=0

    base_name=$(basename -- "$config_file" .yaml)

    echo "Running experiment: $base_name"

    # Fall back to a single attempt if the override is not a positive integer
    if [[ ! "$max_attempts" =~ ^[1-9][0-9]*$ ]]; then
        echo "Warning: ignoring invalid EXPERIMENT_MAX_ATTEMPTS '$max_attempts', using 1" >&2
        max_attempts=1
    fi

    # Choose experiment script based on EXPERIMENT_TYPE (does not change
    # between attempts, so it is resolved once, before the retry loop)
    case "$EXPERIMENT_TYPE" in
        logreg) exp_script=simulator.run_experiment ;;
        resnet) exp_script=simulator.run_experiment_torch ;;
        gpt2)   exp_script=simulator.run_experiment_llm ;;
        *)
            echo "Error: Unknown EXPERIMENT_TYPE: $EXPERIMENT_TYPE. Must be 'logreg', 'resnet', or 'gpt2'" >&2
            # Deliberately `exit`, not `return`: no experiment can run with
            # a bad type, so the whole (sub)shell is stopped.
            exit 1
            ;;
    esac

    while [ "$attempt" -le "$max_attempts" ]; do
        # Each attempt truncates and rewrites the .out/.err files
        python3 -m "$exp_script" \
            --save_path "$EXPERIMENT_DIR/results/${base_name}_results.json" \
            --config "$config_file" \
            1>"$EXPERIMENT_DIR/stdout/${base_name}.out" \
            2>"$EXPERIMENT_DIR/stderr/${base_name}.err"
        exit_code=$?

        if [ "$exit_code" -eq 0 ]; then
            return 0
        fi

        # Only announce a retry when another attempt will actually be made
        if [ "$attempt" -lt "$max_attempts" ]; then
            echo "Experiment $base_name failed with exit code $exit_code (attempt $attempt/$max_attempts). Retrying..." >&2
        else
            echo "Experiment $base_name failed with exit code $exit_code (attempt $attempt/$max_attempts)." >&2
        fi
        attempt=$((attempt + 1))
    done

    echo "Experiment $base_name failed after $max_attempts attempts." >&2
    # Report failure with a fixed status instead of the Python exit code:
    # xargs stops processing all remaining inputs when a command exits 255.
    return 1
}

# PYTHONUNBUFFERED is inherited by every python3 child process started below
PYTHONUNBUFFERED=1
export PYTHONUNBUFFERED

# The child `bash -c` shells started by xargs can only call run_experiment
# if the function itself is exported into their environment
export -f run_experiment

# Check if CONFIG_DIR exists; without it there are no configs to run.
# The error goes to stderr, matching the diagnostics in run_experiment.
if [ ! -d "$CONFIG_DIR" ]; then
    echo "Error: Configuration directory $CONFIG_DIR does not exist" >&2
    exit 1
fi

# Temporary files: TEMP_FILE holds the sorted list of all configs,
# WORKER_FILES the slice assigned to this worker. They are removed by an EXIT
# trap so that they are also cleaned up on early exits (for example when
# run_experiment calls `exit` in sequential mode).
TEMP_FILE=""
WORKER_FILES=""
cleanup_temp_files() {
    if [ -n "$TEMP_FILE" ]; then
        rm -f -- "$TEMP_FILE"
    fi
    if [ -n "$WORKER_FILES" ]; then
        rm -f -- "$WORKER_FILES"
    fi
}
trap cleanup_temp_files EXIT

# Create a temporary file with sorted yaml files.
# LC_ALL=C makes the order byte-wise and therefore identical on every
# machine; with a locale-dependent order, workers on differently configured
# hosts could compute overlapping or incomplete slices.
TEMP_FILE=$(mktemp) || { echo "Error: mktemp failed" >&2; exit 1; }
find "$CONFIG_DIR" -name "*.yaml" -type f | LC_ALL=C sort > "$TEMP_FILE"

# Calculate total number of files ($(( )) also strips the padding some
# `wc` implementations put in front of the number)
TOTAL_FILES=$(( $(wc -l < "$TEMP_FILE") ))

# Calculate start and end indices (0-based, inclusive) for this worker
# We use bash arithmetic to divide the work as evenly as possible
START_IDX=$(( (TOTAL_FILES * WORKER_ID) / TOTAL_WORKERS ))
END_IDX=$(( (TOTAL_FILES * (WORKER_ID + 1)) / TOTAL_WORKERS - 1 ))

if [ "$END_IDX" -ge "$START_IDX" ]; then
    echo "Worker $WORKER_ID/$TOTAL_WORKERS will process files $START_IDX to $END_IDX (out of $TOTAL_FILES)"

    # Extract the files for this worker (sed line numbers are 1-based)
    WORKER_FILES=$(mktemp) || { echo "Error: mktemp failed" >&2; exit 1; }
    sed -n "$((START_IDX + 1)),$((END_IDX + 1))p" "$TEMP_FILE" > "$WORKER_FILES"

    # Choose execution method based on MAX_CONCURRENCY
    if [ "$MAX_CONCURRENCY" -gt 0 ]; then
        # Parallel execution with xargs when MAX_CONCURRENCY > 0.
        # - NUL-delimited input (-0) stops xargs from interpreting quotes,
        #   backslashes and blanks inside file names.
        # - The file name is handed to the child shell as $1 rather than
        #   being spliced into the command string, so it is never parsed
        #   as shell code.
        tr '\n' '\0' < "$WORKER_FILES" \
            | xargs -0 -P "$MAX_CONCURRENCY" -I {} bash -c 'run_experiment "$1"' _ {}
    else
        # Sequential execution when MAX_CONCURRENCY = 0
        while IFS= read -r config_file; do
            echo "Running $config_file"
            start_time=$(date +%s)
            # stdin comes from /dev/null so the experiment cannot consume
            # the remaining lines of the file list this loop is reading
            run_experiment "$config_file" < /dev/null
            end_time=$(date +%s)
            duration=$((end_time - start_time))
            echo "Experiment $config_file took $duration seconds"
        done < "$WORKER_FILES"
    fi
else
    echo "Worker $WORKER_ID/$TOTAL_WORKERS has no files to process"
fi

# Temporary files are removed by the EXIT trap installed above

echo "Worker $WORKER_ID/$TOTAL_WORKERS completed. Results are in $EXPERIMENT_DIR"