#!/bin/bash
# Usage: bash cluster_scripts/submit_cluster_jobs.sh <config_file> <output_dir> [cpu_cores] [memory] [run_ids]
#   - run_ids (optional): comma-separated 0-indexed run IDs to submit (e.g. "0" or "0,3,7")
# Set these for your cluster (anonymous review: no institution-specific paths in repo):
#   SLURM_PARTITION, SINGULARITY_IMAGE, CONDA_ACTIVATE_SCRIPT, CONDA_ENV_PATH
# ---------------------------------------------------------------------------
# Positional arguments
#   $1  config file (required; relative to the current directory or absolute)
#   $2  output directory (required; relative or absolute)
#   $3  CPU cores per task (optional)
#   $4  memory per task (optional)
#   $5  run IDs to submit (optional)
# ---------------------------------------------------------------------------
CONFIG_FILE=${1:-}
OUTPUT_DIR=${2:-}
CPU_CORES=${3:-32}  # Use $3 if provided, otherwise default to 32
MEMORY=${4:-128G}   # Use $4 if provided, otherwise default to 128G
RUN_IDS=${5:-}      # Optional comma-separated list of 0-indexed run IDs

# Both required arguments must be present; without them every derived path
# below would silently collapse to the current directory.
if [ -z "$CONFIG_FILE" ] || [ -z "$OUTPUT_DIR" ]; then
    echo "Usage: bash $0 <config_file> <output_dir> [cpu_cores] [memory] [run_ids]" >&2
    exit 1
fi

