#!/bin/bash

# LLaMA2-7B Gradient Difference + Sine-based Parameter-Efficient Unlearning Script
# Default configuration: GD (grad_diff) + Sine-based Parameter-Efficient approach, rank 4, forget10

set -o errexit  # Stop at the first command that returns a non-zero status

# ---- Default settings ----

# What to unlearn and with which objective
MODEL_FAMILY='llama2-7b'
SPLIT='forget10'
FORGET_LOSS='grad_diff'

# Optimisation schedule
NUM_EPOCHS=5
BATCH_SIZE=2
GRADIENT_ACCUMULATION_STEPS=16
LEARNING_RATE=0.0001

# Low-rank adapter shape; alpha is always twice the rank
LORA_RANK=4
LORA_ALPHA=$((LORA_RANK * 2))
LORA_DROPOUT=0

# Sine activation parameters
SINE_FREQ=100
SINE_SCALE=45.25

# Reproducibility
SEED=42

# Parse command line arguments

# Print the option summary to stdout.
print_usage() {
    echo "Usage: $0 [OPTIONS]"
    echo ""
    echo "Options:"
    echo "  --model_family MODEL         Model family to use (default: llama2-7b)"
    echo "  --split SPLIT               Dataset split to use (default: forget10)"
    echo "  --batch_size SIZE           Batch size (default: 2)"
    echo "  --gradient_accumulation_steps STEPS  Gradient accumulation steps (default: 16)"
    echo "  --num_epochs EPOCHS         Number of epochs (default: 5)"
    echo "  --lr LEARNING_RATE          Learning rate (default: 0.0001)"
    echo "  --lora_rank RANK            LoRA rank (default: 4)"
    echo "  --forget_loss LOSS          Forget loss type (default: grad_diff)"
    echo "  --sine_freq FREQ            Sine frequency (default: 100)"
    echo "  --sine_scale SCALE          Sine scale (default: 45.25)"
    echo "  --seed SEED                 Random seed (default: 42)"
    echo "  --help, -h                  Show this help message"
}

# Report a command-line mistake on stderr and terminate with status 1.
#   $1 - message describing the problem
usage_error() {
    echo "$1" >&2
    echo "Use --help for usage information" >&2
    exit 1
}

# Apply the given options to the global configuration variables.
#   $@ - the script's command-line arguments
# Exits 0 after printing help; exits 1 (with a message on stderr) for an
# unknown option, an option whose value is missing, or a non-integer
# --lora_rank.
parse_args() {
    local opt
    while (( $# > 0 )); do
        opt=$1

        # First pass: classify the option so a missing value is reported
        # explicitly instead of making `shift 2` fail without any message.
        case "$opt" in
            --help|-h)
                print_usage
                exit 0
                ;;
            --model_family|--split|--batch_size|--gradient_accumulation_steps|\
            --num_epochs|--lr|--lora_rank|--forget_loss|--sine_freq|--sine_scale|--seed)
                (( $# >= 2 )) || usage_error "Missing value for option: $opt"
                ;;
            *)
                usage_error "Unknown option: $opt"
                ;;
        esac

        # Second pass: store the value.
        case "$opt" in
            --model_family) MODEL_FAMILY="$2" ;;
            --split) SPLIT="$2" ;;
            --batch_size) BATCH_SIZE="$2" ;;
            --gradient_accumulation_steps) GRADIENT_ACCUMULATION_STEPS="$2" ;;
            --num_epochs) NUM_EPOCHS="$2" ;;
            --lr) LEARNING_RATE="$2" ;;
            --lora_rank)
                # The rank is used in shell arithmetic below, so only plain
                # decimal digits are accepted; anything else would either be
                # a syntax error or be evaluated as an expression.
                [[ "$2" =~ ^[0-9]+$ ]] \
                    || usage_error "Invalid value for --lora_rank (expected a non-negative integer): $2"
                LORA_RANK="$2"
                # 10# forces base 10 so a leading zero is not read as octal.
                LORA_ALPHA=$(( 10#$LORA_RANK * 2 ))
                ;;
            --forget_loss) FORGET_LOSS="$2" ;;
            --sine_freq) SINE_FREQ="$2" ;;
            --sine_scale) SINE_SCALE="$2" ;;
            --seed) SEED="$2" ;;
        esac
        shift 2
    done
}

parse_args "$@"

