#!/bin/bash

# ==============================================================================
# Configuration
# ==============================================================================

# CUDA toolkit location; exported once here so child processes inherit it
CUDA_HOME=/usr/local/cuda-12.5
export CUDA_HOME

# Datasets to evaluate (arxiv is excluded from the transductive runs)
datasets=(
    "cora"
    "citeseer"
    "pubmed"
    "wikics"
    "instagram"
    "history"
    "reddit"
)

# Models to evaluate (the list includes prognn, stable and the purify variants)
models=(
    "gcn" "gat" "gcorn" "gnnguard" "appnp" "gprgnn"
    "evennet" "prognn" "stable" "purify_jaccard" "purify_cosine" "rung"
    "elasticgnn" "noisy_gcn" "robustgcn" "twirls" "softmediangdc" "grand"
)

# Perturbation rates to sweep
ptb_rates=(0.30)

# Embedding types used on the attack side and on the defense side
atk_emb_types=("bow")
def_emb_types=("bow" "roberta" "MiniLM")

# Parallel execution settings
# A GPU counts as "available" only while its memory usage is below this percentage
MAX_MEMORY_PERCENT=90

#######################################
# Report how many evaluation processes may run concurrently.
# Arguments: $1 - dataset name, $2 - model name (neither is consulted yet)
# Outputs:   the process limit on stdout; currently always 2
#######################################
get_max_processes() {
    local dataset=$1 model=$2
    printf '%s\n' 2
}

# Attack method to use for each dataset, keyed by dataset name
declare -A dataset_to_attack=()
dataset_to_attack["cora"]="strg"
dataset_to_attack["citeseer"]="strg"
dataset_to_attack["pubmed"]="strg"
dataset_to_attack["wikics"]="strg"
dataset_to_attack["instagram"]="strg"
dataset_to_attack["computer"]="strg"
dataset_to_attack["photo"]="strg"
dataset_to_attack["reddit"]="strg"
dataset_to_attack["history"]="strg"

# ==============================================================================
# Logging Setup
# ==============================================================================
# All log files live under one directory
LOG_DIR="./eval_logs_trans"
MAIN_LOG="${LOG_DIR}/eval_progress.log"
COMPLETED_LOG="${LOG_DIR}/completed_tasks.log"
FAILED_LOG="${LOG_DIR}/failed_tasks.log"
PROGRESS_LOG="${LOG_DIR}/progress.log"

# Create the directory and make sure every log file exists
# (touch leaves the contents of existing files alone)
mkdir -p "${LOG_DIR}"
touch \
    "${MAIN_LOG}" \
    "${COMPLETED_LOG}" \
    "${FAILED_LOG}" \
    "${PROGRESS_LOG}"

# Epoch seconds at script start; elapsed-time figures are measured from here
START_TIME="$(date '+%s')"

