#!/bin/bash

# Sweep configuration for the PGD attacks.
# datasets_thresholds maps a dataset name to a space-separated list of cosine
# similarity thresholds; each threshold is passed to the attack script as
# --atk_threshold.
declare -A datasets_thresholds
datasets_thresholds[cora]="0.0 0.3 0.5 0.7"
datasets_thresholds[citeseer]="0.0 0.3 0.5 0.7"
datasets_thresholds[pubmed]="0.0 0.3 0.5 0.7"
datasets_thresholds[wikics]="0.0 0.3 0.5 0.7"

# Values combined with every (dataset, threshold) pair.
ptb_rates=("0.20")
emb_types=(roberta)

# Scheduling limits: a GPU is only picked while its memory usage is below
# MAX_MEMORY_PERCENT, and at most MAX_CONCURRENT_PROCESSES attacks run at once.
MAX_MEMORY_PERCENT=90
MAX_CONCURRENT_PROCESSES=2

# Log file layout. Everything lives under LOG_DIR, which is created on startup
# when it does not exist yet.
LOG_DIR="./log_guard"
MAIN_LOG="${LOG_DIR}/guard_attack_progress.log"       # timestamped progress messages
COMPLETED_LOG="${LOG_DIR}/completed_guard_tasks.log"  # one line per completed task
FAILED_LOG="${LOG_DIR}/failed_guard_tasks.log"        # one line per failed task
mkdir -p "${LOG_DIR}"

