#!/bin/bash

# Configuration for text attacks.
# Each map: attack name -> space-separated list of datasets to attack in that setting.
declare -A text_datasets_attacks_inductive=(
    ["textfooler"]="cora citeseer pubmed wikics instagram reddit history photo"
)

declare -A text_datasets_attacks_transductive=(
    ["textfooler"]="cora citeseer pubmed wikics instagram reddit history"
)

ptb_rates=(0.40)             # Perturbation rates (passed as --ptb_rate)
emb_types=("MiniLM")         # Text embedding types (passed as --emb_type)
settings=("inductive")       # Settings to run; add "transductive" to run both
MAX_MEMORY_PERCENT=95        # A GPU is only used while its memory usage is below this percentage
MAX_CONCURRENT_PROCESSES=2   # Maximum number of concurrent attack processes

# Setup logging
LOG_DIR="./logs_text_attack"
mkdir -p "$LOG_DIR"
MAIN_LOG="$LOG_DIR/text_attack_progress.log"
COMPLETED_LOG="$LOG_DIR/text_completed_tasks.log"
FAILED_LOG="$LOG_DIR/text_failed_tasks.log"

# Count the tasks in the current sweep
# (settings x attacks x datasets x ptb_rates x emb_types) and how many of
# them are already recorded in $COMPLETED_LOG, so progress resumes correctly.
#
# Only exact whole-line matches for tasks in THIS sweep count as completed.
# Stale lines from other sweeps and duplicate lines are ignored, so
# COMPLETED_TASKS never exceeds TOTAL_TASKS. Task strings use the same
# "setting attack dataset ptb_rate emb" format that run_text_attack writes.
#
# Globals read:    settings, ptb_rates, emb_types, COMPLETED_LOG,
#                  text_datasets_attacks_inductive,
#                  text_datasets_attacks_transductive
# Globals written: TOTAL_TASKS, COMPLETED_TASKS
count_tasks() {
    local setting attack dataset ptb_rate emb task_str datasets
    local -a attack_types

    TOTAL_TASKS=0
    COMPLETED_TASKS=0

    for setting in "${settings[@]}"; do
        if [ "$setting" = "inductive" ]; then
            attack_types=("${!text_datasets_attacks_inductive[@]}")
        else
            attack_types=("${!text_datasets_attacks_transductive[@]}")
        fi

        for attack in "${attack_types[@]}"; do
            if [ "$setting" = "inductive" ]; then
                datasets=${text_datasets_attacks_inductive[$attack]}
            else
                datasets=${text_datasets_attacks_transductive[$attack]}
            fi

            # Dataset list is space-separated; word splitting is intended.
            for dataset in $datasets; do
                for ptb_rate in "${ptb_rates[@]}"; do
                    for emb in "${emb_types[@]}"; do
                        TOTAL_TASKS=$((TOTAL_TASKS + 1))
                        task_str="$setting $attack $dataset $ptb_rate $emb"
                        # Fixed-string, whole-line match: "0.40" must not match "0x40".
                        if [ -f "$COMPLETED_LOG" ] && grep -qxF -- "$task_str" "$COMPLETED_LOG"; then
                            COMPLETED_TASKS=$((COMPLETED_TASKS + 1))
                        fi
                    done
                done
            done
        done
    done
}

count_tasks

# Print "[YYYY-mm-dd HH:MM:SS] <message>" to stdout and append the same line
# to $MAIN_LOG.
# Arguments: $1 - message text (only the first argument is logged)
log_message() {
    local stamp
    stamp=$(date '+%Y-%m-%d %H:%M:%S')
    printf '[%s] %s\n' "$stamp" "$1" | tee -a "$MAIN_LOG"
}

# Record a finished task: append it to $COMPLETED_LOG, bump the
# COMPLETED_TASKS counter and log an integer progress percentage.
# Arguments: $1 - task description string
# Globals:   COMPLETED_LOG, TOTAL_TASKS (read); COMPLETED_TASKS (incremented)
mark_completed() {
    local finished=$1
    local pct

    echo "$finished" >> "$COMPLETED_LOG"
    COMPLETED_TASKS=$((COMPLETED_TASKS + 1))
    pct=$(( COMPLETED_TASKS * 100 / TOTAL_TASKS ))
    log_message "Progress: ${pct}% (${COMPLETED_TASKS}/${TOTAL_TASKS}) - Completed: ${finished}"
}

# Record a failed task: append it to $FAILED_LOG and report it via log_message.
# Arguments: $1 - task description string
mark_failed() {
    local failed=$1
    local summary="FAILED: ${failed}"
    echo "$failed" >> "$FAILED_LOG"
    log_message "$summary"
}

