#!/bin/bash
#
# Experiment driver. Steps, in order:
#   1.   retraining.py
#   2-3. unlearning.py + eval_unlearning.py for each method in run_methods
#   4.   gen_saliency_masks.py (prerequisite for SalUN)
#   5.   unlearning.py + eval_unlearning.py for SalUN
# Set DEBUG="true" for a quick smoke run (2 epochs; 2 runs per method).
#
# Fail fast: stop at the first failing command instead of, e.g., evaluating
# missing or stale checkpoints after a failed retraining/unlearning run;
# treat unset variables as errors; propagate failures through pipelines.
set -euo pipefail

# ==============================
# Configurable Variables
# ==============================
# "true" enables the shortened debug run; any other value is a normal run.
DEBUG="false"

DATASET="mnist"
TRAIN_INDEX="outputs/dataset_indexes/${DATASET}-train-index.csv"
TEST_INDEX="outputs/dataset_indexes/${DATASET}-test-index.csv"
# Starting model handed to unlearning.py / gen_saliency_masks.py via --checkpoint.
CHECKPOINT_BASE="outputs/in_models/${DATASET}_scratch.pt"
# Handed to eval_unlearning.py via --retrained_checkpoint.
RETRAINED_CHECKPOINT="outputs/retrained/models/${DATASET}_retrained.pt"
EVAL_SAVE_PATH="outputs/eval"
# Mask consumed by SalUN (--mask_path).
MASK_PATH="outputs/masks/${DATASET}/with_0.7.pt"

NUM_RETAINS=10
NUM_FORGETS=1000
# Label(s) to forget. The value is word-split when passed on, so a
# space-separated list becomes several arguments (NOTE(review): confirm the
# Python scripts accept multiple values).
FORGET_LABELS=7
LR_SCHEDULER="cswr"
SEED=42
# BATCH_SIZE: retraining, unlearning and mask generation; EVAL_BATCH_SIZE: evaluation.
BATCH_SIZE=1024
EVAL_BATCH_SIZE=1024

# ==============================
# Debug Mode Settings
# ==============================
# Extra retraining flags; debug mode shortens retraining to 2 epochs.
# Held in an array so every flag is passed as its own argument.
RETRAIN_ARGS=()
if [[ "$DEBUG" == "true" ]]; then
    echo "!!! RUNNING IN DEBUG MODE (2 epochs) !!!"
    RETRAIN_ARGS=(--num_epochs 2)
fi

# ==============================
# 1. Run Retraining
# ==============================
# NOTE(review): RETRAINED_CHECKPOINT (handed to evaluation) is not passed to
# retraining.py; confirm the script writes its model there by default.
# NOTE(review): "-lr" is a single-dash flag while the other scripts receive
# --learning_rate; confirm retraining.py defines "-lr".
echo "Starting Retraining..."

# FORGET_LABELS may be a space-separated list; split it so each label is a
# separate argument (the same splitting an unquoted expansion performed).
read -r -a forget_label_args <<<"$FORGET_LABELS"