# Get absolute paths (Singularity needs full paths).
# A path that is already absolute is kept as is; only relative paths are
# anchored at the current working directory.
CURRENT_DIR=$(pwd)
case "$CONFIG_FILE" in
    /*) ABS_CONFIG="$CONFIG_FILE" ;;
    *)  ABS_CONFIG="$CURRENT_DIR/$CONFIG_FILE" ;;
esac
case "$OUTPUT_DIR" in
    /*) ABS_OUTPUT="$OUTPUT_DIR" ;;
    *)  ABS_OUTPUT="$CURRENT_DIR/$OUTPUT_DIR" ;;
esac

# Fail early with a clear message instead of a Python traceback later on.
if [ ! -f "$ABS_CONFIG" ]; then
    echo "Config file not found: $ABS_CONFIG" >&2
    exit 1
fi

# Cluster-specific (set these for your site; not set in repo for anonymous review):
IMAGE=${SINGULARITY_IMAGE:-}
PARTITION=${SLURM_PARTITION:-}
CONDA_ACTIVATE=${CONDA_ACTIVATE_SCRIPT:-}
CONDA_ENV=${CONDA_ENV_PATH:-}
if [ -z "$IMAGE" ] || [ -z "$PARTITION" ] || [ -z "$CONDA_ACTIVATE" ] || [ -z "$CONDA_ENV" ]; then
    echo "For anonymous review, cluster paths are not in the repo. Set: SINGULARITY_IMAGE, SLURM_PARTITION, CONDA_ACTIVATE_SCRIPT, CONDA_ENV_PATH" >&2
    exit 1
fi

# Calculate Array Size from JSON
# The config path is handed to Python as an argument (sys.argv[1]) instead of
# being pasted into the Python source, so quotes or backslashes in the path
# cannot break or alter the one-liner.
if ! TOTAL_RUNS=$(python3 -c 'import json, sys; print(json.load(open(sys.argv[1]))["metadata"]["total_runs"])' "$CONFIG_FILE"); then
    echo "Could not read metadata.total_runs from config: $CONFIG_FILE" >&2
    exit 1
fi

# total_runs must be a positive integer; anything else would yield a
# meaningless array range such as "0--1".
case "$TOTAL_RUNS" in
    ''|*[!0-9]*)
        echo "metadata.total_runs in $CONFIG_FILE is not a non-negative integer: '$TOTAL_RUNS'" >&2
        exit 1
        ;;
esac
TOTAL_RUNS=$((10#$TOTAL_RUNS))  # force base 10 so a padded value is not read as octal
if [ "$TOTAL_RUNS" -lt 1 ]; then
    echo "metadata.total_runs in $CONFIG_FILE is $TOTAL_RUNS; nothing to submit." >&2
    exit 1
fi

# Highest 0-indexed array task ID.
ARRAY_LIMIT=$((TOTAL_RUNS - 1))

# Return 0 when $1 is a well-formed SLURM --array index specification:
# comma-separated entries, each a non-negative integer or a range
# "A-B" / "A-B:STEP", optionally followed by a "%LIMIT" concurrency cap.
# Anything else (empty string, spaces, newlines, shell metacharacters) is
# rejected so it can never be templated into the generated job script.
is_valid_array_spec() {
    local spec=$1
    local item='[0-9]+(-[0-9]+(:[0-9]+)?)?'
    local re="^${item}(,${item})*(%[0-9]+)?\$"
    [[ $spec =~ $re ]]
}

# If specific run IDs are provided, submit ONLY those and exit
if [ -n "$RUN_IDS" ]; then
    if ! is_valid_array_spec "$RUN_IDS"; then
        echo "❌ Invalid run_ids '$RUN_IDS': expected comma-separated 0-indexed IDs (e.g. \"0\" or \"0,3,7\")" >&2
        exit 1
    fi

    echo "Submitting selected run IDs only (Array: $RUN_IDS)..."

    # Create logs folder (the #SBATCH --output/--error paths below point into it)
    if ! mkdir -p "$ABS_OUTPUT/logs"; then
        echo "❌ Could not create log directory: $ABS_OUTPUT/logs" >&2
        exit 1
    fi

    # Copy config file to output directory (for reproducibility)
    CONFIG_FILENAME=$(basename "$ABS_CONFIG")
    CONFIG_COPY_PATH="$ABS_OUTPUT/$CONFIG_FILENAME"
    if [ ! -f "$CONFIG_COPY_PATH" ]; then
        if ! cp "$ABS_CONFIG" "$CONFIG_COPY_PATH"; then
            echo "❌ Could not copy configuration to: $CONFIG_COPY_PATH" >&2
            exit 1
        fi
        echo "📋 Configuration copied to: $CONFIG_COPY_PATH"
    else
        echo "📋 Configuration already exists at: $CONFIG_COPY_PATH"
    fi

    # The heredoc is unquoted on purpose: $PARTITION, $RUN_IDS, ... are filled
    # in now, while the \$-escaped names are left for the job itself to expand.
    if ! sbatch <<EOT
#!/bin/bash
#SBATCH --job-name=mpf_bench_selected
#SBATCH --partition=$PARTITION
#SBATCH --array=$RUN_IDS
#SBATCH --cpus-per-task=$CPU_CORES
#SBATCH --mem=$MEMORY
#SBATCH --output=$ABS_OUTPUT/logs/job_%A_%a.out
#SBATCH --error=$ABS_OUTPUT/logs/job_%A_%a.err

echo "Task ID: \$SLURM_ARRAY_TASK_ID on node \$(hostname)"

# Run inside the container
srun singularity exec $IMAGE bash -c "
    # 1. Activate the environment (The one with MPF installed)
    source $CONDA_ACTIVATE
    conda activate $CONDA_ENV
    
    # 2. Runtime Linker Paths (Crucial for OpenBLAS/OpenSSL at runtime)
    export LD_LIBRARY_PATH=\$CONDA_PREFIX/lib:\$LD_LIBRARY_PATH
    
    # 3. Run the Python Script
    # We use 'python3' from the conda env
    python3 cluster_scripts/run_cluster_experiment.py \\
        --config '$ABS_CONFIG' \\
        --run-id \$SLURM_ARRAY_TASK_ID \\
        --output '$ABS_OUTPUT'
"
EOT
    then
        # Do not report success when the scheduler rejected the submission.
        echo "❌ sbatch failed; selected run IDs were NOT submitted: $RUN_IDS" >&2
        exit 1
    fi

    echo "✅ Submitted selected run IDs: $RUN_IDS"
    exit 0
fi

# Function to submit a single run
#
# Submits one (non-array) SLURM job that executes a single experiment run
# inside the Singularity container.
#
# Globals (read): PARTITION, CPU_CORES, MEMORY, ABS_OUTPUT, ABS_CONFIG,
#                 IMAGE, CONDA_ACTIVATE, CONDA_ENV
# Arguments:      $1 - run ID (non-negative integer) passed to --run-id and
#                      embedded in the job name and log file names
# Returns:        2 if the run ID is not a non-negative integer,
#                 1 if the log directory cannot be created,
#                 otherwise the exit status of sbatch
submit_single_run() {
    local run_id=${1:-}

    # The run ID is templated into #SBATCH directives and the command line,
    # so refuse anything that is not plain digits.
    case "$run_id" in
        ''|*[!0-9]*)
            echo "submit_single_run: run ID must be a non-negative integer, got '$run_id'" >&2
            return 2
            ;;
    esac

    # The --output/--error paths below point into this directory.
    mkdir -p "$ABS_OUTPUT/logs" || return 1

    # Unquoted heredoc: plain $VARS are expanded now, \$-escaped ones are left
    # for the job script to expand at run time.
    sbatch <<EOT
#!/bin/bash
#SBATCH --job-name=mpf_bench_${run_id}
#SBATCH --partition=$PARTITION
#SBATCH --cpus-per-task=$CPU_CORES
#SBATCH --mem=$MEMORY
#SBATCH --output=$ABS_OUTPUT/logs/job_%j_${run_id}.out
#SBATCH --error=$ABS_OUTPUT/logs/job_%j_${run_id}.err

echo "Task ID: ${run_id} on node \$(hostname)"

# Run inside the container
srun singularity exec $IMAGE bash -c "
    # 1. Activate the environment (The one with MPF installed)
    source $CONDA_ACTIVATE
    conda activate $CONDA_ENV
    
    # 2. Runtime Linker Paths (Crucial for OpenBLAS/OpenSSL at runtime)
    export LD_LIBRARY_PATH=\$CONDA_PREFIX/lib:\$LD_LIBRARY_PATH
    
    # 3. Run the Python Script
    # We use 'python3' from the conda env
    python3 cluster_scripts/run_cluster_experiment.py \\
        --config '$ABS_CONFIG' \\
        --run-id ${run_id} \\
        --output '$ABS_OUTPUT'
"
EOT
}

# ---------------------------------------------------------------------------
# Resume support: requeue incomplete runs of a previous submission
#
# When the output directory already holds a copy of this config and a logs/
# folder, the job logs are scanned and only the runs that timed out, failed,
# or never produced a log are resubmitted as a single SLURM job array.
#
# A run may have several logs (one per attempt). Only the NEWEST .out file of
# each run (by modification time) decides its status, so a run that timed out
# once and later completed is not requeued again, and its Optuna logs are not
# deleted because of a stale failure record.
# ---------------------------------------------------------------------------

# Print the 0-indexed run ID recorded in a job's .out file.
# Arguments: $1 - path to the .out file
#            $2 - basename of the config currently being submitted
#            $3 - total number of runs
# Outputs:   the run ID on stdout
# Returns:   0 if the log names this config (matched by basename in its
#            "Configuration:" line) and carries a "Run ID: X/Y" line whose X
#            maps to a valid 0-indexed ID; 1 otherwise
requeue_log_run_id() {
    local out_file=$1 config_basename=$2 total_runs=$3
    local config_in_file run_id

    # Extract the config path and trim surrounding whitespace (sed instead of
    # xargs, which aborts on paths containing quote characters).
    config_in_file=$(grep -m 1 "^Configuration:" "$out_file" 2>/dev/null \
        | sed -e 's/^Configuration:[[:space:]]*//' -e 's/[[:space:]]*$//')
    [ -n "$config_in_file" ] || return 1
    [ "$(basename -- "$config_in_file")" = "$config_basename" ] || return 1

    # Extract the first number (X) from "Run ID: X/Y"
    run_id=$(grep -m 1 "^Run ID:" "$out_file" 2>/dev/null \
        | sed -n 's/^Run ID: \([0-9]*\)\/.*/\1/p')
    [ -n "$run_id" ] || return 1

    # Convert to 0-indexed; 10# forces base 10 so "08" is not parsed as octal.
    run_id=$((10#$run_id - 1))
    if [ "$run_id" -ge 0 ] && [ "$run_id" -lt "$total_runs" ]; then
        echo "$run_id"
        return 0
    fi
    return 1
}

# Delete the Optuna log files associated with the failed run whose job log
# is $1. Presumably this lets the requeued run start from a clean study -
# TODO confirm against run_cluster_experiment.py.
# Globals (read): ABS_OUTPUT
# Arguments:      $1 - path to the failed run's .out file
# Outputs:        one line per removed file
requeue_remove_optuna_logs() {
    local out_file=$1
    local experiment_line dataset_name model_name dataset_loaded_line actual_dataset
    local safe_dataset safe_model simple_log nested_log_dir nested_log
    local storage_line log_dir fold_log

    # Look for the "Running experiment: dataset x model" line
    experiment_line=$(grep -m 1 "^Running experiment:" "$out_file" 2>/dev/null | sed 's/^Running experiment: //')

    if [ -n "$experiment_line" ]; then
        # Split "dataset x model" at the first " x "
        dataset_name=${experiment_line%% x *}
        model_name=${experiment_line#* x }

        # Prefer the dataset name from the "Dataset loaded: name, shape..." line
        dataset_loaded_line=$(grep -m 1 "✅ Dataset loaded:" "$out_file" 2>/dev/null)
        if [ -n "$dataset_loaded_line" ]; then
            actual_dataset=$(echo "$dataset_loaded_line" | sed -n 's/.*Dataset loaded: \([^,]*\).*/\1/p')
            if [ -n "$actual_dataset" ]; then
                dataset_name="$actual_dataset"
            fi
        fi

        # Sanitize names for filesystem: spaces and slashes become underscores
        safe_dataset=${dataset_name// /_}
        safe_dataset=${safe_dataset//\//_}
        safe_model=${model_name// /_}
        safe_model=${safe_model//\//_}

        # Simple CV log file
        simple_log="$ABS_OUTPUT/$safe_dataset/$safe_model/${safe_dataset}_${safe_model}.log"
        if [ -f "$simple_log" ]; then
            echo "   🗑️  Removing optuna log: $simple_log"
            rm -f "$simple_log"
        fi

        # Nested CV log files (..._fold*.log); only the trailing part is a glob
        nested_log_dir="$ABS_OUTPUT/$safe_dataset/$safe_model"
        if [ -d "$nested_log_dir" ]; then
            for nested_log in "$nested_log_dir/${safe_dataset}_${safe_model}_fold"*.log; do
                # Guards against an unmatched pattern being left as a literal
                if [ -f "$nested_log" ]; then
                    echo "   🗑️  Removing optuna log: $nested_log"
                    rm -f "$nested_log"
                fi
            done
        fi
    else
        # Fallback: take the log path from a "Storage:" line, or from a
        # "Using shared log file storage:" line
        storage_line=$(grep -m 1 "^Storage:" "$out_file" 2>/dev/null | sed 's/^Storage: //')
        if [ -z "$storage_line" ]; then
            storage_line=$(grep -m 1 "Using shared log file storage:" "$out_file" 2>/dev/null | sed 's/.*Using shared log file storage: //')
        fi

        if [ -n "$storage_line" ]; then
            if [ -f "$storage_line" ]; then
                echo "   🗑️  Removing optuna log: $storage_line"
                rm -f "$storage_line"
            fi

            # Also remove any *_fold*.log files in the same directory
            log_dir=$(dirname "$storage_line")
            if [ -d "$log_dir" ]; then
                for fold_log in "$log_dir"/*_fold*.log; do
                    if [ -f "$fold_log" ]; then
                        echo "   🗑️  Removing optuna log: $fold_log"
                        rm -f "$fold_log"
                    fi
                done
            fi
        fi
    fi
}

# Scan a logs directory and classify every run of the current config.
# Globals (written): existing_run_ids  - runs that have at least one log
#                    timed_out_run_ids - newest log's .err has "DUE TO TIME LIMIT"
#                    failed_run_ids    - newest log has "Experiment failed: All folds failed"
#                    missing_run_ids   - runs in [0, total_runs) without any log
# Globals (read):    ABS_OUTPUT (via requeue_remove_optuna_logs)
# Arguments:         $1 - logs directory, $2 - config basename, $3 - total runs
# Outputs:           progress messages on stdout
requeue_scan_logs() {
    local logs_dir=$1 config_basename=$2 total_runs=$3
    local out_file err_file run_id i had_nullglob=0
    # Sparse array indexed by run ID -> path of that run's newest .out file
    local -a newest_log=()

    existing_run_ids=()
    timed_out_run_ids=()
    failed_run_ids=()
    missing_run_ids=()

    # Enable nullglob to handle the case when no .out files exist
    shopt -q nullglob && had_nullglob=1
    shopt -s nullglob

    # Pass 1: remember only the most recently modified log of each run
    for out_file in "$logs_dir"/*.out; do
        run_id=$(requeue_log_run_id "$out_file" "$config_basename" "$total_runs") || continue
        if [ -z "${newest_log[$run_id]:-}" ] || [ "$out_file" -nt "${newest_log[$run_id]}" ]; then
            newest_log[$run_id]=$out_file
        fi
    done

    # Pass 2: classify each run by its newest log (indices come out ascending)
    for run_id in "${!newest_log[@]}"; do
        out_file=${newest_log[$run_id]}
        existing_run_ids+=("$run_id")

        # Check corresponding .err file for "DUE TO TIME LIMIT"
        err_file="${out_file%.out}.err"
        if [ -f "$err_file" ] && grep -q "DUE TO TIME LIMIT" "$err_file" 2>/dev/null; then
            timed_out_run_ids+=("$run_id")
            echo "⏱️  Found timed-out run: $run_id"
        fi

        # Check for "Experiment failed: All folds failed" in .out file
        if grep -q "Experiment failed: All folds failed" "$out_file" 2>/dev/null; then
            failed_run_ids+=("$run_id")
            echo "❌ Found failed run: $run_id"
            requeue_remove_optuna_logs "$out_file"
        fi
    done

    # Restore nullglob setting
    [ "$had_nullglob" -eq 1 ] || shopt -u nullglob

    # Runs without any log are missing (direct lookup instead of a nested loop)
    for ((i = 0; i < total_runs; i++)); do
        [ -n "${newest_log[$i]:-}" ] || missing_run_ids+=("$i")
    done
}

# Check if output directory and config exist
if [ -d "$ABS_OUTPUT" ]; then
    CONFIG_FILENAME=$(basename "$ABS_CONFIG")
    CONFIG_COPY_PATH="$ABS_OUTPUT/$CONFIG_FILENAME"

    if [ -f "$CONFIG_COPY_PATH" ]; then
        echo "📁 Output directory and config file exist. Checking for incomplete/timed-out runs..."

        LOGS_DIR="$ABS_OUTPUT/logs"
        if [ -d "$LOGS_DIR" ]; then
            # Logs are matched to this config by basename
            CURRENT_CONFIG_BASENAME=$(basename "$ABS_CONFIG")

            requeue_scan_logs "$LOGS_DIR" "$CURRENT_CONFIG_BASENAME" "$TOTAL_RUNS"

            # Combine timed-out, failed, and missing run IDs; sort and de-duplicate
            declare -a runs_to_requeue=()
            while IFS= read -r requeue_id; do
                # printf emits one empty line when all three lists are empty
                [ -n "$requeue_id" ] && runs_to_requeue+=("$requeue_id")
            done < <(printf '%s\n' "${timed_out_run_ids[@]}" "${failed_run_ids[@]}" "${missing_run_ids[@]}" | sort -n | uniq)

            if [ ${#runs_to_requeue[@]} -gt 0 ]; then
                echo "🔄 Found ${#runs_to_requeue[@]} runs to requeue:"
                echo "   Timed-out: ${#timed_out_run_ids[@]} runs"
                echo "   Failed: ${#failed_run_ids[@]} runs"
                echo "   Missing: ${#missing_run_ids[@]} runs"
                echo "   Run IDs: ${runs_to_requeue[*]}"
                echo ""

                # Build array specification string for SLURM (comma-separated list)
                array_spec=$(IFS=','; echo "${runs_to_requeue[*]}")

                echo "Submitting job array with ${#runs_to_requeue[@]} tasks (Array: $array_spec)..."

                # logs/ and the config copy are known to exist in this branch,
                # so nothing has to be created or copied before submitting.
                # Unquoted heredoc: plain $VARS expand now, \$-escaped ones at job run time.
                if ! sbatch <<EOT
#!/bin/bash
#SBATCH --job-name=mpf_bench_requeue
#SBATCH --partition=$PARTITION
#SBATCH --array=$array_spec
#SBATCH --cpus-per-task=$CPU_CORES
#SBATCH --mem=$MEMORY
#SBATCH --output=$ABS_OUTPUT/logs/job_%A_%a.out
#SBATCH --error=$ABS_OUTPUT/logs/job_%A_%a.err

echo "Task ID: \$SLURM_ARRAY_TASK_ID on node \$(hostname)"

# Run inside the container
srun singularity exec $IMAGE bash -c "
    # 1. Activate the environment (The one with MPF installed)
    source $CONDA_ACTIVATE
    conda activate $CONDA_ENV
    
    # 2. Runtime Linker Paths (Crucial for OpenBLAS/OpenSSL at runtime)
    export LD_LIBRARY_PATH=\$CONDA_PREFIX/lib:\$LD_LIBRARY_PATH
    
    # 3. Run the Python Script
    # We use 'python3' from the conda env
    python3 cluster_scripts/run_cluster_experiment.py \\
        --config '$ABS_CONFIG' \\
        --run-id \$SLURM_ARRAY_TASK_ID \\
        --output '$ABS_OUTPUT'
"
EOT
                then
                    # Do not report success when the scheduler rejected the submission.
                    echo "❌ sbatch failed; no runs were requeued." >&2
                    exit 1
                fi

                echo "✅ Finished requeuing ${#runs_to_requeue[@]} runs as job array"
                exit 0
            else
                echo "✅ All runs are complete. No runs to requeue."
                exit 0
            fi
        fi
    fi
fi

# If we get here, there is nothing to resume from: the output directory, its
# copy of the config, or its logs/ folder does not exist yet.
# Proceed with normal submission of all runs as one job array.

# Refuse to build an empty/negative array range such as "0--1".
if [ "${ARRAY_LIMIT:--1}" -lt 0 ]; then
    echo "❌ No runs to submit (total_runs='$TOTAL_RUNS')." >&2
    exit 1
fi

echo "Submitting $TOTAL_RUNS jobs (Array 0-$ARRAY_LIMIT)..."

# Create logs folder (the #SBATCH --output/--error paths below point into it)
if ! mkdir -p "$ABS_OUTPUT/logs"; then
    echo "❌ Could not create log directory: $ABS_OUTPUT/logs" >&2
    exit 1
fi

# Copy config file to output directory once (before submitting jobs)
CONFIG_FILENAME=$(basename "$ABS_CONFIG")
CONFIG_COPY_PATH="$ABS_OUTPUT/$CONFIG_FILENAME"
if [ "$ABS_CONFIG" -ef "$CONFIG_COPY_PATH" ]; then
    # The config passed in already IS the copy in the output directory;
    # cp would fail with "are the same file".
    echo "📋 Configuration already exists at: $CONFIG_COPY_PATH"
elif cp "$ABS_CONFIG" "$CONFIG_COPY_PATH"; then
    echo "📋 Configuration copied to: $CONFIG_COPY_PATH"
else
    echo "❌ Could not copy configuration to: $CONFIG_COPY_PATH" >&2
    exit 1
fi

# Submit to SLURM
# Unquoted heredoc: plain $VARS are expanded now, \$-escaped ones are left
# for the job script to expand at run time.
if ! sbatch <<EOT
#!/bin/bash
#SBATCH --job-name=mpf_bench
#SBATCH --partition=$PARTITION
#SBATCH --array=0-$ARRAY_LIMIT
#SBATCH --cpus-per-task=$CPU_CORES
#SBATCH --mem=$MEMORY
#SBATCH --output=$ABS_OUTPUT/logs/job_%A_%a.out
#SBATCH --error=$ABS_OUTPUT/logs/job_%A_%a.err

echo "Task ID: \$SLURM_ARRAY_TASK_ID on node \$(hostname)"

# Run inside the container
srun singularity exec $IMAGE bash -c "
    # 1. Activate the environment (The one with MPF installed)
    source $CONDA_ACTIVATE
    conda activate $CONDA_ENV
    
    # 2. Runtime Linker Paths (Crucial for OpenBLAS/OpenSSL at runtime)
    export LD_LIBRARY_PATH=\$CONDA_PREFIX/lib:\$LD_LIBRARY_PATH
    
    # 3. Run the Python Script
    # We use 'python3' from the conda env
    python3 cluster_scripts/run_cluster_experiment.py \\
        --config '$ABS_CONFIG' \\
        --run-id \$SLURM_ARRAY_TASK_ID \\
        --output '$ABS_OUTPUT'
"
EOT
then
    echo "❌ sbatch failed; no jobs were submitted." >&2
    exit 1
fi

# python cluster_scripts/generate_cluster_config.py --hyperparams cluster_scripts/hyperparams/interpretable.json --suite "353[1-22]" --models MPFRegressor XGBRegressor LGBMRegressor RandomForestRegressor ExplainableBoostingRegressor --output cluster_scripts/configs/interpretable_config.json
# python cluster_scripts/generate_cluster_config.py --hyperparams cluster_scripts/hyperparams/blackbox.json --suite "353[1-22]" --models MPFRegressor XGBRegressor LGBMRegressor RandomForestRegressor --output cluster_scripts/configs/blackbox_config.json

# python cluster_scripts/generate_cluster_config.py --hyperparams cluster_scripts/hyperparams/interpretable.json --suite "353" --models ExplainableBoostingRegressor --output cluster_scripts/configs/explainableboosting_config.json
# python cluster_scripts/generate_cluster_config.py --hyperparams cluster_scripts/hyperparams/interpretable.json --suite "353[23-]" --models XGBRegressor LGBMRegressor RandomForestRegressor --output cluster_scripts/configs/interpretable_config_high_np.json
# python cluster_scripts/generate_cluster_config.py --hyperparams cluster_scripts/hyperparams/interpretable.json --suite "353[23-]" --models MPFRegressor --output cluster_scripts/configs/interpretable_mpf_high_np.json

# python cluster_scripts/generate_cluster_config.py --hyperparams cluster_scripts/hyperparams/blackbox.json --suite "353[23-]" --models XGBRegressor LGBMRegressor RandomForestRegressor --output cluster_scripts/configs/blackbox_config_high_np.json
# python cluster_scripts/generate_cluster_config.py --hyperparams cluster_scripts/hyperparams/blackbox.json --suite "353[23-]" --models MPFRegressor --output cluster_scripts/configs/blackbox_mpf_high_np.json

# bash cluster_scripts/submit_cluster_jobs.sh cluster_scripts/configs/explainableboosting_config.json cluster_scripts/results/interpretable/others 8
# bash cluster_scripts/submit_cluster_jobs.sh cluster_scripts/configs/interpretable_config_high_np.json cluster_scripts/results/interpretable/others 8
# bash cluster_scripts/submit_cluster_jobs.sh cluster_scripts/configs/interpretable_mpf_high_np.json cluster_scripts/results/interpretable/mpf 64
# bash cluster_scripts/submit_cluster_jobs.sh cluster_scripts/configs/blackbox_config_high_np.json cluster_scripts/results/blackbox/others 8
# bash cluster_scripts/submit_cluster_jobs.sh cluster_scripts/configs/blackbox_mpf_high_np.json cluster_scripts/results/blackbox/mpf 64
