#!/bin/bash
set -euo pipefail

# Orchestrates MSA pipeline: finetune MSA components, then run MSA unlearn.
# Considers all three variants: instruct, tofu, pretrained.

# ---- Default settings (match the refactored scripts) ----
# Every value below can be replaced from the command line; see usage().

# Model family and data splits.
model_family='Llama-3.2-1B'
forget_split='forget10'
retain_split='retain90'
holdout_split='holdout10'

# MSA sweep values (whitespace-separated lists), epoch count and variants to run.
alphas='0.75 1.0 1.25'
betas='0.0 0.75 1.0 1.25'
msa_epochs=5
run_what='instruct tofu pretrained'

# ---- Optional overrides ----
tofu_model=''  # if empty, derived from model_family inside called scripts
train_devices='0,1'
eval_devices='0'
accelerate_config='configs/accelerate/default_config.yaml'
batch_size=4
use_constant_lr=0  # 0 = flag not forwarded, 1 = forward --use_constant_lr

#######################################
# Print the command-line help text.
# Globals (read): model_family, forget_split, retain_split, holdout_split,
#   alphas, betas, msa_epochs, run_what, train_devices, eval_devices,
#   accelerate_config, batch_size -- their current values are shown as
#   the defaults.
# Arguments: none
# Outputs:   help text on stdout
#######################################
usage() {
  # The here-document delimiter is intentionally unquoted so that the
  # ${...} defaults expand. In that mode a backslash-newline pair is removed
  # as a line continuation, so the example's trailing backslashes are written
  # as "\\" to be printed literally.
  cat <<EOF
Usage: ${0##*/} [options]

Runs MSA finetune first, then MSA unlearn, for instruct/tofu/pretrained.

Options:
  --model_family NAME         Llama model family (default: ${model_family})
  --tofu_model PATH           Path/HF id for TOFU target model (default derives from model_family)
  --forget_split NAME         Forget split (default: ${forget_split})
  --retain_split NAME         Retain split (default: ${retain_split})
  --holdout_split NAME        Holdout split (default: ${holdout_split})
  --alphas "a1 a2 ..."        Space/comma-separated alphas (default: ${alphas})
  --betas "b1 b2 ..."         Space/comma-separated betas (default: ${betas})
  --msa_epochs N              Epochs for MSA (default: ${msa_epochs})
  --run "modes"               instruct,tofu,pretrained (default: ${run_what})
  --train_devices DEVICES     CUDA devices for finetune (default: ${train_devices})
  --eval_devices DEVICES      CUDA devices for eval/unlearn (default: ${eval_devices})
  --accelerate_config PATH    Accelerate config (default: ${accelerate_config})
  --batch_size N              Per-device batch size for finetune (default: ${batch_size})
  --use_constant_lr           Use constant LR for finetune and unlearn
  -h | --help                 Show this help

Examples:
  # Run all variants with defaults
  ./scripts/experiments/msa.sh

  # Only tofu+pretrained with custom family and alphas/betas
  ./scripts/experiments/msa.sh \\
    --run "tofu,pretrained" \\
    --model_family Llama-3.2-8B \\
    --alphas "0.75 1.0" --betas "0.0 0.75 1.0"
EOF
}

#######################################
# Abort with a readable message when a value-taking option has no value.
# Without this check, `set -u` turns a trailing option such as `--alphas`
# into an opaque "unbound variable" failure.
# Arguments: $1 - option name
#            $2 - number of arguments still unparsed (option included)
# Outputs:   error message on stderr; exits 1 when the value is missing
#######################################
_require_value() {
  if (( $2 < 2 )); then
    printf 'Missing value for %s\n' "$1" >&2
    exit 1
  fi
}

#######################################
# Parse command-line options into the script's global settings.
# Globals (written): model_family, tofu_model, forget_split, retain_split,
#   holdout_split, alphas, betas, msa_epochs, run_what, train_devices,
#   eval_devices, accelerate_config, batch_size, use_constant_lr
# Arguments: the script's command-line arguments
# Outputs:   help on stdout for -h/--help (exit 0); diagnostics plus help on
#            stderr for an unknown argument (exit 1)
#######################################
parse_cli_args() {
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --model_family) _require_value "$1" "$#"; model_family="$2"; shift 2 ;;
      --tofu_model) _require_value "$1" "$#"; tofu_model="$2"; shift 2 ;;
      --forget_split) _require_value "$1" "$#"; forget_split="$2"; shift 2 ;;
      --retain_split) _require_value "$1" "$#"; retain_split="$2"; shift 2 ;;
      --holdout_split) _require_value "$1" "$#"; holdout_split="$2"; shift 2 ;;
      # List-valued options accept commas or spaces; commas become spaces.
      --alphas) _require_value "$1" "$#"; alphas="${2//,/ }"; shift 2 ;;
      --betas) _require_value "$1" "$#"; betas="${2//,/ }"; shift 2 ;;
      --msa_epochs) _require_value "$1" "$#"; msa_epochs="$2"; shift 2 ;;
      --run) _require_value "$1" "$#"; run_what="${2//,/ }"; shift 2 ;;
      --train_devices) _require_value "$1" "$#"; train_devices="$2"; shift 2 ;;
      --eval_devices) _require_value "$1" "$#"; eval_devices="$2"; shift 2 ;;
      --accelerate_config) _require_value "$1" "$#"; accelerate_config="$2"; shift 2 ;;
      --batch_size) _require_value "$1" "$#"; batch_size="$2"; shift 2 ;;
      --use_constant_lr) use_constant_lr=1; shift 1 ;;
      -h|--help) usage; exit 0 ;;
      *) echo "Unknown arg: $1" >&2; usage >&2; exit 1 ;;
    esac
  done
}

# Parse CLI
parse_cli_args "$@"

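# NOTE: the finetune stage below is commented out, so only the unlearn stage runs;
# --batch_size and --accelerate_config are parsed but only used by this disabled call.
# echo "[MSA] Finetune: model_family=${model_family} | splits=${forget_split}/${retain_split} | run=${run_what}"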
# ./scripts/msa_tofu/msa_finetune.sh \
#   --model_family "${model_family}" \
#   --forget_split "${forget_split}" \
#   --retain_split "${retain_split}" \
#   --holdout_split "${holdout_split}" \
#   --epochs "${msa_epochs}" \
#   --batch_size "${batch_size}" \
#   --accelerate_config "${accelerate_config}" \
#   --train_devices "${train_devices}" \
#   --run "${run_what}" \
#   $([[ "$use_constant_lr" -eq 1 ]] && echo "--use_constant_lr") \
#   ${tofu_model:+--tofu_model "${tofu_model}"}

#######################################
# Assemble the argument list for scripts/eval_msa/msa_unlearn.sh.
# Globals (read):    model_family, forget_split, retain_split, holdout_split,
#                    alphas, betas, msa_epochs, run_what, train_devices,
#                    eval_devices, use_constant_lr, tofu_model
# Globals (written): unlearn_args (array)
# Arguments: none
#######################################
build_unlearn_args() {
  # An array keeps each value a single argument even when it contains spaces
  # (alphas, betas and run_what are space-separated lists).
  unlearn_args=(
    --model_family "${model_family}"
    --forget_split "${forget_split}"
    --retain_split "${retain_split}"
    --holdout_split "${holdout_split}"
    --alphas "${alphas}"
    --betas "${betas}"
    --msa_epochs "${msa_epochs}"
    --run "${run_what}"
    # NOTE(review): unlearn runs on train_devices; confirm this is intended.
    --unlearn_devices "${train_devices}"
    --eval_devices "${eval_devices}"
  )
  # Optional arguments are appended only when requested, in the same order
  # as before: --use_constant_lr first, then --tofu_model.
  if [[ "${use_constant_lr}" -eq 1 ]]; then
    unlearn_args+=(--use_constant_lr)
  fi
  if [[ -n "${tofu_model}" ]]; then
    unlearn_args+=(--tofu_model "${tofu_model}")
  fi
}

build_unlearn_args
echo "[MSA] Unlearn: alphas=${alphas} | betas=${betas} | run=${run_what}"
./scripts/eval_msa/msa_unlearn.sh "${unlearn_args[@]}"

# The finetune invocation earlier in this file is commented out, so only the
# unlearn stage is reported as completed.
echo "[MSA] Completed unlearn for: ${run_what} on splits ${forget_split}/${retain_split}."