retrain_cmd=(
    python retraining.py
    --dataset "$DATASET"
    --forget_labels "${forget_label_args[@]}"
    --index_file "$TRAIN_INDEX"
    --test_index_file "$TEST_INDEX"
    --lr_scheduler "$LR_SCHEDULER"
    --batch_size "$BATCH_SIZE"
    -lr 1e-3
    --weight_decay 1e-5
    --seed "$SEED"
)
# Only expand the extra flags when present (empty "${arr[@]}" trips set -u
# on bash < 4.4).
if (( ${#RETRAIN_ARGS[@]} > 0 )); then
    retrain_cmd+=("${RETRAIN_ARGS[@]}")
fi

# Every evaluation is given --retrained_checkpoint, so there is no point in
# continuing if retraining failed.
if ! "${retrain_cmd[@]}"; then
    echo "ERROR: retraining failed; aborting" >&2
    exit 1
fi
echo "Retraining Completed"

# ==============================
# 2. Methods and their parameters
# ==============================
# Registry mapping each method key to its unlearning.py options. Each value is
# a whitespace-separated option string that is split into individual
# arguments when the method is run. "w_fisher" sets no --num_epochs, and
# "salun" points at the saliency mask generated in step 4.
declare -A methods=(
    [ft]="--method ft --num_epochs 10 --learning_rate 1e-2"
    [l1_sparse]="--method l1_sparse --learning_rate 1e-2 --alpha 1e-4 --num_epochs 10"
    [ga]="--method ga --num_epochs 8 --learning_rate 5e-4"
    [neg_grad]="--method neg_grad --num_epochs 8 --learning_rate 5e-4 --alpha 0.9"
    [w_fisher]="--method w_fisher --alpha 5"
    [random_label]="--method random_label --learning_rate 1e-3 --num_epochs 10"
    [scrub]="--method scrub --forget_steps 3 --num_epochs 10 --learning_rate 1e-3 --beta 0.3 --gamma 0.7"
    [uul]="--method uul --num_epochs 10 --learning_rate 5e-4"
    [salun]="--method salun --mask_path ${MASK_PATH} --learning_rate 1e-3 --num_epochs 10"
)

# ==============================
# Shared Arguments
# ==============================
# Options appended to every unlearning.py invocation, assembled piecewise.
UNLEARN_COMMON_ARGS="--num_retains ${NUM_RETAINS} --num_forgets ${NUM_FORGETS}"
UNLEARN_COMMON_ARGS+=" --batch_size ${BATCH_SIZE} --dataset ${DATASET}"
UNLEARN_COMMON_ARGS+=" --index_file ${TRAIN_INDEX} --test_index_file ${TEST_INDEX}"
UNLEARN_COMMON_ARGS+=" --lr_scheduler ${LR_SCHEDULER} --forget_labels ${FORGET_LABELS}"
UNLEARN_COMMON_ARGS+=" --checkpoint ${CHECKPOINT_BASE} --seed ${SEED}"

# Options appended to every eval_unlearning.py invocation.
EVAL_COMMON_ARGS="--dataset ${DATASET} --forget_labels ${FORGET_LABELS}"
EVAL_COMMON_ARGS+=" --samples_per_member 10000 --val_percent 0.2"
EVAL_COMMON_ARGS+=" --batch_size ${EVAL_BATCH_SIZE} --seed ${SEED}"
EVAL_COMMON_ARGS+=" --index_file ${TRAIN_INDEX} --test_index_file ${TEST_INDEX}"
EVAL_COMMON_ARGS+=" --retrained_checkpoint ${RETRAINED_CHECKPOINT} --save_path ${EVAL_SAVE_PATH}"

# ==============================
# 3. Run Unlearning Methods and Evaluate
# ==============================

#######################################
# Run unlearning.py for one registered method, then evaluate its checkpoints.
# In debug mode, adds "--num_runs 2" and forces 2 epochs (rewriting an
# existing --num_epochs value or appending one).
# Globals (read): methods, DEBUG, DATASET, UNLEARN_COMMON_ARGS, EVAL_COMMON_ARGS
# Arguments:      $1 - key into the `methods` registry
# Outputs:        progress messages on stdout, errors on stderr
# Exits:          1 if the key is unregistered or either python step fails
#######################################
run_unlearn_and_eval() {
    local method_name=$1
    local method_args checkpoint_dir
    local -a unlearn_argv eval_argv

    # "+set" expands to empty only for a missing key, and is safe under set -u.
    if [[ -z "${methods[$method_name]+set}" ]]; then
        echo "ERROR: no parameters registered for method '${method_name}'" >&2
        exit 1
    fi
    method_args=${methods[$method_name]}

    echo "--- Running Unlearning for ${method_name} ---"
    if [[ "$DEBUG" == "true" ]]; then
        method_args+=" --num_runs 2"
        if [[ $method_args == *"--num_epochs"* ]]; then
            echo "Debug mode: Overriding epochs to 2 for ${method_name}"
            # Replace every "--num_epochs <int or decimal>" with 2 epochs.
            method_args=$(sed -E 's/--num_epochs[[:space:]]+[0-9]+(\.[0-9]+)?/--num_epochs 2/g' <<<"$method_args")
        else
            echo "Debug mode: Adding epochs=2 to ${method_name}"
            method_args+=" --num_epochs 2"
        fi
    fi

    # Split the option strings on whitespace into separate arguments; unlike an
    # unquoted expansion, read -a never glob-expands the words.
    read -r -a unlearn_argv <<<"${method_args} ${UNLEARN_COMMON_ARGS}"
    if ! python unlearning.py "${unlearn_argv[@]}"; then
        # Evaluating now would score missing or stale checkpoints.
        echo "ERROR: unlearning for ${method_name} failed; skipping evaluation and aborting" >&2
        exit 1
    fi
    echo "Unlearning for ${method_name} Completed"

    echo "--- Evaluating ${method_name} ---"
    checkpoint_dir="outputs/unlearn/${method_name}/${DATASET}/models/"
    read -r -a eval_argv <<<"${EVAL_COMMON_ARGS}"
    if ! python eval_unlearning.py --method "$method_name" \
            --checkpoints "$checkpoint_dir" "${eval_argv[@]}"; then
        echo "ERROR: evaluation for ${method_name} failed; aborting" >&2
        exit 1
    fi
    echo "Evaluation for ${method_name} Completed"
}

# SalUN is deliberately absent: it needs the mask generated in step 4.
run_methods=("ft" "l1_sparse" "ga" "neg_grad" "w_fisher" "random_label" "scrub" "uul")
# run_methods=("w_fisher" "uul")

for method_name in "${run_methods[@]}"; do
    run_unlearn_and_eval "$method_name"
done

# ==============================
# 4. Saliency Mask Generation (Prerequisite for SalUN)
# ==============================
# NOTE(review): no output path is passed here, while SalUN reads ${MASK_PATH};
# confirm gen_saliency_masks.py writes its mask to that location.
echo "--- Generating Saliency Mask ---"
# FORGET_LABELS may be a space-separated list; pass each label separately.
read -r -a mask_forget_labels <<<"$FORGET_LABELS"
if ! python gen_saliency_masks.py \
    --num_retains "$NUM_RETAINS" \
    --num_forgets "$NUM_FORGETS" \
    --dataset "$DATASET" \
    --index_file "$TRAIN_INDEX" \
    --test_index_file "$TEST_INDEX" \
    --forget_labels "${mask_forget_labels[@]}" \
    --checkpoint "$CHECKPOINT_BASE" \
    --batch_size "$BATCH_SIZE" \
    --seed "$SEED" \
    --learning_rate 1e-3 \
    --momentum 0.01 \
    --weight_decay 1e-5 \
    --gpu 0; then
    echo "ERROR: saliency mask generation failed; cannot run SalUN" >&2
    exit 1
fi
echo "Saliency Map Generation Completed"

# ==============================
# 5. Run and Evaluate SalUN
# ==============================
SALUN_METHOD_NAME="salun"
run_unlearn_and_eval "$SALUN_METHOD_NAME"

echo "All processes finished."