# ==============================================================================
# Task Calculation
# ==============================================================================
# The sweep is the full cross product of the five configuration lists, so the
# number of tasks is the product of their lengths (no need to walk the loops).
TOTAL_TASKS=$((${#datasets[@]} * ${#models[@]} * ${#ptb_rates[@]} * ${#atk_emb_types[@]} * ${#def_emb_types[@]}))

# Tasks already recorded in the completed log; duplicate lines count once.
# tr strips the padding some wc implementations put in front of the number.
COMPLETED_TASKS=0
if [ -f "$COMPLETED_LOG" ]; then
    COMPLETED_TASKS=$(sort -u "$COMPLETED_LOG" | wc -l | tr -d ' ')
fi

# ==============================================================================
# Helper Functions
# ==============================================================================

#######################################
# Print a timestamped message and append the same line to the main log.
# Globals:   MAIN_LOG (appended to)
# Arguments: $1 - message text
# Outputs:   "[YYYY-MM-DD HH:MM:SS] message" on stdout
#######################################
log_message() {
    local stamp
    stamp=$(date '+%Y-%m-%d %H:%M:%S')
    printf '[%s] %s\n' "$stamp" "$1" | tee -a "$MAIN_LOG"
}

#######################################
# Append one progress line (percentage, elapsed time, ETA) to the progress log.
# Globals:   START_TIME (read), PROGRESS_LOG (appended to)
# Arguments: $1 - number of completed tasks
#            $2 - total number of tasks
# Notes:     The ETA is elapsed-time-since-START_TIME divided by $1. If $1
#            also counts tasks finished before START_TIME (e.g. by an earlier
#            run), the per-task time, and so the ETA, is underestimated.
#######################################
log_progress() {
    local completed=$1
    local total=$2
    local now stamp elapsed
    local percent=0

    # Declared first, assigned second, so a failing `date` is not masked by `local`
    now=$(date +%s)
    stamp=$(date '+%Y-%m-%d %H:%M:%S')
    elapsed=$((now - START_TIME))

    # An empty sweep (total == 0) would otherwise be a division by zero that
    # aborts the function before anything is logged
    if ((total > 0)); then
        percent=$((completed * 100 / total))
    fi

    if ((completed > 0)); then
        local time_per_task=$((elapsed / completed))
        local remaining_tasks=$((total - completed))
        # More log entries than tasks in the sweep must not yield a negative ETA
        if ((remaining_tasks < 0)); then
            remaining_tasks=0
        fi
        local estimated_remaining=$((time_per_task * remaining_tasks))

        # Split both durations into hours / minutes / seconds
        local eta_h=$((estimated_remaining / 3600))
        local eta_m=$(((estimated_remaining % 3600) / 60))
        local eta_s=$((estimated_remaining % 60))
        local elapsed_h=$((elapsed / 3600))
        local elapsed_m=$(((elapsed % 3600) / 60))
        local elapsed_s=$((elapsed % 60))

        echo "[$stamp] Progress: $percent% ($completed/$total) | Elapsed: ${elapsed_h}h:${elapsed_m}m:${elapsed_s}s | ETA: ${eta_h}h:${eta_m}m:${eta_s}s" >> "$PROGRESS_LOG"
    else
        echo "[$stamp] Progress: $percent% ($completed/$total) | Just started" >> "$PROGRESS_LOG"
    fi
}

#######################################
# Record a finished task and report the updated progress.
# Globals:   COMPLETED_LOG (appended to), COMPLETED_TASKS (written),
#            TOTAL_TASKS (read)
# Arguments: $1 - task description string
#######################################
mark_completed() {
    local task_str="$1"
    local percent=0

    printf '%s\n' "$task_str" >> "$COMPLETED_LOG"

    # Count unique lines; strip the padding some wc implementations emit so the
    # value is a bare number, matching how the count is computed at startup
    COMPLETED_TASKS=$(sort -u "$COMPLETED_LOG" | wc -l | tr -d ' ')

    # Guard the percentage so an empty sweep cannot cause a division by zero
    # that would drop the completion message
    if ((TOTAL_TASKS > 0)); then
        percent=$((COMPLETED_TASKS * 100 / TOTAL_TASKS))
    fi

    log_progress "$COMPLETED_TASKS" "$TOTAL_TASKS"
    log_message "Progress: ${percent}% ($COMPLETED_TASKS/$TOTAL_TASKS) | Completed: $task_str"
}

#######################################
# Record a failed task in the failure log and announce it.
# Globals:   FAILED_LOG (appended to)
# Arguments: $1 - task description string
#######################################
mark_failed() {
    local failed_task="$1"

    echo "${failed_task}" >> "${FAILED_LOG}"
    log_message "!!!!!!!! FAILED: ${failed_task} !!!!!!!!"
}

# PIDs of launched background tasks, and a PID -> task description lookup
declare -a PIDS=()
declare -A PID_TO_TASK

#######################################
# Signal handler: stop every tracked task that is still alive, then exit 1.
# Each live task first receives SIGTERM so it can shut down cleanly; tasks
# that are still alive after a short grace period are sent SIGKILL.
# Globals:   PID_TO_TASK (read)
# Outputs:   progress messages via log_message
#######################################
cleanup() {
    local pid deadline
    local grace_seconds=5
    local -a alive=()
    local -a remaining=()

    log_message "Received stop signal. Cleaning up..."

    # Polite request first
    for pid in "${!PID_TO_TASK[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            log_message "Stopping process $pid (${PID_TO_TASK[$pid]})..."
            kill -TERM "$pid" 2>/dev/null
            alive+=("$pid")
        fi
    done

    # Poll until every signalled task is gone or the grace period runs out
    deadline=$((SECONDS + grace_seconds))
    while ((${#alive[@]} > 0 && SECONDS < deadline)); do
        sleep 0.2
        remaining=()
        for pid in "${alive[@]}"; do
            if kill -0 "$pid" 2>/dev/null; then
                remaining+=("$pid")
            fi
        done
        alive=("${remaining[@]}")
    done

    # Last resort for anything that did not react to SIGTERM
    for pid in "${alive[@]}"; do
        log_message "Process $pid did not exit after SIGTERM; sending SIGKILL..."
        kill -KILL "$pid" 2>/dev/null
    done

    log_message "Cleanup complete. Exiting."
    exit 1
}
trap cleanup SIGINT SIGTERM

#######################################
# Report the memory usage of one GPU as an integer percentage.
# Arguments: $1 - GPU index as understood by `nvidia-smi -i`
# Outputs:   on stdout:
#              0    when nvidia-smi is not installed
#              100  when the GPU cannot be queried or a reading is not a
#                   usable number (unknown is treated as fully used)
#              used * 100 / total otherwise
#######################################
get_gpu_memory_usage() {
    local gpu_id=$1
    local used total

    if ! command -v nvidia-smi &> /dev/null; then
        echo "0"
        return
    fi

    # Check if the GPU exists
    if ! nvidia-smi -i "$gpu_id" &> /dev/null; then
        echo "100"  # Return high usage for non-existent GPU
        return
    fi

    used=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i "$gpu_id" 2>/dev/null)
    total=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits -i "$gpu_id" 2>/dev/null)

    # Keep the first line only and drop surrounding whitespace
    used=${used%%$'\n'*}
    used=${used//[[:space:]]/}
    total=${total%%$'\n'*}
    total=${total//[[:space:]]/}

    # Anything that is not a plain non-negative integer (empty output, "[N/A]",
    # an error message) cannot be used in arithmetic; report the GPU as full
    if ! [[ "$used" =~ ^[0-9]+$ && "$total" =~ ^[0-9]+$ ]]; then
        echo "100"
        return
    fi

    # 10# forces base 10 so a reading with leading zeros is not read as octal
    if ((10#$total == 0)); then
        echo "100"
    else
        echo $((10#$used * 100 / 10#$total))
    fi
}

#######################################
# Pick the GPU with the lowest memory usage.
# Globals:   MAX_MEMORY_PERCENT (read)
# Outputs:   on stdout:
#              index of the least-used GPU when its usage is below
#                 MAX_MEMORY_PERCENT
#              -1 when nvidia-smi is missing or lists no GPUs
#              -2 when every GPU is at or above the limit or inaccessible
#######################################
get_available_gpu() {
    # Everything is local so a direct (non-subshell) call cannot clobber the
    # caller's variables of the same name
    local num_gpus gpu_id usage
    local best_gpu=-1
    local min_usage=101

    # First check if nvidia-smi exists and can be executed
    if ! command -v nvidia-smi &> /dev/null; then
        echo "-1"
        return
    fi

    # One line per GPU; strip any padding wc adds around the count
    num_gpus=$(nvidia-smi --list-gpus | wc -l)
    num_gpus=${num_gpus//[[:space:]]/}

    # Validate num_gpus is a number and greater than 0
    if ! [[ "$num_gpus" =~ ^[0-9]+$ ]] || [ "$num_gpus" -eq 0 ]; then
        echo "-1"
        return
    fi

    for ((gpu_id = 0; gpu_id < num_gpus; gpu_id++)); do
        # Skip GPUs that cannot be queried at all
        if nvidia-smi -i "$gpu_id" &> /dev/null; then
            usage=$(get_gpu_memory_usage "$gpu_id")
            # A reading that is not a number is treated as fully used
            if ! [[ "$usage" =~ ^[0-9]+$ ]]; then
                usage=100
            fi
            if [ "$usage" -lt "$min_usage" ]; then
                min_usage=$usage
                best_gpu=$gpu_id
            fi
        fi
    done

    if [ "$best_gpu" -ge 0 ] && [ "$min_usage" -lt "$MAX_MEMORY_PERCENT" ]; then
        echo "$best_gpu"
    else
        # Return -2 if all GPUs are busy or inaccessible
        echo "-2"
    fi
}

#######################################
# Block until get_available_gpu reports a usable GPU.
# Globals:   MAX_MEMORY_PERCENT (read, for the log text only)
# Outputs:   the chosen GPU index on stdout; log lines are redirected to
#            stderr so callers capturing stdout receive only the index
#######################################
wait_for_gpu() {
    local gpu_id

    while true; do
        gpu_id=$(get_available_gpu)

        case "$gpu_id" in
            "-1")
                log_message "No NVIDIA drivers or GPUs detected. Waiting..." >&2
                sleep 5
                ;;
            "-2")
                log_message "All GPUs are busy (usage > $MAX_MEMORY_PERCENT%). Waiting..." >&2
                sleep 3
                ;;
            *)
                if [[ "$gpu_id" =~ ^[0-9]+$ ]]; then
                    log_message "Found available GPU $gpu_id" >&2
                    echo "$gpu_id"
                    return
                fi
                # Empty or malformed answer: pause before retrying so the loop
                # does not spin at full speed
                log_message "Unexpected GPU query result '$gpu_id'. Retrying..." >&2
                sleep 3
                ;;
        esac
    done
}

#######################################
# Tell whether a task is already recorded in the completed log.
# Globals:   COMPLETED_LOG (read)
# Arguments: $1 - task description string
# Returns:   0 when the log contains a line equal to $1 (fixed-string,
#            whole-line match); non-zero otherwise or when the log is missing
#######################################
is_task_completed() {
    local wanted="$1"

    if [ ! -f "$COMPLETED_LOG" ]; then
        return 1
    fi
    grep -F -x -q "$wanted" "$COMPLETED_LOG"
}

# ==============================================================================
# Main Execution Logic
# ==============================================================================

#######################################
# Launch one evaluation run in the background unless it is already done.
# Globals:   dataset_to_attack (read), PIDS (appended to),
#            PID_TO_TASK (written)
# Arguments: $1 - dataset, $2 - model, $3 - attack embedding type,
#            $4 - defense embedding type, $5 - perturbation rate
#######################################
run_evaluation_task() {
    local dataset=$1
    local model=$2
    local atk_emb=$3
    local def_emb=$4
    local ptb_rate=$5
    local attack=${dataset_to_attack[$dataset]}

    # This exact string is what the completed log stores and is matched against
    local task_str="$dataset | $model | $attack | atk:$atk_emb def:$def_emb | ptb:$ptb_rate"

    # Nothing to do for tasks recorded by an earlier run
    if is_task_completed "$task_str"; then
        return
    fi

    # Training budget depends on the dataset
    local epochs patience
    case "$dataset" in
        cora | citeseer | instagram | pubmed | wikics)
            epochs=400
            patience=100
            ;;
        computer | photo | reddit | history)
            epochs=600
            patience=200
            ;;
        *)
            log_message "WARNING: No epoch/patience setting for $dataset. Using defaults."
            epochs=200
            patience=50
            ;;
    esac

    # Blocks until a GPU is reported as available
    local gpu_id
    gpu_id=$(wait_for_gpu)

    log_message "Starting on GPU $gpu_id: $task_str"

    local -a py_args=(
        --dataset "$dataset"
        --model "$model"
        --attack "$attack"
        --atk_emb_type "$atk_emb"
        --def_emb_type "$def_emb"
        --ptb_rate "$ptb_rate"
        --epochs "$epochs"
        --patience "$patience"
        --device "$gpu_id"
        --use_existing_logs
    )
    python eval_transductive.py "${py_args[@]}" &

    # Track the background job so its exit status can be collected later
    local pid=$!
    PIDS+=("$pid")
    PID_TO_TASK["$pid"]="$task_str"

    # Pause before returning; presumably gives the new process time to claim
    # GPU memory before the next GPU query (TODO confirm)
    sleep 5
}

#######################################
# Collect background tasks that have exited and record their outcome.
# A task whose exit status is 0 is passed to mark_completed, any other status
# to mark_failed; tasks that are still running stay in PIDS.
# Globals:   PIDS (rewritten), PID_TO_TASK (entries of finished tasks removed)
#######################################
check_completed_tasks() {
    # pid is local so the loop does not overwrite a global of the same name
    local pid status task_str
    local -a still_running=()

    for pid in "${PIDS[@]}"; do
        if [ -z "$pid" ]; then
            continue
        fi

        # kill -0 only probes whether the process still exists
        if kill -0 "$pid" 2>/dev/null; then
            still_running+=("$pid")
            continue
        fi

        # The process is gone: wait returns its exit status
        wait "$pid"
        status=$?
        task_str="${PID_TO_TASK[$pid]}"
        if [ "$status" -eq 0 ]; then
            mark_completed "$task_str"
        else
            mark_failed "$task_str"
        fi
        unset "PID_TO_TASK[$pid]"
    done

    PIDS=("${still_running[@]}")
}

# Start with an empty PID list; make sure PID_TO_TASK is an associative array
PIDS=()
declare -A PID_TO_TASK

log_message "Starting transductive evaluation tasks... (Press Ctrl+C to stop)"
log_message "Total tasks to process: $TOTAL_TASKS. Completed: $COMPLETED_TASKS"
echo "Starting... Progress will be logged to $MAIN_LOG"

# Announce a resume only when earlier results exist. The completed log is
# created at startup, so testing for the file alone would always succeed.
# COMPLETED_TASKS is the de-duplicated count, matching the line logged above.
if ((${COMPLETED_TASKS:-0} > 0)); then
    log_message "Resuming from completed tasks log. $COMPLETED_TASKS tasks already done."
fi

for dataset in "${datasets[@]}"; do
    for ptb_rate in "${ptb_rates[@]}"; do
        for model in "${models[@]}"; do
            for atk_emb in "${atk_emb_types[@]}"; do
                for def_emb in "${def_emb_types[@]}"; do
                    # Wait for a free slot before launching the next task
                    MAX_CONCURRENT_PROCESSES=$(get_max_processes "$dataset" "$model")
                    while [ "${#PIDS[@]}" -ge "$MAX_CONCURRENT_PROCESSES" ]; do
                        check_completed_tasks
                        sleep 5
                    done

                    run_evaluation_task "$dataset" "$model" "$atk_emb" "$def_emb" "$ptb_rate"
                done
            done
        done
    done
done

# Drain: keep collecting results until no background task is left
log_message "All tasks launched. Waiting for remaining processes to complete..."
while [ "${#PIDS[@]}" -gt 0 ]; do
    check_completed_tasks
    sleep 5
done

log_message "All evaluation tasks completed!"
echo "All tasks completed!"
echo "Check $MAIN_LOG for full progress log."
echo "Completed tasks are listed in $COMPLETED_LOG."
echo "Failed tasks are listed in $FAILED_LOG."