#!/bin/bash

# Change to InstructionTuning directory.
# The path is relative to the caller's current working directory. Abort if the
# directory is missing: every later step (log directory creation, train_atk.py
# lookup) uses paths relative to it and would otherwise run in the wrong place.
cd ../LLMPredictor/InstructionTuning/ || {
    echo "Error: cannot change directory to ../LLMPredictor/InstructionTuning/" >&2
    exit 1
}

# Require the three mandatory positional arguments; otherwise print usage and stop.
if (( $# < 3 )); then
    printf '%s\n' \
        "Usage: bash run_sft_atk_trans.sh <dataset_name> <attack> <ptb_rate> [seed] [llm_model] [atk_emb_type] [prompt_type]" \
        "Supported datasets: cora, citeseer, wikics, instagram, pubmed, reddit, photo, computer, history" \
        "Supported attacks: pgd, grbcd, prbcd, textfooler, llm" \
        "Supported LLMs: Mistral-7B (default)" \
        "Note: Attack seed will automatically be set to the same value as data seed" \
        "Special: Set seed=-1 to run seeds 0,1,2 in parallel on one GPU" \
        "Example: bash run_sft_atk_trans.sh cora llm 0.2 0 Mistral-7B bow neighbor_label" \
        "Example (parallel): bash run_sft_atk_trans.sh cora llm 0.2 -1 Mistral-7B bow neighbor_label"
    exit 1
fi

# Mandatory arguments
DATASET="$1"
ATTACK="$2"
PTB_RATE="$3"

# Optional arguments; each falls back to its default when unset or empty
SEED="${4:-0}"                    # data seed (0 by default; -1 selects parallel mode)
LLM="${5:-Mistral-7B}"            # LLM name
ATK_EMB_TYPE="${6:-bow}"          # attack embedding type
PROMPT_TYPE="${7:-neighbor}"      # prompt type

#######################################
# Run one attacked fine-tuning experiment for a single seed.
# Globals (read): DATASET, ATTACK, PTB_RATE, LLM, ATK_EMB_TYPE, PROMPT_TYPE
# Arguments:
#   $1 - data seed; the attack seed is set to the same value
# Outputs:
#   Progress messages on stdout. The trainer's stdout and stderr are echoed
#   and also written to a log file under ./logs_atk/.
# Returns:
#   The exit status of the `accelerate launch` command (not of `tee`).
#   Calls `exit 1` when DATASET is not one of the supported names.
#######################################
run_single_experiment() {
    local current_seed=$1
    local atk_seed=$current_seed  # Attack seed must be same as data seed

    # Determine attack type based on attack name
    local atk_type
    case "$ATTACK" in
        textfooler | llm | gpt) atk_type="text" ;;
        *) atk_type="structure" ;;
    esac

    echo "Starting experiment with seed $current_seed at $(date)..."
    echo "Dataset: $DATASET, Attack: $ATTACK, PTB Rate: $PTB_RATE, Seed: $current_seed, LLM: $LLM"
    echo "Attack Embedding Type: $ATK_EMB_TYPE, Attack Seed: $atk_seed, Attack Type: $atk_type"
    echo "Prompt Type: $PROMPT_TYPE"

    # Create logs directory if it doesn't exist
    mkdir -p ./logs_atk

    # Per-dataset hyperparameters. Only these three values differ between
    # dataset groups; every other trainer flag is shared.
    local num_epoch
    local batch_size=4
    local max_origin_txt_length=200
    case "$DATASET" in
        cora | citeseer)
            num_epoch=16
            ;;
        wikics | pubmed)
            # Middle-scale datasets (wikics, pubmed)
            echo "Running $DATASET with middle-scale configuration (wikics/pubmed)..."
            # Pubmed: Avg Query Prompt Length 51.0000 | Avg OriginTxT Length 425.6667 | Avg Output Length 4.3333
            # WikiCS: Avg Query Prompt Length 76.0000 | Avg OriginTxT Length 597.9836 | Avg Output Length 3.1882
            num_epoch=4
            ;;
        instagram)
            # Instagram special configuration
            echo "Running $DATASET with instagram configuration..."
            # Instagram: Avg Query Prompt Length 37.0000 | Avg OriginTxT Length 54.8707 | Avg Output Length 2.0000
            num_epoch=8
            batch_size=12
            max_origin_txt_length=100
            ;;
        reddit | photo | computer | history)
            # Large-scale datasets
            echo "Running $DATASET with large-scale configuration..."
            num_epoch=2
            ;;
        *)
            echo "Error: Unsupported dataset '$DATASET'"
            echo "Supported datasets: cora, citeseer, wikics, instagram, pubmed, reddit, photo, computer, history"
            exit 1
            ;;
    esac

    # ${PTB_RATE//.} strips the dots from the rate (e.g. 0.2 -> 02) for the file name.
    local log_file="./logs_atk/attack_transductive_${DATASET}_${LLM}_${PROMPT_TYPE}_${ATTACK}_${atk_type}_ptb${PTB_RATE//.}_${ATK_EMB_TYPE}_atkseed${atk_seed}_seed${current_seed}.log"

    local -a train_args=(
        --num_epoch="$num_epoch"
        --llm="$LLM"
        --batch_size="$batch_size"
        --re_split=1
        --dataset="$DATASET"
        --max_txt_length=80
        --max_origin_txt_length="$max_origin_txt_length"
        --seed="$current_seed"
        --prompt_type="$PROMPT_TYPE"
        --attack="$ATTACK"
        --atk_type="$atk_type"
        --ptb_rate="$PTB_RATE"
        --atk_emb_type="$ATK_EMB_TYPE"
        --atk_seed="$atk_seed"
    )

    accelerate launch train_atk.py "${train_args[@]}" 2>&1 | tee "$log_file"
    # A pipeline's status is that of its last stage (tee), which would hide a
    # failed training run; report the trainer's own status instead.
    return "${PIPESTATUS[0]}"
}

