#!/bin/bash
set -euo pipefail

# Grid search over alpha (trainer.method_args.alpha) for RMU unlearning on TOFU

# Pick a free TCP port for accelerate's main process: the python probe binds
# port 0 so the OS assigns an unused port, prints it, and closes the socket.
# Assign before exporting: `export VAR=$(cmd)` returns export's own status and
# would hide a failing probe from `set -e`.
MASTER_PORT="$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")"
export MASTER_PORT
echo "Master Port: $MASTER_PORT"

# Single model (Hydra model key), trainer, and experiment
model="Llama-3.2-1B-Instruct"
trainer="RMU"
experiment="unlearn/tofu/default.yaml"

# Dataset splits (CLI-overridable)
forget_split="forget10"
retain_split="retain90"
holdout_split="holdout10"

# TOFU model path for initialization (CLI-overridable); built here from the
# default model name above.
path_to_tofu_model="open-unlearning/tofu_${model}_full"

# Training batch parameters (per device; gradients accumulated over 4 steps)
per_device_train_batch_size=4
gradient_accumulation_steps=4

# Optional flags (0 = off, 1 = on), toggled by --use_constant_lr / --forget-only
use_constant_lr=0
forget_only=0

# Alpha grid (CLI-overridable via --alphas); one unlearning run per value.
alphas=(1 2 4)


# Print command-line help to stdout. The "default" values shown are the current
# values of the globals model, path_to_tofu_model, forget_split, retain_split,
# holdout_split and alphas.
# Note: the heredoc is unquoted (so ${...} expands); literal backslashes in the
# example must therefore be written as \\ or they would be eaten as line
# continuations.
usage() {
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --model NAME                 Hydra model key (default: ${model})
  --tofu_model PATH            Pretrained TOFU model path or HF id (default: ${path_to_tofu_model})
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --holdout_split NAME         Holdout split (default: ${holdout_split})
  --alphas "a1 a2 ..."         Space/comma-separated alphas (default: ${alphas[*]})
  --use_constant_lr            Use constant LR scheduler and suffix task_name
  --forget-only                Use forget_only experiment and suffix task_name
  -h | --help                  Show this help

Example:
  $0 \\
    --model Llama-3.2-1B-Instruct \\
    --tofu_model open-unlearning/tofu_Llama-3.2-1B-Instruct_full \\
    --forget_split forget10 --retain_split retain90 --holdout_split holdout10 \\
    --alphas "1 2 4 8"
EOF
}

# Parse CLI options into the globals defined above.
# Value-taking options must be followed by an argument; unknown options print
# usage to stderr and exit 1.
tofu_model_explicit=0
while [[ $# -gt 0 ]]; do
  # Fail clearly (instead of "$2: unbound variable" under set -u) when a
  # value-taking option is the last argument.
  case "$1" in
    --model|--tofu_model|--forget_split|--retain_split|--holdout_split|--alphas)
      if [[ $# -lt 2 ]]; then
        echo "Missing value for $1" >&2
        usage >&2
        exit 1
      fi
      ;;
  esac
  case "$1" in
    --model) model="$2"; shift 2;;
    --tofu_model) path_to_tofu_model="$2"; tofu_model_explicit=1; shift 2;;
    --forget_split) forget_split="$2"; shift 2;;
    --retain_split) retain_split="$2"; shift 2;;
    --holdout_split) holdout_split="$2"; shift 2;;
    # Commas are treated as separators, then the list is word-split.
    --alphas) s="${2//,/ }"; read -r -a alphas <<< "$s"; shift 2;;
    --use_constant_lr) use_constant_lr=1; shift 1;;
    --forget-only) forget_only=1; shift 1;;
    -h|--help) usage; exit 0;;
    *) echo "Unknown arg: $1" >&2; usage >&2; exit 1;;
  esac
done

# The default init path embeds the model name but was computed from the
# default model before parsing; rebuild it so --model alone selects the
# matching open-unlearning/tofu_<model>_full weights.
if [[ "${tofu_model_explicit}" -eq 0 ]]; then
  path_to_tofu_model="open-unlearning/tofu_${model}_full"
fi

# An empty grid would make the main loop a silent no-op.
if [[ "${#alphas[@]}" -eq 0 ]]; then
  echo "No alphas given (--alphas was empty)" >&2
  exit 1
fi

# Adjust experiment based on forget_only flag
if [[ "${forget_only}" -eq 1 ]]; then
  experiment="unlearn/tofu/forget_only.yaml"
fi

# Checkpoint directory names to evaluate, one schedule per forget split.
checkpoints_for_forget01=(checkpoint-{1,2,3,5,6,7,8,10})
checkpoints_for_forget05=(checkpoint-{6,12,18,25,31,37,43,50,56,60})
# An earlier forget10 schedule used checkpoints 12, 25, 37, 50 and 60.
checkpoints_for_forget10=(checkpoint-{13,26,39,52,60})

# Pick the schedule for the requested split; any other split name warns and
# falls back to the forget10 schedule.
case "${forget_split}" in
  forget01) checkpoints=("${checkpoints_for_forget01[@]}") ;;
  forget05) checkpoints=("${checkpoints_for_forget05[@]}") ;;
  forget10) checkpoints=("${checkpoints_for_forget10[@]}") ;;
  *)
    echo "[Warn] Unknown forget_split='${forget_split}', defaulting to forget10 checkpoints"
    checkpoints=("${checkpoints_for_forget10[@]}")
    ;;
