#!/bin/bash
#
# Default configuration for the MSA unlearning sweep on MUSE Books.
# Several of these values can be overridden on the command line; see usage().

set -euo pipefail

# --- Model and data ---------------------------------------------------------
# Model family; later prefixed with "meta-llama/" to locate the pretrained
# base model (e.g., Llama-3.2-1B or Llama-3.1-8B).
model_family='Llama-3.1-8B'
muse_model='saves/finetune/muse_books_llama3_full'
data_split='Books'

# Split names embedded in the forget/retain checkpoint paths and output dirs.
forget_split='forget'
retain_split='retain'

# NOTE(review): not referenced anywhere else in the visible script.
msa_epochs=5

# --- Sweep grid -------------------------------------------------------------
# Every (alpha, beta) combination is unlearned and evaluated.
alphas=(0.75 1.0 1.25 3.0)
betas=(0.0 0.5 1.0 1.5)

# --- Modes ------------------------------------------------------------------
# Which modes to run, drawn from {target,pretrained}. On the command line,
# --run accepts them comma- or space-separated.
run_what=(target pretrained)

# --- Devices ----------------------------------------------------------------
# Values assigned to CUDA_VISIBLE_DEVICES for the unlearn and eval steps.
unlearn_devices='0,1'
eval_devices='0'

# Print the command-line help text to stdout.
#
# Globals read: model_family, muse_model, alphas, betas, run_what,
#               unlearn_devices, eval_devices
# The "default" values shown are the current values of those globals, so
# options parsed before -h/--help are reflected in the output.
usage() {
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --model_family NAME          Llama model family (default: ${model_family})
  --muse_model PATH            Path or HF id for the MUSE model (default: ${muse_model})
  --alphas "a1 a2 ..."        Space/comma-separated alpha values (default: ${alphas[*]})
  --betas "b1 b2 ..."         Space/comma-separated beta values (default: ${betas[*]})
  --run "modes"               Which to run: target,pretrained (default: ${run_what[*]})
  --unlearn_devices DEVICES    CUDA devices for unlearn step (default: ${unlearn_devices})
  --eval_devices DEVICES       CUDA devices for eval step (default: ${eval_devices})
  -h | --help                  Show this help

EOF
}

# Parse CLI args
#
# parse_args consumes "--option value" pairs and overrides the corresponding
# global defaults. List-valued options (--alphas, --betas, --run) accept
# comma- and/or space-separated items and are stored as arrays.
#
# Arguments: the script's command-line arguments ("$@")
# Globals written: model_family, muse_model, alphas, betas, run_what,
#                  unlearn_devices, eval_devices
# Exits: 0 after printing help for -h/--help; 1 (diagnostic and usage on
#        stderr) for an unknown option, an option without a value, or an
#        empty list value.
parse_args() {
  local opt val

  while [[ $# -gt 0 ]]; do
    opt="$1"

    # First pass: recognise the option (help and unknown options end here).
    case "$opt" in
      -h|--help) usage; exit 0;;
      --model_family|--muse_model|--alphas|--betas|--run|--unlearn_devices|--eval_devices) ;;
      *) echo "Unknown arg: $opt" >&2; usage >&2; exit 1;;
    esac

    # Every remaining option takes exactly one value. Without this guard,
    # "set -u" would abort with a bare "unbound variable" error on "$2".
    if [[ $# -lt 2 ]]; then
      echo "Missing value for $opt" >&2
      usage >&2
      exit 1
    fi
    val="$2"
    shift 2

    case "$opt" in
      --model_family) model_family="$val";;
      --muse_model) muse_model="$val";;
      --unlearn_devices) unlearn_devices="$val";;
      --eval_devices) eval_devices="$val";;
      --alphas|--betas|--run)
        # A list made only of separators would silently produce an empty
        # sweep; reject it up front.
        if [[ -z "${val//[[:space:],]/}" ]]; then
          echo "Empty list for $opt" >&2
          usage >&2
          exit 1
        fi
        # Commas become spaces so read can split on default IFS.
        case "$opt" in
          --alphas) read -r -a alphas <<< "${val//,/ }";;
          --betas) read -r -a betas <<< "${val//,/ }";;
          --run) read -r -a run_what <<< "${val//,/ }";;
        esac
        ;;
    esac
  done
}

parse_args "$@"


# Derive the pretrained-base identifiers from the model family as it stands
# at this point in the script: the bare family name, and the same name
# prefixed with "meta-llama/".
base_pretrained_model="${model_family}"
path_to_base_pretrained_model="meta-llama/${base_pretrained_model}"