# Dispatch on the requested seed:
#   SEED=-1   -> run seeds 0, 1, 2 concurrently as background jobs
#   otherwise -> run a single experiment in the foreground
# The script's exit status is non-zero when any experiment reports a non-zero
# status, so callers can detect failures instead of always seeing success.
overall_status=0

if [[ "$SEED" == "-1" ]]; then
    echo "==================== PARALLEL EXECUTION MODE ===================="
    echo "Running experiments with seeds 0, 1, 2 in parallel at $(date)..."
    echo "Dataset: $DATASET, Attack: $ATTACK, PTB Rate: $PTB_RATE, LLM: $LLM"
    echo "Attack Embedding Type: $ATK_EMB_TYPE, Prompt Type: $PROMPT_TYPE"
    echo "=================================================================="

    # Start experiments in parallel, remembering each job's PID in seed order
    parallel_seeds=(0 1 2)
    job_pids=()
    for seed in "${parallel_seeds[@]}"; do
        run_single_experiment "$seed" &
        job_pids+=("$!")
    done

    pid_summary=$(printf '%s, ' "${job_pids[@]}")
    echo "Started ${#job_pids[@]} parallel experiments with PIDs: ${pid_summary%, }"

    # Wait for every background job and record its exit status
    code_summary=""
    for idx in "${!parallel_seeds[@]}"; do
        job_status=0
        wait "${job_pids[idx]}" || job_status=$?
        echo "Experiment with seed ${parallel_seeds[idx]} completed with exit code: $job_status"
        code_summary+="seed${parallel_seeds[idx]}=${job_status}, "
        if (( job_status != 0 )); then
            overall_status=1
        fi
    done

    echo "==================== ALL EXPERIMENTS COMPLETED ===================="
    echo "Parallel execution completed at $(date)."
    echo "Exit codes: ${code_summary%, }"
    echo "Check log files for individual results."
    echo "=================================================================="
else
    # Single experiment execution
    echo "==================== SINGLE EXPERIMENT MODE ===================="
    run_single_experiment "$SEED" || overall_status=$?
    echo "Single experiment completed at $(date). Check log files for results."
fi

exit "$overall_status"