#!/bin/bash
#
# Launches gen_attacks.py / gen_attacks_inductive.py for every
# (attack, dataset, ptb_rate, emb) combination configured below, placing each
# job on the GPU with the lowest memory usage and limiting how many jobs run
# at once. Finished tasks are recorded in $LOG_DIR/completed_tasks.log and
# skipped on the next run, so the script can be restarted after interruption.

# Configuration
# Map: attack name -> space-separated list of datasets to run that attack on.
declare -A datasets_attacks=(
    #["strg"]="cora citeseer pubmed wikics instagram reddit history"
    #["pgd"]="cora citeseer pubmed wikics instagram"
    #["grbcd"]="computer photo history reddit arxiv"
    ["mettack"]="cora citeseer instagram"
)
ptb_rates=(0.20)
emb_types=("bow") #"bow" "roberta" "MiniLM"
MAX_MEMORY_PERCENT=90  # Maximum memory usage percentage
MAX_CONCURRENT_PROCESSES=2  # Maximum number of concurrent processes

# Setup logging
LOG_DIR="./logs"
# Abort early if the log directory is unusable: every later log write
# (tee -a / >>) would otherwise fail silently and resume state would be lost.
if ! mkdir -p -- "$LOG_DIR"; then
    echo "Cannot create log directory: $LOG_DIR" >&2
    exit 1
fi
MAIN_LOG="$LOG_DIR/attack_progress.log"
COMPLETED_LOG="$LOG_DIR/completed_tasks.log"
FAILED_LOG="$LOG_DIR/failed_tasks.log"

