#!/bin/bash
#
# msa_unlearn.sh - apply MSA unlearning (src/utils/msa_unlearn.py) for one
# (alpha, beta) pair, then evaluate the result on TOFU (src/eval.py), for one
# or more modes: instruct, tofu, pretrained. Run with --help for all options.

set -euo pipefail

# ---------------------------------------------------------------------------
# Defaults (each can be overridden from the command line).
# ---------------------------------------------------------------------------

# TOFU data splits.
forget_split=forget10
retain_split=retain90
holdout_split=holdout10

# Echoed in the run banner.
msa_epochs=5

# Llama model family; Hugging Face ids are derived from it
# (meta-llama/<family> and meta-llama/<family>-Instruct), e.g. Llama-3.2-1B,
# Llama-3.2-3B or Llama-3.1-8B.
model_family=Llama-3.2-1B

# MSA coefficients for this single run; both must be given via --alpha/--beta.
alpha=
beta=

# Modes to run. On the CLI pass a comma- or space-separated subset of
# {instruct,tofu,pretrained}.
run_what=("instruct" "tofu" "pretrained")

# CUDA_VISIBLE_DEVICES values for the unlearning and the evaluation step.
unlearn_devices=0,1
eval_devices=0

# 1 => use the "_constant_lr"-suffixed MSA checkpoints and output directories.
use_constant_lr=1

#######################################
# Print CLI help to stdout.
# Globals (read): forget_split, retain_split, holdout_split, model_family,
#   msa_epochs, run_what, unlearn_devices, eval_devices, use_constant_lr
#######################################
usage() {
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --holdout_split NAME         Holdout split (default: ${holdout_split})
  --model_family NAME          Llama model family (default: ${model_family})
  --tofu_model PATH            Path or HF id for the target/TOFU model (default derives from model_family)
  --alpha VALUE                Alpha value (required; one value per run)
  --beta VALUE                 Beta value (required; one value per run)
  --msa_epochs N               Number of MSA epochs (default: ${msa_epochs})
  --run "modes"                Which to run: instruct,tofu,pretrained (default: ${run_what[*]})
  --unlearn_devices DEVICES    CUDA devices for unlearn step (default: ${unlearn_devices})
  --eval_devices DEVICES       CUDA devices for eval step (default: ${eval_devices})
  --use_constant_lr            Use constant LR finetuned MSA models (suffix paths; default: ${use_constant_lr}, 1 = on)
  --no_constant_lr             Do not use the constant LR suffix
  -h | --help                  Show this help

Examples:
  # Run all modes (instruct, tofu, pretrained) with default splits
  ./scripts/eval_msa/msa_unlearn.sh --alpha 1.0 --beta 1.0

  # Run only instruct with custom splits and params
  ./scripts/eval_msa/msa_unlearn.sh \\
    --run instruct \\
    --forget_split forget10 --retain_split retain90 --holdout_split holdout10 \\
    --model_family Llama-3.2-1B \\
    --tofu_model open-unlearning/tofu_Llama-3.2-1B-Instruct_full \\
    --alpha 1.0 --beta 0.75 \\
    --msa_epochs 5

  # Switch to an 8B model and run tofu + pretrained
  ./scripts/eval_msa/msa_unlearn.sh --model_family Llama-3.1-8B --run "tofu,pretrained" --alpha 1.0 --beta 1.0
EOF
}

# Parse CLI args
tofu_model_cli=""

#######################################
# Parse command-line options, overwriting the default globals above
# (and tofu_model_cli).
# Arguments: the script's arguments ("$@").
# Exits 0 after printing --help; exits 1 (message on stderr) on an unknown
# option or an option that is missing its value.
#######################################
parse_args() {
  local opt modes
  while (( $# > 0 )); do
    opt="$1"
    case "$opt" in
      --forget_split | --retain_split | --holdout_split | --model_family | \
      --tofu_model | --alpha | --beta | --msa_epochs | --run | \
      --unlearn_devices | --eval_devices)
        # Without this check a trailing value-less option reads an unset $2
        # (an "unbound variable" abort under set -u).
        if (( $# < 2 )); then
          echo "Missing value for ${opt}" >&2
          usage >&2
          exit 1
        fi
        ;;
    esac
    case "$opt" in
      --forget_split) forget_split="$2"; shift 2 ;;
      --retain_split) retain_split="$2"; shift 2 ;;
      --holdout_split) holdout_split="$2"; shift 2 ;;
      --model_family) model_family="$2"; shift 2 ;;
      --tofu_model) tofu_model_cli="$2"; shift 2 ;;
      --alpha) alpha="$2"; shift 2 ;;
      --beta) beta="$2"; shift 2 ;;
      --msa_epochs) msa_epochs="$2"; shift 2 ;;
      --run)
        # Accept comma- and/or space-separated mode lists.
        modes="${2//,/ }"
        read -r -a run_what <<< "$modes"
        shift 2
        ;;
      --unlearn_devices) unlearn_devices="$2"; shift 2 ;;
      --eval_devices) eval_devices="$2"; shift 2 ;;
      --use_constant_lr) use_constant_lr=1; shift ;;
      --no_constant_lr) use_constant_lr=0; shift ;;
      -h | --help) usage; exit 0 ;;
      *) echo "Unknown arg: ${opt}" >&2; usage >&2; exit 1 ;;
    esac
  done
}

