#!/bin/bash
#
# Run MSA-based unlearning on a TOFU-finetuned OLMo-2 target model for each
# configured base checkpoint, then evaluate the result (see --help).

# Strict mode: stop on command failures, unset variables and failed pipeline
# stages.
set -euo pipefail

# ---- Defaults; every value below can be overridden on the command line ----

# TOFU splits used for unlearning and evaluation.
forget_split="forget10" retain_split="retain90" holdout_split="holdout10"

# MSA epoch count (currently only reported in log output).
msa_epochs=5

# Alpha/beta value lists (--alphas / --betas).
declare -a alphas=(0.75 1 1.25 3.0) betas=(0.75 1.0 1.25)

# CUDA_VISIBLE_DEVICES values for the unlearning and evaluation commands.
unlearn_devices="0,1" eval_devices="0"

# 1 = use the "_constant_lr" variants of the MSA finetuned models.
use_constant_lr=0

# Reference OLMo-2 1B checkpoints (currently disabled). Each row pairs a model
# config name (no .yaml) with its revision; to re-enable, add matching
# model_names/model_revisions entries below.
#   Olmo-2-1B-stage1-210B   stage1-step100000-tokens210B
#   Olmo-2-1B-stage1-504B   stage1-step240000-tokens504B
#   Olmo-2-1B-stage1-2098B  stage1-step1000000-tokens2098B
#   Olmo-2-1B-stage1-3146B  stage1-step1500000-tokens3146B
#   Olmo-2-1B-stage1-3566B  stage1-step1700000-tokens3566B
#   Olmo-2-1B-stage1-final  stage1-step1900000-tokens3985B

# Active OLMo-2 7B checkpoints. model_names[i] is the model config name
# (no .yaml) and model_revisions[i] is the revision of base_model_hf_path for
# the same checkpoint, so the two arrays must stay index-aligned.
model_names=() model_revisions=()
model_names+=("Olmo-2-7B-stage1-final"); model_revisions+=("stage1-step928000-tokens3893B")
model_names+=("Olmo-2-7B-stage1-3859B"); model_revisions+=("stage1-step920000-tokens3859B")
model_names+=("Olmo-2-7B-stage1-3691B"); model_revisions+=("stage1-step880000-tokens3691B")
model_names+=("Olmo-2-7B-stage1-2207B"); model_revisions+=("stage1-step526000-tokens2207B")
model_names+=("Olmo-2-7B-stage1-500B");  model_revisions+=("stage1-step119000-tokens500B")

# Hugging Face repo of the base model; the revisions above select checkpoints.
base_model_hf_path="allenai/OLMo-2-1124-7B"

# Default target model to unlearn from; override with --target_model.
target_model="saves/finetune/tofu_Olmo-2-7B-stage1-final_full"

# Print command-line help to stdout.
# The "(default: ...)" values are the globals' current values, so any options
# parsed before --help are reflected in the output.
usage() {
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --target_model PATH          Target model to unlearn from (default: ${target_model})
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --holdout_split NAME         Holdout split (default: ${holdout_split})
  --alphas "a1 a2 ..."         Space/comma-separated alpha values (default: ${alphas[*]})
  --betas "b1 b2 ..."          Space/comma-separated beta values (default: ${betas[*]})
  --msa_epochs N               Number of MSA epochs (default: ${msa_epochs})
  --unlearn_devices DEVICES    CUDA devices for unlearn step (default: ${unlearn_devices})
  --eval_devices DEVICES       CUDA devices for eval step (default: ${eval_devices})
  --use_constant_lr            Use constant LR finetuned MSA models (suffix paths)
  -h | --help                  Show this help

Example:
  ./scripts/experiments/Olmo/msa_unlearn.sh \\
    --target_model "saves/finetune/tofu_Olmo-2-1B-stage1-final_full_constant_lr" \\
    --alphas "0.75 1.0 1.25" --betas "0.0 0.75 1.0 1.25" \\
    --use_constant_lr
EOF
}

# Abort with an error when value-taking option "$1" is the last argument.
# Without this guard, `set -u` kills the script with a cryptic
# "$2: unbound variable" message instead of naming the offending flag.
require_value() {
  if (( $# < 2 )); then
    echo "[ERROR] Option $1 requires a value." >&2
    usage >&2
    exit 1
  fi
}

# Parse CLI args, overwriting the default globals defined above.
# --alphas/--betas accept space- and/or comma-separated lists.
# Unknown options and missing values are reported on stderr (exit status 1).
parse_args() {
  local list
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --forget_split) require_value "$@"; forget_split="$2"; shift 2;;
      --retain_split) require_value "$@"; retain_split="$2"; shift 2;;
      --holdout_split) require_value "$@"; holdout_split="$2"; shift 2;;
      --target_model) require_value "$@"; target_model="$2"; shift 2;;
      --alphas)
        require_value "$@"
        list="${2//,/ }"
        read -r -a alphas <<< "$list"
        shift 2;;
      --betas)
        require_value "$@"
        list="${2//,/ }"
        read -r -a betas <<< "$list"
        shift 2;;
      --msa_epochs) require_value "$@"; msa_epochs="$2"; shift 2;;
      --unlearn_devices) require_value "$@"; unlearn_devices="$2"; shift 2;;
      --eval_devices) require_value "$@"; eval_devices="$2"; shift 2;;
      --use_constant_lr) use_constant_lr=1; shift 1;;
      -h|--help) usage; exit 0;;
      *) echo "Unknown arg: $1" >&2; usage >&2; exit 1;;
    esac
  done
}