esac

echo "Using checkpoints for ${forget_split}: ${checkpoints[*]}"


# Print the step of the best checkpoint reported by pruner.py.
# Reads pruner output on stdin, takes the third field of the first line that
# starts with "Best checkpoint: ", and strips an optional "checkpoint-" prefix
# so both "26" and "checkpoint-26" yield "26". Prints an empty line when no
# such line is present.
best_step_from_pruner_output() {
  local step
  step="$(awk '/^Best checkpoint: /{print $3; exit}')"
  printf '%s\n' "${step#checkpoint-}"
}

# One unlearning run per alpha: train (unless output already exists), evaluate
# the selected checkpoints, then keep weights only for the best checkpoint.
for alpha in "${alphas[@]}"; do
  # Task name encodes the run settings, e.g.
  # tofu_Llama-3.2-1B-Instruct_forget10_RMU_a2 (+ _constant_lr / _forget_only)
  task_name="tofu_${model}_${forget_split}_${trainer}_a${alpha}"
  if [[ "${use_constant_lr}" -eq 1 ]]; then
    task_name+="_constant_lr"
  fi
  if [[ "${forget_only}" -eq 1 ]]; then
    task_name+="_forget_only"
  fi
  echo "Running unlearn: ${task_name}"

  task_dir="saves/unlearn/${task_name}"
  retain_logs_path="saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json"

  # Skip training if the output dir already exists; otherwise run unlearning.
  if [[ -d "${task_dir}" ]]; then
    echo "Directory ${task_dir} exists. Skipping training."
  else
    cmd=(accelerate launch --config_file configs/accelerate/default_config.yaml
      --main_process_port "${MASTER_PORT}"
      src/train.py --config-name=unlearn.yaml
      "experiment=${experiment}"
      "trainer=${trainer}"
      "task_name=${task_name}"
      "model=${model}"
      "forget_split=${forget_split}"
      "retain_split=${retain_split}"
      "model.model_args.pretrained_model_name_or_path=${path_to_tofu_model}"
      "retain_logs_path=${retain_logs_path}"
      "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
      "trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps}"
      trainer.args.ddp_find_unused_parameters=true
      trainer.args.eval_strategy=no
      trainer.args.gradient_checkpointing=true
      "trainer.method_args.alpha=${alpha}")
    if [[ "${use_constant_lr}" -eq 1 ]]; then
      cmd+=(trainer.args.lr_scheduler_type=constant)
    fi
    CUDA_VISIBLE_DEVICES=0,1 "${cmd[@]}"
    echo "Completed training: ${task_name}"
  fi

  # Evaluate every selected checkpoint that exists on disk; an existing
  # evals_val directory means that checkpoint was already evaluated.
  for idx in "${!checkpoints[@]}"; do
    checkpoint="${checkpoints[$idx]}"
    ckpt_dir="${task_dir}/${checkpoint}"
    if [[ ! -d "${ckpt_dir}" ]]; then
      echo "Checkpoint ${checkpoint} does not exist, skipping."
      continue
    fi
    echo "Evaluating checkpoint: ${checkpoint} (index ${idx})"
    eval_dir="${ckpt_dir}/evals_val"
    if [[ -d "${eval_dir}" ]]; then
      echo "Eval output exists at ${eval_dir}. Skipping evaluation."
      continue
    fi
    CUDA_VISIBLE_DEVICES=0 python src/eval.py \
      experiment=eval/tofu/default.yaml \
      "forget_split=${forget_split}" \
      "holdout_split=${holdout_split}" \
      "model=${model}" \
      is_validation=true \
      "task_name=${task_name}_${checkpoint}" \
      "model.model_args.pretrained_model_name_or_path=${ckpt_dir}" \
      "paths.output_dir=${eval_dir}" \
      "retain_logs_path=${retain_logs_path}"
  done

  # Run pruner.py once, echo its report, and parse the captured text. Parsing a
  # string (instead of piping into an early-exiting awk) cannot SIGPIPE the
  # producer, which pipefail + set -e would turn into a script abort.
  pruner_output="$(python pruner.py "${task_name}")"
  printf '%s\n' "${pruner_output}"
  best_ckpt="$(best_step_from_pruner_output <<< "${pruner_output}")"

  echo "Best checkpoint for ${task_name} is ${best_ckpt}"

  # With no best checkpoint, "checkpoint-" would match nothing and every
  # checkpoint's weights would be deleted, so skip per-checkpoint pruning.
  if [[ -z "${best_ckpt}" ]]; then
    echo "[Warn] Could not determine best checkpoint for ${task_name}; keeping all checkpoint weights." >&2
  else
    for checkpoint in "${checkpoints[@]}"; do
      if [[ "${checkpoint}" != "checkpoint-${best_ckpt}" && -d "${task_dir}/${checkpoint}" ]]; then
        echo "Removing non-best checkpoint: ${checkpoint}"
        find "${task_dir}/${checkpoint}" -maxdepth 1 -type f -name "*.safetensors" -delete
      fi
    done
  fi

  # Remove *.safetensors lying directly in the task directory (top level only;
  # checkpoint subdirectories are handled above).
  find "${task_dir}" -maxdepth 1 -type f -name "*.safetensors" -delete
done
