#!/bin/bash -l
set -euo pipefail

# --- repos ---
# Hugging Face repository ids, processed in this order by the main loop.
declare -a REPOS
REPOS=(
  # TinyStories
  "roneneldan/TinyStories-1M"
  "roneneldan/TinyStories-8M"
  "roneneldan/TinyStories-33M"
  # GPT-2 family
  "openai-community/gpt2"
  "openai-community/gpt2-medium"
  "openai-community/gpt2-large"
  # Gemma 3
  "google/gemma-3-1b-pt"
  "google/gemma-3-4b-pt"
  "google/gemma-3-12b-pt"
  # Other families
  "microsoft/Phi-4-mini-instruct"
  "mistralai/Mistral-7B-v0.1"
  "meta-llama/Llama-3.1-8B"
)

# Token-flag defaults. Each entry is overwritten when its flag appears on the
# command line; the key doubles as the lookup name used by the parser.
declare -A TOKDEF
TOKDEF[n]="5000"            # -n
TOKDEF[prompt_tokens]="500" # --prompt-tokens
TOKDEF[mcl]="400"           # -mcl
TOKDEF[cis]="25"            # -cis

# Parameter counts (billions) for model names that carry no "<num>B"/"<num>M"
# size suffix. NOTE(review): "Phi-4-mini-instruct" matches neither suffix
# pattern and has no entry here, so it falls back to the unknown-size default.
declare -A KNOWN_B
KNOWN_B[gpt2]="0.124"
KNOWN_B[gpt2-medium]="0.355"
KNOWN_B[gpt2-large]="0.774"

# Unrecognized CLI arguments, forwarded verbatim after the known token flags.
declare -a EXTRA_TOKEN_ARGS=()

#######################################
# Print the help text, including the current TOKDEF defaults, to stdout and
# exit with status 0.
# Globals: TOKDEF (read)
#######################################
usage() {
  printf '%s\n' \
    "Usage: $0 [TOKEN FLAGS]" \
    "" \
    "Known token flags (replace defaults if provided):" \
    "  -n <int>                  (default: ${TOKDEF[n]})" \
    "  --prompt-tokens <int>     (default: ${TOKDEF[prompt_tokens]})" \
    "  -mcl <int>                (default: ${TOKDEF[mcl]})" \
    "  -cis <int>                (default: ${TOKDEF[cis]})" \
    "" \
    "You may also use = forms, e.g. --prompt-tokens=1000 or -n=50000." \
    "Any unrecognized flags are forwarded verbatim to the token args."
  exit 0
}

