#!/bin/bash

set -euo pipefail


# Dataset splits (overridable with --forget_split / --retain_split / --holdout_split)
forget_split='forget10'
retain_split='retain90'
holdout_split='holdout10'

# Number of MSA epochs (overridable with --msa_epochs)
msa_epochs=5

# Model family from which the instruct/base/TOFU model ids are derived
# (e.g., Llama-3.2-1B or Llama-3.1-8B)
model_family='Llama-3.2-1B'

# Default alpha/beta grids; every (alpha, beta) combination is run
declare -a alphas=(0.75 1.0 1.25)
declare -a betas=(0.0 0.75 1.0 1.25)

# Modes to run, any subset of {instruct,tofu,pretrained};
# --run accepts them as a comma- or space-separated list
declare -a run_what=(instruct tofu pretrained)

# CUDA_VISIBLE_DEVICES values for the unlearn and eval steps
unlearn_devices='0,1'
eval_devices='0'

# Optional flags: set to 1 by --use_constant_lr
use_constant_lr=0

#######################################
# Print the help text to stdout.
# Globals (read): forget_split, retain_split, holdout_split, model_family,
#   alphas, betas, msa_epochs, run_what, unlearn_devices, eval_devices
#   (their current values are shown as the defaults).
# Arguments: none
# Outputs: usage text on stdout
#######################################
usage() {
  # The heredoc delimiter is unquoted so the defaults expand. In that mode a
  # backslash before a newline is a line continuation, so the literal
  # trailing backslashes of the multi-line example are written as `\\`.
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --holdout_split NAME         Holdout split (default: ${holdout_split})
  --model_family NAME          Llama model family (default: ${model_family})
  --tofu_model PATH            Path or HF id for the target/TOFU model (default derives from model_family)
  --alphas "a1 a2 ..."         Space/comma-separated alpha values (default: ${alphas[*]})
  --betas "b1 b2 ..."          Space/comma-separated beta values (default: ${betas[*]})
  --msa_epochs N               Number of MSA epochs (default: ${msa_epochs})
  --run "modes"                Which to run: instruct,tofu,pretrained (default: ${run_what[*]})
  --unlearn_devices DEVICES    CUDA devices for unlearn step (default: ${unlearn_devices})
  --eval_devices DEVICES       CUDA devices for eval step (default: ${eval_devices})
  --use_constant_lr            Use constant LR finetuned MSA models (suffix paths)
  -h | --help                  Show this help

Examples:
  # Run all (instruct, tofu, pretrained) with defaults
  ./scripts/eval_msa/msa_unlearn.sh

  # Run only instruct with custom splits and params
  ./scripts/eval_msa/msa_unlearn.sh \\
    --run instruct \\
    --forget_split forget10 --retain_split retain90 --holdout_split holdout10 \\
    --model_family Llama-3.2-1B \\
    --tofu_model open-unlearning/tofu_Llama-3.2-1B-Instruct_full \\
    --alphas "0.75 1.0 1.25" --betas "0.0 0.75 1.0 1.25" \\
    --msa_epochs 5

  # Switch to 8B (Llama-3.1-8B) and run tofu + pretrained
  ./scripts/eval_msa/msa_unlearn.sh --model_family Llama-3.1-8B --run "tofu,pretrained"
EOF
}

# Parse CLI args
tofu_model_cli=""

#######################################
# Report a command-line error: message and usage go to stderr, exit status 1.
# Arguments: $* - error message
#######################################
_cli_fail() {
  printf '%s\n' "$*" >&2
  usage >&2
  exit 1
}

