#!/usr/bin/env bash
# Launch a LoRA fine-tuning run of train.py with Fisher-information
# constraints. All hyperparameters are set as shell variables below and
# forwarded to train.py as command-line flags.
#
# Strict mode: abort on any failed command, on use of an unset variable, and
# on a failure anywhere in a pipeline, instead of silently continuing.
set -euo pipefail

# -------------------------------
# Environment variables
# -------------------------------
# huggingface_hub offline mode: use only locally cached files, no Hub calls.
export HF_HUB_OFFLINE=1
# Opt-in checked by the Hugging Face `code_eval` metric before it executes
# model-generated code. Only enable this in a sandboxed environment.
export HF_ALLOW_CODE_EVAL=1
# Exported so train.py inherits it. NOTE(review): no --seed flag is passed
# below, so presumably train.py reads SEED from the environment; confirm.
export SEED=42

# ================================================================
# Hyperparameters forwarded to train.py
# ================================================================
# Optimizer and schedule
LR=1e-4                # --learning-rate
WEIGHT_DECAY=1e-4      # --weight_decay
EPOCHS=6               # --num-train-epochs
WARMUP_STEPS=100       # --warmup_steps

# Per-device batch sizes
BATCH_SIZE_TRAIN=2     # --per-device-train-batch-size
BATCH_SIZE_EVAL=2      # --per-device-eval-batch-size

# Evaluation / checkpoint cadence
EVAL_STEPS=100         # --eval-steps
SAVE_STEPS=100         # --save-steps
SAVE_TOTAL_LIMIT=1     # --save_total_limit

# Precision and training data
BF16=True                       # --bf16
DATASET_NAME="custom_dataset"   # --dataset_name

# LoRA adapter settings (--lora_r / --lora_alpha / --lora_dropout)
LORA_R=4
LORA_ALPHA=16
LORA_DROPOUT=0.05

# Fisher-information files. USE_LORA (set further down) picks which one is
# used. NOTE(review): both are empty by default; confirm how train.py treats
# an empty --fisher_path.
LORA_FISHER_PATH=""
FULL_FISHER_PATH=""

# Reference data and Fisher / regularization knobs
REF_DATASET_NAME="./ref_dataset"   # --ref_dataset_name
FISHER_RATIO=0.6                   # --fisher_ratio
EWC_LAMBDA=0.0                     # --ewc_lambda
DYNAMIC_STRATEGY="history"         # --dynamic_k_strategy

# ================================================================
# Experiment selection
# Default: LoRA training with Fisher constraints.
# ================================================================
JOB_NAME="lora_fisher"
USE_LORA=True        # also selects LORA_FISHER_PATH vs FULL_FISHER_PATH
USE_FISHER=True
GPU_COUNT=2
FISHER_MODE="hard"   # --fisher_mode
# NOTE(review): USE_FISHER and GPU_COUNT are only reported below; nothing in
# this script forwards them to train.py.

printf '%s\n' "Running $JOB_NAME | USE_LORA=$USE_LORA | USE_FISHER=$USE_FISHER | GPU_COUNT=$GPU_COUNT | FISHER_MODE=$FISHER_MODE"

# ================================================================
# Fisher-information file
# The LoRA run reads LORA_FISHER_PATH; any other USE_LORA value falls
# back to FULL_FISHER_PATH.
# ================================================================
case "$USE_LORA" in
  True) FISHER_PATH="$LORA_FISHER_PATH" ;;
  *)    FISHER_PATH="$FULL_FISHER_PATH" ;;
esac
printf '%s\n' "Using FISHER_PATH=$FISHER_PATH"

# -------------------------------
# Output directory
# -------------------------------
# Passed to train.py as --output_dir.
OUTPUT_DIR="./outputs/lora_fisher_output"
# Quote the path so it is never word-split or glob-expanded, and stop here if
# it cannot be created instead of letting train.py fail later with a less
# obvious error.
if ! mkdir -p -- "$OUTPUT_DIR"; then
  echo "Error: cannot create output directory: $OUTPUT_DIR" >&2
  exit 1
fi
echo "Output directory: $OUTPUT_DIR"

# -------------------------------
# Display GPU usage
# -------------------------------
# Informational only. A missing or failing nvidia-smi is reported as a
# warning and must not abort the run (important under `set -e`).
if command -v nvidia-smi >/dev/null 2>&1; then
  nvidia-smi || echo "Warning: nvidia-smi exited with status $?" >&2
else
  echo "Warning: nvidia-smi not found; skipping GPU report" >&2
fi

# -------------------------------
# Call training script
# -------------------------------
# Arguments are collected in an array so each value (paths in particular)
# reaches train.py as exactly one argument, even if it is empty or contains
# spaces.
#
# NOTE(review): GPU_COUNT is not used here; this starts a single `python`
# process. If multi-GPU training is intended, a launcher such as
# `torchrun --nproc_per_node="$GPU_COUNT"` is presumably required; confirm
# against train.py. USE_FISHER is likewise never forwarded; confirm how
# train.py enables or disables the Fisher constraint.
train_args=(
  --logging-steps=2
  --per-device-train-batch-size="$BATCH_SIZE_TRAIN"
  --per-device-eval-batch-size="$BATCH_SIZE_EVAL"
  --num-train-epochs="$EPOCHS"
  --learning-rate="$LR"
  --do-eval=True
  --eval-strategy=steps
  --save-strategy=steps
  --eval-steps="$EVAL_STEPS"
  --save-steps="$SAVE_STEPS"
  --save_total_limit="$SAVE_TOTAL_LIMIT"
  --load_best_model_at_end=True
  --lr_scheduler_type=cosine
  --bf16="$BF16"
  --dataset_name="$DATASET_NAME"
  --remove_unused_columns=True
  --warmup_steps="$WARMUP_STEPS"
  --output_dir="$OUTPUT_DIR"
  --fisher_path="$FISHER_PATH"
  --fisher_ratio="$FISHER_RATIO"
  --fisher_mode="$FISHER_MODE"
  --ewc_lambda="$EWC_LAMBDA"
  --weight_decay="$WEIGHT_DECAY"
  --dynamic_k_strategy="$DYNAMIC_STRATEGY"
  # Follow USE_LORA, which also selected FISHER_PATH above, so the LoRA
  # setting and the Fisher file cannot disagree.
  --use_lora="$USE_LORA"
  --lora_r="$LORA_R"
  --lora_alpha="$LORA_ALPHA"
  --lora_dropout="$LORA_DROPOUT"
  --ref_dataset_name="$REF_DATASET_NAME"
)

# Capture train.py's exit status so a failed run is reported as a failure
# (and propagated to the caller) instead of printing "Training completed.".
rc=0
python train.py "${train_args[@]}" || rc=$?
if (( rc != 0 )); then
  echo "Training failed with exit code $rc." >&2
  exit "$rc"
fi

echo "Training completed."