# --- parse CLI, replacing defaults when specified ---
#######################################
# Walk the command line and overwrite TOKDEF entries for known token flags.
# Accepted spellings:
#   -n / -mcl / -cis   "-f VALUE", "-f=VALUE", or "-f<digit>..." (glued)
#   --prompt-tokens    "--prompt-tokens VALUE" or "--prompt-tokens=VALUE"
# -h/--help prints usage and exits 0; a spaced flag with no following value
# prints an error to stderr and exits 1. Every other argument is appended,
# unchanged, to EXTRA_TOKEN_ARGS.
# Globals:   TOKDEF (written), EXTRA_TOKEN_ARGS (appended)
# Arguments: the script's command-line arguments
#######################################
parse_token_flags() {
  local arg key
  while (( $# > 0 )); do
    arg="$1"
    case "$arg" in
      -h|--help)
        usage
        ;;
      -n|--prompt-tokens|-mcl|-cis)
        if (( $# < 2 )); then
          echo "Error: ${arg} requires a value" >&2
          exit 1
        fi
        # Flag -> TOKDEF key: strip leading dashes, inner dashes become "_"
        # ("-n" -> n, "--prompt-tokens" -> prompt_tokens).
        key="${arg#-}"
        key="${key#-}"
        TOKDEF[${key//-/_}]="$2"
        shift 2
        ;;
      -n=*|--prompt-tokens=*|-mcl=*|-cis=*)
        # Flag is everything before the first "=", value everything after it.
        key="${arg%%=*}"
        key="${key#-}"
        key="${key#-}"
        TOKDEF[${key//-/_}]="${arg#*=}"
        shift
        ;;
      -n[0-9]*|-mcl[0-9]*|-cis[0-9]*)
        # Value glued onto the flag: split the flag name from the remainder.
        [[ "$arg" =~ ^-(n|mcl|cis)(.*)$ ]]
        TOKDEF[${BASH_REMATCH[1]}]="${BASH_REMATCH[2]}"
        shift
        ;;
      *)
        EXTRA_TOKEN_ARGS+=("$arg")
        shift
        ;;
    esac
  done
}

parse_token_flags "$@"

# Build final TOKEN_ARGS: the four known flags in a fixed order, followed by
# any unrecognized arguments exactly as the user passed them.
# EXTRA_TOKEN_ARGS is usually empty. A bare "${arr[@]}" on an empty array
# aborts with "unbound variable" under `set -u` in bash < 4.4, so the
# ${arr[@]+"${arr[@]}"} form is used: it expands to nothing when the array is
# empty and to the quoted elements otherwise.
TOKEN_ARGS=(
  -n "${TOKDEF[n]}"
  --prompt-tokens "${TOKDEF[prompt_tokens]}"
  -mcl "${TOKDEF[mcl]}"
  -cis "${TOKDEF[cis]}"
  ${EXTRA_TOKEN_ARGS[@]+"${EXTRA_TOKEN_ARGS[@]}"}
)

# --- helpers ---
# Print the part of a repo id after its last "/" ("org/model" -> "model");
# an id with no "/" is printed unchanged.
name_from_repo() {
  local repo_id="$1"
  echo "${repo_id##*/}"
}

#######################################
# Estimate a model's parameter count, in billions, from its name.
# Tried in order:
#   1. the last "<number>B"/"<number>b" in the name (e.g. 7B, 1b, 0.5B)
#   2. the last "<number>M"/"<number>m" (e.g. 33M), divided by 1000
#   3. an exact-name lookup in KNOWN_B
# The number must directly precede the letter (only whitespace may sit in
# between), so e.g. "gpt2-medium" or "Phi-4-mini-instruct" match neither
# pattern.
# Globals:   KNOWN_B (read)
# Arguments: $1 - model name (repo id without the org prefix)
# Outputs:   the count on stdout. M-derived values are printed with six
#            decimals (e.g. "0.033000"). Prints an empty line when unknown.
#######################################
params_in_billions() {
  local name="$1" bmatch mmatch val
  # [[:space:]] rather than \s: \s is a GNU grep extension, not POSIX ERE.
  # `|| true` keeps a no-match (grep exit 1) from tripping set -e/pipefail.
  bmatch=$(grep -oE '([0-9]+([.][0-9]+)?)[[:space:]]*[Bb]' <<<"$name" | tail -n1 || true)
  if [[ -n "${bmatch:-}" ]]; then
    val="${bmatch//[Bb[:space:]]/}"
    echo "$val"
    return 0
  fi
  mmatch=$(grep -oE '([0-9]+([.][0-9]+)?)[[:space:]]*[Mm]' <<<"$name" | tail -n1 || true)
  if [[ -n "${mmatch:-}" ]]; then
    val="${mmatch//[Mm[:space:]]/}"
    # awk is always available, unlike bc. %.6f yields a plain decimal (never
    # exponent notation) that any downstream numeric comparison can parse.
    awk -v m="$val" 'BEGIN { printf "%.6f\n", m / 1000 }'
    return 0
  fi
  # Fallback: names without a size suffix (e.g. gpt2, gpt2-medium, gpt2-large)
  if [[ -n "${KNOWN_B[$name]:-}" ]]; then
    echo "${KNOWN_B[$name]}"
    return 0
  fi
  echo ""  # unknown
}

#######################################
# Choose a --batch-size from a model's estimated parameter count:
#   unknown size -> 16 (conservative default)
#   < 1B -> 64,  < 4B -> 32,  < 9B -> 16,  otherwise -> 8
# Arguments: $1 - model name, passed to params_in_billions
# Outputs:   the batch size on stdout
#######################################
batch_size_for() {
  local name="$1" numB
  numB="$(params_in_billions "$name")"
  # If unknown, keep a conservative default
  if [[ -z "${numB:-}" ]]; then
    echo 16
    return 0
  fi
  # Bash arithmetic is integer-only. One awk call does every float comparison,
  # replacing up to three bc forks and the dependency on bc being installed.
  awk -v b="$numB" 'BEGIN {
    b += 0  # force numeric; accepts ".124" as well as "0.124"
    if (b < 1)      print 64
    else if (b < 4) print 32
    else if (b < 9) print 16
    else            print 8
  }'
}

# --- main loop ---
# For each repo, assemble a slurm.py command line (its flags suggest it writes
# "<NAME>.sh" under --sh-dir for dataset_collisions.py; confirm in slurm.py)
# and print it. The launch line at the bottom is commented out, so this loop
# is currently a dry run that only prints commands.
for MODEL_ID in "${REPOS[@]}"; do
  NAME="$(name_from_repo "$MODEL_ID")"
  BATCH_SIZE="$(batch_size_for "$NAME")"

  SH_NAME_ARG=( --sh-name "${NAME}.sh" )
  MODEL_ARGS=( --batch-size "${BATCH_SIZE}" --model-id "${MODEL_ID}" --output "dataset_collisions-${NAME}" )

  # Repo ids containing "gpt" or "gemma" (case-insensitive) also get --full.
  FULL_FLAG=()
  case "${MODEL_ID,,}" in
    *gpt*|*gemma*) FULL_FLAG=( --full ) ;;
  esac

  # Build the command as an array so every argument stays a single word.
  CMD=( python3.11 slurm.py --sh-dir ./scripts/auto_generated/dataset_collisions )
  CMD+=( --py-path src/ablations/dataset_collisions.py -o dataset_exp --semi )
  CMD+=( "${SH_NAME_ARG[@]}" )
  CMD+=( "${TOKEN_ARGS[@]}" )
  # FULL_FLAG is empty for most repos. A bare "${FULL_FLAG[@]}" aborts with
  # "unbound variable" under `set -u` in bash < 4.4, so guard the expansion.
  CMD+=( "${MODEL_ARGS[@]}" ${FULL_FLAG[@]+"${FULL_FLAG[@]}"} )

  # Log shell-quoted (%q) so the printed line can be copy-pasted verbatim,
  # even when forwarded extra arguments contain spaces or metacharacters.
  printf '+'
  printf ' %q' "${CMD[@]}"
  printf '\n'
  # NOTE(review): if re-enabled, this backgrounds the job and discards all of
  # its output, so failures would go unnoticed.
  # "${CMD[@]}" > /dev/null 2>&1 &
done