#######################################
# Apply command-line options to the global configuration variables.
# Globals (written): forget_split, retain_split, holdout_split, model_family,
#   tofu_model_cli, alphas, betas, msa_epochs, run_what, unlearn_devices,
#   eval_devices, use_constant_lr
# Arguments: the script's argument list ("$@")
# Exits: 0 after printing help for -h/--help; 1 on an unknown option, a
#   missing option value, an empty alpha/beta/mode list, or an unknown mode.
#######################################
parse_cli_args() {
  local mode_name=""

  while [[ $# -gt 0 ]]; do
    # Options that take a value must be followed by one; without this check
    # "$2" trips `set -u` with a cryptic "unbound variable" error.
    case "$1" in
      --forget_split | --retain_split | --holdout_split | --model_family | \
        --tofu_model | --alphas | --betas | --msa_epochs | --run | \
        --unlearn_devices | --eval_devices)
        [[ $# -ge 2 ]] || _cli_fail "Missing value for $1"
        ;;
    esac

    case "$1" in
      --forget_split) forget_split="$2"; shift 2 ;;
      --retain_split) retain_split="$2"; shift 2 ;;
      --holdout_split) holdout_split="$2"; shift 2 ;;
      --model_family) model_family="$2"; shift 2 ;;
      --tofu_model) tofu_model_cli="$2"; shift 2 ;;
      --alphas)
        # Commas are treated as spaces so "a,b" and "a b" are equivalent.
        read -r -a alphas <<< "${2//,/ }"
        (( ${#alphas[@]} > 0 )) || _cli_fail "--alphas needs at least one value"
        shift 2
        ;;
      --betas)
        read -r -a betas <<< "${2//,/ }"
        (( ${#betas[@]} > 0 )) || _cli_fail "--betas needs at least one value"
        shift 2
        ;;
      --msa_epochs) msa_epochs="$2"; shift 2 ;;
      --run)
        read -r -a run_what <<< "${2//,/ }"
        (( ${#run_what[@]} > 0 )) || _cli_fail "--run needs at least one mode"
        # Fail fast on a typo instead of after earlier modes have already run.
        for mode_name in "${run_what[@]}"; do
          case "${mode_name}" in
            instruct | tofu | pretrained) ;;
            *) _cli_fail "Unknown mode for --run: ${mode_name}" ;;
          esac
        done
        shift 2
        ;;
      --unlearn_devices) unlearn_devices="$2"; shift 2 ;;
      --eval_devices) eval_devices="$2"; shift 2 ;;
      --use_constant_lr) use_constant_lr=1; shift 1 ;;
      -h | --help) usage; exit 0 ;;
      *) _cli_fail "Unknown arg: $1" ;;
    esac
  done
}

parse_cli_args "$@"

# Model names/paths derived from model_family
base_pretrained_model="${model_family}"
instruct_model="${base_pretrained_model}-Instruct"
path_to_base_pretrained_model="meta-llama/${base_pretrained_model}"
path_to_instruct_model="meta-llama/${instruct_model}"
# A non-empty --tofu_model value wins; otherwise fall back to the
# open-unlearning id built from the instruct model name.
path_to_tofu_model="${tofu_model_cli:-open-unlearning/tofu_${instruct_model}_full}"

