#!/bin/bash

# ------------------------------------------------------------------------------
# Pre-compute average alignment-metric ratios per *module type*.
# Mirrors the structure of `scripts/finetune.sh` but only runs the
# lightweight `src/precompute_module_type_ratios.py` helper.
# ------------------------------------------------------------------------------

# Strict mode ------------------------------------------------------------------
# Abort on the first failing command, on any reference to an unset variable,
# and when any stage of a pipeline fails, instead of carrying on silently.
set -euo pipefail

# Environment-configurable knobs ------------------------------------------------
# Every scalar below keeps the value exported by the caller and only falls back
# to the default shown when the variable is unset or empty (`:-`).  The array
# settings (dataset mixer, splits, transform fns) are fixed in this script and
# are NOT read from the environment.
MODEL_NAME_OR_PATH=${MODEL_NAME_OR_PATH:-"google/gemma-3-4b-it"}

# Dataset (interleaved list: <dataset1> <frac_or_num> ...).  By default we use
# the same MetaMathQA dataset as the finetuning example.
DATASET_MIXER_LIST=(
  "meta-math/MetaMathQA" "1.0"
)
# Split names forwarded via --dataset_mixer_list_splits.
DATASET_SPLITS=("train")

# Alignment-metric parameters
SAMPLE_SIZE=${SAMPLE_SIZE:-50}         # Number of examples to sample for metrics
GEN_MAX_LENGTH=${GEN_MAX_LENGTH:-128}  # Truncation length when preparing batch for model

# --max_seq_length controls tokenisation/truncation during dataset processing.
# The fallback used here is 512.  NOTE(review): the original note says
# MetaMathQA setups often use 1024 -- export TOKEN_MAX_SEQ to override.
TOKEN_MAX_SEQ=${TOKEN_MAX_SEQ:-512}

# Provide transform fns explicitly for MetaMathQA
DATASET_TRANSFORM_FN=(
  "sft_metamathqa_tokenize_and_truncate_v1"
  "sft_metamathqa_filter_v1"
)

# Where to place outputs (defaults to a folder derived from the model name).
# Every '/' in the model name becomes '_' so the name forms a single path
# component.  Parameter expansion is used instead of `echo | tr` so the name is
# never word-split or glob-expanded and no subprocess is spawned.
model_slug=${MODEL_NAME_OR_PATH//\//_}
OUTPUT_DIR=${OUTPUT_DIR:-"./outputs_${model_slug}_math"}

# ------------------------------------------------------------------------------
# Run summary
# ------------------------------------------------------------------------------
# Print the effective settings, one per line, before the job starts so the
# configuration can be read back from the job log.  Array-valued settings are
# joined into a single space-separated string via [*].
printf '%s\n' \
  "📐  Pre-computing module-type ratios" \
  "Model:            $MODEL_NAME_OR_PATH" \
  "Dataset mixer:    ${DATASET_MIXER_LIST[*]}" \
  "Splits:           ${DATASET_SPLITS[*]}" \
  "Sample size:      $SAMPLE_SIZE" \
  "Gen max length:   $GEN_MAX_LENGTH" \
  "Token max_seq:    $TOKEN_MAX_SEQ" \
  "Output directory: $OUTPUT_DIR"

# ------------------------------------------------------------------------------
# Python invocation
# ------------------------------------------------------------------------------
# Hand every setting to the helper script.  The script path is relative, so
# this must be started from the directory that contains `src/`.
# Arrays are expanded with "${arr[@]}" so each element becomes its own
# argument; every scalar is double-quoted so an override containing whitespace
# or glob characters is passed through as exactly one argument instead of
# being word-split into extra, misparsed arguments.
python src/precompute_module_type_ratios.py \
  --model_name_or_path "$MODEL_NAME_OR_PATH" \
  --dataset_mixer_list "${DATASET_MIXER_LIST[@]}" \
  --dataset_mixer_list_splits "${DATASET_SPLITS[@]}" \
  --dataset_transform_fn "${DATASET_TRANSFORM_FN[@]}" \
  --sample_size "$SAMPLE_SIZE" \
  --max_length "$GEN_MAX_LENGTH" \
  --max_seq_length "$TOKEN_MAX_SEQ" \
  --output_dir "$OUTPUT_DIR"