parse_args "$@"



# MSA finetunes trained with a constant LR live under "_constant_lr"-suffixed
# paths; select the suffix once for the whole run.
if (( use_constant_lr == 1 )); then
  task_suffix="_constant_lr"
else
  task_suffix=""
fi

# A target model must be set. It has a default, so this only trips when an
# empty value is passed explicitly (e.g. --target_model "").
[[ -n "${target_model}" ]] || {
  echo "[ERROR] --target_model is required (path to the target model)." >&2
  exit 1
}

# The per-model loop indexes both arrays with the same index.
if (( ${#model_names[@]} != ${#model_revisions[@]} )); then
  echo "[ERROR] model_names and model_revisions length mismatch: ${#model_names[@]} vs ${#model_revisions[@]}" >&2
  exit 1
fi

# Pick the pre-selected test-time alpha/beta for checkpoint name "$1".
# Sets the globals `alpha` and `beta` and returns 0, or returns 1 when the
# checkpoint has no configured values. Patterns are anchored on the "-"
# separator so that e.g. "...-1500B" cannot be mistaken for "...-500B".
# NOTE(review): the --alphas/--betas lists are not consulted here; these
# per-checkpoint values are hard-coded.
select_test_alpha_beta() {
  case "$1" in
    *-500B)  alpha="3.0";  beta="1.0" ;;
    *-2207B) alpha="1.25"; beta="0.75" ;;
    *-3691B) alpha="1.25"; beta="1.0" ;;
    *-3859B) alpha="1.25"; beta="1.25" ;;
    *-final) alpha="1.25"; beta="1.0" ;;
    *) return 1 ;;
  esac
}

for idx in "${!model_names[@]}"; do
  model="${model_names[$idx]}"
  revision="${model_revisions[$idx]}"
  echo "[OLMo] Unlearn via MSA for base=${model} (rev=${revision}) | forget=${forget_split} retain=${retain_split}"

  # Paths to forget/retain MSA finetuned models from msa_finetune.sh
  forget_model="saves/obtain_msa/msa_${model}_${forget_split}${task_suffix}"
  retain_model="saves/obtain_msa_split/msa_${model}_${retain_split}${task_suffix}"

  # Output container for this base model's unlearned blends
  out_dir="${forget_split}_${model}${task_suffix}"

  if ! select_test_alpha_beta "${model}"; then
    echo "[SKIP] No test config specified for ${model}"
    continue
  fi

  # NOTE(review): msa_epochs is only logged; it is not passed to either command.
  echo "[TEST] Using alpha=${alpha}, beta=${beta} | epochs=${msa_epochs}"

  path_to_unlearned_model="saves/unlearn_msa/${out_dir}/msa_alpha_${alpha}_beta_${beta}"
  eval_output_dir="${path_to_unlearned_model}/evals"

  # Skip settings already evaluated in a previous run.
  # NOTE(review): only the directory's existence is checked, so an eval that
  # crashed after creating it will also be skipped on rerun.
  if [[ -d "${eval_output_dir}" ]]; then
    echo "[SKIP] Found existing test results at ${eval_output_dir}. Skipping alpha=${alpha}, beta=${beta}."
    continue
  fi

  # alpha is passed negated (-alpha); presumably it scales the forget-model
  # contribution in msa_unlearn.py — confirm there.
  CUDA_VISIBLE_DEVICES="${unlearn_devices}" python src/utils/msa_unlearn.py \
    --base_model "${base_model_hf_path}" \
    --base_revision "${revision}" \
    --forget_model "${forget_model}" \
    --retain_model "${retain_model}" \
    --target_model "${target_model}" \
    --alpha "-${alpha}" \
    --beta "${beta}" \
    --save_path "${path_to_unlearned_model}"

  echo "Unlearning done for base=${model}, alpha=${alpha}, beta=${beta}"

  # Each key=value override is quoted so paths/splits containing spaces or
  # glob characters reach src/eval.py as a single argument.
  # NOTE(review): retain logs always come from the stage1-final eval,
  # regardless of ${model} — confirm this is intended.
  CUDA_VISIBLE_DEVICES="${eval_devices}" python src/eval.py \
    experiment=eval/tofu/default.yaml \
    "forget_split=${forget_split}" \
    "holdout_split=${holdout_split}" \
    "model=${model}" \
    is_validation=false \
    "task_name=${path_to_unlearned_model}" \
    "model.model_args.pretrained_model_name_or_path=${path_to_unlearned_model}" \
    "paths.output_dir=${eval_output_dir}" \
    "retain_logs_path=saves/eval/tofu_Olmo-2-7B-stage1-final_${retain_split}/TOFU_EVAL.json"

  # Free disk space: drop the blended *.safetensors weights once evaluated;
  # other files (including the evals directory) are left in place.
  find "${path_to_unlearned_model:?}" -type f -name '*.safetensors' -exec rm -v -- {} +

  echo "[TEST] Evaluation done for base=${model}, alpha=${alpha}, beta=${beta}"
done

# Cheatsheet:
#   ./scripts/experiments/Olmo/msa_unlearn.sh \
#     --alphas "0.75 1.0 1.25" --betas "0.0 0.75 1.0 1.25" --use_constant_lr