#######################################
# For one mode, run MSA unlearning and then the TOFU validation eval for
# every (alpha, beta) combination.
# Globals (read): use_constant_lr, forget_split, retain_split, holdout_split,
#   model_family, instruct_model, base_pretrained_model,
#   path_to_instruct_model, path_to_tofu_model,
#   path_to_base_pretrained_model, alphas, betas, unlearn_devices,
#   eval_devices, msa_epochs (only printed here, not passed to any command)
# Arguments: $1 - mode, one of: instruct | tofu | pretrained
# Outputs: progress messages on stdout; each combination is saved under
#   saves/unlearn_msa/<out_dir>/msa_alpha_<alpha>_beta_<beta>
# Exits: status 1 on an unknown mode
#######################################
run_msa_eval() {
  local mode="$1"  # one of: instruct | tofu | pretrained

  local out_dir=""
  local forget_model=""
  local retain_model=""
  local base_model_path=""
  local task_suffix=""
  # Loop state is local so nothing leaks into the caller's scope.
  local alpha=""
  local beta=""
  local path_to_unlearned_model=""
  local eval_output_dir=""

  if [[ "${use_constant_lr}" -eq 1 ]]; then
    task_suffix="_constant_lr"
  fi

  case "$mode" in
    instruct)
      out_dir="${forget_split}_instruct"
      forget_model="saves/obtain_msa/msa_tofu_${instruct_model}_${forget_split}${task_suffix}"
      retain_model="saves/obtain_msa_split/msa_tofu_${instruct_model}_${retain_split}${task_suffix}"
      base_model_path="$path_to_instruct_model"
      ;;
    tofu)
      out_dir="${forget_split}_tofu"
      forget_model="saves/obtain_msa/msa_tofu_${instruct_model}_tofu_${forget_split}${task_suffix}"
      retain_model="saves/obtain_msa_split/msa_tofu_${instruct_model}_tofu_${retain_split}${task_suffix}"
      base_model_path="$path_to_tofu_model"
      ;;
    pretrained)
      out_dir="${forget_split}_pretrained"
      forget_model="saves/obtain_msa/msa_tofu_${base_pretrained_model}_${forget_split}${task_suffix}"
      retain_model="saves/obtain_msa_split/msa_tofu_${base_pretrained_model}_${retain_split}${task_suffix}"
      base_model_path="$path_to_base_pretrained_model"
      ;;
    *)
      echo "Unknown mode: $mode" >&2
      exit 1
      ;;
  esac

  # If constant LR is requested, suffix the output directory as well
  out_dir="${out_dir}${task_suffix}"

  echo "Running MSA for mode=$mode | forget=$forget_split retain=$retain_split holdout=$holdout_split | model_family=$model_family"
  echo "Alphas: ${alphas[*]} | Betas: ${betas[*]} | msa_epochs=${msa_epochs}"

  for alpha in "${alphas[@]}"; do
    for beta in "${betas[@]}"; do
      path_to_unlearned_model="saves/unlearn_msa/${out_dir}/msa_alpha_${alpha}_beta_${beta}"

      # Skip if eval results already exist for this setting.
      # NOTE(review): only the directory's existence is checked; an eval that
      # was interrupted after creating it would also be skipped - confirm.
      eval_output_dir="${path_to_unlearned_model}/evals_val"
      if [[ -d "${eval_output_dir}" ]]; then
        echo "[SKIP] Found existing results at ${eval_output_dir}. Skipping unlearning and evaluation for mode=${mode}, alpha=${alpha}, beta=${beta}."
        continue
      fi

      # alpha is passed negated; the target model is always the TOFU model,
      # whichever base model the mode selects.
      CUDA_VISIBLE_DEVICES="${unlearn_devices}" python src/utils/msa_unlearn.py \
        --base_model "$base_model_path" \
        --forget_model "$forget_model" \
        --retain_model "$retain_model" \
        --target_model "$path_to_tofu_model" \
        --alpha "-$alpha" \
        --beta "$beta" \
        --save_path "$path_to_unlearned_model"

      echo "Unlearning done for mode=${mode}, alpha=${alpha}, beta=${beta}"

      # Each override is quoted so a value containing whitespace or glob
      # characters stays a single argument.
      CUDA_VISIBLE_DEVICES="${eval_devices}" python src/eval.py \
        experiment=eval/tofu/default.yaml \
        "forget_split=${forget_split}" \
        "holdout_split=${holdout_split}" \
        "model=${instruct_model}" \
        is_validation=true \
        "task_name=${path_to_unlearned_model}" \
        "model.model_args.pretrained_model_name_or_path=${path_to_unlearned_model}" \
        "paths.output_dir=${eval_output_dir}" \
        "retain_logs_path=saves/eval/tofu_${instruct_model}_${retain_split}/TOFU_EVAL.json"

      # Remove the saved *.safetensors files once evaluated
      # (presumably to free disk space - confirm).
      find "${path_to_unlearned_model}" -type f -name '*.safetensors' -exec rm -v -- {} +

      echo "Evaluation done for mode=${mode}, alpha=${alpha}, beta=${beta}"
    done
  done
}

# Dispatch every requested mode, in the order given
for selected_mode in "${run_what[@]}"; do
  run_msa_eval "${selected_mode}"
done

#
# How to run (cheatsheet):
#   - Defaults (all modes):
#       ./scripts/eval_msa/msa_unlearn.sh
#   - Only instruct, custom splits and params:
#       ./scripts/eval_msa/msa_unlearn.sh \
#         --run instruct \
#         --forget_split forget10 --retain_split retain90 --holdout_split holdout10 \
#         --model_family Llama-3.2-1B \
#         --alphas "0.75 1.0 1.25" --betas "0.0 0.75 1.0 1.25" \
#         --msa_epochs 5
#   - 8B model (Llama-3.1-8B; Llama 3.2 has no 8B variant), tofu+pretrained:
#       ./scripts/eval_msa/msa_unlearn.sh --model_family Llama-3.1-8B --run "tofu,pretrained"
