#!/bin/bash

set -euo pipefail

# Run the MSA finetune + unlearn pipeline (./scripts/experiments/msa.sh) on the
# forget10/retain90/holdout10 splits, passing every argument explicitly, and
# report how long the run took.
#
# NOTE(review): only the MSA pipeline is invoked in this file; NPO, GradDiff,
# RMU, SatImp and UnDial are not run here.

# Defaults (explicitly passed below so you can see what runs)
model="Llama-3.1-8B-Instruct"

tofu_model="open-unlearning/tofu_${model}_full"

# Alternative checkpoint, kept for reference:
# tofu_model="saves/finetune/tofu_Llama-3.1-8B-Instruct_full_constant_lr"
model_family="Llama-3.1-8B"  # for MSA pipeline

forget_split="forget10"
retain_split="retain90"
holdout_split="holdout10"


# Flag toggles for all compatible calls
# Set to 1 to include the flag; 0 to omit
use_constant_lr=0
# Not currently forwarded to any call in this script.
# shellcheck disable=SC2034
forget_only=0

# Timing helpers
#######################################
# Format a whole-second duration as zero-padded HH:MM:SS.
# Hours are not capped, so 360000 seconds prints as "100:00:00".
# Arguments: $1 - duration in seconds (integer)
# Outputs:   the formatted duration on stdout, without a trailing newline
#######################################
sec_to_hms() {
  local secs=$1
  local hours minutes
  hours=$(( secs / 3600 ))
  minutes=$(( (secs % 3600) / 60 ))
  printf "%02d:%02d:%02d" "$hours" "$minutes" "$(( secs % 60 ))"
}

# "<label>: HH:MM:SS" entries, one per timed step, printed in the summary below.
declare -a TIMINGS=()


echo "[Experiment] MSA finetune + unlearn for ${model_family} on ${forget_split}/${retain_split}"

# Optional msa.sh flags are collected in an array so a disabled toggle expands
# to zero arguments instead of relying on an unquoted $(...) being word-split
# away. An unset toggle counts as off, matching "0 to omit" in the config.
msa_extra_args=()
if [[ "${use_constant_lr:-0}" -eq 1 ]]; then
  msa_extra_args+=(--use_constant_lr)
fi

start_ts=$(date +%s)
# ${arr[@]+"${arr[@]}"} expands an empty array to nothing even under `set -u`
# on older bash releases, where a bare "${arr[@]}" would abort.
./scripts/experiments/msa.sh \
  --model_family "${model_family}" \
  --tofu_model "${tofu_model}" \
  --forget_split "${forget_split}" \
  --retain_split "${retain_split}" \
  --holdout_split "${holdout_split}" \
  --alphas "0.5" \
  --betas "0.0" \
  --msa_epochs 5 \
  --run "instruct" \
  ${msa_extra_args[@]+"${msa_extra_args[@]}"}
dur=$(( $(date +%s) - start_ts ))
echo "[Timing] MSA duration: $(sec_to_hms "$dur")"
TIMINGS+=("MSA: $(sec_to_hms "$dur")")

# Only MSA is run by this script, so report only MSA as completed.
echo "[Experiment] Completed MSA for ${forget_split}/${retain_split}."
echo "[Timing] Summary:"
for line in "${TIMINGS[@]}"; do
  echo "  - ${line}"
done

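# Wider alpha/beta sweep values; substitute them for the --alphas/--betas
# arguments in the msa.sh call above to run them:
#   --alphas "0.5 0.75 1.0 1.25 1.5 3" \
#   --betas "0.0 0.5 0.75 1.0 1.25 1.5" \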