#!/bin/bash
#
# Sweep MSA unlearning over a list of OLMo-2 checkpoints: for every checkpoint
# and every (alpha, beta) pair the script builds a blended model with
# src/utils/msa_unlearn.py and evaluates it with src/eval.py.
#
# Everything in this section is a default. The splits, alpha/beta grids,
# epochs, devices, constant-LR flag and target model can be overridden from
# the command line; the checkpoint list and base model path cannot.

set -euo pipefail

# --- Dataset splits ----------------------------------------------------------
forget_split='forget10'
retain_split='retain90'
holdout_split='holdout10'

# --- Sweep grid --------------------------------------------------------------
alphas=(0.75 1.0 1.5 3.0)
betas=(0.75 1.0 1.25)
msa_epochs=5

# --- CUDA_VISIBLE_DEVICES values for the two stages --------------------------
unlearn_devices='0,1'
eval_devices='0'

# --- Optional flags ----------------------------------------------------------
# 1 => use the "_constant_lr" variants of the MSA finetuned models.
use_constant_lr=0

# --- Checkpoints -------------------------------------------------------------
# model_names[i] is the local name of a checkpoint and model_revisions[i] the
# revision string handed to the unlearn step; the two arrays are index-aligned,
# so each pair is appended together.
model_names=()
model_revisions=()
model_names+=('Olmo-2-7B-stage1-final');  model_revisions+=('stage1-step928000-tokens3893B')
model_names+=('Olmo-2-7B-stage1-3859B');  model_revisions+=('stage1-step920000-tokens3859B')
model_names+=('Olmo-2-7B-stage1-3691B');  model_revisions+=('stage1-step880000-tokens3691B')
model_names+=('Olmo-2-7B-stage1-2207B');  model_revisions+=('stage1-step526000-tokens2207B')
model_names+=('Olmo-2-7B-stage1-500B');   model_revisions+=('stage1-step119000-tokens500B')

# Base model identifier passed as --base_model to the unlearn step.
# base_model_hf_path="allenai/OLMo-2-0425-1B"
base_model_hf_path='allenai/OLMo-2-1124-7B'

# Target model path. Must be non-empty; this default applies unless
# --target_model is given.
target_model='saves/finetune/tofu_Olmo-2-7B-stage1-final_full'

#######################################
# Print the command-line help, including the current default of each option.
# Globals (read):
#   forget_split, retain_split, holdout_split, target_model, alphas, betas,
#   msa_epochs, unlearn_devices, eval_devices
# Outputs:
#   Help text on stdout.
#######################################
usage() {
  # The heredoc delimiter is unquoted so the defaults are expanded. In that
  # mode a backslash followed by a newline is a line continuation, so the
  # trailing backslashes of the example must be written as "\\" to be printed.
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --holdout_split NAME         Holdout split (default: ${holdout_split})
  --target_model PATH          Target model path (default: ${target_model})
  --alphas "a1 a2 ..."        Space/comma-separated alpha values (default: ${alphas[*]})
  --betas "b1 b2 ..."         Space/comma-separated beta values (default: ${betas[*]})
  --msa_epochs N               Number of MSA epochs (default: ${msa_epochs})
  --unlearn_devices DEVICES    CUDA devices for unlearn step (default: ${unlearn_devices})
  --eval_devices DEVICES       CUDA devices for eval step (default: ${eval_devices})
  --use_constant_lr            Use constant LR finetuned MSA models (suffix paths)
  -h | --help                  Show this help

Example:
  ./scripts/experiments/Olmo/msa_unlearn.sh \\
    --target_model "saves/finetune/tofu_Olmo-2-1B-stage1-final_full_constant_lr" \\
    --alphas "0.75 1.0 1.25" --betas "0.0 0.75 1.0 1.25" \\
    --use_constant_lr
EOF
}

# Parse CLI args

#######################################
# Abort with a readable message when an option is the last word on the
# command line. Without this guard the "$2" lookup fails with a bare
# "unbound variable" error under `set -u`.
# Arguments:
#   The remaining command line, starting at the option itself.
#######################################
need_value() {
  if (( $# < 2 )); then
    echo "[ERROR] Option $1 requires a value." >&2
    exit 1
  fi
}

#######################################
# Apply command-line options on top of the defaults.
# Globals (written):
#   forget_split, retain_split, holdout_split, target_model, alphas, betas,
#   msa_epochs, unlearn_devices, eval_devices, use_constant_lr
# Arguments:
#   The script's command line ("$@").
# Exits:
#   0 after printing help (-h/--help); 1 on an unknown option, a missing
#   option value, or an empty --alphas/--betas list.
#######################################
parse_args() {
  local raw
  while (( $# > 0 )); do
    case "$1" in
      --forget_split) need_value "$@"; forget_split="$2"; shift 2 ;;
      --retain_split) need_value "$@"; retain_split="$2"; shift 2 ;;
      --holdout_split) need_value "$@"; holdout_split="$2"; shift 2 ;;
      --target_model) need_value "$@"; target_model="$2"; shift 2 ;;
      --alphas)
        need_value "$@"
        # Accept both "a,b,c" and "a b c": commas become spaces, then split.
        raw="${2//,/ }"
        read -r -a alphas <<< "${raw}"
        if (( ${#alphas[@]} == 0 )); then
          echo "[ERROR] --alphas needs at least one value." >&2
          exit 1
        fi
        shift 2
        ;;
      --betas)
        need_value "$@"
        raw="${2//,/ }"
        read -r -a betas <<< "${raw}"
        if (( ${#betas[@]} == 0 )); then
          echo "[ERROR] --betas needs at least one value." >&2
          exit 1
        fi
        shift 2
        ;;
      --msa_epochs) need_value "$@"; msa_epochs="$2"; shift 2 ;;
      --unlearn_devices) need_value "$@"; unlearn_devices="$2"; shift 2 ;;
      --eval_devices) need_value "$@"; eval_devices="$2"; shift 2 ;;
      --use_constant_lr) use_constant_lr=1; shift 1 ;;
      -h|--help) usage; exit 0 ;;
      *)
        # Diagnostics belong on stderr so they are not mixed into piped output.
        echo "Unknown arg: $1" >&2
        usage >&2
        exit 1
        ;;
    esac
  done
}