# Run the MSA unlearn + eval sweep for a single mode.
#
# Arguments:
#   $1 - mode, one of:
#          target     - base model is ${muse_model}; forget/retain models are
#                       the "..._muse_books_<split>" checkpoints
#          pretrained - base model is ${path_to_base_pretrained_model};
#                       forget/retain models are the "..._<split>" checkpoints
# Globals read: forget_split, retain_split, model_family, muse_model,
#               path_to_base_pretrained_model, data_split, alphas, betas,
#               unlearn_devices, eval_devices
# For every (alpha, beta) pair this runs src/utils/msa_unlearn.py, then
# src/eval.py on the saved model, then deletes the saved *.safetensors files.
# A pair whose evals directory already exists is skipped entirely.
# Exits the script with status 1 (message on stderr) on an unknown mode.
run_msa_eval() {
  local mode="$1"  # one of: target | pretrained

  local out_dir=""
  local forget_model=""
  local retain_model=""
  local base_model_path=""

  case "$mode" in
    target)
      out_dir="${forget_split}_${model_family}_muse_books"
      forget_model="saves/finetune/msa_muse_books_${model_family}_muse_books_${forget_split}"
      retain_model="saves/finetune/msa_muse_books_${model_family}_muse_books_${retain_split}"
      base_model_path="${muse_model}"
      ;;
    pretrained)
      out_dir="${forget_split}_${model_family}_pretrained"
      forget_model="saves/finetune/msa_muse_books_${model_family}_${forget_split}"
      retain_model="saves/finetune/msa_muse_books_${model_family}_${retain_split}"
      base_model_path="$path_to_base_pretrained_model"
      ;;
    *) echo "Unknown mode: $mode" >&2; exit 1;;
  esac

  echo "Running MSA for mode=$mode | forget=$forget_split retain=$retain_split | model_family=$model_family"
  echo "Alphas: ${alphas[*]} | Betas: ${betas[*]}"

  # Loop state is local so it does not leak into (or clobber) globals.
  local alpha beta path_to_unlearned_model eval_output_dir

  for alpha in "${alphas[@]}"; do
    for beta in "${betas[@]}"; do
      path_to_unlearned_model="saves/unlearn_msa/${out_dir}/msa_alpha_${alpha}_beta_${beta}"
      eval_output_dir="${path_to_unlearned_model}/evals"

      # Skip if eval results already exist for this setting.
      # NOTE(review): only the directory's existence is checked, so a run that
      # was interrupted mid-eval would also be skipped - confirm that is
      # acceptable.
      if [[ -d "${eval_output_dir}" ]]; then
        echo "[SKIP] Found existing results at ${eval_output_dir}. Skipping unlearning and evaluation for mode=${mode}, alpha=${alpha}, beta=${beta}."
        continue
      fi

      echo "Doing unlearning for mode=${mode}, alpha=${alpha}, beta=${beta}"

      # alpha is passed with a leading "-" (i.e. negated); beta is passed as-is.
      CUDA_VISIBLE_DEVICES="${unlearn_devices}" python src/utils/msa_unlearn.py \
        --base_model "$base_model_path" \
        --forget_model "$forget_model" \
        --retain_model "$retain_model" \
        --target_model "$muse_model" \
        --alpha "-$alpha" \
        --beta "$beta" \
        --save_path "$path_to_unlearned_model"

      echo "Unlearning done for mode=${mode}, alpha=${alpha}, beta=${beta}"

      # Each key=value override is quoted so a value containing whitespace
      # or glob characters stays a single argument.
      CUDA_VISIBLE_DEVICES="${eval_devices}" python src/eval.py \
        experiment=eval/muse/default.yaml \
        "data_split=${data_split}" \
        "task_name=${path_to_unlearned_model}" \
        "model=${model_family}" \
        "model.model_args.pretrained_model_name_or_path=${path_to_unlearned_model}" \
        "paths.output_dir=${eval_output_dir}" \
        retain_logs_path=saves/finetune/muse_books_llama3_retain/evals/MUSE_EVAL.json

      # Remove the saved *.safetensors files once evaluation has finished;
      # the evals directory is left in place, which is what the skip check
      # above looks for on a re-run.
      find "${path_to_unlearned_model}" -type f -name '*.safetensors' -exec rm -v -- {} +

      echo "Evaluation done for mode=${mode}, alpha=${alpha}, beta=${beta}"
    done
  done
}

# Dispatch: run the full sweep once per requested mode, in the order the
# modes appear in run_what.
for requested_mode in "${run_what[@]}"; do
  run_msa_eval "${requested_mode}"
done

#
# How to run (cheatsheet):
#   - Defaults (both modes: target and pretrained):
#       ./scripts/eval_msa/msa_unlearn.sh
#   - Only the target mode, with a custom model family and sweep values:
#       ./scripts/eval_msa/msa_unlearn.sh \
#         --run target \
#         --model_family Llama-3.2-1B \
#         --alphas "0.75 1.0 1.25" --betas "0.0 0.75 1.0 1.25"
#   - 8B model, target+pretrained:
#       ./scripts/eval_msa/msa_unlearn.sh --model_family Llama-3.1-8B --run "target,pretrained"
#   Note: the forget/retain split names and msa_epochs have no command-line
#   flags; edit the variables at the top of this script to change them.
