#!/bin/bash
set -euo pipefail

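# Orchestrate NPO, GradDiff, RMU, SatImp, UnDial, and MSA for the splits configured below
# (currently forget10/retain90/holdout10) with explicit args.
# Only the RMU, SatImp, and UnDial sections are active; NPO, GradDiff, and MSA are commented out.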

# Run configuration. Each value is passed explicitly to the calls further down,
# so the invocations show exactly what runs.

# Model name; also interpolated into the default tofu_model id.
model='Llama-3.1-8B-Instruct'
# Model family for the MSA pipeline (only referenced by the commented-out MSA section).
model_family='Llama-3.1-8B'

# Value handed to --tofu_model.
tofu_model="open-unlearning/tofu_${model}_full"
# Alternative value, kept for reference:
# tofu_model="saves/finetune/tofu_Llama-3.1-8B-Instruct_full_constant_lr"

# Split names handed to --forget_split / --retain_split / --holdout_split.
forget_split='forget10'
retain_split='retain90'
holdout_split='holdout10'

# Flag toggles shared by every compatible call:
# 1 includes the flag, 0 omits it.
use_constant_lr=0   # controls --use_constant_lr
forget_only=1       # controls --forget-only


# Timing helpers
#######################################
# Format a duration in whole seconds as HH:MM:SS.
# Arguments: $1 - duration in seconds (integer)
# Outputs:   the formatted duration on stdout, without a trailing newline.
#            Each field is zero-padded to at least two digits; the hours
#            field is not capped, so it grows past two digits when needed.
#######################################
sec_to_hms() {
  local secs=$1
  local hours minutes seconds
  hours=$(( secs / 3600 ))
  minutes=$(( (secs % 3600) / 60 ))
  seconds=$(( secs % 60 ))
  printf "%02d:%02d:%02d" "$hours" "$minutes" "$seconds"
}

# Per-experiment timing entries ("<name>: HH:MM:SS"), appended after each run.
# Initialised explicitly to empty: a bare `declare -a` would turn a TIMINGS
# value inherited from the environment into the first array element.
declare -a TIMINGS=()

# --- RMU grid search ---------------------------------------------------------
# Runs the RMU grid-search script with the configured model and splits, then
# records how long it took. Under `set -e` a failing run aborts the script.
echo "[Experiment] RMU grid search for ${model} on ${forget_split}/${retain_split}"

# The command is assembled as an array so every argument stays a single word
# and optional flags are appended without unquoted command substitution.
rmu_cmd=(
  ./scripts/custom_tofu_unlearn/rmu_tofu_unlearn_grid_search.sh
  --model "${model}"
  --tofu_model "${tofu_model}"
  --forget_split "${forget_split}"
  --retain_split "${retain_split}"
  --holdout_split "${holdout_split}"
  --alphas "2 4"
)
# Toggles are tested in the main shell, so a problem evaluating them surfaces
# here rather than inside a command substitution whose failure is discarded.
# An unset or empty toggle counts as 1 (flag included), as before.
if [[ "${use_constant_lr:-1}" -eq 1 ]]; then
  rmu_cmd+=(--use_constant_lr)
fi
if [[ "${forget_only:-1}" -eq 1 ]]; then
  rmu_cmd+=(--forget-only)
fi

start_ts=$(date +%s)
"${rmu_cmd[@]}"
dur=$(( $(date +%s) - start_ts ))

rmu_elapsed=$(sec_to_hms "$dur")
echo "[Timing] RMU grid duration: ${rmu_elapsed}"
TIMINGS+=("RMU: ${rmu_elapsed}")


# --- SatImp grid search ------------------------------------------------------
# Runs the SatImp grid-search script with the configured model and splits, then
# records how long it took. Under `set -e` a failing run aborts the script.
echo "[Experiment] SatImp grid search for ${model} on ${forget_split}/${retain_split}"

# NOTE(review): the script file name below is spelled "satlmp" (lowercase L)
# while the method is called "SatImp" — confirm it matches the file on disk.
# The command is assembled as an array so every argument stays a single word
# and optional flags are appended without unquoted command substitution.
satimp_cmd=(
  ./scripts/custom_tofu_unlearn/satlmp_tofu_unlearn_grid_search.sh
  --model "${model}"
  --tofu_model "${tofu_model}"
  --forget_split "${forget_split}"
  --retain_split "${retain_split}"
  --holdout_split "${holdout_split}"
  --lrs "1e-5 1e-4"
  --gammas "8.0"
)
# Toggles are tested in the main shell, so a problem evaluating them surfaces
# here rather than inside a command substitution whose failure is discarded.
# An unset or empty toggle counts as 1 (flag included), as before.
if [[ "${use_constant_lr:-1}" -eq 1 ]]; then
  satimp_cmd+=(--use_constant_lr)
fi
if [[ "${forget_only:-1}" -eq 1 ]]; then
  satimp_cmd+=(--forget-only)
fi

start_ts=$(date +%s)
"${satimp_cmd[@]}"
dur=$(( $(date +%s) - start_ts ))

satimp_elapsed=$(sec_to_hms "$dur")
echo "[Timing] SatImp grid duration: ${satimp_elapsed}"
TIMINGS+=("SatImp: ${satimp_elapsed}")


# --- UnDial grid search ------------------------------------------------------
# Runs the UnDial grid-search script with the configured model and splits, then
# records how long it took. Under `set -e` a failing run aborts the script.
echo "[Experiment] UnDial grid search for ${model} on ${forget_split}/${retain_split}"

# The command is assembled as an array so every argument stays a single word
# and optional flags are appended without unquoted command substitution.
undial_cmd=(
  ./scripts/custom_tofu_unlearn/undial_tofu_unlearn_grid_search.sh
  --model "${model}"
  --tofu_model "${tofu_model}"
  --forget_split "${forget_split}"
  --retain_split "${retain_split}"
  --holdout_split "${holdout_split}"
  --lrs "5e-5 1e-4"
  --alphas "1"
)
# Toggles are tested in the main shell, so a problem evaluating them surfaces
# here rather than inside a command substitution whose failure is discarded.
# An unset or empty toggle counts as 1 (flag included), as before.
if [[ "${use_constant_lr:-1}" -eq 1 ]]; then
  undial_cmd+=(--use_constant_lr)
fi
if [[ "${forget_only:-1}" -eq 1 ]]; then
  undial_cmd+=(--forget-only)
fi

start_ts=$(date +%s)
"${undial_cmd[@]}"
dur=$(( $(date +%s) - start_ts ))

undial_elapsed=$(sec_to_hms "$dur")
echo "[Timing] UnDial grid duration: ${undial_elapsed}"
TIMINGS+=("UnDial: ${undial_elapsed}")


# echo "[Experiment] NPO grid search for ${model} on ${forget_split}/${retain_split}"
# start_ts=$(date +%s)
# ./scripts/custom_tofu_unlearn/npo_tofu_unlearn_grid_search.sh \
#   --model "${model}" \
#   --tofu_model "${tofu_model}" \
#   --forget_split "${forget_split}" \
#   --retain_split "${retain_split}" \
#   --holdout_split "${holdout_split}" \
#   --lrs "5e-6" \
#   --alphas "4" \
#   $([[ "${use_constant_lr:-1}" -eq 1 ]] && echo "--use_constant_lr") \
#   $([[ "${forget_only:-1}" -eq 1 ]] && echo "--forget-only")
# dur=$(( $(date +%s) - start_ts ))
# echo "[Timing] NPO grid duration: $(sec_to_hms "$dur")"
# TIMINGS+=("NPO: $(sec_to_hms "$dur")")



# echo "[Experiment] GradDiff grid search for ${model} on ${forget_split}/${retain_split}"
# start_ts=$(date +%s)
# ./scripts/custom_tofu_unlearn/grad_diff_tofu_unlearn_grid_search.sh \
#   --model "${model}" \
#   --tofu_model "${tofu_model}" \
#   --forget_split "${forget_split}" \
#   --retain_split "${retain_split}" \
#   --holdout_split "${holdout_split}" \
#   --lrs "1e-5" \
#   --alphas "4 8" \
#   $([[ "${use_constant_lr:-1}" -eq 1 ]] && echo "--use_constant_lr") \
#   $([[ "${forget_only:-1}" -eq 1 ]] && echo "--forget-only")
# dur=$(( $(date +%s) - start_ts ))
# echo "[Timing] GradDiff grid duration: $(sec_to_hms "$dur")"
# TIMINGS+=("GradDiff: $(sec_to_hms "$dur")")



# echo "[Experiment] MSA finetune + unlearn for ${model_family} on ${forget_split}/${retain_split}"
# start_ts=$(date +%s)
# ./scripts/experiments/msa.sh \
#   --model_family "${model_family}" \
#   --tofu_model "${tofu_model}" \
#   --forget_split "${forget_split}" \
#   --retain_split "${retain_split}" \
#   --holdout_split "${holdout_split}" \
#   --alphas "0.5 0.75 1.0 1.25 1.5 3" \
#   --betas "0.0 0.5 0.75 1.0 1.25 1.5" \
#   --msa_epochs 5 \
#   --run "instruct tofu pretrained" \
#   $([[ "${use_constant_lr:-1}" -eq 1 ]] && echo "--use_constant_lr")
# dur=$(( $(date +%s) - start_ts ))
# echo "[Timing] MSA duration: $(sec_to_hms "$dur")"
# TIMINGS+=("MSA: $(sec_to_hms "$dur")")

# echo "[Experiment] Completed NPO, GradDiff, RMU, SatImp, UnDial, and MSA for ${forget_split}/${retain_split}."
# echo "[Timing] Summary:"
# for line in "${TIMINGS[@]}"; do
#   echo "  - ${line}"
# done
