#!/bin/bash
#
# Sweep MSA unlearning over a grid of alpha/beta values on TOFU.
#
# For every (alpha, beta) pair the script runs src/utils/msa_unlearn.py to
# build an unlearned model, evaluates it with src/eval.py, and then deletes
# the *.safetensors files under that pair's output directory. Pairs whose
# evals_val directory already exists are skipped.
#
# Run with --help for the list of command-line overrides.

set -euo pipefail

# Dataset splits (overridable from the command line).
forget_split="forget10"
retain_split="retain90"
holdout_split="holdout10"

# Overridable with --msa_epochs.
# NOTE(review): in the visible script this value is only echoed in the
# start-up banner; it is not forwarded to either Python entry point.
msa_epochs=5

# Alpha/Beta defaults: the sweep visits every (alpha, beta) combination.
alphas=(0.75 1.0 1.5 3.0)
betas=(0.75 1.0 1.25)

# Device controls: values exported as CUDA_VISIBLE_DEVICES for each step.
unlearn_devices="0,1"
eval_devices="0"

# Optional flags: 0 = off, 1 = on (set by --use_constant_lr).
use_constant_lr=0

# Model identity. `model` is passed to src/eval.py as `model=...`;
# `model_name` is used to build the on-disk artefact names below.
model="Olmo-2-7B"
model_name="Olmo-2-7B-stage1-final"
tag="${model_name}_tofu"

# NOTE(review): base_model_hf_path is not referenced anywhere else in the
# visible script; kept in case external tooling reads it.
# base_model_hf_path="allenai/OLMo-2-0425-1B"
base_model_hf_path="allenai/OLMo-2-1124-7B"

# Target model path. This is only the default; --target_model overrides it.
path_to_tofu_model="saves/finetune/tofu_${model_name}_full"
target_model="${path_to_tofu_model}"

#######################################
# Print the command-line help text.
# Globals (read): forget_split, retain_split, holdout_split, alphas, betas,
#   msa_epochs, unlearn_devices, eval_devices -- their current values are
#   shown as the defaults.
# Outputs: help text on stdout.
#######################################
usage() {
  # The heredoc delimiter is unquoted so the defaults expand. Because of that,
  # a bare backslash at end of line would act as a line continuation and glue
  # the example onto one line; the trailing backslashes are therefore doubled.
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --holdout_split NAME         Holdout split (default: ${holdout_split})
  --target_model PATH          Target model path (required)
  --alphas "a1 a2 ..."        Space/comma-separated alpha values (default: ${alphas[*]})
  --betas "b1 b2 ..."         Space/comma-separated beta values (default: ${betas[*]})
  --msa_epochs N               Number of MSA epochs (default: ${msa_epochs})
  --unlearn_devices DEVICES    CUDA devices for unlearn step (default: ${unlearn_devices})
  --eval_devices DEVICES       CUDA devices for eval step (default: ${eval_devices})
  --use_constant_lr            Use constant LR finetuned MSA models (suffix paths)
  -h | --help                  Show this help

Example:
  ./scripts/experiments/Olmo/msa_unlearn.sh \\
    --target_model "saves/finetune/tofu_Olmo-2-1B-stage1-final_full_constant_lr" \\
    --alphas "0.75 1.0 1.25" --betas "0.0 0.75 1.0 1.25" \\
    --use_constant_lr
EOF
}

# Parse CLI args

#######################################
# Abort with a clear diagnostic when a value-taking option has no value.
# Without this check `set -u` kills the script with "$2: unbound variable".
# Arguments: the remaining positional parameters, option name first.
# Outputs:   error message and help text on stderr.
#######################################
require_value() {
  if [[ $# -lt 2 ]]; then
    echo "[ERROR] $1 requires a value." >&2
    usage >&2
    exit 1
  fi
}

#######################################
# Apply command-line overrides to the script's configuration globals.
# Globals (written): forget_split, retain_split, holdout_split, target_model,
#   alphas, betas, msa_epochs, unlearn_devices, eval_devices, use_constant_lr
# Arguments: the script's command-line arguments.
# Exits: 0 after printing help; 1 on an unknown option, a missing value, or
#   an empty alpha/beta list.
#######################################
parse_args() {
  local s
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --forget_split) require_value "$@"; forget_split="$2"; shift 2;;
      --retain_split) require_value "$@"; retain_split="$2"; shift 2;;
      --holdout_split) require_value "$@"; holdout_split="$2"; shift 2;;
      --target_model) require_value "$@"; target_model="$2"; shift 2;;
      --alphas)
        require_value "$@"
        # Accept both "a,b,c" and "a b c": commas become spaces before splitting.
        s="${2//,/ }"
        read -r -a alphas <<< "$s"
        if [[ ${#alphas[@]} -eq 0 ]]; then
          echo "[ERROR] --alphas needs at least one value." >&2
          exit 1
        fi
        shift 2
        ;;
      --betas)
        require_value "$@"
        s="${2//,/ }"
        read -r -a betas <<< "$s"
        if [[ ${#betas[@]} -eq 0 ]]; then
          echo "[ERROR] --betas needs at least one value." >&2
          exit 1
        fi
        shift 2
        ;;
      --msa_epochs) require_value "$@"; msa_epochs="$2"; shift 2;;
      --unlearn_devices) require_value "$@"; unlearn_devices="$2"; shift 2;;
      --eval_devices) require_value "$@"; eval_devices="$2"; shift 2;;
      --use_constant_lr) use_constant_lr=1; shift 1;;
      -h|--help) usage; exit 0;;
      *) echo "Unknown arg: $1" >&2; usage >&2; exit 1;;
    esac
  done
}

