#!/bin/bash
set -euo pipefail

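# Single-run NPO unlearning on TOFU for a specific (alpha, lr, checkpoint).
# NOTE: the training launch further down is currently commented out, so as
# written this script only evaluates an already-existing checkpoint.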

# Ask the OS for a currently free TCP port: bind a socket to port 0 and read
# back the port number that was assigned.
# Assignment and export are deliberately separate statements:
# `export VAR=$(cmd)` returns export's own status, so a failing python call
# would slip past `set -e` and leave MASTER_PORT silently empty.
MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
export MASTER_PORT
echo "Master Port: $MASTER_PORT"

# Experiment identity: trainer name, experiment config, and Hydra model key.
trainer=NPO
experiment=unlearn/tofu/default.yaml
model=Llama-3.2-1B-Instruct

# TOFU dataset splits; each one has a matching command-line option.
forget_split=forget10
retain_split=retain90
holdout_split=holdout10

# TOFU model path (or HF id) for initialization. The default is derived from
# the model key above and can be replaced with --tofu_model.
path_to_tofu_model="open-unlearning/tofu_${model}_full"

# Training batch parameters.
gradient_accumulation_steps=4
per_device_train_batch_size=4

# Boolean switches (0 = off, 1 = on), enabled by --forget-only and
# --use_constant_lr respectively.
forget_only=0
use_constant_lr=0

# Required hyperparameters; they stay empty until given on the command line.
alpha=
lr=
ckpt=  # numeric step or full name like checkpoint-100

#######################################
# Print the command-line help text.
# Globals:   model, path_to_tofu_model, forget_split, retain_split,
#            holdout_split (read; their current values are shown as defaults)
# Arguments: none
# Outputs:   help text on stdout
#######################################
usage() {
  # The here-document delimiter is unquoted so the defaults expand. In that
  # mode a backslash directly before a newline is a line continuation and is
  # dropped, so the Example's trailing backslashes are written as `\\` to make
  # them appear in the output and keep the example on separate lines.
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --model NAME                 Hydra model key (default: ${model})
  --tofu_model PATH            Pretrained TOFU model path or HF id (default: ${path_to_tofu_model})
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --holdout_split NAME         Holdout split (default: ${holdout_split})
  --lr VALUE                  Learning rate (required)
  --alpha VALUE               Alpha (required)
  --ckpt VALUE                Checkpoint number or name (e.g., 100 or checkpoint-100) (required)
  --use_constant_lr            Use constant LR scheduler and suffix task_name
  --forget-only               Use forget_only experiment and suffix task_name
  -h | --help                  Show this help

Example:
  ./scripts/custom_tofu_unlearn/npo_tofu_unlearn_grid_search.sh \\
    --model Llama-3.2-1B-Instruct \\
    --tofu_model open-unlearning/tofu_Llama-3.2-1B-Instruct_full \\
    --forget_split forget10 --retain_split retain90 --holdout_split holdout10 \\
    --lr 1e-5 --alpha 2 --ckpt 100
EOF
}

# Parse CLI

#######################################
# Ensure the option named in $1 is followed by a value.
# Without this check, a value-taking option given as the last word (e.g. a
# trailing `--lr`) expands an unset $2 and the script dies under `set -u`
# with only a bare "unbound variable" message.
# Arguments: $1 - option name; $2 - its value (further arguments may follow)
# Outputs:   error message and usage text on stderr when the value is missing
# Returns:   0 when a value is present; exits the script with 1 otherwise
#######################################
require_value() {
  if [[ $# -lt 2 ]]; then
    echo "[ERROR] Option $1 requires a value." >&2
    usage >&2
    exit 1
  fi
}

#######################################
# Apply command-line options to the script's configuration variables.
# Globals:   model, path_to_tofu_model, forget_split, retain_split,
#            holdout_split, lr, alpha, ckpt, use_constant_lr, forget_only
#            (written)
# Arguments: the script's command-line arguments
# Outputs:   usage on stdout for -h/--help; diagnostics and usage on stderr
#            for an unknown option or a missing value
# Returns:   0 on success; exits 0 after --help, exits 1 on invalid input
#######################################
parse_cli_args() {
  local model_given=0 tofu_model_given=0

  while [[ $# -gt 0 ]]; do
    case "$1" in
      --model) require_value "$@"; model="$2"; model_given=1; shift 2 ;;
      --tofu_model) require_value "$@"; path_to_tofu_model="$2"; tofu_model_given=1; shift 2 ;;
      --forget_split) require_value "$@"; forget_split="$2"; shift 2 ;;
      --retain_split) require_value "$@"; retain_split="$2"; shift 2 ;;
      --holdout_split) require_value "$@"; holdout_split="$2"; shift 2 ;;
      --lr) require_value "$@"; lr="$2"; shift 2 ;;
      --alpha) require_value "$@"; alpha="$2"; shift 2 ;;
      --ckpt) require_value "$@"; ckpt="$2"; shift 2 ;;
      --use_constant_lr) use_constant_lr=1; shift 1 ;;
      --forget-only) forget_only=1; shift 1 ;;
      -h|--help) usage; exit 0 ;;
      *)
        echo "Unknown arg: $1" >&2
        usage >&2
        exit 1
        ;;
    esac
  done

  # The default TOFU path is built from the model key before parsing. When
  # only --model is given, rebuild it with the same pattern so it does not
  # keep pointing at the default model's checkpoint. An explicit --tofu_model
  # always wins, whatever the option order.
  if (( model_given && ! tofu_model_given )); then
    path_to_tofu_model="open-unlearning/tofu_${model}_full"
  fi
}

parse_cli_args "$@"

# All three of --lr, --alpha and --ckpt must have been supplied with a
# non-empty value; otherwise report the problem, print the help text and stop.
if ! [[ -n "${lr}" && -n "${alpha}" && -n "${ckpt}" ]]; then
  printf '%s\n' "[ERROR] --lr, --alpha, and --ckpt are required." >&2
  usage
  exit 1
fi

echo "Using single checkpoint: ${ckpt}"

# Accept either a bare step ("100") or a full directory name
# ("checkpoint-100"); ckpt_dir always ends up in the second form.
case "${ckpt}" in
  checkpoint-*) ckpt_dir="${ckpt}" ;;
  *) ckpt_dir="checkpoint-${ckpt}" ;;
esac

# The task name encodes model, forget split, trainer, alpha and lr, e.g.
# tofu_Llama-3.2-1B-Instruct_forget10_NPO_a2_lr1e-5. Each enabled flag adds a
# suffix: _constant_lr first, then _forget_only.
task_name="tofu_${model}_${forget_split}_${trainer}_a${alpha}_lr${lr}"
if (( use_constant_lr == 1 )); then
  task_name="${task_name}_constant_lr"
fi
if (( forget_only == 1 )); then
  # The forget-only variant also switches to its own experiment config.
  experiment="unlearn/tofu/forget_only.yaml"
  task_name="${task_name}_forget_only"
fi
echo "Running unlearn: ${task_name}"

# Report an existing output directory. This check only prints a notice; it
# does not alter what runs afterwards.
if [[ -d "saves/unlearn/${task_name}" ]]; then
  echo "Directory saves/unlearn/${task_name} exists. Skipping training."
fi

# cmd=(accelerate launch --config_file configs/accelerate/default_config.yaml --main_process_port "$MASTER_PORT" \
#   src/train.py --config-name=unlearn.yaml \
#   experiment=${experiment} \
#   trainer=${trainer} \
#   task_name=${task_name} \
#   model=${model} \
#   forget_split=${forget_split} \
#   retain_split=${retain_split} \
#   model.model_args.pretrained_model_name_or_path=${path_to_tofu_model} \
#   retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json \
#   trainer.args.per_device_train_batch_size=${per_device_train_batch_size} \
#   trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps} \
#   trainer.args.ddp_find_unused_parameters=true \
#   trainer.args.gradient_checkpointing=true \
#   trainer.args.eval_strategy=no \
#   trainer.args.learning_rate=${lr} \
#   trainer.method_args.alpha=${alpha})
# if [[ "${use_constant_lr}" -eq 1 ]]; then
#   cmd+=(trainer.args.lr_scheduler_type=constant)
# fi
# CUDA_VISIBLE_DEVICES=0,1 "${cmd[@]}"
  
# echo "Completed training: ${task_name}"

#######################################
# Evaluate only the requested checkpoint with src/eval.py, unless an evals/
# directory for it already exists.
# Globals:   task_name, ckpt_dir, model, forget_split, holdout_split,
#            retain_split (read)
# Arguments: none
# Outputs:   progress messages on stdout; missing-checkpoint warning on stderr
# Returns:   0 when the checkpoint is missing or evaluation is skipped;
#            otherwise the exit status of the eval command
#######################################
evaluate_checkpoint() {
  local ckpt_path="saves/unlearn/${task_name}/${ckpt_dir}"
  # Results are written to an evals/ directory inside the checkpoint directory.
  local eval_dir_root="${ckpt_path}/evals"

  if [[ ! -d "${ckpt_path}" ]]; then
    echo "[WARN] Requested checkpoint directory ${ckpt_path} does not exist." >&2
    return 0
  fi

  echo "Evaluating checkpoint: ${ckpt_dir}"
  if [[ -d "${eval_dir_root}" ]]; then
    # The mere presence of the directory is treated as "already evaluated".
    echo "Eval output exists at ${eval_dir_root}. Skipping evaluation."
  else
    # Build the command as an array so every override stays a single word even
    # if a value contains spaces or glob characters.
    local -a eval_cmd=(
      python src/eval.py
      "experiment=eval/tofu/default.yaml"
      "forget_split=${forget_split}"
      "holdout_split=${holdout_split}"
      "model=${model}"
      "is_validation=false"
      "task_name=${task_name}_${ckpt_dir}"
      "model.model_args.pretrained_model_name_or_path=${ckpt_path}"
      "paths.output_dir=${eval_dir_root}"
      "retain_logs_path=saves/eval/tofu_${model}_${retain_split}/TOFU_EVAL.json"
    )
    CUDA_VISIBLE_DEVICES=1 "${eval_cmd[@]}"
  fi

  # # Cleanup: remove only safetensors inside checkpoints to save space, keep checkpoint dirs
  # echo "Removing *.safetensors files under saves/unlearn/${task_name}/checkpoint-*/"
  # find "saves/unlearn/${task_name}" -type f -path "*/checkpoint-*/model.safetensors" -exec rm -f {} +
  # find "saves/unlearn/${task_name}" -type f -path "*/checkpoint-*/model.safetensor" -exec rm -f {} +

  # # Cleanup stray safetensors
  # for mf in "model.safetensors" "model.safetensor"; do
  #   root_model_path="saves/unlearn/${task_name}/${mf}"
  #   if [ -f "${root_model_path}" ]; then
  #     echo "Removing ${root_model_path}"
  #     rm -f "${root_model_path}"
  #   fi
  # done
}

evaluate_checkpoint