# Calculate total number of tasks: each (attack, dataset, ptb_rate, emb)
# combination in the configuration is one task, so per attack the count is
# (#datasets) * (#ptb_rates) * (#emb_types).
TOTAL_TASKS=0
for attack in "${!datasets_attacks[@]}"; do
    # Dataset lists are space-separated strings; split them into words.
    read -r -a attack_datasets <<< "${datasets_attacks[$attack]}"
    TOTAL_TASKS=$(( TOTAL_TASKS + ${#attack_datasets[@]} * ${#ptb_rates[@]} * ${#emb_types[@]} ))
done
unset attack_datasets

# Resume support: tasks finished in earlier runs are listed one per line in
# the completed log and count toward the progress figure.
if [[ -f "$COMPLETED_LOG" ]]; then
    COMPLETED_TASKS=$(wc -l < "$COMPLETED_LOG")
else
    COMPLETED_TASKS=0
fi

# Print a message prefixed with "[YYYY-mm-dd HH:MM:SS]" to stdout and append
# the same line to $MAIN_LOG.
#   $1 - message text
log_message() {
    local stamp
    stamp=$(date '+%Y-%m-%d %H:%M:%S')
    echo "[${stamp}] $1" | tee -a "$MAIN_LOG"
}

# Record a finished task and log the overall progress.
#   $1 - task description, appended verbatim as one line to $COMPLETED_LOG
# Globals: COMPLETED_TASKS (incremented), TOTAL_TASKS (read),
#          COMPLETED_LOG (appended to)
mark_completed() {
    local task_str="$1"
    local progress=100
    echo "$task_str" >> "$COMPLETED_LOG"
    # Plain assignment instead of ((x++)), which returns status 1 when x was 0.
    COMPLETED_TASKS=$((COMPLETED_TASKS + 1))
    # Guard the division: a "division by 0" arithmetic error is fatal to the
    # whole top-level command that is running, not just this function.
    if (( TOTAL_TASKS > 0 )); then
        progress=$((COMPLETED_TASKS * 100 / TOTAL_TASKS))
    fi
    log_message "Progress: $progress% ($COMPLETED_TASKS/$TOTAL_TASKS) - Completed: $task_str"
}

# Record a task whose process exited with a non-zero status.
#   $1 - task description, appended as one line to $FAILED_LOG and reported
#        in the main log with a "FAILED: " prefix
mark_failed() {
    local failed_task=$1
    echo "$failed_task" >>"$FAILED_LOG"
    log_message "FAILED: ${failed_task}"
}

# Background attack processes that have been launched but not yet reaped.
typeset -a PIDS=()
# Associative map: PID -> task description (used for logging and cleanup).
typeset -A PID_TO_TASK

# Trap Ctrl+C / SIGTERM and cleanup.
#
# Stops every still-running process registered in PID_TO_TASK, then exits
# with status 1. Processes first get SIGTERM so they can shut down cleanly;
# any still alive after CLEANUP_GRACE_SECONDS (default 5) get SIGKILL.
# This explicit signalling is required: in a non-interactive bash script,
# background jobs ignore SIGINT, so they never see the user's Ctrl+C.
cleanup() {
    # Ignore further INT/TERM so repeated Ctrl+C cannot re-enter cleanup
    # while it is still stopping processes.
    trap '' SIGINT SIGTERM

    local pid
    local grace=${CLEANUP_GRACE_SECONDS:-5}
    local waited=0
    local -a live=() still_alive=()

    echo  # start on a fresh line after the terminal's "^C"
    # log_message already prints to stdout (via tee), so no extra echo.
    log_message "Received stop signal. Cleaning up..."

    # Ask every running process to terminate gracefully.
    for pid in "${!PID_TO_TASK[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            log_message "Stopping process $pid (${PID_TO_TASK[$pid]})..."
            kill -TERM "$pid" 2>/dev/null
            live+=("$pid")
        fi
    done

    # Poll once per second until all have exited or the grace period ends.
    while (( ${#live[@]} > 0 && waited < grace )); do
        sleep 1
        waited=$((waited + 1))
        still_alive=()
        for pid in "${live[@]}"; do
            if kill -0 "$pid" 2>/dev/null; then
                still_alive+=("$pid")
            fi
        done
        live=("${still_alive[@]}")
    done

    # Last resort for anything that ignored SIGTERM.
    for pid in "${live[@]}"; do
        log_message "Process $pid (${PID_TO_TASK[$pid]}) ignored SIGTERM; sending SIGKILL"
        kill -KILL "$pid" 2>/dev/null
    done

    log_message "Cleanup complete. Exiting."
    exit 1
}

trap cleanup SIGINT SIGTERM

# Print the memory usage of one GPU as an integer percentage (0-100).
#   $1 - GPU index as accepted by `nvidia-smi -i`
# Outputs:
#   0    if nvidia-smi is not installed
#   100  if the GPU does not exist or its reading is missing/non-numeric
#        (an unknown GPU is treated as busy, never as free)
#   used*100/total otherwise
get_gpu_memory_usage() {
    local gpu_id=$1
    local reading used total

    if ! command -v nvidia-smi &> /dev/null; then
        echo "0"
        return
    fi

    # Check if the GPU exists
    if ! nvidia-smi -i "$gpu_id" &> /dev/null; then
        echo "100"  # Return high usage for non-existent GPU
        return
    fi

    # Query used and total in one call so both values come from the same
    # snapshot. Output looks like "512, 1024".
    reading=$(nvidia-smi --query-gpu=memory.used,memory.total --format=csv,noheader,nounits -i "$gpu_id" 2>/dev/null)
    IFS=', ' read -r used total _ <<< "$reading"

    # Values such as "[N/A]" or an empty reading are not usable numbers.
    if [[ "$used" =~ ^[0-9]+$ && "$total" =~ ^[0-9]+$ ]] && (( 10#$total > 0 )); then
        echo $(( 10#$used * 100 / 10#$total ))
    else
        echo "100"
    fi
}

# Pick the GPU with the lowest memory usage.
# Globals: MAX_MEMORY_PERCENT (read)
# Outputs:
#   <index>  the least-used accessible GPU, if its usage is < MAX_MEMORY_PERCENT
#   -1       nvidia-smi missing, failing, or reporting no GPUs
#   -2       GPUs exist but all are busy or inaccessible
get_available_gpu() {
    local listing num_gpus gpu_id usage
    local best_gpu=-1
    local min_usage=101

    # First check if nvidia-smi exists and can be executed
    if ! command -v nvidia-smi &> /dev/null; then
        echo "-1"
        return
    fi

    # A failing nvidia-smi (e.g. driver not loaded) means no usable GPUs.
    if ! listing=$(nvidia-smi --list-gpus 2>/dev/null); then
        echo "-1"
        return
    fi
    # Count only "GPU <n>: ..." lines, so error text or indented MIG device
    # lines are not mistaken for physical GPUs.
    num_gpus=$(grep -c '^GPU ' <<< "$listing")

    # Validate num_gpus is a number and greater than 0
    if ! [[ "$num_gpus" =~ ^[0-9]+$ ]] || (( num_gpus == 0 )); then
        echo "-1"
        return
    fi

    for (( gpu_id=0; gpu_id<num_gpus; gpu_id++ )); do
        # Check if GPU is accessible
        if nvidia-smi -i "$gpu_id" &> /dev/null; then
            usage=$(get_gpu_memory_usage "$gpu_id")
            # Treat any non-numeric reading as fully used
            if ! [[ "$usage" =~ ^[0-9]+$ ]]; then
                usage=100
            fi
            if (( usage < min_usage )); then
                min_usage=$usage
                best_gpu=$gpu_id
            fi
        fi
    done

    if (( best_gpu >= 0 && min_usage < MAX_MEMORY_PERCENT )); then
        echo "$best_gpu"
    else
        # Return -2 if all GPUs are busy or inaccessible
        echo "-2"
    fi
}

# Block until get_available_gpu reports a usable GPU, then print its index.
# Status messages go to stderr (and the main log), so callers can capture
# just the index with $(wait_for_gpu).
wait_for_gpu() {
    local gpu_id
    while true; do
        gpu_id=$(get_available_gpu)

        case "$gpu_id" in
            "-1")
                log_message "No NVIDIA drivers or GPUs detected. Waiting..." >&2
                sleep 5
                ;;
            "-2")
                log_message "All GPUs are busy (usage > $MAX_MEMORY_PERCENT%). Waiting..." >&2
                sleep 3
                ;;
            *)
                if [[ "$gpu_id" =~ ^[0-9]+$ ]]; then
                    log_message "Found available GPU $gpu_id" >&2
                    echo "$gpu_id"
                    return
                fi
                # Unexpected probe output (empty, garbage): back off instead
                # of re-probing in a tight loop.
                log_message "Unexpected GPU probe result '$gpu_id'. Waiting..." >&2
                sleep 3
                ;;
        esac
    done
}

# Succeed (status 0) if $1 appears as an entire line of $COMPLETED_LOG.
# The match is a fixed-string, whole-line comparison: task strings contain
# regex metacharacters (e.g. the "." in "0.20") that must match literally.
is_task_completed() {
    local task_str="$1"
    [[ -f "$COMPLETED_LOG" ]] || return 1
    grep -qxF -- "$task_str" "$COMPLETED_LOG"
}

# Start "$@" in the background and register its PID in PIDS / PID_TO_TASK
# under the task label $1.
_launch_attack_process() {
    local label=$1
    shift
    "$@" &
    local pid=$!
    PIDS+=("$pid")
    PID_TO_TASK[$pid]=$label
}

# Function to run a single attack task.
#   $1 attack, $2 dataset, $3 ptb_rate, $4 emb
# Skips work already recorded as completed, waits for a free GPU, then starts
# the attack script(s) in the background (not awaited here; finished
# processes are collected by check_completed_tasks).
run_attack() {
    local attack=$1
    local dataset=$2
    local ptb_rate=$3
    local emb=$4

    local task_str="$attack $dataset $ptb_rate $emb"

    # Skip if task is already completed
    if is_task_completed "$task_str"; then
        log_message "Skipping completed task: $task_str"
        return
    fi

    # strg launches one process per threshold, and each one is recorded in
    # the completed log as "<task_str> threshold=<t>". Completion must
    # therefore be checked per threshold, or finished strg runs are redone
    # on every restart.
    local -a thresholds=()
    local threshold
    if [ "$attack" = "strg" ]; then
        for threshold in 0.5; do
            if is_task_completed "$task_str threshold=$threshold"; then
                log_message "Skipping completed task: $task_str threshold=$threshold"
            else
                thresholds+=("$threshold")
            fi
        done
        if [ ${#thresholds[@]} -eq 0 ]; then
            return
        fi
    fi

    # Wait for available GPU
    local gpu_id
    gpu_id=$(wait_for_gpu)

    log_message "Starting $attack on GPU $gpu_id: $dataset (ptb_rate=$ptb_rate, emb=$emb)"

    # Arguments shared by every attack script invocation.
    local -a common_args=(
        --dataset "$dataset"
        --ptb_rate "$ptb_rate"
        --attack "$attack"
        --emb_type "$emb"
    )

    case "$attack" in
        strg)
            for threshold in "${thresholds[@]}"; do
                _launch_attack_process "$task_str threshold=$threshold" \
                    python gen_attacks.py "${common_args[@]}" \
                    --threshold "$threshold" \
                    --device "$gpu_id" \
                    --re_split 1
            done
            ;;
        mettack)
            _launch_attack_process "$task_str" \
                python gen_attacks.py "${common_args[@]}" \
                --device "$gpu_id" \
                --re_split 1
            ;;
        *)
            _launch_attack_process "$task_str" \
                python gen_attacks_inductive.py "${common_args[@]}" \
                --device "$gpu_id"
            ;;
    esac

    # Small delay to allow nvidia-smi to update
    sleep 5
}

# Reap background processes that have exited and record their outcome.
# For each PID in PIDS that is no longer running: collect its exit status
# with `wait`, call mark_completed (status 0) or mark_failed (non-zero), and
# drop it from PIDS and PID_TO_TASK.
# Returns: 0 if at least one process was reaped, 1 otherwise
#          (normal shell convention, so `if check_completed_tasks; then` works).
check_completed_tasks() {
    local pid status
    local any_completed=0
    local -a still_running=()

    for pid in "${PIDS[@]}"; do
        [ -n "$pid" ] || continue

        if kill -0 "$pid" 2>/dev/null; then
            still_running+=("$pid")
            continue
        fi

        # Process has finished; `wait` returns its recorded exit status.
        wait "$pid"
        status=$?
        if (( status == 0 )); then
            mark_completed "${PID_TO_TASK[$pid]}"
        else
            mark_failed "${PID_TO_TASK[$pid]}"
        fi
        unset "PID_TO_TASK[$pid]"
        any_completed=1
    done

    PIDS=("${still_running[@]}")
    (( any_completed ))
}

# Remember how many failures were already logged so the summary reports
# (and the exit status reflects) only failures from this run.
failed_before=0
if [ -f "$FAILED_LOG" ]; then
    failed_before=$(wc -l < "$FAILED_LOG")
fi

# log_message already prints to stdout (via tee), so no separate echo.
log_message "Starting attack tasks... (Press Ctrl+C to stop)"
log_message "Total tasks to process: $TOTAL_TASKS"
echo "Progress will be logged to $MAIN_LOG"

# Main execution loop - process all attack types
for attack in "${!datasets_attacks[@]}"; do
    for ptb_rate in "${ptb_rates[@]}"; do
        for emb in "${emb_types[@]}"; do
            # Intentionally unquoted: the value is a space-separated list.
            for dataset in ${datasets_attacks[$attack]}; do
                # Throttle to MAX_CONCURRENT_PROCESSES. Re-check right after
                # reaping so a freed slot is used without an extra sleep.
                while (( ${#PIDS[@]} >= MAX_CONCURRENT_PROCESSES )); do
                    check_completed_tasks
                    (( ${#PIDS[@]} >= MAX_CONCURRENT_PROCESSES )) || break
                    sleep 5
                done
                run_attack "$attack" "$dataset" "$ptb_rate" "$emb"
            done
        done
    done
done

# Wait for remaining processes to complete
while (( ${#PIDS[@]} > 0 )); do
    check_completed_tasks
    (( ${#PIDS[@]} > 0 )) || break
    sleep 10
done

failed_after=0
if [ -f "$FAILED_LOG" ]; then
    failed_after=$(wc -l < "$FAILED_LOG")
fi
new_failures=$(( failed_after - failed_before ))

if (( new_failures > 0 )); then
    log_message "All attack tasks finished; $new_failures task(s) failed."
else
    log_message "All attack tasks completed!"
fi
echo "Check $MAIN_LOG for full progress log"
echo "Check $COMPLETED_LOG for completed tasks"
echo "Check $FAILED_LOG for failed tasks"

# Non-zero exit status so callers/automation can detect failed tasks.
if (( new_failures > 0 )); then
    exit 1
fi