parse_args "$@"



# Constant-LR runs use artefact names carrying a "_constant_lr" suffix;
# otherwise the suffix is empty.
if [[ "${use_constant_lr}" -eq 1 ]]; then
  task_suffix="_constant_lr"
else
  task_suffix=""
fi

# Refuse to continue when the target model path is empty.
[[ -n "${target_model}" ]] || {
  echo "[ERROR] --target_model is required (path to the target model)." >&2
  exit 1
}



# Locations of the forget/retain MSA finetuned models (per the original note,
# these come from msa_finetune.sh). The forget model is looked up under
# saves/obtain_msa and the retain model under saves/obtain_msa_split.
printf -v forget_model 'saves/obtain_msa/msa_%s_%s%s' \
  "${tag}" "${forget_split}" "${task_suffix}"
printf -v retain_model 'saves/obtain_msa_split/msa_%s_%s%s' \
  "${tag}" "${retain_split}" "${task_suffix}"

# Short alias for the target model path.
tm="${target_model}"

# Directory name (under saves/unlearn_msa/) that groups every alpha/beta
# blend produced for this forget split and model tag.
printf -v out_dir '%s_%s%s' "${forget_split}" "${tag}" "${task_suffix}"

echo "Alphas: ${alphas[*]} | Betas: ${betas[*]} | epochs=${msa_epochs}"

#######################################
# Build, evaluate and prune the unlearned model for one (alpha, beta) pair.
# Globals (read): out_dir, tm, forget_model, retain_model, unlearn_devices,
#   eval_devices, forget_split, holdout_split, retain_split, model, model_name
# Arguments: $1 - alpha, $2 - beta
# Outputs:   progress messages on stdout.
# Returns:   0 when the pair was processed or skipped. A failing step is
#            expected to abort the script through its `set -e`.
#######################################
run_msa_setting() {
  local alpha="$1"
  local beta="$2"
  local path_to_unlearned_model="saves/unlearn_msa/${out_dir}/msa_alpha_${alpha}_beta_${beta}"
  local eval_output_dir="${path_to_unlearned_model}/evals_val"

  # Skip if eval results already exist for this setting.
  # NOTE(review): only the directory's existence is checked, so a run that
  # died mid-evaluation would also be skipped -- confirm this is acceptable.
  if [[ -d "${eval_output_dir}" ]]; then
    echo "[SKIP] Found existing results at ${eval_output_dir}. Skipping alpha=${alpha}, beta=${beta}."
    return 0
  fi

  # The target model is passed as both --base_model and --target_model, and
  # alpha is handed over with a leading minus sign.
  CUDA_VISIBLE_DEVICES="${unlearn_devices}" python src/utils/msa_unlearn.py \
    --base_model "${tm}" \
    --forget_model "${forget_model}" \
    --retain_model "${retain_model}" \
    --target_model "${tm}" \
    --alpha "-${alpha}" \
    --beta "${beta}" \
    --save_path "${path_to_unlearned_model}"

  echo "Unlearning done for base=${model}, alpha=${alpha}, beta=${beta}"

  # Every override is quoted so a value containing whitespace or glob
  # characters reaches src/eval.py as a single argument.
  CUDA_VISIBLE_DEVICES="${eval_devices}" python src/eval.py \
    experiment=eval/tofu/default.yaml \
    "forget_split=${forget_split}" \
    "holdout_split=${holdout_split}" \
    "model=${model}" \
    is_validation=true \
    "task_name=${path_to_unlearned_model}" \
    "model.model_args.pretrained_model_name_or_path=${path_to_unlearned_model}" \
    "paths.output_dir=${eval_output_dir}" \
    "retain_logs_path=saves/eval/tofu_${model_name}_${retain_split}/TOFU_EVAL.json"

  # Delete the weight files once evaluation has finished; everything else in
  # the output directory is kept.
  find "${path_to_unlearned_model}" -type f -name '*.safetensors' -exec rm -v -- {} +

  echo "Evaluation done for base=${model}, alpha=${alpha}, beta=${beta}"
}

for alpha in "${alphas[@]}"; do
  for beta in "${betas[@]}"; do
    run_msa_setting "${alpha}" "${beta}"
  done
done


# Cheatsheet:
#   ./scripts/experiments/Olmo/msa_unlearn.sh \
#     --alphas "0.75 1.0 1.25" --betas "0.0 0.75 1.0 1.25" --use_constant_lr
