#!/bin/bash

set -euo pipefail

# Ask the OS for a currently free TCP port (bind to port 0) to use as
# accelerate's --main_process_port. Assignment and export are split so a
# failing `python` call trips `set -e`; `export VAR=$(cmd)` would return
# export's own (zero) status and leave MASTER_PORT empty.
# NOTE(review): the port is released before accelerate binds it and is reused
# for every launch below; another process could grab it in between.
MASTER_PORT=$(python -c "import socket; s=socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()")
export MASTER_PORT
echo "Master Port: $MASTER_PORT"

############################################
# Defaults (OLMo stage-1 models)
############################################
# Run-wide defaults; every value here can be overridden by a CLI flag.

# TOFU data split names.
forget_split="forget10"
retain_split="retain90"
holdout_split="holdout10"

# Training knobs.
msa_epochs=5
per_device_train_batch_size=4
use_constant_lr=0 # 1 => constant LR scheduler and "_constant_lr" task_name suffix

# Launcher / hardware.
train_devices="0,1"
accelerate_config="configs/accelerate/default_config.yaml"

# OLMo model config names (no .yaml suffix), run in this order.
# Override with --model_names, e.g. "Olmo-2-1B-stage1-210B,Olmo-2-1B-stage1-final".
# Other stage-1 checkpoints that can be passed the same way (1B family):
#   Olmo-2-1B-stage1-210B   Olmo-2-1B-stage1-504B   Olmo-2-1B-stage1-2098B
#   Olmo-2-1B-stage1-3146B  Olmo-2-1B-stage1-3566B  Olmo-2-1B-stage1-final
declare -a model_names=(
  "Olmo-2-7B-stage1-final"
  "Olmo-2-7B-stage1-3859B"
  "Olmo-2-7B-stage1-3691B"
  "Olmo-2-7B-stage1-2207B"
  "Olmo-2-7B-stage1-500B"
)

#######################################
# Print CLI help, including the current default values, to stdout.
# Globals (read): forget_split, retain_split, holdout_split, model_names,
#   msa_epochs, per_device_train_batch_size, accelerate_config, train_devices
#######################################
usage() {
  # The here-doc is unquoted so the defaults expand. In an unquoted here-doc a
  # bare trailing `\` is a line continuation, so the example uses `\\` to print
  # a literal backslash and keep its line breaks.
  cat <<EOF
Usage: $(basename "$0") [options]

Options:
  --forget_split NAME          Forget split (default: ${forget_split})
  --retain_split NAME          Retain split (default: ${retain_split})
  --holdout_split NAME         Holdout split (default: ${holdout_split})
  --model_names "NAMES"        Space/comma-separated OLMo model names (default: ${model_names[*]})
  --epochs N                   Number of finetune epochs (default: ${msa_epochs})
  --batch_size N               Per-device batch size (default: ${per_device_train_batch_size})
  --accelerate_config PATH     Accelerate config file (default: ${accelerate_config})
  --train_devices DEVICES      CUDA devices (default: ${train_devices})
  --use_constant_lr            Use constant LR scheduler and suffix task_name
  -h | --help                  Show this help

Examples:
  # Run for a subset of OLMo stage-1 checkpoints
  ./scripts/experiments/Olmo/msa_finetune.sh \\
    --model_names "Olmo-2-1B-stage1-210B,Olmo-2-1B-stage1-final" \\
    --epochs 5 --batch_size 4
EOF
}

# Parse CLI
#######################################
# Abort with a message if an option that needs a value has none.
# Arguments:
#   $1 - option name (for the error message)
#   $2 - number of remaining arguments, counting the option itself
#######################################
_require_value() {
  if (( $2 < 2 )); then
    echo "Missing value for $1" >&2
    usage >&2
    exit 1
  fi
}