# Bookkeeping for launched background tasks:
#   PID_TO_TASK - associative array: PID -> task description string
#   PIDS        - indexed array of PIDs still being tracked
declare -A PID_TO_TASK
declare -a PIDS=()

# Trap Ctrl+C / SIGTERM: stop every tracked background task, then exit 1.
#
# Each live process first gets SIGTERM so it can shut down cleanly (e.g. free
# its GPU memory and flush output). Processes still alive after a grace
# period of CLEANUP_GRACE_SECONDS (default 10, polled once per second) are
# force-killed with SIGKILL.
# Globals: PID_TO_TASK (read), CLEANUP_GRACE_SECONDS (read, optional)
cleanup() {
    local pid alive
    local -a signalled=()
    local grace=${CLEANUP_GRACE_SECONDS:-10}

    echo -e "\nReceived stop signal. Cleaning up..."
    log_message "Received stop signal. Cleaning up..."

    # Ask every live background process to terminate.
    for pid in "${!PID_TO_TASK[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            echo "Stopping process $pid (${PID_TO_TASK[$pid]})..."
            log_message "Stopping process $pid (${PID_TO_TASK[$pid]})..."
            kill -TERM "$pid" 2>/dev/null
            signalled+=("$pid")
        fi
    done

    # Wait until all signalled processes have exited or the grace period ends.
    while (( grace > 0 && ${#signalled[@]} > 0 )); do
        alive=0
        for pid in "${signalled[@]}"; do
            if kill -0 "$pid" 2>/dev/null; then
                alive=1
                break
            fi
        done
        (( alive )) || break
        sleep 1
        grace=$((grace - 1))
    done

    # Force-kill anything that did not honour SIGTERM.
    for pid in "${signalled[@]}"; do
        if kill -0 "$pid" 2>/dev/null; then
            log_message "Process $pid still running after SIGTERM; sending SIGKILL..."
            kill -9 "$pid" 2>/dev/null
        fi
    done

    log_message "Cleanup complete. Exiting."
    echo "Cleanup complete. Exiting."
    exit 1
}

trap cleanup SIGINT SIGTERM

# Print the memory usage of one GPU as an integer percentage (0-100).
#
# Arguments: $1 - nvidia-smi GPU index
# Outputs:
#   "0"   if nvidia-smi is not installed,
#   "100" if the GPU does not exist, a query fails, or a reading is not a
#         plain integer (e.g. "[N/A]") or total memory is 0. Failing closed
#         keeps an unreadable GPU from looking idle,
#   otherwise floor(used * 100 / total).
get_gpu_memory_usage() {
    local gpu_id=$1
    local memory_usage total_memory

    if ! command -v nvidia-smi &> /dev/null; then
        echo "0"
        return
    fi

    # Check if the GPU exists
    if ! nvidia-smi -i "$gpu_id" &> /dev/null; then
        echo "100"  # Return high usage for non-existent GPU
        return
    fi

    if ! memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i "$gpu_id" 2>/dev/null) \
        || ! total_memory=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits -i "$gpu_id" 2>/dev/null); then
        echo "100"
        return
    fi

    # Drop any padding so the values can be validated as plain integers.
    memory_usage=${memory_usage//[[:space:]]/}
    total_memory=${total_memory//[[:space:]]/}

    if ! [[ "$memory_usage" =~ ^[0-9]+$ && "$total_memory" =~ ^[0-9]+$ ]] \
        || (( 10#$total_memory == 0 )); then
        echo "100"
    else
        # 10# forces base-10 so a leading zero is not parsed as octal.
        echo $(( 10#$memory_usage * 100 / 10#$total_memory ))
    fi
}

# Pick the accessible GPU with the lowest memory usage.
#
# Outputs:
#   "<index>" the nvidia-smi index of the least-used GPU, if its usage is below
#             MAX_MEMORY_PERCENT,
#   "-1"      if nvidia-smi is missing or reports no GPUs,
#   "-2"      if every GPU is at/above MAX_MEMORY_PERCENT or inaccessible.
# Globals: MAX_MEMORY_PERCENT (read)
get_available_gpu() {
    local num_gpus gpu_id usage
    local best_gpu=-1
    local min_usage=101

    # First check if nvidia-smi exists and can be executed
    if ! command -v nvidia-smi &> /dev/null; then
        echo "-1"
        return
    fi

    # Get number of GPUs directly from nvidia-smi list
    num_gpus=$(nvidia-smi --list-gpus | wc -l)
    # Some wc implementations pad the count with spaces; strip them.
    num_gpus=${num_gpus//[[:space:]]/}

    # Validate num_gpus is a number and greater than 0
    if ! [[ "$num_gpus" =~ ^[0-9]+$ ]] || [ "$num_gpus" -eq 0 ]; then
        echo "-1"
        return
    fi

    for (( gpu_id = 0; gpu_id < num_gpus; gpu_id++ )); do
        # Skip GPUs nvidia-smi cannot address
        if nvidia-smi -i "$gpu_id" &> /dev/null; then
            usage=$(get_gpu_memory_usage "$gpu_id")
            # Treat any non-numeric reading as a full GPU
            if ! [[ "$usage" =~ ^[0-9]+$ ]]; then
                usage=100
            fi
            if [ "$usage" -lt "$min_usage" ]; then
                min_usage=$usage
                best_gpu=$gpu_id
            fi
        fi
    done

    if [ "$best_gpu" -ge 0 ] && [ "$min_usage" -lt "$MAX_MEMORY_PERCENT" ]; then
        echo "$best_gpu"
    else
        # Return -2 if all GPUs are busy or inaccessible
        echo "-2"
    fi
}

# Block until get_available_gpu reports a usable GPU, then print its index.
#
# All progress messages go to stderr so the caller can capture the index
# with $(wait_for_gpu). Retries every 5 s when no GPU/driver is found and
# every 3 s when all GPUs are busy. It also waits 3 s after any unexpected
# result, so it never spins in a tight loop.
# Outputs: the selected GPU index on stdout
wait_for_gpu() {
    local gpu_id
    while true; do
        gpu_id=$(get_available_gpu)

        case "$gpu_id" in
            "-1")
                log_message "No NVIDIA drivers or GPUs detected. Waiting..." >&2
                sleep 5
                ;;
            "-2")
                log_message "All GPUs are busy (usage > $MAX_MEMORY_PERCENT%). Waiting..." >&2
                sleep 3
                ;;
            *)
                if [[ "$gpu_id" =~ ^[0-9]+$ ]]; then
                    log_message "Found available GPU $gpu_id" >&2
                    echo "$gpu_id"
                    return 0
                fi
                # Empty or malformed probe output: back off before retrying.
                log_message "Unexpected GPU probe result '$gpu_id'. Waiting..." >&2
                sleep 3
                ;;
        esac
    done
}

# Succeed (status 0) if the exact task string appears as a whole line in
# $COMPLETED_LOG. Return 1 if it does not, or if the log does not exist yet.
#
# The task string is matched literally (grep -F). Treated as a regex, the
# "." in a ptb_rate like "0.40" would match any character and cause false
# "already completed" skips.
# Arguments: $1 - task description string
is_task_completed() {
    local task_str="$1"
    [ -f "$COMPLETED_LOG" ] || return 1
    grep -qxF -- "$task_str" "$COMPLETED_LOG"
}

# Launch a single text attack task in the background.
#
# If the task is already listed in $COMPLETED_LOG, it is skipped. Otherwise
# the function waits for a GPU below MAX_MEMORY_PERCENT memory usage and
# starts the matching gen_text_attacks_{inductive,transductive}.py script on
# THAT GPU. The new PID is recorded in PIDS / PID_TO_TASK, and the function
# then sleeps 5 s so nvidia-smi reflects the new process before the next
# GPU probe.
#
# Arguments:
#   $1 setting  - "inductive" selects the inductive script; anything else
#                 selects the transductive one
#   $2 attack   - attack name (--attack)
#   $3 dataset  - dataset name (--dataset)
#   $4 ptb_rate - perturbation rate (--ptb_rate)
#   $5 emb      - embedding type (--emb_type)
# Globals: PIDS, PID_TO_TASK (appended)
run_text_attack() {
    local setting=$1
    local attack=$2
    local dataset=$3
    local ptb_rate=$4
    local emb=$5

    local task_str="$setting $attack $dataset $ptb_rate $emb"

    # Skip if task is already completed
    if is_task_completed "$task_str"; then
        log_message "Skipping completed task: $task_str"
        return
    fi

    # Wait for available GPU
    local gpu_id
    gpu_id=$(wait_for_gpu)

    log_message "Starting text attack $attack ($setting) on GPU $gpu_id: $dataset (ptb_rate=$ptb_rate, emb=$emb)"

    # The two settings differ only in the script and the --re_split value.
    local script re_split
    if [ "$setting" = "inductive" ]; then
        script="gen_text_attacks_inductive.py"
        re_split=2
    else
        script="gen_text_attacks_transductive.py"
        re_split=1
    fi

    # Run on the GPU selected above, so concurrent tasks spread across devices
    # and land on the GPU the log line reports. gpu_id is an nvidia-smi index;
    # CUDA_DEVICE_ORDER=PCI_BUS_ID (unless the caller already set it) makes
    # CUDA number devices the same way.
    # NOTE(review): assumes --device is a CUDA device ordinal and that
    # CUDA_VISIBLE_DEVICES does not remap devices — confirm.
    CUDA_DEVICE_ORDER="${CUDA_DEVICE_ORDER:-PCI_BUS_ID}" python "$script" \
        --dataset "$dataset" \
        --ptb_rate "$ptb_rate" \
        --attack "$attack" \
        --emb_type "$emb" \
        --device "$gpu_id" \
        --re_split "$re_split" \
        --seeds 3 \
        --epochs 200 \
        --patience 15 \
        --use_batch \
        --attack_batch_size 16 &

    local pid=$!
    PIDS+=("$pid")
    PID_TO_TASK[$pid]="$task_str"

    # Small delay to allow nvidia-smi to update
    sleep 5
}

# Reap tracked background processes that have exited.
#
# For each PID in PIDS that is no longer running, collect its exit status
# with `wait`. Status 0 calls mark_completed; anything else calls
# mark_failed. The PID is then dropped from PID_TO_TASK. Empty entries in
# PIDS are ignored. PIDS is rewritten to hold only still-running processes.
#
# Returns: 1 if at least one process was reaped, 0 otherwise (note: inverted
#          relative to the usual shell convention; current callers ignore it).
check_completed_tasks() {
    local -a still_running=()
    local reaped=0
    local rc

    for pid in "${PIDS[@]}"; do
        [ -n "$pid" ] || continue

        if kill -0 "$pid" 2>/dev/null; then
            still_running+=("$pid")
            continue
        fi

        # Process has finished: fetch its exit status.
        wait "$pid"
        rc=$?
        if (( rc == 0 )); then
            mark_completed "${PID_TO_TASK[$pid]}"
        else
            mark_failed "${PID_TO_TASK[$pid]}"
        fi
        unset "PID_TO_TASK[$pid]"
        reaped=1
    done

    PIDS=("${still_running[@]}")
    return "$reaped"
}

# ---------------------------------------------------------------------------
# Main driver.
# 1. Announce the run.
# 2. Launch every task in the sweep, keeping at most
#    MAX_CONCURRENT_PROCESSES tracked at once. Loop order is
#    setting -> attack -> ptb_rate -> emb -> dataset.
# 3. Keep reaping until no tracked process remains.
# ---------------------------------------------------------------------------
log_message "Starting text attack tasks... (Press Ctrl+C to stop)"
log_message "Total tasks to process: $TOTAL_TASKS"
printf '%s\n' \
    "Starting text attack tasks... (Press Ctrl+C to stop)" \
    "Total tasks to process: $TOTAL_TASKS" \
    "Progress will be logged to $MAIN_LOG"

for setting in "${settings[@]}"; do
    if [[ $setting == inductive ]]; then
        attack_types=("${!text_datasets_attacks_inductive[@]}")
    else
        attack_types=("${!text_datasets_attacks_transductive[@]}")
    fi

    for attack in "${attack_types[@]}"; do
        if [[ $setting == inductive ]]; then
            datasets=${text_datasets_attacks_inductive[$attack]}
        else
            datasets=${text_datasets_attacks_transductive[$attack]}
        fi

        for ptb_rate in "${ptb_rates[@]}"; do
            for emb in "${emb_types[@]}"; do
                # Dataset list is space-separated; word splitting is intended.
                for dataset in $datasets; do
                    # Throttle: poll until a concurrency slot is free.
                    while (( ${#PIDS[@]} >= MAX_CONCURRENT_PROCESSES )); do
                        check_completed_tasks
                        sleep 5
                    done
                    run_text_attack "$setting" "$attack" "$dataset" "$ptb_rate" "$emb"
                done
            done
        done
    done
done

# Drain: keep reaping until every launched process has been accounted for.
while (( ${#PIDS[@]} > 0 )); do
    check_completed_tasks
    sleep 10
done

log_message "All text attack tasks completed!"
printf '%s\n' \
    "All text attack tasks completed!" \
    "Check $MAIN_LOG for full progress log" \
    "Check $COMPLETED_LOG for completed tasks" \
    "Check $FAILED_LOG for failed tasks"