#!/bin/bash
set -euo pipefail

# Orchestrate the RMU, SatImp, UnDial, NPO, and GradDiff grid searches, followed by the MSA pipeline,
# on forget01/retain99/holdout10, passing every argument explicitly.

# Defaults (explicitly passed below so you can see what runs)
model="Llama-3.2-1B-Instruct"

# tofu_model="open-unlearning/tofu_${model}_full"

# Fine-tuned checkpoint under saves/finetune; derived from ${model} so the two
# cannot drift apart when the model is changed.
tofu_model="saves/finetune/tofu_${model}_full_constant_lr"
# NOTE(review): must name the same family as ${model}; update the two together.
model_family="Llama-3.2-1B"  # for MSA pipeline

forget_split="forget01"
retain_split="retain99"
holdout_split="holdout10"


# Flag toggles for all compatible calls
# Set to 1 to include the flag; 0 to omit.
# Either may be overridden from the environment, e.g. `forget_only=1 ./script.sh`;
# the values after ':-' are the defaults used otherwise.
use_constant_lr="${use_constant_lr:-1}"
forget_only="${forget_only:-0}"


# Timing helpers

# Convert a whole number of seconds into a zero-padded HH:MM:SS string.
# Arguments: $1 - elapsed seconds (integer)
# Outputs:   the formatted duration on stdout, with no trailing newline
sec_to_hms() {
  local secs=$1
  local hours minutes seconds
  hours=$(( secs / 3600 ))
  minutes=$(( (secs % 3600) / 60 ))
  seconds=$(( secs % 60 ))
  printf "%02d:%02d:%02d" "$hours" "$minutes" "$seconds"
}

# Collects one "Label: HH:MM:SS" entry per experiment for the final summary.
declare -a TIMINGS


echo "[Experiment] RMU grid search for ${model} on ${forget_split}/${retain_split}"
# Optional toggles go into an array instead of unquoted $(...) word-splitting
# (ShellCheck SC2046); fallbacks match the top-level defaults
# (use_constant_lr=1, forget_only=0).
extra_flags=()
if [[ "${use_constant_lr:-1}" -eq 1 ]]; then extra_flags+=("--use_constant_lr"); fi
if [[ "${forget_only:-0}" -eq 1 ]]; then extra_flags+=("--forget-only"); fi
start_ts=$(date +%s)
# The ${arr[@]+...} form expands to nothing for an empty array, even under
# set -u on bash < 4.4.
./scripts/custom_tofu_unlearn/rmu_tofu_unlearn_grid_search.sh \
  --model "${model}" \
  --tofu_model "${tofu_model}" \
  --forget_split "${forget_split}" \
  --retain_split "${retain_split}" \
  --holdout_split "${holdout_split}" \
  --alphas "1 2 4" \
  ${extra_flags[@]+"${extra_flags[@]}"}
dur=$(( $(date +%s) - start_ts ))
echo "[Timing] RMU grid duration: $(sec_to_hms "$dur")"
TIMINGS+=("RMU: $(sec_to_hms "$dur")")



echo "[Experiment] SatImp grid search for ${model} on ${forget_split}/${retain_split}"
# Optional toggles go into an array instead of unquoted $(...) word-splitting
# (ShellCheck SC2046); fallbacks match the top-level defaults
# (use_constant_lr=1, forget_only=0).
extra_flags=()
if [[ "${use_constant_lr:-1}" -eq 1 ]]; then extra_flags+=("--use_constant_lr"); fi
if [[ "${forget_only:-0}" -eq 1 ]]; then extra_flags+=("--forget-only"); fi
start_ts=$(date +%s)
# The ${arr[@]+...} form expands to nothing for an empty array, even under
# set -u on bash < 4.4.
./scripts/custom_tofu_unlearn/satlmp_tofu_unlearn_grid_search.sh \
  --model "${model}" \
  --tofu_model "${tofu_model}" \
  --forget_split "${forget_split}" \
  --retain_split "${retain_split}" \
  --holdout_split "${holdout_split}" \
  --gammas "0.1 1.0 4.0" \
  ${extra_flags[@]+"${extra_flags[@]}"}
dur=$(( $(date +%s) - start_ts ))
echo "[Timing] SatImp grid duration: $(sec_to_hms "$dur")"
TIMINGS+=("SatImp: $(sec_to_hms "$dur")")



echo "[Experiment] UnDial grid search for ${model} on ${forget_split}/${retain_split}"
# Optional toggles go into an array instead of unquoted $(...) word-splitting
# (ShellCheck SC2046); fallbacks match the top-level defaults
# (use_constant_lr=1, forget_only=0).
extra_flags=()
if [[ "${use_constant_lr:-1}" -eq 1 ]]; then extra_flags+=("--use_constant_lr"); fi
if [[ "${forget_only:-0}" -eq 1 ]]; then extra_flags+=("--forget-only"); fi
start_ts=$(date +%s)
# The ${arr[@]+...} form expands to nothing for an empty array, even under
# set -u on bash < 4.4.
./scripts/custom_tofu_unlearn/undial_tofu_unlearn_grid_search.sh \
  --model "${model}" \
  --tofu_model "${tofu_model}" \
  --forget_split "${forget_split}" \
  --retain_split "${retain_split}" \
  --holdout_split "${holdout_split}" \
  --lrs "1e-5 2e-5" \
  --alphas "1 2 4" \
  ${extra_flags[@]+"${extra_flags[@]}"}
