#!/bin/bash
# run_experiments.sh - For every checkpoint of a finetuned model: evaluate it, run data attribution,
# run unlearning (forgetting), evaluate the unlearned model, and compute ROUGE scores

# ---------------------------------------------------------------------------
# GPU configuration - set GPU_CONFIG here.
# GPU_CONFIG is used as CUDA_VISIBLE_DEVICES for the multi-GPU (torchrun)
# steps. NUM_GPUS is passed to torchrun --nproc_per_node; it is derived from
# the number of comma-separated IDs in GPU_CONFIG so the two cannot drift
# apart when only one of them is edited.
# ---------------------------------------------------------------------------
GPU_CONFIG="0,1,2,3"
IFS=',' read -r -a _gpu_ids <<< "${GPU_CONFIG}"
NUM_GPUS="${#_gpu_ids[@]}"
unset _gpu_ids

# Base parameters (forwarded as key=value overrides to the Python entry points)
MODEL_FAMILY="llama2-7b"  # Alternatives: llama2-7b, phi
SPLIT="forget10"  # Alternatives: full, forget01, forget05, forget10
LR="1e-5"
WEIGHT_DECAY="0.01"
NUM_EPOCHS="5"
BATCH_SIZE="1"
GRAD_ACCUM_STEPS="1"
# NOTE(review): in this script DS_SIZE only names the eval_results/ds_size<N>
# output directories; it is not passed to evaluate_all.py.
DS_SIZE="300"
USE_FT="False"  # forwarded as use_flash_attention_2

# Attribution parameters
UNIFY_METHOD="exp"  # Alternatives: exp, power
ATTRIBUTION_METHOD="g_prod"  # Alternatives: g_norm, g_grad
# NOTE(review): the forgetting section below lists g_prod/g_norm as attribution
# names; confirm which names data_attribution.py actually accepts.
TAU_VALUES=("0.03")  # tau values whose attribution results are checked per checkpoint

# Forgetting parameters
FORGET_LOSS="grad_ascent"  # Alternatives: grad_ascent, grad_diff, KL, idk, dpo
ATTRIBUTION="none"  # Alternatives: none, g_prod, g_norm
UNIFICATION="none"  # Alternatives: none, power, exp
# TAU picks the influence dict (..._t${TAU}_influence_dict.json) used for
# forgetting, so it should be one of TAU_VALUES; forget.py receives it as tao=.
TAU="0.03"

# Path to the finetuned model - IMPORTANT: Set this to your finetuned model path.
# Its checkpoint-* subdirectories are the checkpoints that get processed.
FINETUNED_MODEL_PATH="checkpoints/ft_epoch5_lr1e-05_llama2-7b_full_wd0.01"

# ---------------------------------------------------------------------------
# Helpers (mapfile -d '' below requires bash >= 4.4)
# ---------------------------------------------------------------------------

# Print a warning to stderr.
warn() {
    printf 'Warning: %s\n' "$*" >&2
}

# Warn and record the current checkpoint (CHECKPOINT_NAME) as incomplete.
# The caller is responsible for the subsequent `continue`.
fail_checkpoint() {
    warn "$1"
    FAILED_CHECKPOINTS+=("${CHECKPOINT_NAME}")
}

# Print the checkpoint-* subdirectories directly under $1, NUL-delimited and
# in natural version order (checkpoint-2 before checkpoint-10). Prints nothing
# when $1 is not a directory. NUL delimiters keep paths with spaces intact.
find_checkpoints() {
    [[ -d "$1" ]] || return 0
    find "$1" -maxdepth 1 -name "checkpoint-*" -type d -print0 | sort -zV
}

# Names of checkpoints that did not finish every step; reported at the end.
FAILED_CHECKPOINTS=()

# Find all checkpoint directories
echo "Looking for checkpoint directories in ${FINETUNED_MODEL_PATH}"
mapfile -d '' -t CHECKPOINTS < <(find_checkpoints "${FINETUNED_MODEL_PATH}")

if (( ${#CHECKPOINTS[@]} == 0 )); then
    echo "No checkpoint directories found in ${FINETUNED_MODEL_PATH}" >&2
    exit 1
fi

echo "Found the following checkpoints:"
printf '  %s\n' "${CHECKPOINTS[@]}"

# Process each checkpoint. When a step whose output a later step needs fails,
# the rest of that checkpoint is skipped (and recorded) instead of running the
# remaining steps against missing outputs.
for CHECKPOINT in "${CHECKPOINTS[@]}"; do
    CHECKPOINT_NAME="${CHECKPOINT##*/}"
    echo "======= Processing ${CHECKPOINT_NAME} ======="

    # Step 1: Run evaluation for this checkpoint. Its aggregated log is passed
    # to the step-4 evaluation as ckpt_result.
    echo "Step 1: Evaluating checkpoint ${CHECKPOINT_NAME}"
    EVAL_DIR="${CHECKPOINT}/eval_results/ds_size${DS_SIZE}"
    mkdir -p -- "${EVAL_DIR}"

    # Runs single-process on GPU 0; the multi-GPU launch kept for reference:
    #CUDA_VISIBLE_DEVICES=${GPU_CONFIG} torchrun --nproc_per_node=${NUM_GPUS}
    eval_args=(
        --config-name=eval_everything.yaml
        model_family="${MODEL_FAMILY}"
        model_path="${CHECKPOINT}"
        split="${SPLIT}_perturbed"
        save_dir="${EVAL_DIR}"
        use_flash_attention_2="${USE_FT}"
    )
    if ! CUDA_VISIBLE_DEVICES=0 python evaluate_all.py "${eval_args[@]}"; then
        fail_checkpoint "evaluation failed for ${CHECKPOINT_NAME}; skipping this checkpoint"
        continue
    fi

    # Step 2: Run data attribution. Failure is only a warning here, because
    # step 3 falls back to influence_nonDA.json when no score dict exists.
    echo "Step 2: Running data attribution for ${CHECKPOINT_NAME}"

    attr_args=(
        --config-name=data_attribution.yaml
        split="${SPLIT}"
        model_family="${MODEL_FAMILY}"
        unify_method="${UNIFY_METHOD}"
        attribution_method="${ATTRIBUTION_METHOD}"
        model_path="${CHECKPOINT}"
        use_flash_attention_2="${USE_FT}"
    )
    for tau in "${TAU_VALUES[@]}"; do
        echo "  Running attribution with tau=${tau}"

        # NOTE(review): tau is not passed to data_attribution.py, so every
        # iteration runs the same command and tau only changes which result
        # file is checked below. Add the matching override once the config
        # key is confirmed.
        if ! CUDA_VISIBLE_DEVICES="${GPU_CONFIG}" torchrun \
                --nproc_per_node="${NUM_GPUS}" data_attribution.py "${attr_args[@]}"; then
            warn "data attribution failed for ${CHECKPOINT_NAME} (tau=${tau})"
            continue
        fi

        # Check the result under the same name step 3 reads it from: it is
        # keyed by the settings attribution ran with (ATTRIBUTION_METHOD and
        # UNIFY_METHOD), not by the forgetting-side UNIFICATION.
        ATTR_DICT_PATH="${CHECKPOINT}/${SPLIT}_${ATTRIBUTION_METHOD}_${UNIFY_METHOD}t${tau}_influence_dict.json"
        if [[ ! -f "${ATTR_DICT_PATH}" ]]; then
            warn "Attribution result not found at ${ATTR_DICT_PATH}"
        fi
    done

    # Step 3: Run unlearning (forgetting) process
    echo "Step 3: Running unlearning for ${CHECKPOINT_NAME}"

    # NOTE(review): the score dict is looked up with the attribution-step
    # settings (ATTRIBUTION_METHOD/UNIFY_METHOD) while forget.py receives
    # attribution=ATTRIBUTION and unification=UNIFICATION; keep the pairs in
    # agreement, otherwise forget.py is given scores from a different method.
    SCORE_DICT_PATH="${CHECKPOINT}/${SPLIT}_${ATTRIBUTION_METHOD}_${UNIFY_METHOD}t${TAU}_influence_dict.json"
    if [[ "${ATTRIBUTION}" == "none" || ! -f "${SCORE_DICT_PATH}" ]]; then
        if [[ "${ATTRIBUTION}" != "none" ]]; then
            warn "${SCORE_DICT_PATH} not found although ATTRIBUTION=${ATTRIBUTION}"
        fi
        echo "Using default influence_nonDA.json as score_dict_path"
        SCORE_DICT_PATH="influence_nonDA.json"
    fi

    FORGET_MODEL_DIR="${CHECKPOINT}/${SPLIT}_${FORGET_LOSS}_tao${TAU}_lr${LR}_wd_${WEIGHT_DECAY}_ep${NUM_EPOCHS}_bs${BATCH_SIZE}_DA_${ATTRIBUTION}_${UNIFICATION}"

    forget_args=(
        --config-name=forget.yaml
        forget_loss="${FORGET_LOSS}"
        attribution="${ATTRIBUTION}"
        unification="${UNIFICATION}"
        tao="${TAU}"
        score_dict_path="${SCORE_DICT_PATH}"
        model_path="${CHECKPOINT}"
        save_dir="${FORGET_MODEL_DIR}"
        split="${SPLIT}"
        batch_size="${BATCH_SIZE}"
        gradient_accumulation_steps="${GRAD_ACCUM_STEPS}"
        model_family="${MODEL_FAMILY}"
        lr="${LR}"
        weight_decay="${WEIGHT_DECAY}"
        num_epochs="${NUM_EPOCHS}"
        use_flash_attention_2="${USE_FT}"
    )
    if ! TOKENIZERS_PARALLELISM=false CUDA_VISIBLE_DEVICES="${GPU_CONFIG}" torchrun \
            --nproc_per_node="${NUM_GPUS}" forget.py "${forget_args[@]}"; then
        fail_checkpoint "unlearning failed for ${CHECKPOINT_NAME}; skipping its evaluation"
        continue
    fi

    # Find the latest checkpoint in the forget model directory (last entry in
    # version order).
    mapfile -d '' -t FORGET_CHECKPOINTS < <(find_checkpoints "${FORGET_MODEL_DIR}")
    if (( ${#FORGET_CHECKPOINTS[@]} == 0 )); then
        fail_checkpoint "No checkpoint found in ${FORGET_MODEL_DIR}"
        continue
    fi
    LATEST_FORGET_CHECKPOINT="${FORGET_CHECKPOINTS[${#FORGET_CHECKPOINTS[@]}-1]}"

    # Step 4: Final evaluation of the forgotten model; the step-1 aggregated
    # log is passed as ckpt_result.
    echo "Step 4: Evaluating forgotten model from ${LATEST_FORGET_CHECKPOINT}"

    FORGET_EVAL_DIR="${LATEST_FORGET_CHECKPOINT}/eval_results/ds_size${DS_SIZE}"
    mkdir -p -- "${FORGET_EVAL_DIR}"

    forget_eval_args=(
        --config-name=eval_everything.yaml
        model_family="${MODEL_FAMILY}"
        model_path="${LATEST_FORGET_CHECKPOINT}"
        ckpt_result="${EVAL_DIR}/eval_log_aggregated.json"
        split="${SPLIT}_perturbed"
        save_dir="${FORGET_EVAL_DIR}"
        use_flash_attention_2="${USE_FT}"
    )
    if ! CUDA_VISIBLE_DEVICES="${GPU_CONFIG}" torchrun \
            --nproc_per_node="${NUM_GPUS}" evaluate_all.py "${forget_eval_args[@]}"; then
        fail_checkpoint "evaluation of the unlearned model failed for ${CHECKPOINT_NAME}"
        continue
    fi

    # Step 5: Calculate ROUGE scores
    echo "Step 5: Calculating ROUGE scores"
    if ! CUDA_VISIBLE_DEVICES=0 python get_rouge.py file_path="${FORGET_EVAL_DIR}/eval_log_aggregated.json"; then
        fail_checkpoint "ROUGE calculation failed for ${CHECKPOINT_NAME}"
        continue
    fi

    echo "======= Finished processing ${CHECKPOINT_NAME} ======="
done

echo "All experiments completed."

# Exit non-zero if any checkpoint was skipped, so callers/CI notice.
if (( ${#FAILED_CHECKPOINTS[@]} > 0 )); then
    printf 'The following checkpoints did not complete every step:\n' >&2
    printf '  %s\n' "${FAILED_CHECKPOINTS[@]}" >&2
    exit 1
fi

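# Optional: to sweep additional parameter combinations, uncomment the
# run_experiment helper below and adapt the example loop at the end.
# Note: unlike the main loop above, this helper does not pass
# use_flash_attention_2=${USE_FT}, builds the score-dict path from the
# forgetting parameters (attr/unif) rather than ATTRIBUTION_METHOD/UNIFY_METHOD,
# and redirects each step's output to a .log file.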
#
# run_experiment() {
#     local checkpoint=$1
#     local loss=$2
#     local attr=$3
#     local unif=$4
#     local tau=$5
#
#     echo "Running experiment: ${checkpoint} - ${loss}/${attr}/${unif}/tau${tau}"
#     
#     # Set paths for this experiment
#     local score_dict_path="${checkpoint}/${SPLIT}_${attr}_${unif}t${tau}_influence_dict.json"
#     if [ ! -f "${score_dict_path}" ] || [ "${attr}" == "none" ]; then
#         score_dict_path="influence_nonDA.json"
#     fi
#     
#     local forget_dir="${checkpoint}/${SPLIT}_${loss}_tao${tau}_lr${LR}_wd_${WEIGHT_DECAY}_ep${NUM_EPOCHS}_bs${BATCH_SIZE}_DA_${attr}_${unif}"
#     
#     # Run unlearning
#     TOKENIZERS_PARALLELISM=false CUDA_VISIBLE_DEVICES=${GPU_CONFIG} torchrun --nproc_per_node=${NUM_GPUS} forget.py \
#         --config-name=forget.yaml \
#         forget_loss=${loss} \
#         attribution=${attr} \
#         unification=${unif} \
#         tao=${tau} \
#         score_dict_path=${score_dict_path} \
#         model_path=${checkpoint} \
#         save_dir=${forget_dir} \
#         split=${SPLIT} \
#         batch_size=${BATCH_SIZE} \
#         gradient_accumulation_steps=${GRAD_ACCUM_STEPS} \
#         model_family=${MODEL_FAMILY} \
#         lr=${LR} \
#         weight_decay=${WEIGHT_DECAY} \
#         num_epochs=${NUM_EPOCHS} \
#         > ${SPLIT}_${loss}_${attr}_${unif}_tao${tau}_ep${NUM_EPOCHS}.log 2>&1
#
#     # Find the latest checkpoint
#     local forget_checkpoints=$(find ${forget_dir} -maxdepth 1 -name "checkpoint-*" -type d | sort -V)
#     local latest_forget_checkpoint=$(echo "$forget_checkpoints" | tail -n 1)
#     
#     if [ -z "$latest_forget_checkpoint" ]; then
#         echo "Warning: No checkpoint found in ${forget_dir}"
#         return
#     fi
#     
#     # Evaluate
#     local forget_eval_dir="${latest_forget_checkpoint}/eval_results/ds_size${DS_SIZE}"
#     mkdir -p ${forget_eval_dir}
#     
#     CUDA_VISIBLE_DEVICES=${GPU_CONFIG} torchrun --nproc_per_node=${NUM_GPUS} evaluate_all.py \
#         --config-name=eval_everything.yaml \
#         model_family=${MODEL_FAMILY} \
#         model_path=${latest_forget_checkpoint} \
#         ckpt_result="${checkpoint}/eval_results/ds_size${DS_SIZE}/eval_log_aggregated.json" \
#         split=${SPLIT}_perturbed \
#         save_dir=${forget_eval_dir} \
#         > ${SPLIT}_eval_${loss}_${attr}_${unif}_tao${tau}_ep${NUM_EPOCHS}.log 2>&1
#     
#     # Calculate ROUGE
#     CUDA_VISIBLE_DEVICES=0 python get_rouge.py file_path="${forget_eval_dir}/eval_log_aggregated.json" \
#         > ${SPLIT}_rouge_${loss}_${attr}_${unif}_tao${tau}_ep${NUM_EPOCHS}.log 2>&1
# }
#
# # Example usage of the run_experiment function:
# # for CHECKPOINT in $CHECKPOINTS; do
# #    for LOSS in "grad_ascent" "grad_diff" "KL" "dpo"; do
# #        run_experiment "$CHECKPOINT" "$LOSS" "g_prod" "exp" "0.03"
# #    done
# # done