parse_args "$@"

# Derive model ids/paths from model_family. The TOFU target model is taken
# from --tofu_model when given (non-empty), otherwise from the
# open-unlearning naming scheme tofu_<family>-Instruct_full.
base_pretrained_model="${model_family}"
instruct_model="${base_pretrained_model}-Instruct"
path_to_base_pretrained_model="meta-llama/${base_pretrained_model}"
path_to_instruct_model="meta-llama/${instruct_model}"
path_to_tofu_model="${tofu_model_cli:-open-unlearning/tofu_${instruct_model}_full}"

#######################################
# Run MSA unlearning followed by TOFU evaluation for one mode.
# Globals (read): forget_split, retain_split, holdout_split, model_family,
#   instruct_model, base_pretrained_model, path_to_instruct_model,
#   path_to_base_pretrained_model, path_to_tofu_model, alpha, beta,
#   msa_epochs, use_constant_lr, unlearn_devices, eval_devices
# Arguments: $1 - mode: instruct | tofu | pretrained
# Outputs: progress on stdout, errors on stderr. Saves the unlearned model to
#   saves/unlearn_msa/<forget_split>_<mode>[_constant_lr]/msa_alpha_<a>_beta_<b>,
#   evaluates it into <that dir>/evals, then deletes its *.safetensors files.
# Exits 1 on an unknown mode or a missing --alpha/--beta.
# NOTE(review): msa_epochs is only echoed here; it does not select any
#   checkpoint path - confirm that is intended.
#######################################
run_msa_eval() {
  local mode="$1"  # one of: instruct | tofu | pretrained

  local out_dir="" forget_model="" retain_model="" base_model_path=""
  local task_suffix="" neg_alpha=""
  local path_to_unlearned_model="" eval_output_dir=""

  # Validate required params before printing a banner or doing any work.
  if [[ -z "${alpha}" || -z "${beta}" ]]; then
    echo "[ERROR] --alpha and --beta are required." >&2
    exit 1
  fi

  if [[ "${use_constant_lr}" -eq 1 ]]; then
    task_suffix="_constant_lr"
  fi

  # Pick the base model and the forget/retain MSA checkpoints for this mode.
  case "$mode" in
    instruct)
      out_dir="${forget_split}_instruct"
      forget_model="saves/obtain_msa/msa_tofu_${instruct_model}_${forget_split}${task_suffix}"
      retain_model="saves/obtain_msa_split/msa_tofu_${instruct_model}_${retain_split}${task_suffix}"
      base_model_path="$path_to_instruct_model"
      ;;
    tofu)
      out_dir="${forget_split}_tofu"
      forget_model="saves/obtain_msa/msa_tofu_${instruct_model}_tofu_${forget_split}${task_suffix}"
      retain_model="saves/obtain_msa_split/msa_tofu_${instruct_model}_tofu_${retain_split}${task_suffix}"
      base_model_path="$path_to_tofu_model"
      ;;
    pretrained)
      out_dir="${forget_split}_pretrained"
      forget_model="saves/obtain_msa/msa_tofu_${base_pretrained_model}_${forget_split}${task_suffix}"
      retain_model="saves/obtain_msa_split/msa_tofu_${base_pretrained_model}_${retain_split}${task_suffix}"
      base_model_path="$path_to_base_pretrained_model"
      ;;
    *)
      echo "Unknown mode: $mode" >&2
      exit 1
      ;;
  esac

  # If constant LR is requested, suffix the output directory as well
  out_dir="${out_dir}${task_suffix}"

  # msa_unlearn.py receives the negated alpha. Toggle the sign instead of
  # blindly prefixing "-", so a negative --alpha becomes e.g. "0.5" rather
  # than the malformed "--0.5" (which a CLI parser would likely read as an
  # option name).
  if [[ "${alpha}" == -* ]]; then
    neg_alpha="${alpha#-}"
  else
    neg_alpha="-${alpha}"
  fi

  echo "Running MSA for mode=$mode | forget=$forget_split retain=$retain_split holdout=$holdout_split | model_family=$model_family"
  echo "Alpha: ${alpha} | Beta: ${beta} | msa_epochs=${msa_epochs}"

  path_to_unlearned_model="saves/unlearn_msa/${out_dir}/msa_alpha_${alpha}_beta_${beta}"
  eval_output_dir="${path_to_unlearned_model}/evals"

  # Skip if test eval results already exist for this setting.
  # NOTE(review): only the directory's existence is checked, so a run whose
  # evaluation failed after creating it would also be skipped; delete the
  # directory to force a rerun.
  if [[ -d "${eval_output_dir}" ]]; then
    echo "[SKIP] Found existing results at ${eval_output_dir}. Skipping unlearning and evaluation for mode=${mode}, alpha=${alpha}, beta=${beta}."
    return 0
  fi

  CUDA_VISIBLE_DEVICES="${unlearn_devices}" python src/utils/msa_unlearn.py \
    --base_model "$base_model_path" \
    --forget_model "$forget_model" \
    --retain_model "$retain_model" \
    --target_model "$path_to_tofu_model" \
    --alpha "$neg_alpha" \
    --beta "$beta" \
    --save_path "$path_to_unlearned_model"

  echo "Unlearning done for mode=${mode}, alpha=${alpha}, beta=${beta}"

  # Quote every override so values containing spaces or glob characters
  # reach the evaluator as one argument each.
  CUDA_VISIBLE_DEVICES="${eval_devices}" python src/eval.py \
    experiment=eval/tofu/default.yaml \
    "forget_split=${forget_split}" \
    "holdout_split=${holdout_split}" \
    "model=${instruct_model}" \
    is_validation=false \
    "task_name=${path_to_unlearned_model}" \
    "model.model_args.pretrained_model_name_or_path=${path_to_unlearned_model}" \
    "paths.output_dir=${eval_output_dir}" \
    "retain_logs_path=saves/eval/tofu_${instruct_model}_${retain_split}/TOFU_EVAL.json"

  # Free disk space: remove the unlearned weights once they are evaluated.
  find "${path_to_unlearned_model}" -type f -name '*.safetensors' -exec rm -v -- {} +

  echo "[TEST] Evaluation done for mode=${mode}, alpha=${alpha}, beta=${beta}"
}