dur=$(( $(date +%s) - start_ts ))
echo "[Timing] UnDial grid duration: $(sec_to_hms "$dur")"
TIMINGS+=("UnDial: $(sec_to_hms "$dur")")



echo "[Experiment] NPO grid search for ${model} on ${forget_split}/${retain_split}"
# Optional toggles go into an array instead of unquoted $(...) word-splitting
# (ShellCheck SC2046); fallbacks match the top-level defaults
# (use_constant_lr=1, forget_only=0).
extra_flags=()
if [[ "${use_constant_lr:-1}" -eq 1 ]]; then extra_flags+=("--use_constant_lr"); fi
if [[ "${forget_only:-0}" -eq 1 ]]; then extra_flags+=("--forget-only"); fi
start_ts=$(date +%s)
# The ${arr[@]+...} form expands to nothing for an empty array, even under
# set -u on bash < 4.4.
./scripts/custom_tofu_unlearn/npo_tofu_unlearn_grid_search.sh \
  --model "${model}" \
  --tofu_model "${tofu_model}" \
  --forget_split "${forget_split}" \
  --retain_split "${retain_split}" \
  --holdout_split "${holdout_split}" \
  --lrs "1e-5 2e-5" \
  --alphas "2 4 8" \
  ${extra_flags[@]+"${extra_flags[@]}"}
dur=$(( $(date +%s) - start_ts ))
echo "[Timing] NPO grid duration: $(sec_to_hms "$dur")"
TIMINGS+=("NPO: $(sec_to_hms "$dur")")



echo "[Experiment] GradDiff grid search for ${model} on ${forget_split}/${retain_split}"
# Optional toggles go into an array instead of unquoted $(...) word-splitting
# (ShellCheck SC2046); fallbacks match the top-level defaults
# (use_constant_lr=1, forget_only=0).
extra_flags=()
if [[ "${use_constant_lr:-1}" -eq 1 ]]; then extra_flags+=("--use_constant_lr"); fi
if [[ "${forget_only:-0}" -eq 1 ]]; then extra_flags+=("--forget-only"); fi
start_ts=$(date +%s)
# The ${arr[@]+...} form expands to nothing for an empty array, even under
# set -u on bash < 4.4.
./scripts/custom_tofu_unlearn/grad_diff_tofu_unlearn_grid_search.sh \
  --model "${model}" \
  --tofu_model "${tofu_model}" \
  --forget_split "${forget_split}" \
  --retain_split "${retain_split}" \
  --holdout_split "${holdout_split}" \
  --lrs "1e-5 2e-5" \
  --alphas "1 2 4" \
  ${extra_flags[@]+"${extra_flags[@]}"}
dur=$(( $(date +%s) - start_ts ))
echo "[Timing] GradDiff grid duration: $(sec_to_hms "$dur")"
TIMINGS+=("GradDiff: $(sec_to_hms "$dur")")





echo "[Experiment] MSA finetune + unlearn for ${model_family} on ${forget_split}/${retain_split}"
# Only --use_constant_lr is forwarded to msa.sh; --forget-only is not passed here
# (presumably unsupported by the MSA pipeline — confirm). The flag goes into an
# array instead of unquoted $(...) word-splitting (ShellCheck SC2046).
extra_flags=()
if [[ "${use_constant_lr:-1}" -eq 1 ]]; then extra_flags+=("--use_constant_lr"); fi
start_ts=$(date +%s)
# The ${arr[@]+...} form expands to nothing for an empty array, even under
# set -u on bash < 4.4.
./scripts/experiments/msa.sh \
  --model_family "${model_family}" \
  --tofu_model "${tofu_model}" \
  --forget_split "${forget_split}" \
  --retain_split "${retain_split}" \
  --holdout_split "${holdout_split}" \
  --alphas "1.0 1.5 2.5 3.0 3.5" \
  --betas "0.0 0.5 0.75 1.0 1.25 1.5" \
  --msa_epochs 5 \
  --run "instruct tofu pretrained" \
  ${extra_flags[@]+"${extra_flags[@]}"}
dur=$(( $(date +%s) - start_ts ))
echo "[Timing] MSA duration: $(sec_to_hms "$dur")"
TIMINGS+=("MSA: $(sec_to_hms "$dur")")


printf '%s\n' "[Experiment] Completed NPO, GradDiff, RMU, SatImp, UnDial, and MSA for ${forget_split}/${retain_split}."
printf '%s\n' "[Timing] Summary:"
# One bullet per recorded experiment, in the order the entries were appended.
for entry in "${TIMINGS[@]}"; do
  printf '  - %s\n' "${entry}"
done


# Reference timings recorded from a previous run of this script (hardware-dependent):
# [Timing] Summary:
#   - RMU: 00:24:04
#   - SatImp: 01:07:10
#   - UnDial: 00:41:24
#   - NPO: 00:28:18
#   - GradDiff: 00:23:16
#   - MSA: 01:47:18
