#!/bin/bash
# Finetune the MSA forget and retain components on MUSE Books for one or
# more model tags. Run with --help for usage.
set -euo pipefail

# Ask the OS for a free TCP port for accelerate's main process.
# Assignment and `export` are separate statements so that a failing probe
# aborts the script under `set -e` instead of being masked by `export`'s own
# (always successful) exit status -- ShellCheck SC2155.
# The probe socket is closed before accelerate binds the port, so this is a
# best-effort pick, not a reservation.
py_bin=$(command -v python || command -v python3) || {
  echo "Error: neither 'python' nor 'python3' is on PATH" >&2
  exit 1
}
MASTER_PORT=$("$py_bin" -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
export MASTER_PORT
echo "Master Port: $MASTER_PORT"

# Defaults (the *_split, train_devices and accelerate_config values can be
# overridden on the command line; the rest are fixed here)
per_device_train_batch_size=4
gradient_accumulation_steps=8

forget_split="forget"
retain_split="retain"

msa_forget_epochs=10
msa_retain_epochs=4

train_devices="0,1"
accelerate_config="configs/accelerate/default_config.yaml"

path_to_muse_model="saves/finetune/muse_books_full_olmo2"

# Candidate models + tags.
# These are parallel arrays: models[i] is the model config used for
# method_tags[i]. Enable/disable entries in BOTH arrays together.
models=(
  # "Olmo-2-7B-stage1-final"
  # "Olmo-2-7B-stage1-3859B"
  # "Olmo-2-7B-stage1-3691B"
  # "Olmo-2-7B-stage1-2207B"
  "Olmo-2-7B"
)

method_tags=(
  # "final"
  # "3859B"
  # "3691B"
  # "2207B"
  "muse_books"
)

# Fail fast if the parallel arrays drift out of sync; otherwise a tag could
# silently map to the wrong model (or to an unset element).
if (( ${#models[@]} != ${#method_tags[@]} )); then
  echo "Error: models (${#models[@]}) and method_tags (${#method_tags[@]}) must have the same length" >&2
  exit 1
fi

#######################################
# Print usage help to stdout.
# Globals (read): method_tags, forget_split, retain_split,
#   accelerate_config, train_devices
# Examples are built from the first enabled entry of method_tags, so the
# help never advertises a tag that the script would reject.
#######################################
usage() {
  local prog example_tag
  prog=$(basename "$0")
  example_tag="${method_tags[0]:-TAG}"
  cat <<EOF
Usage: ${prog} --run TAG[,TAG...] [options]

Where TAG is one of:
  ${method_tags[*]}

Several tags may be passed to --run as a comma- or space-separated list;
they are run in the order given.

Options:
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --accelerate_config PATH     Accelerate config file (default: ${accelerate_config})
  --train_devices DEVICES      CUDA devices for training (default: ${train_devices})
  -h | --help                  Show this help

Examples:
  # Run finetuning for tag "${example_tag}"
  bash ${prog} --run ${example_tag}

  # Same, on custom devices
  bash ${prog} --run ${example_tag} --train_devices 2,3
EOF
}

# Parse CLI
# Value-taking options verify that a value follows them; otherwise `$2` would
# trip `set -u` with a bare "unbound variable" error. Diagnostics go to stderr.

# Abort with a clear message if option $1 has no value after it.
# Arguments: $1 - option name, $2 - number of args remaining (incl. the option)
require_value() {
  if (( $2 < 2 )); then
    echo "Error: option $1 requires a value" >&2
    usage >&2
    exit 1
  fi
}

run_tags=()
while [[ $# -gt 0 ]]; do
  case "$1" in
    --forget_split) require_value "$1" "$#"; forget_split="$2"; shift 2;;
    --retain_split) require_value "$1" "$#"; retain_split="$2"; shift 2;;
    --accelerate_config) require_value "$1" "$#"; accelerate_config="$2"; shift 2;;
    --train_devices) require_value "$1" "$#"; train_devices="$2"; shift 2;;
    --run)
      require_value "$1" "$#"
      # Tags may be comma- and/or space-separated; repeated --run flags
      # accumulate instead of replacing earlier ones.
      read -r -a new_tags <<< "${2//,/ }"
      # Guarded append: "${arr[@]}" on an empty array trips `set -u` in
      # bash < 4.4.
      if (( ${#new_tags[@]} > 0 )); then
        run_tags+=("${new_tags[@]}")
      fi
      shift 2;;
    -h|--help) usage; exit 0;;
    *) echo "Unknown arg: $1" >&2; usage >&2; exit 1;;
  esac
done

if [[ ${#run_tags[@]} -eq 0 ]]; then
  echo "Error: You must specify at least one --run TAG" >&2
  usage >&2
  exit 1
fi

#######################################
# Launch a single MSA finetuning stage with `accelerate launch`.
# Globals (read): accelerate_config, MASTER_PORT, train_devices,
#   per_device_train_batch_size, gradient_accumulation_steps,
#   path_to_muse_model
# Arguments:
#   $1    - experiment config, passed as experiment=<value>
#   $2    - passed as task_name=<value>
#   $3    - passed as model=<value>
#   $4    - passed as trainer.args.num_train_epochs=<value>
#   $5... - extra overrides, inserted verbatim right after task_name
# Returns: the exit status of the launched command.
#######################################
_msa_launch() {
  local experiment="$1" task_name="$2" model_name="$3" epochs="$4"
  shift 4
  # Every expansion is quoted so values containing spaces or glob characters
  # reach the trainer as a single, literal argument.
  local -a cmd=(
    accelerate launch
    --config_file "${accelerate_config}"
    --main_process_port "${MASTER_PORT}"
    src/train.py
    "experiment=${experiment}"
    "task_name=${task_name}"
    "$@"
    "model=${model_name}"
    "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
    trainer.args.ddp_find_unused_parameters=true
    trainer.args.gradient_checkpointing=true
    "trainer.args.num_train_epochs=${epochs}"
    "trainer.args.gradient_accumulation_steps=${gradient_accumulation_steps}"
    "model.model_args.pretrained_model_name_or_path=${path_to_muse_model}"
    trainer.args.eval_strategy=no
  )
  CUDA_VISIBLE_DEVICES="${train_devices}" "${cmd[@]}"
}

#######################################
# Run the forget stage and then the retain stage for one tag.
# Globals (read): models, method_tags, forget_split, retain_split,
#   msa_forget_epochs, msa_retain_epochs (plus those read by _msa_launch)
# Arguments:
#   $1 - tag; must be one of method_tags
# Exits the script with status 1 on an unknown tag. Returns non-zero, without
# starting the retain stage, if the forget stage fails.
#######################################
run_finetune_tag() {
  local tag="$1"
  local model_name=""
  local i

  # models and method_tags are parallel arrays: the tag's index selects its model.
  for i in "${!method_tags[@]}"; do
    if [[ "${method_tags[$i]}" == "$tag" ]]; then
      model_name="${models[$i]}"
      break
    fi
  done

  if [[ -z "$model_name" ]]; then
    echo "Error: Unknown tag '$tag'. Must be one of: ${method_tags[*]}" >&2
    exit 1
  fi

  echo "Finetuning MSA components | tag=$tag | model=$model_name | forget_epochs=$msa_forget_epochs | retain_epochs=$msa_retain_epochs"

  # NOTE(review): forget_split/retain_split only appear in task_name here;
  # they are not passed to the trainer as data-split overrides -- confirm
  # that is intended.

  # Forget finetune. Only this stage passes data_split=Books.
  # NOTE(review): confirm retain_books.yaml selects its data split itself.
  # Explicit `|| return` so a failed forget stage never proceeds to retain,
  # even if this function is ever called where `set -e` is suspended.
  _msa_launch finetune/msa/forget_books.yaml \
    "msa_muse_books_${tag}_${forget_split}" "$model_name" "$msa_forget_epochs" \
    data_split=Books || return

  # Retain finetune
  _msa_launch finetune/msa/retain_books.yaml \
    "msa_muse_books_${tag}_${retain_split}" "$model_name" "$msa_retain_epochs"
}

# Process the requested tags one after another, in command-line order.
for requested_tag in "${run_tags[@]}"; do
  run_finetune_tag "$requested_tag"
done

# Example invocations (tags must be enabled in method_tags above):
#   bash scripts/msa_finetune.sh --run muse_books
#   bash scripts/msa_finetune.sh --run "2207B 3859B 3691B final"
