#!/bin/bash
#
# Finetune the MSA forget and retain components on TOFU splits for one or
# more model modes (instruct, tofu, pretrained). Run with --help for options.

set -euo pipefail

# Ask the OS for a free TCP port for accelerate's --main_process_port. The
# probe socket is closed right away, so the port is only likely (not
# guaranteed) to still be free when accelerate binds it.
# Assignment and export are separate statements: `export VAR=$(cmd)` masks
# cmd's exit status, which would leave MASTER_PORT silently empty.
python_bin=$(command -v python || command -v python3) || {
  echo "python (or python3) not found on PATH" >&2
  exit 1
}
MASTER_PORT=$("${python_bin}" -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
export MASTER_PORT
echo "Master Port: $MASTER_PORT"

# Defaults (the first group can be overridden with CLI flags; see usage()).
per_device_train_batch_size=4
forget_split="forget10"
retain_split="retain90"
holdout_split="holdout10"
msa_epochs=5
model_family="Llama-3.2-1B"
train_devices="0,1"
accelerate_config="configs/accelerate/default_config.yaml"

# Set to 1 to train with a constant LR schedule: run_finetune_mode then adds
# trainer.args.lr_scheduler_type=constant and a "_constant_lr" task-name
# suffix. It must have a default: under `set -u`, reading it while unset
# aborts the script.
use_constant_lr=0

# Which modes to finetune, in order: instruct, tofu, pretrained.
run_what=(instruct tofu pretrained)

# Print the help text to stdout. The defaults shown are the current values of
# the corresponding global variables.
usage() {
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --holdout_split NAME         Holdout split (default: ${holdout_split})
  --model_family NAME          Llama model family (default: ${model_family})
  --tofu_model PATH            Path or HF id for the TOFU model (default derives from model_family)
  --epochs N                   Number of finetune epochs (default: ${msa_epochs})
  --batch_size N               Per-device batch size (default: ${per_device_train_batch_size})
  --accelerate_config PATH     Accelerate config file (default: ${accelerate_config})
  --train_devices DEVICES      CUDA devices (default: ${train_devices})
  --run MODES                  Modes to run, comma- or space-separated: instruct,tofu,pretrained (default: ${run_what[*]})
  -h | --help                  Show this help

Examples:
  # Defaults for all modes
  ./scripts/msa_tofu/msa_finetune.sh

  # Only tofu with custom model and epochs
  ./scripts/msa_tofu/msa_finetune.sh \\
    --run tofu \\
    --model_family Llama-3.2-1B \\
    --tofu_model open-unlearning/tofu_Llama-3.2-1B-Instruct_full \\
    --epochs 5 --batch_size 4
EOF
}

# Parse CLI. Every option except -h/--help takes exactly one value.
# Errors (unknown option, missing value, bad mode) are reported on stderr
# together with the usage text, and the script exits with status 1.
tofu_model_cli=""
while [[ $# -gt 0 ]]; do
  # First pass: recognise the option and make sure its value is present.
  case "$1" in
    -h|--help) usage; exit 0;;
    --forget_split | --retain_split | --holdout_split | --model_family | \
    --tofu_model | --epochs | --batch_size | --accelerate_config | \
    --train_devices | --run)
      if [[ $# -lt 2 ]]; then
        echo "Missing value for $1" >&2
        usage >&2
        exit 1
      fi
      opt="$1"
      val="$2"
      shift 2
      ;;
    *)
      echo "Unknown arg: $1" >&2
      usage >&2
      exit 1
      ;;
  esac

  # Second pass: store the value. --run accepts comma- and/or
  # space-separated mode lists.
  case "$opt" in
    --forget_split) forget_split="$val";;
    --retain_split) retain_split="$val";;
    --holdout_split) holdout_split="$val";;
    --model_family) model_family="$val";;
    --tofu_model) tofu_model_cli="$val";;
    --epochs) msa_epochs="$val";;
    --batch_size) per_device_train_batch_size="$val";;
    --accelerate_config) accelerate_config="$val";;
    --train_devices) train_devices="$val";;
    --run) read -r -a run_what <<< "${val//,/ }";;
  esac
done