# Size of the sweep: every configured threshold of every dataset is combined
# with each perturbation rate and each embedding type.
TOTAL_TASKS=0
for dataset in "${!datasets_thresholds[@]}"; do
    # Unquoted on purpose: the threshold list is split on whitespace.
    for atk_threshold in ${datasets_thresholds[$dataset]}; do
        TOTAL_TASKS=$(( TOTAL_TASKS + ${#ptb_rates[@]} * ${#emb_types[@]} ))
    done
done

# Seed the progress counter with the number of lines already present in the
# completed-task log, so a resumed run continues from the previous progress.
COMPLETED_TASKS=0
[[ ! -f "$COMPLETED_LOG" ]] || COMPLETED_TASKS=$(wc -l < "$COMPLETED_LOG")

# Print a timestamped message on stdout and append the same line to MAIN_LOG.
# Globals:   MAIN_LOG (read) - file the line is appended to
# Arguments: $1 - message text
# Outputs:   "[YYYY-MM-DD HH:MM:SS] <message>" on stdout
log_message() {
    local stamp
    stamp=$(date '+%Y-%m-%d %H:%M:%S')
    printf '[%s] %s\n' "$stamp" "$1" | tee -a "$MAIN_LOG"
}

# Record a finished task and report overall progress.
# Globals:   COMPLETED_LOG (appended to), COMPLETED_TASKS (incremented),
#            TOTAL_TASKS (read)
# Arguments: $1 - task description line to store
# Outputs:   a progress line through log_message
mark_completed() {
    local finished_task="$1"
    printf '%s\n' "$finished_task" >> "$COMPLETED_LOG"
    COMPLETED_TASKS=$(( COMPLETED_TASKS + 1 ))
    # Integer percentage of the sweep that is done (rounded down).
    local percent=$(( COMPLETED_TASKS * 100 / TOTAL_TASKS ))
    log_message "Progress: $percent% ($COMPLETED_TASKS/$TOTAL_TASKS) - Completed: $finished_task"
}

# Record a failed task.
# Globals:   FAILED_LOG (appended to)
# Arguments: $1 - task description line to store
# Outputs:   a "FAILED: ..." line through log_message
mark_failed() {
    local failed_task="$1"
    printf '%s\n' "$failed_task" >> "$FAILED_LOG"
    log_message "FAILED: $failed_task"
}

# Bookkeeping for launched background attacks.
#   PIDS        - indexed array of the PIDs that are still being tracked
#   PID_TO_TASK - associative array mapping a PID to its task description
PIDS=()
declare -A PID_TO_TASK=()

# Signal handler for SIGINT/SIGTERM: stop every tracked background task, then
# exit with status 1.
#
# Tasks are first asked to stop with SIGTERM so they get a chance to shut down
# cleanly; anything still alive once the grace period has elapsed is killed
# with SIGKILL.
#
# Globals:   PID_TO_TASK (read) - PID -> task description of launched tasks
#            CLEANUP_GRACE_SECONDS (read, optional) - whole seconds to wait
#            between SIGTERM and SIGKILL; defaults to 10
# Outputs:   progress lines through log_message (which already prints to
#            stdout, so the messages are not echoed a second time)
# Returns:   never returns; exits the script with status 1
cleanup() {
    local pid deadline
    local grace_seconds="${CLEANUP_GRACE_SECONDS:-10}"
    local -a alive=()
    local -a remaining=()

    # Fall back to the default when the override is not a plain integer.
    [[ "$grace_seconds" =~ ^[0-9]+$ ]] || grace_seconds=10

    echo  # start on a fresh line after a "^C" echoed by the terminal
    log_message "Received stop signal. Cleaning up..."

    # Phase 1: polite termination request for every task that is still running.
    for pid in "${!PID_TO_TASK[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            log_message "Stopping process $pid (${PID_TO_TASK[$pid]})..."
            kill -TERM "$pid" 2>/dev/null
            alive+=("$pid")
        fi
    done

    # Phase 2: poll until all tasks are gone or the grace period is over.
    deadline=$(( SECONDS + grace_seconds ))
    while (( ${#alive[@]} > 0 && SECONDS < deadline )); do
        remaining=()
        for pid in "${alive[@]}"; do
            if kill -0 "$pid" 2>/dev/null; then
                remaining+=("$pid")
            fi
        done
        alive=("${remaining[@]}")
        if (( ${#alive[@]} > 0 )); then
            sleep 0.2
        fi
    done

    # Phase 3: force-kill whatever ignored the termination request.
    for pid in "${alive[@]}"; do
        log_message "Process $pid did not exit within ${grace_seconds}s; sending SIGKILL"
        kill -KILL "$pid" 2>/dev/null
    done

    log_message "Cleanup complete. Exiting."
    exit 1
}

trap cleanup SIGINT SIGTERM

# Report the memory usage of one GPU as an integer percentage.
#
# Arguments: $1 - GPU index as understood by `nvidia-smi -i`
# Outputs:   on stdout
#              0      when nvidia-smi is not installed
#              100    when the GPU cannot be queried, or when the reported
#                     used/total values are missing, non-numeric or total is 0
#                     (an unreadable GPU is treated as fully busy)
#              0-100  used * 100 / total otherwise (rounded down)
get_gpu_memory_usage() {
    local gpu_id=$1
    local mem_used mem_total

    if ! command -v nvidia-smi &> /dev/null; then
        echo "0"
        return
    fi

    # A GPU that does not answer at all is reported as fully used.
    if ! nvidia-smi -i "$gpu_id" &> /dev/null; then
        echo "100"
        return
    fi

    mem_used=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i "$gpu_id" 2>/dev/null)
    mem_total=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits -i "$gpu_id" 2>/dev/null)

    # Drop stray whitespace so the numeric validation below is exact.
    mem_used=${mem_used//[[:space:]]/}
    mem_total=${mem_total//[[:space:]]/}

    # Anything that is not a plain number (empty output, "[N/A]", ...) would
    # break the arithmetic below, so it is reported as "busy" instead.
    if ! [[ "$mem_used" =~ ^[0-9]+$ && "$mem_total" =~ ^[0-9]+$ ]] \
        || (( 10#$mem_total == 0 )); then
        echo "100"
        return
    fi

    # 10# forces base 10 so a leading zero is never read as octal.
    echo $(( 10#$mem_used * 100 / 10#$mem_total ))
}

# Pick the GPU with the lowest memory usage.
#
# Globals:   MAX_MEMORY_PERCENT (read) - a GPU only qualifies while its usage
#            is strictly below this percentage
# Outputs:   on stdout
#              -1   nvidia-smi is missing or lists no GPU
#              -2   GPUs exist but none is accessible with usage below the limit
#              N    index of the least-used qualifying GPU
get_available_gpu() {
    local num_gpus gpu_id usage
    local best_gpu=-1
    local min_usage=101

    if ! command -v nvidia-smi &> /dev/null; then
        echo "-1"
        return
    fi

    # Count only the "GPU <n>: ..." lines. Counting every line would also
    # count driver error text or indented MIG sub-device entries as GPUs.
    num_gpus=$(nvidia-smi --list-gpus 2>/dev/null | grep -c '^GPU ')

    if ! [[ "$num_gpus" =~ ^[0-9]+$ ]] || [ "$num_gpus" -eq 0 ]; then
        echo "-1"
        return
    fi

    for (( gpu_id = 0; gpu_id < num_gpus; gpu_id++ )); do
        # Skip GPUs that cannot be queried at all.
        nvidia-smi -i "$gpu_id" &> /dev/null || continue

        usage=$(get_gpu_memory_usage "$gpu_id")
        # An unreadable usage value counts as fully busy.
        [[ "$usage" =~ ^[0-9]+$ ]] || usage=100

        if [ "$usage" -lt "$min_usage" ]; then
            min_usage=$usage
            best_gpu=$gpu_id
        fi
    done

    if [ "$best_gpu" -ge 0 ] && [ "$min_usage" -lt "$MAX_MEMORY_PERCENT" ]; then
        echo "$best_gpu"
    else
        # All GPUs are busy or inaccessible.
        echo "-2"
    fi
}

# Block until get_available_gpu reports a usable GPU.
#
# Globals:   MAX_MEMORY_PERCENT (read, only for the log text)
# Outputs:   the chosen GPU index on stdout; all log lines are sent to stderr
#            so that callers capturing stdout receive nothing but the index
wait_for_gpu() {
    local candidate
    while :; do
        candidate=$(get_available_gpu)

        if [[ "$candidate" == "-1" ]]; then
            # No driver / no GPU: poll again a little later.
            log_message "No NVIDIA drivers or GPUs detected. Waiting..." >&2
            sleep 5
        elif [[ "$candidate" == "-2" ]]; then
            # GPUs exist but all are above the usage limit.
            log_message "All GPUs are busy (usage > $MAX_MEMORY_PERCENT%). Waiting..." >&2
            sleep 3
        elif [ "$candidate" -ge 0 ]; then
            log_message "Found available GPU $candidate" >&2
            echo "$candidate"
            return
        fi
    done
}

# Check whether a task is already listed in the completed-task log.
#
# The task description is compared as a literal, whole-line string: it contains
# characters such as "." (e.g. "atk_threshold=0.3") that must not be
# interpreted as regular-expression wildcards.
#
# Globals:   COMPLETED_LOG (read)
# Arguments: $1 - task description line
# Returns:   0 if the exact line is present, non-zero otherwise (including
#            when the log file does not exist)
is_task_completed() {
    local task_str="$1"
    [[ -f "$COMPLETED_LOG" ]] || return 1
    # -F literal match, -x whole line, -- protects descriptions starting with "-"
    grep -qxF -- "$task_str" "$COMPLETED_LOG"
}

# Launch one PGD attack run in the background.
#
# A task that is_task_completed reports as already done is skipped. Otherwise
# the function waits for a free GPU, logs what is about to run, starts
# gen_attacks_inductive_guard.py as a background job and registers its PID.
#
# Globals:   PIDS (appended to), PID_TO_TASK (entry added for the new PID)
# Arguments: $1 - dataset name
#            $2 - attack cosine-similarity threshold ("0.0" is logged as the
#                 original PGD baseline)
#            $3 - perturbation rate
#            $4 - embedding type
run_guard_attack() {
    local dataset=$1 atk_threshold=$2 ptb_rate=$3 emb=$4
    local task_str="pgdguard $dataset atk_threshold=$atk_threshold ptb_rate=$ptb_rate emb=$emb"

    # Nothing to do for tasks recorded by an earlier run.
    if is_task_completed "$task_str"; then
        log_message "Skipping completed task: $task_str"
        return
    fi

    # Blocks until a GPU is free.
    local gpu_id
    gpu_id=$(wait_for_gpu)

    # The first and third log lines depend on whether this is the baseline run.
    local headline mode_note
    if [[ "$atk_threshold" == "0.0" ]]; then
        headline="Starting original PGD attack (baseline) on GPU $gpu_id: $dataset (ptb_rate=$ptb_rate, emb=$emb)"
        mode_note="  - No cosine similarity filtering (equivalent to original PGD)"
    else
        headline="Starting PGD attack with cosine threshold on GPU $gpu_id: $dataset (atk_threshold=$atk_threshold, ptb_rate=$ptb_rate, emb=$emb)"
        mode_note="  - Cosine similarity threshold: $atk_threshold"
    fi
    log_message "$headline"
    log_message "  - Surrogate model: GCN (hid=64)"
    log_message "$mode_note"
    log_message "  - Defense models: GCN (hid=128), GNNGUARD (hid=128, thresholds=0.1,0.3,0.5,0.7)"

    local -a attack_cmd=(
        python gen_attacks_inductive_guard.py
        --dataset "$dataset"
        --ptb_rate "$ptb_rate"
        --emb_type "$emb"
        --atk_threshold "$atk_threshold"
        --device "$gpu_id"
        --re_split 2
    )
    "${attack_cmd[@]}" &

    # Remember the job so it can be polled and, on interrupt, stopped.
    local pid=$!
    PIDS+=("$pid")
    PID_TO_TASK[$pid]="$task_str"

    # Small delay to allow nvidia-smi to update before the next GPU is chosen.
    sleep 5
}

# Poll the tracked background jobs once and process those that have exited.
#
# A finished job is reaped with `wait`; exit status 0 is recorded through
# mark_completed, anything else through mark_failed. Jobs that are still
# running stay in PIDS.
#
# Globals:   PIDS (rewritten to hold only the still-running PIDs),
#            PID_TO_TASK (entries of finished jobs are removed)
# Returns:   1 if at least one job finished during this call, 0 otherwise
#            (note: this is the reverse of the usual success/failure meaning;
#            it is kept as-is for existing callers)
check_completed_tasks() {
    local pid status
    local any_completed=0
    local -a still_running=()

    for pid in "${PIDS[@]}"; do
        if [ -z "$pid" ]; then
            continue
        fi

        if kill -0 "$pid" 2>/dev/null; then
            still_running+=("$pid")
            continue
        fi

        # The process is gone: collect its exit status and file the result.
        wait "$pid"
        status=$?
        if [ "$status" -eq 0 ]; then
            mark_completed "${PID_TO_TASK[$pid]}"
        else
            mark_failed "${PID_TO_TASK[$pid]}"
        fi
        unset "PID_TO_TASK[$pid]"
        any_completed=1
    done

    PIDS=("${still_running[@]}")
    return "$any_completed"
}

# ---------------------------------------------------------------------------
# Main driver: announce the run, schedule every task of the sweep while
# keeping at most MAX_CONCURRENT_PROCESSES attacks running, then wait for the
# stragglers and print where the logs are.
#
# log_message already prints to stdout (through tee), so the banner and the
# final status are not echoed a second time.
# ---------------------------------------------------------------------------

# Report the attack thresholds that are actually configured instead of a
# hard-coded list. The expansion is unquoted on purpose so that each
# space-separated threshold becomes its own word.
# shellcheck disable=SC2068
configured_thresholds=$(printf '%s\n' ${datasets_thresholds[@]} | sort -u | paste -sd ' ' -)

log_message "Starting PGD attacks with cosine similarity thresholds... (Press Ctrl+C to stop)"
log_message "Total tasks to process: $TOTAL_TASKS"
log_message "Attack method: PGD only"
log_message "Attack thresholds: $configured_thresholds (0.0 = baseline)"
log_message "Defense thresholds: 0.1, 0.3, 0.5, 0.7 (evaluated for each attack)"
echo "Progress will be logged to $MAIN_LOG"

# Main execution loop - process all datasets and attack thresholds
for dataset in "${!datasets_thresholds[@]}"; do
    # Unquoted on purpose: the threshold list is split on whitespace.
    for atk_threshold in ${datasets_thresholds[$dataset]}; do
        for ptb_rate in "${ptb_rates[@]}"; do
            for emb in "${emb_types[@]}"; do
                # Throttle: wait until a slot frees up before launching more.
                while [ "${#PIDS[@]}" -ge "$MAX_CONCURRENT_PROCESSES" ]; do
                    check_completed_tasks
                    sleep 5
                done
                run_guard_attack "$dataset" "$atk_threshold" "$ptb_rate" "$emb"
            done
        done
    done
done

# Wait for remaining processes to complete
while [ "${#PIDS[@]}" -gt 0 ]; do
    check_completed_tasks
    sleep 10
done

log_message "All PGD attack tasks completed!"
echo "Check $MAIN_LOG for full progress log"
echo "Check $COMPLETED_LOG for completed tasks"
echo "Check $FAILED_LOG for failed tasks"
echo "Individual experiment logs are saved in log_guard/ directory"