#!/bin/bash
set -euo pipefail

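# Orchestrate NPO, GradDiff, RMU, SatImp, UnDial, and MSA unlearning runs with explicit args
# for the splits configured below (currently forget10/retain90/holdout10).
# Only the MSA run is active; the other methods' calls are commented out.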

# Run configuration. Each value is forwarded explicitly to the experiment
# calls further down, so the settings of a run are visible in this one file.

# Model name: shown in the log banners and passed as --model by the
# grid-search calls.
model='Olmo-2-1B-stage1-final'

# Checkpoint path handed to the experiment scripts.
# Alternative value kept for reference:
#   tofu_model="open-unlearning/tofu_${model}_full"
tofu_model='saves/finetune/tofu_Olmo-2-1B-stage1-final_full_constant_lr'

# Tag for the MSA pipeline.
# NOTE(review): not referenced by any active call in this script - confirm
# whether it is still needed.
model_family='tofu_Olmo-2-1B-stage1'

# Dataset splits.
forget_split='forget10'
retain_split='retain90'
holdout_split='holdout10'

# Optional-flag switches for every call that supports them:
# 1 adds the flag, 0 leaves it out.
use_constant_lr=1 # adds --use_constant_lr
forget_only=0     # adds --forget-only (only used by the commented-out grid searches)


# Timing helpers

#######################################
# Format a duration in whole seconds as zero-padded HH:MM:SS.
# Arguments: $1 - integer number of seconds
# Outputs:   the formatted string on stdout, without a trailing newline;
#            the hours field widens beyond two digits when needed
#######################################
sec_to_hms() {
  local secs=$1
  # Plain $(( )) expansions are used rather than (( )) statements: a (( ))
  # statement returns non-zero when its value is 0, which would abort the
  # script under `set -e`.
  printf "%02d:%02d:%02d" \
    "$((secs / 3600))" \
    "$(((secs % 3600) / 60))" \
    "$((secs % 60))"
}

# One "<label>: HH:MM:SS" entry is appended per completed experiment.
declare -a TIMINGS


# echo "[Experiment] RMU grid search for ${model} on ${forget_split}/${retain_split}"
# start_ts=$(date +%s)
# ./scripts/custom_tofu_unlearn/rmu_tofu_unlearn_grid_search.sh \
#   --model "${model}" \
#   --tofu_model "${tofu_model}" \
#   --forget_split "${forget_split}" \
#   --retain_split "${retain_split}" \
#   --holdout_split "${holdout_split}" \
#   --alphas "1 2 4" \
#   $([[ "${use_constant_lr:-1}" -eq 1 ]] && echo "--use_constant_lr") \
#   $([[ "${forget_only:-1}" -eq 1 ]] && echo "--forget-only")
# dur=$(( $(date +%s) - start_ts ))
# echo "[Timing] RMU grid duration: $(sec_to_hms "$dur")"
# TIMINGS+=("RMU: $(sec_to_hms "$dur")")


# echo "[Experiment] SatImp grid search for ${model} on ${forget_split}/${retain_split}"
# start_ts=$(date +%s)
# ./scripts/custom_tofu_unlearn/satlmp_tofu_unlearn_grid_search.sh \
#   --model "${model}" \
#   --tofu_model "${tofu_model}" \
#   --forget_split "${forget_split}" \
#   --retain_split "${retain_split}" \
#   --holdout_split "${holdout_split}" \
#   --gammas "0.1 1.0 4.0" \
#   $([[ "${use_constant_lr:-1}" -eq 1 ]] && echo "--use_constant_lr") \
#   $([[ "${forget_only:-1}" -eq 1 ]] && echo "--forget-only")
# dur=$(( $(date +%s) - start_ts ))
# echo "[Timing] SatImp grid duration: $(sec_to_hms "$dur")"
# TIMINGS+=("SatImp: $(sec_to_hms "$dur")")


# echo "[Experiment] UnDial grid search for ${model} on ${forget_split}/${retain_split}"
# start_ts=$(date +%s)
# ./scripts/custom_tofu_unlearn/undial_tofu_unlearn_grid_search.sh \
#   --model "${model}" \
#   --tofu_model "${tofu_model}" \
#   --forget_split "${forget_split}" \
#   --retain_split "${retain_split}" \
#   --holdout_split "${holdout_split}" \
#   --lrs "1e-5 2e-5" \
#   --alphas "1 2 4" \
#   $([[ "${use_constant_lr:-1}" -eq 1 ]] && echo "--use_constant_lr") \
#   $([[ "${forget_only:-1}" -eq 1 ]] && echo "--forget-only")
# dur=$(( $(date +%s) - start_ts ))
# echo "[Timing] UnDial grid duration: $(sec_to_hms "$dur")"
# TIMINGS+=("UnDial: $(sec_to_hms "$dur")")


# echo "[Experiment] NPO grid search for ${model} on ${forget_split}/${retain_split}"
# start_ts=$(date +%s)
# ./scripts/custom_tofu_unlearn/npo_tofu_unlearn_grid_search.sh \
#   --model "${model}" \
#   --tofu_model "${tofu_model}" \
#   --forget_split "${forget_split}" \
#   --retain_split "${retain_split}" \
#   --holdout_split "${holdout_split}" \
#   --lrs "1e-5" \
#   --alphas "2 4 8" \
#   $([[ "${use_constant_lr:-1}" -eq 1 ]] && echo "--use_constant_lr") \
#   $([[ "${forget_only:-1}" -eq 1 ]] && echo "--forget-only")
# dur=$(( $(date +%s) - start_ts ))
# echo "[Timing] NPO grid duration: $(sec_to_hms "$dur")"
# TIMINGS+=("NPO: $(sec_to_hms "$dur")")



# echo "[Experiment] GradDiff grid search for ${model} on ${forget_split}/${retain_split}"
# start_ts=$(date +%s)
# ./scripts/custom_tofu_unlearn/grad_diff_tofu_unlearn_grid_search.sh \
#   --model "${model}" \
#   --tofu_model "${tofu_model}" \
#   --forget_split "${forget_split}" \
#   --retain_split "${retain_split}" \
#   --holdout_split "${holdout_split}" \
#   --lrs "1e-5 2e-5" \
#   --alphas "1 2 4" \
#   $([[ "${use_constant_lr:-1}" -eq 1 ]] && echo "--use_constant_lr") \
#   $([[ "${forget_only:-1}" -eq 1 ]] && echo "--forget-only")
# dur=$(( $(date +%s) - start_ts ))
# echo "[Timing] GradDiff grid duration: $(sec_to_hms "$dur")"
# TIMINGS+=("GradDiff: $(sec_to_hms "$dur")")


# MSA unlearning run: invoke the MSA script on the configured checkpoint and
# splits, then record how long it took.
echo "[Experiment] MSA unlearning for ${model} on ${forget_split}/${retain_split}"

# Build the argument list as an array so each argument stays exactly one word
# and the optional flag is appended with a plain `if`, instead of relying on
# word splitting of an unquoted command substitution.
# --alphas / --betas each receive their value list as ONE space-separated string.
msa_args=(
  --target_model "${tofu_model}"
  --forget_split "${forget_split}"
  --retain_split "${retain_split}"
  --holdout_split "${holdout_split}"
  --alphas "0.75 1.0 1.5 3.0"
  --betas "0.75 1.0 1.25"
  --msa_epochs 5
)
# The toggle counts as enabled when the variable is unset (default 1).
if [[ "${use_constant_lr:-1}" -eq 1 ]]; then
  msa_args+=(--use_constant_lr)
fi

start_ts=$(date +%s)
# Under `set -e`, a non-zero exit here aborts the script before the timing
# lines below are reached.
./scripts/experiments/Olmo/msa_unlearn.sh "${msa_args[@]}"

# Wall-clock duration of the run, logged and added to the TIMINGS list.
dur=$(($(date +%s) - start_ts))
echo "[Timing] MSA unlearning duration: $(sec_to_hms "$dur")"
TIMINGS+=("MSA: $(sec_to_hms "$dur")")
