#!/bin/bash
#
# MSA unlearning sweep.
# For each selected method tag, builds unlearned models over a grid of
# (alpha, beta) values with src/utils/msa_unlearn.py and evaluates each one
# on MUSE with src/eval.py. Run with --help for the available options.
set -euo pipefail

# ----- Split names: suffixes of the fine-tuned forget/retain model dirs ------
forget_split="forget"
retain_split="retain"

# NOTE(review): msa_epochs is not read anywhere else in this script.
msa_epochs=5

# ----- Models and data (model_family / muse_model overridable via CLI) ------
# NOTE(review): model_family can be overridden with --model_family, but
# nothing in this script reads it.
model_family="allenai/OLMo-2-1124-7B"
muse_model="saves/finetune/muse_books_full_olmo2"
data_split="Books"

# ----- Sweep grid: every (alpha, beta) combination is processed -------------
alphas=(0.75 1.0 1.25 3.0)
betas=(0.0 0.5 1.0 1.5)

# ----- Method tags and model revisions --------------------------------------
# Parallel arrays: method_tags[i] corresponds to models[i]. Disabled entries
# are kept commented out; re-enable a tag together with the model entry at
# the same position.
method_tags=(
  # "final"
  # "3859B"
  # "3691B"
  # "2207B"
  "muse_books"
)
models=(
  # "stage1-step928000-tokens3893B"
  # "stage1-step920000-tokens3859B"
  # "stage1-step880000-tokens3691B"
  # "stage1-step526000-tokens2207B"
  "Olmo-2-7B"
)

# ----- GPU assignment (CUDA_VISIBLE_DEVICES for each stage) -----------------
unlearn_devices="0,1"
eval_devices="0"

#######################################
# Print command-line help to stdout.
# The defaults shown are the current values of the corresponding globals.
# Globals (read):
#   method_tags, model_family, muse_model, alphas, betas,
#   unlearn_devices, eval_devices
# Outputs:
#   Help text on stdout.
#######################################
usage() {
  cat <<EOF
Usage: $(basename "$0") [--run TAGS] [options]

Where TAGS are one or more of:
  ${method_tags[*]}

Options:
  --model_family NAME        Model family (default: ${model_family})
  --muse_model PATH          Path or HF id for the MUSE model (default: ${muse_model})
  --alphas "a1 a2 ..."       Alpha values (default: ${alphas[*]})
  --betas "b1 b2 ..."        Beta values (default: ${betas[*]})
  --run "tags"               Which tags to run (space/comma-separated, default: all)
  --unlearn_devices DEVICES  CUDA devices for unlearn step (default: ${unlearn_devices})
  --eval_devices DEVICES     CUDA devices for eval step (default: ${eval_devices})
  -h | --help                Show this help

Examples:
  # Run a single tag
  ./scripts/eval_msa/msa_unlearn.sh --run ${method_tags[0]:-TAG}

  # Run multiple tags
  ./scripts/eval_msa/msa_unlearn.sh --run "${method_tags[*]}"

  # Run all enabled tags with a custom alpha/beta grid
  ./scripts/eval_msa/msa_unlearn.sh --alphas "1.0,2.0" --betas "0.0 0.5"
EOF
}

# Defaults to all tags
run_tags=("${method_tags[@]}")

#######################################
# Parse command-line options into the global settings.
# List options (--alphas, --betas, --run) accept comma- and/or
# space-separated values and are split into the matching arrays.
# Globals (written):
#   model_family, muse_model, alphas, betas, run_tags,
#   unlearn_devices, eval_devices
# Arguments:
#   The script's command-line arguments.
# Exits:
#   0 after printing help; 1 on an unknown option or a missing value.
#######################################
parse_args() {
  local opt
  while (( $# > 0 )); do
    opt="$1"

    # First pass: classify the option and validate that it has a value.
    case "$opt" in
      -h|--help)
        usage
        exit 0
        ;;
      --model_family|--muse_model|--alphas|--betas|--run|--unlearn_devices|--eval_devices)
        # Without this check a trailing value-taking option would make "$2"
        # abort under `set -u` with an unhelpful "unbound variable" error.
        if (( $# < 2 )); then
          echo "Error: option '$opt' requires a value." >&2
          exit 1
        fi
        ;;
      *)
        echo "Unknown arg: $opt" >&2
        usage >&2
        exit 1
        ;;
    esac

    # Second pass: store the value. read assigns the (non-local) globals.
    case "$opt" in
      --model_family) model_family="$2" ;;
      --muse_model) muse_model="$2" ;;
      --alphas) read -r -a alphas <<< "${2//,/ }" ;;
      --betas) read -r -a betas <<< "${2//,/ }" ;;
      --run) read -r -a run_tags <<< "${2//,/ }" ;;
      --unlearn_devices) unlearn_devices="$2" ;;
      --eval_devices) eval_devices="$2" ;;
    esac
    shift 2
  done
}