# Log the effective configuration, framed by horizontal rules, so a run can
# be identified from its output.
banner_rule="=================================="
printf '%s\n' \
    "$banner_rule" \
    "LLaMA2 Sine based Parameter-Efficient Unlearning Setup" \
    "$banner_rule" \
    "Model Family: $MODEL_FAMILY" \
    "Split: $SPLIT" \
    "Batch Size: $BATCH_SIZE" \
    "Gradient Accumulation Steps: $GRADIENT_ACCUMULATION_STEPS" \
    "Number of Epochs: $NUM_EPOCHS" \
    "Learning Rate: $LEARNING_RATE" \
    "LoRA Rank: $LORA_RANK" \
    "LoRA Alpha: $LORA_ALPHA" \
    "LoRA Dropout: $LORA_DROPOUT" \
    "Forget Loss: $FORGET_LOSS" \
    "Sine Frequency: $SINE_FREQ" \
    "Sine Scale: $SINE_SCALE" \
    "Seed: $SEED" \
    "$banner_rule"

# Change to TOFU directory
#
# A TOFU directory under the caller's working directory is used when present.
# Otherwise fall back to a TOFU directory next to this script, so the script
# can also be started from elsewhere. If neither exists, the final `cd`
# fails and reports the missing directory.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
if [[ ! -d TOFU && -d "$SCRIPT_DIR/TOFU" ]]; then
    cd -- "$SCRIPT_DIR"
fi
cd TOFU

# Check if GPU is available
# Only the presence of the nvidia-smi binary is tested; when it exists, print
# name and total/free memory for each GPU it reports.
if command -v nvidia-smi &> /dev/null; then
    echo "GPU detected:"
    nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader,nounits
    echo ""
fi

# Set number of processes based on the GPUs this run is allowed to use

# Print the number of torchrun workers to start (always >= 1).
#   * CUDA_VISIBLE_DEVICES set by the caller: one worker per comma-separated
#     entry (set but empty counts as a single worker).
#   * otherwise: one worker per "GPU ..." line of `nvidia-smi --list-gpus`.
#   * nvidia-smi missing, failing, or listing nothing: 1.
resolve_gpu_count() {
    local count=0 listing
    local -a visible=()

    if [[ -n "${CUDA_VISIBLE_DEVICES+set}" ]]; then
        IFS=',' read -r -a visible <<< "$CUDA_VISIBLE_DEVICES" || true
        count=${#visible[@]}
    elif command -v nvidia-smi &> /dev/null \
        && listing=$(nvidia-smi --list-gpus 2> /dev/null); then
        # Count only device lines; the exit-status check above keeps
        # nvidia-smi's error text from being mistaken for GPUs.
        count=$(grep -c '^GPU ' <<< "$listing" || true)
    fi

    if (( count < 1 )); then
        count=1
    fi
    printf '%s\n' "$count"
}

NUM_GPUS=$(resolve_gpu_count)

echo "Using $NUM_GPUS GPU(s)"

# Create directories if they don't exist
mkdir -p saved_weights final_results

# Run the unlearning process
# CUDA_VISIBLE_DEVICES is deliberately not pinned here: the worker count was
# derived from the devices visible to this process, and restricting
# visibility to a single device would leave the remaining workers without
# one. To limit the run to specific GPUs, set CUDA_VISIBLE_DEVICES when
# invoking this script.
echo "Starting unlearning process..."
torchrun \
    --nproc_per_node="$NUM_GPUS" \
    --master_port=28765 \
    forget.py \
    --config-name=forget.yaml \
    "model_family=$MODEL_FAMILY" \
    "split=$SPLIT" \
    "batch_size=$BATCH_SIZE" \
    "gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS" \
    "num_epochs=$NUM_EPOCHS" \
    "lr=$LEARNING_RATE" \
    LoRA.targets=ffn \
    "LoRA.r=$LORA_RANK" \
    "LoRA.alpha=$LORA_ALPHA" \
    "LoRA.dropout=$LORA_DROPOUT" \
    "forget_loss=$FORGET_LOSS" \
    use_sinelora=true \
    "sine_freq=$SINE_FREQ" \
    "sine_scale=$SINE_SCALE" \
    learnable_sine=false \
    "seed=$SEED"

echo "Unlearning process completed!"
echo "Results saved in saved_weights/ directory"