parse_args "$@"



# Suffix appended to the MSA model paths and the output directory name;
# empty unless the constant-LR variants were requested.
if (( use_constant_lr == 1 )); then
  task_suffix="_constant_lr"
else
  task_suffix=""
fi

# A target model must be set (an empty --target_model value is rejected here).
[[ -n "${target_model}" ]] || {
  echo "[ERROR] --target_model is required (path to the target model)." >&2
  exit 1
}

# model_names and model_revisions are walked by a shared index, so they must
# have the same number of entries.
if (( ${#model_names[@]} != ${#model_revisions[@]} )); then
  echo "[ERROR] model_names and model_revisions length mismatch: ${#model_names[@]} vs ${#model_revisions[@]}" >&2
  exit 1
fi

#######################################
# Build and evaluate the blended model for one checkpoint and one
# (alpha, beta) pair, then delete the blend's *.safetensors files.
# Globals (read):
#   base_model_hf_path, target_model, forget_split, retain_split,
#   holdout_split, task_suffix, unlearn_devices, eval_devices
# Arguments:
#   $1 - checkpoint name (an entry of model_names)
#   $2 - checkpoint revision (the matching entry of model_revisions)
#   $3 - alpha; it is passed to msa_unlearn.py with a leading "-"
#   $4 - beta
# Returns:
#   0 when the setting was skipped or finished. The script runs under
#   `set -e`, so a failing python or find command aborts the whole run.
#######################################
run_msa_setting() {
  local model="$1"
  local revision="$2"
  local alpha="$3"
  local beta="$4"

  # Paths to forget/retain MSA finetuned models from msa_finetune.sh
  # NOTE(review): the forget model is read from saves/obtain_msa but the
  # retain model from saves/obtain_msa_split -- confirm this is intentional.
  local forget_model="saves/obtain_msa/msa_${model}_${forget_split}${task_suffix}"
  local retain_model="saves/obtain_msa_split/msa_${model}_${retain_split}${task_suffix}"

  # Output container for this base model's unlearned blends
  local out_dir="${forget_split}_${model}${task_suffix}"
  local path_to_unlearned_model="saves/unlearn_msa/${out_dir}/msa_alpha_${alpha}_beta_${beta}"
  local eval_output_dir="${path_to_unlearned_model}/evals_val"

  # Skip if eval results already exist for this setting.
  # NOTE(review): only the directory's existence is tested; if the eval step
  # creates it before finishing, an interrupted run would be skipped too.
  if [[ -d "${eval_output_dir}" ]]; then
    echo "[SKIP] Found existing results at ${eval_output_dir}. Skipping alpha=${alpha}, beta=${beta}."
    return 0
  fi

  CUDA_VISIBLE_DEVICES="${unlearn_devices}" python src/utils/msa_unlearn.py \
    --base_model "${base_model_hf_path}" \
    --base_revision "${revision}" \
    --forget_model "${forget_model}" \
    --retain_model "${retain_model}" \
    --target_model "${target_model}" \
    --alpha "-${alpha}" \
    --beta "${beta}" \
    --save_path "${path_to_unlearned_model}"

  echo "Unlearning done for base=${model}, alpha=${alpha}, beta=${beta}"

  # Every override is quoted so that a split name or path containing spaces
  # or glob characters reaches eval.py as a single argument.
  # NOTE(review): retain_logs_path hard-codes the "final" 7B checkpoint name
  # and ignores task_suffix -- confirm this is the intended reference run.
  CUDA_VISIBLE_DEVICES="${eval_devices}" python src/eval.py \
    "experiment=eval/tofu/default.yaml" \
    "forget_split=${forget_split}" \
    "holdout_split=${holdout_split}" \
    "model=${model}" \
    "is_validation=true" \
    "task_name=${path_to_unlearned_model}" \
    "model.model_args.pretrained_model_name_or_path=${path_to_unlearned_model}" \
    "paths.output_dir=${eval_output_dir}" \
    "retain_logs_path=saves/eval/tofu_Olmo-2-7B-stage1-final_${retain_split}/TOFU_EVAL.json"

  # The weights are only needed for the evaluation above; remove them and
  # keep the rest of the directory (including the eval output).
  find "${path_to_unlearned_model}" -type f -name '*.safetensors' -exec rm -v -- {} +

  echo "Evaluation done for base=${model}, alpha=${alpha}, beta=${beta}"
}

# Sweep: every checkpoint x every alpha x every beta.
for idx in "${!model_names[@]}"; do
  model="${model_names[$idx]}"
  revision="${model_revisions[$idx]}"
  echo "[OLMo] Unlearn via MSA for base=${model} (rev=${revision}) | forget=${forget_split} retain=${retain_split}"
  echo "Alphas: ${alphas[*]} | Betas: ${betas[*]} | epochs=${msa_epochs}"

  for alpha in "${alphas[@]}"; do
    for beta in "${betas[@]}"; do
      run_msa_setting "${model}" "${revision}" "${alpha}" "${beta}"
    done
  done
done

# Cheatsheet:
#   ./scripts/experiments/Olmo/msa_unlearn.sh \
#     --alphas "0.75 1.0 1.25" --betas "0.0 0.75 1.0 1.25" --use_constant_lr