#######################################
# Parse command-line options into the global settings.
# Globals (written): forget_split, retain_split, holdout_split, model_names,
#   msa_epochs, per_device_train_batch_size, accelerate_config,
#   train_devices, use_constant_lr
# Arguments: the script's command-line arguments.
# Exits 0 after printing help for -h/--help. Exits 1, with messages on
# stderr, on an unknown option, a missing value, or an empty --model_names.
#######################################
parse_args() {
  local names
  while [[ $# -gt 0 ]]; do
    case "$1" in
      --forget_split) _require_value "$1" "$#"; forget_split="$2"; shift 2;;
      --retain_split) _require_value "$1" "$#"; retain_split="$2"; shift 2;;
      --holdout_split) _require_value "$1" "$#"; holdout_split="$2"; shift 2;;
      --model_names)
        _require_value "$1" "$#"
        # Commas are accepted as separators too; then split on whitespace.
        names="${2//,/ }"
        read -r -a model_names <<< "$names"
        if [[ ${#model_names[@]} -eq 0 ]]; then
          echo "--model_names requires at least one model name" >&2
          exit 1
        fi
        shift 2;;
      --epochs) _require_value "$1" "$#"; msa_epochs="$2"; shift 2;;
      --batch_size) _require_value "$1" "$#"; per_device_train_batch_size="$2"; shift 2;;
      --accelerate_config) _require_value "$1" "$#"; accelerate_config="$2"; shift 2;;
      --train_devices) _require_value "$1" "$#"; train_devices="$2"; shift 2;;
      --use_constant_lr) use_constant_lr=1; shift 1;;
      -h|--help) usage; exit 0;;
      *) echo "Unknown arg: $1" >&2; usage >&2; exit 1;;
    esac
  done
}

parse_args "$@"

# Announce the sweep configuration. The constant-LR task_name suffix is derived
# per run inside the training loop, so it is not computed here.
echo "Finetuning MSA components for OLMo models | epochs=$msa_epochs | batch_size=$per_device_train_batch_size"

#######################################
# Launch one MSA finetuning stage (forget or retain) for a single model.
# Globals (read): accelerate_config, MASTER_PORT, per_device_train_batch_size,
#   msa_epochs, use_constant_lr, train_devices
# Arguments:
#   $1 - model config name (passed as model=<name>, also used in task_name)
#   $2 - stage, "forget" or "retain"; selects finetune/msa/<stage>.yaml and
#        the TOFU_QA_<stage> dataset
#   $3 - split name passed as the dataset's hf_args.name
# Returns: exit status of the `accelerate launch` command.
#######################################
run_msa_stage() {
  local model_cfg=$1 stage=$2 split=$3
  local dataset="TOFU_QA_${stage}"
  local task_suffix=""
  if [[ "${use_constant_lr}" -eq 1 ]]; then
    task_suffix="_constant_lr"
  fi

  local -a cmd=(
    accelerate launch
    --config_file "${accelerate_config}"
    --main_process_port "${MASTER_PORT}"
    src/train.py
    "experiment=finetune/msa/${stage}.yaml"
    "task_name=msa_${model_cfg}_${split}${task_suffix}"
    "model=${model_cfg}"
    "data/datasets@data.train=${dataset}"
    "data.train.${dataset}.args.hf_args.name=${split}"
    "trainer.args.per_device_train_batch_size=${per_device_train_batch_size}"
    trainer.args.ddp_find_unused_parameters=true
    trainer.args.gradient_checkpointing=true
    "trainer.args.num_train_epochs=${msa_epochs}"
    trainer.args.eval_strategy=no
  )
  if [[ "${use_constant_lr}" -eq 1 ]]; then
    cmd+=(trainer.args.lr_scheduler_type=constant)
  fi
  CUDA_VISIBLE_DEVICES="${train_devices}" "${cmd[@]}"
}

# Each model gets a forget-split run followed by a retain-split run; a failed
# launch returns non-zero and, under `set -e`, stops the sweep.
for model_cfg in "${model_names[@]}"; do
  echo "[OLMo] Finetune MSA for ${model_cfg}"
  run_msa_stage "${model_cfg}" forget "${forget_split}"
  run_msa_stage "${model_cfg}" retain "${retain_split}"
done

# How to run (cheatsheet):
#   ./scripts/experiments/Olmo/msa_finetune.sh
#   ./scripts/experiments/Olmo/msa_finetune.sh --model_names "Olmo-2-1B-stage1-210B,Olmo-2-1B-stage1-final" \
#       --epochs 3 --batch_size 2 --train_devices 0,1