parse_args "$@"

#######################################
# Run the MSA unlearn + MUSE eval sweep for one method tag.
# For every (alpha, beta) pair in alphas x betas:
#   1. build an unlearned model with src/utils/msa_unlearn.py,
#   2. evaluate it with src/eval.py, writing results to <model dir>/evals,
#   3. delete the *.safetensors files under the model dir to free disk space.
# A pair whose evals/ directory already exists is skipped, so re-running the
# script continues an interrupted sweep.
# Globals (read):
#   method_tags, models, forget_split, retain_split, alphas, betas,
#   muse_model, data_split, unlearn_devices, eval_devices
# Arguments:
#   $1 - method tag; must appear in method_tags
# Outputs:
#   Progress on stdout, errors on stderr.
# Exits:
#   1 if the tag is not found in method_tags.
#######################################
run_msa_eval() {
  local tag="$1"
  local model_revision=""
  local forget_model retain_model base_model_revision
  local i alpha beta path_to_unlearned_model eval_output_dir

  # Map tag -> model revision; method_tags and models are index-aligned.
  for i in "${!method_tags[@]}"; do
    if [[ "${method_tags[$i]}" == "$tag" ]]; then
      model_revision="${models[$i]}"
      break
    fi
  done

  if [[ -z "$model_revision" ]]; then
    echo "Error: Unknown tag '$tag'. Must be one of: ${method_tags[*]}" >&2
    exit 1
  fi

  forget_model="saves/finetune/msa_muse_books_${tag}_${forget_split}"
  retain_model="saves/finetune/msa_muse_books_${tag}_${retain_split}"
  base_model_revision="${model_revision}"

  echo "Running MSA | tag=$tag | model=$model_revision"
  echo "Forget model: $forget_model"
  echo "Retain model: $retain_model"
  echo "Alphas: ${alphas[*]} | Betas: ${betas[*]}"
  echo "Model Revision: ${base_model_revision}"

  for alpha in "${alphas[@]}"; do
    for beta in "${betas[@]}"; do
      path_to_unlearned_model="saves/unlearn_msa/olmo2/${tag}/msa_alpha_${alpha}_beta_${beta}"
      eval_output_dir="${path_to_unlearned_model}/evals"

      # An existing evals/ dir marks the pair as done (the weights are
      # deleted after evaluation). NOTE(review): an eval that crashed after
      # creating evals/ would also be treated as done.
      if [[ -d "${eval_output_dir}" ]]; then
        echo "[SKIP] Found results at ${eval_output_dir}, skipping..."
        continue
      fi

      echo ">>> Unlearning for tag=${tag}, alpha=${alpha}, beta=${beta}"

      # alpha is passed to msa_unlearn.py negated ("-$alpha").
      CUDA_VISIBLE_DEVICES="${unlearn_devices}" python src/utils/msa_unlearn.py \
        --base_model "${muse_model}" \
        --forget_model "$forget_model" \
        --retain_model "$retain_model" \
        --target_model "${muse_model}" \
        --alpha "-$alpha" \
        --beta "$beta" \
        --save_path "$path_to_unlearned_model"

      echo ">>> Unlearning done."

      # NOTE(review): the model config is hard-coded to Olmo-2-7B and the
      # retain reference logs point at a Llama-2-7b run; confirm intended.
      CUDA_VISIBLE_DEVICES="${eval_devices}" python src/eval.py \
        experiment=eval/muse/default.yaml \
        "data_split=${data_split}" \
        "task_name=${path_to_unlearned_model}" \
        model="Olmo-2-7B" \
        "model.model_args.pretrained_model_name_or_path=${path_to_unlearned_model}" \
        "paths.output_dir=${eval_output_dir}" \
        "retain_logs_path=saves/eval/muse_Llama-2-7b-hf_${data_split}_retrain/MUSE_EVAL.json"

      # Free disk space: remove the *.safetensors weight files of this model.
      find "${path_to_unlearned_model}" -type f -name '*.safetensors' -exec rm -v -- {} +

      echo ">>> Evaluation done."
    done
  done
}

# Run for requested tags.
# Every requested tag is validated up front, so a typo in a later tag is
# reported immediately instead of after earlier (long-running) sweeps finish.
if (( ${#run_tags[@]} == 0 )); then
  echo "Error: no tags to run. Must be one of: ${method_tags[*]}" >&2
  exit 1
fi

for tag in "${run_tags[@]}"; do
  tag_known=0
  for known_tag in "${method_tags[@]}"; do
    if [[ "$known_tag" == "$tag" ]]; then
      tag_known=1
      break
    fi
  done
  if (( ! tag_known )); then
    echo "Error: Unknown tag '$tag'. Must be one of: ${method_tags[*]}" >&2
    exit 1
  fi
done

for tag in "${run_tags[@]}"; do
  run_msa_eval "$tag"
done