#!/usr/bin/env bash
#
# Runs hf_checkpoint_converter.py (looked up in the same directory as this
# script) against checkpoint directories found at or below <directory>.
#
# Usage: <this script> <directory>
set -euo pipefail

# Absolute directory containing this script; the converter path is built on it.
script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
converter_py="${script_dir}/hf_checkpoint_converter.py"

# Exactly one positional argument (the directory to process) is accepted.
if (( $# != 1 )); then
  echo "Usage: $0 <directory>" >&2
  exit 1
fi

# Canonicalize up front so later messages and lookups use an absolute path.
root_dir="$(realpath "$1")"

[[ -d "$root_dir" ]] || {
  echo "Error: '$root_dir' is not a directory" >&2
  exit 1
}

# Refuse to start if the converter script is missing.
[[ -f "$converter_py" ]] || {
  echo "Error: converter not found at $converter_py" >&2
  exit 1
}

# Succeeds (status 0) only when both sequence-mixing module files exist as
# regular files directly inside the directory given as $1; otherwise status 1.
has_dynamic_modules() {
  local ckpt="$1"
  local required

  for required in configuration_sequence_mixing.py modeling_sequence_mixing.py; do
    [[ -f "$ckpt/$required" ]] || return 1
  done
  return 0
}

# Succeeds (status 0) when the directory given as $1 directly contains a
# weights file named model.safetensors or pytorch_model.bin; otherwise status 1.
is_checkpoint_dir() {
  local candidate="$1"

  if [[ -f "$candidate/model.safetensors" ]]; then
    return 0
  fi
  [[ -f "$candidate/pytorch_model.bin" ]]
}

# Runs the converter script ($converter_py) on the checkpoint directory given
# as $1, unless has_dynamic_modules reports that the module files are already
# there, in which case a [SKIP] line is printed and nothing else happens.
# Returns 0 on skip, otherwise the exit status of the python invocation.
convert_checkpoint() {
  local target="$1"

  if ! has_dynamic_modules "$target"; then
    echo "[CONVERT] In-place: $target"
    python "$converter_py" "$target"
    # Bare return hands back the converter's own exit status.
    return
  fi

  echo "[SKIP] Modules already present in $target"
  return 0
}

# Converts the checkpoint at $1, or every checkpoint nested anywhere below it.
#
# Arguments: $1 - directory to process
# Outputs:   a "No checkpoints found" notice on stdout when nothing matches,
#            plus whatever convert_checkpoint prints
# Returns:   0 when $1 is itself a checkpoint or contains none; otherwise the
#            status of the last convert_checkpoint call
process_dir() {
  local dir="$1"
  # Keep the work list and loop variable out of the global scope.
  local -a ckpt_dirs=()
  local ckpt

  # If the provided dir itself is a checkpoint, convert just it
  if is_checkpoint_dir "$dir"; then
    convert_checkpoint "$dir"
    return 0
  fi

  # Otherwise, find all nested checkpoint dirs (recursively). Paths are passed
  # NUL-delimited so a directory name containing a newline stays one entry;
  # sort -u collapses directories that hold both weight formats.
  # NOTE: the exit status of the find/sort pipeline is not checked, because it
  # runs inside a process substitution.
  while IFS= read -r -d '' ckpt; do
    ckpt_dirs+=("$ckpt")
  done < <(
    find "$dir" -type f \
      \( -name 'model.safetensors' -o -name 'pytorch_model.bin' \) \
      -printf '%h\0' \
      | sort -zu
  )

  if [[ ${#ckpt_dirs[@]} -eq 0 ]]; then
    echo "No checkpoints found under $dir"
    return 0
  fi

  for ckpt in "${ckpt_dirs[@]}"; do
    convert_checkpoint "$ckpt"
  done
}

# Entry point: process the canonicalized directory supplied on the command line.
process_dir "$root_dir"