#######################################
# Fail fast, before any long-running job starts. Every requested mode must be
# known, and --alpha/--beta must be set whenever at least one mode will run.
# Globals (read): run_what, alpha, beta
# Returns: 0 if the request is valid, 1 otherwise (message on stderr).
#######################################
validate_run_request() {
  local m
  for m in ${run_what[@]+"${run_what[@]}"}; do
    case "$m" in
      instruct | tofu | pretrained) ;;
      *)
        echo "[ERROR] Unknown mode in --run: ${m} (expected instruct, tofu, pretrained)" >&2
        return 1
        ;;
    esac
  done
  if (( ${#run_what[@]} > 0 )) && [[ -z "${alpha}" || -z "${beta}" ]]; then
    echo "[ERROR] --alpha and --beta are required." >&2
    return 1
  fi
  return 0
}

# Run requested modes, after validating the whole request up front.
validate_run_request || exit 1
if (( ${#run_what[@]} == 0 )); then
  echo "Nothing to run (--run was empty)." >&2
fi
# The ${arr[@]+...} form keeps an empty array safe under `set -u` on older bash.
for mode in ${run_what[@]+"${run_what[@]}"}; do
  run_msa_eval "$mode"
done

#
# How to run (cheatsheet):
#   - All modes with default splits (--alpha and --beta are always required):
#       ./scripts/eval_msa/msa_unlearn.sh --alpha 1.0 --beta 1.0
#   - Only instruct, custom splits and params (one alpha/beta pair per run):
#       ./scripts/eval_msa/msa_unlearn.sh \
#         --run instruct \
#         --forget_split forget10 --retain_split retain90 --holdout_split holdout10 \
#         --model_family Llama-3.2-1B \
#         --alpha 1.0 --beta 0.75 \
#         --msa_epochs 5
#   - 8B model (Llama-3.1-8B), tofu+pretrained:
#       ./scripts/eval_msa/msa_unlearn.sh --model_family Llama-3.1-8B --run "tofu,pretrained" --alpha 1.0 --beta 1.0