# Validate the requested modes before any training starts, so a typo in a
# later mode is not discovered only after the earlier modes have run.
if [[ ${#run_what[@]} -eq 0 ]]; then
  echo "--run needs at least one mode (instruct, tofu, pretrained)" >&2
  exit 1
fi
for requested_mode in "${run_what[@]}"; do
  case "$requested_mode" in
    instruct|tofu|pretrained) ;;
    *)
      echo "Unknown mode: $requested_mode" >&2
      usage >&2
      exit 1
      ;;
  esac
done

# Derived model identifiers, all built from model_family. The TOFU model path
# is the --tofu_model value when one was given; otherwise it falls back to the
# open-unlearning "full" TOFU checkpoint for the instruct variant.
base_pretrained_model="${model_family}"
instruct_model="${base_pretrained_model}-Instruct"
path_to_base_pretrained_model="meta-llama/${base_pretrained_model}"
path_to_instruct_model="meta-llama/${instruct_model}"
path_to_tofu_model="${tofu_model_cli:-open-unlearning/tofu_${instruct_model}_full}"

# Launch one MSA finetune stage (forget or retain) through `accelerate launch`.
# Arguments:
#   $1 stage           - "forget" or "retain"; selects
#                        experiment=finetune/msa/<stage>.yaml and the
#                        TOFU_QA_<stage> dataset.
#   $2 split           - TOFU split name, passed as the dataset's hf_args.name.
#   $3 model_name      - model config name (model=...).
#   $4 pretrained_path - checkpoint override; when empty, no
#                        model.model_args.pretrained_model_name_or_path
#                        override is passed.
#   $5 method_tag      - middle part of task_name:
#                        msa_tofu_<method_tag>_<split>[_constant_lr].
# Globals (read): accelerate_config, MASTER_PORT, per_device_train_batch_size,
#   msa_epochs, train_devices, use_constant_lr (treated as 0 when unset, so
#   `set -u` cannot abort here).
_msa_launch_stage() {
  local stage="$1" split="$2" model_name="$3" pretrained_path="$4"
  local method_tag="$5"
  local dataset="TOFU_QA_${stage}"
  local task_suffix=""
  local constant_lr=0
  if [[ "${use_constant_lr:-0}" -eq 1 ]]; then
    constant_lr=1
    task_suffix="_constant_lr"
  fi

  # Every expansion is quoted so values containing spaces or glob
  # characters stay a single argument.
  local -a cmd=(
    accelerate launch --config_file "${accelerate_config}"
    --main_process_port "${MASTER_PORT}"
    src/train.py "experiment=finetune/msa/${stage}.yaml"
    "task_name=msa_tofu_${method_tag}_${split}${task_suffix}"
    "model=${model_name}"
    "data/datasets@data.train=${dataset}"
    "data.train.${dataset}.args.hf_args.name=${split}"
    "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
    trainer.args.ddp_find_unused_parameters=true
    trainer.args.gradient_checkpointing=true
    "trainer.args.num_train_epochs=${msa_epochs}"
    trainer.args.eval_strategy=no
  )
  if [[ -n "${pretrained_path}" ]]; then
    cmd+=("model.model_args.pretrained_model_name_or_path=${pretrained_path}")
  fi
  if (( constant_lr )); then
    cmd+=(trainer.args.lr_scheduler_type=constant)
  fi
  CUDA_VISIBLE_DEVICES="${train_devices}" "${cmd[@]}"
}

# Finetune both MSA components (forget first, then retain) for one mode.
# Arguments:
#   $1 mode - instruct | tofu | pretrained
# Globals (read): instruct_model, base_pretrained_model, path_to_tofu_model,
#   path_to_base_pretrained_model, model_family, msa_epochs,
#   per_device_train_batch_size, forget_split, retain_split, plus everything
#   _msa_launch_stage reads.
# Exits with status 1 (message on stderr) on an unknown mode.
run_finetune_mode() {
  local mode="$1"
  local model_name=""
  local pretrained_path=""
  local method_tag=""

  case "$mode" in
    instruct)
      model_name="$instruct_model"
      pretrained_path=""
      method_tag="${instruct_model}"
      ;;
    tofu)
      model_name="$instruct_model"
      pretrained_path="$path_to_tofu_model"
      method_tag="${instruct_model}_tofu"
      ;;
    pretrained)
      model_name="$base_pretrained_model"
      pretrained_path="$path_to_base_pretrained_model"
      method_tag="${base_pretrained_model}"
      ;;
    *) echo "Unknown mode: $mode" >&2; exit 1;;
  esac

  echo "Finetuning MSA components for mode=$mode | model_family=$model_family | epochs=$msa_epochs | batch_size=$per_device_train_batch_size"

  _msa_launch_stage forget "${forget_split}" "${model_name}" "${pretrained_path}" "${method_tag}"
  _msa_launch_stage retain "${retain_split}" "${model_name}" "${pretrained_path}" "${method_tag}"
}

# Finetune every requested mode, in the order given by --run (or the default).
for requested_mode in "${run_what[@]}"; do
  run_finetune_mode "${requested_mode}"
done

# How to run (cheatsheet):
#   ./scripts/msa_tofu/msa_finetune.sh
#   ./scripts/msa_tofu/msa_finetune.sh --run tofu --model_family Llama-3.2-3B --epochs 3 --batch_size 2 \
#       --tofu_model open-unlearning/tofu_Llama-3.2-3B-Instruct_full --train_devices 0,1
