#!/bin/bash
# Pre-training script for Function Encoder VLA

# Abort on command failure, on use of an unset variable, and on a failure in
# any stage of a pipeline.
set -euo pipefail

# Required environment:
#   WORKDIR      Directory to run from. The relative paths used below
#                (vla-scripts/pretrain_fe.py, runs/fe_pretrain) are resolved
#                against it.
#   DATAROOTDIR  Forwarded to the training script as --data_root_dir.
# Optional environment:
#   NUM_GPUS     Value for torchrun's --nproc-per-node (default: 1).
#
# Without these checks an unset WORKDIR turns `cd` into "go to $HOME", and an
# empty DATAROOTDIR makes --data_root_dir consume the following flag.
: "${WORKDIR:?WORKDIR must be set to the directory to run from}"
: "${DATAROOTDIR:?DATAROOTDIR must be set to the dataset root directory}"

cd -- "$WORKDIR" || {
  printf 'error: cannot cd to WORKDIR: %s\n' "$WORKDIR" >&2
  exit 1
}

# Number of processes per node (adjust based on your hardware). An existing
# NUM_GPUS in the environment takes precedence over the default.
NUM_GPUS="${NUM_GPUS:-1}"

# Settings that appear in both the banner and the command line. They are
# defined once so the printed configuration cannot drift from what is run.
DATA_MIX="oxe_magic_soup_plus"
RUN_ROOT_DIR="runs/fe_pretrain"
FE_BASIS_FUNCTIONS=16
CALIBRATION_BUFFER_SIZE=128
CALIBRATE_INTERVAL=16
BATCH_SIZE=8
LEARNING_RATE="1e-4"
MAX_STEPS=300000
LORA_RANK=32

echo "========================================="
echo "Function Encoder Pre-training"
echo "========================================="
echo ""
echo "Configuration:"
echo "  - Mixture: ${DATA_MIX} (27 datasets)"
echo "  - FE basis functions: ${FE_BASIS_FUNCTIONS}"
# NOTE(review): this line previously printed 512 while the command passed
# --calibration_buffer_size 128. It is assumed to describe that flag; confirm,
# and confirm that the value is really "per dataset".
echo "  - Buffer size: ${CALIBRATION_BUFFER_SIZE} samples per dataset"
echo "  - Calibration interval: every ${CALIBRATE_INTERVAL} steps"
echo "  - Batch size: ${BATCH_SIZE} per GPU"
echo "  - Learning rate: ${LEARNING_RATE}"
echo "  - Max steps: ${MAX_STEPS}"
echo "  - LoRA rank: ${LORA_RANK}"
echo ""

# Arguments for vla-scripts/pretrain_fe.py, kept in an array so that every
# value (DATAROOTDIR in particular) reaches the script as exactly one word.
train_args=(
  --vla_path openvla/openvla-7b
  --data_root_dir "$DATAROOTDIR"
  --data_mix "$DATA_MIX"
  --run_root_dir "$RUN_ROOT_DIR"
  --shuffle_buffer_size 100000
  --fe_basis_functions "$FE_BASIS_FUNCTIONS"
  --n_continuous_actions 7
  --calibration_buffer_size "$CALIBRATION_BUFFER_SIZE"
  --calibrate_interval "$CALIBRATE_INTERVAL"
  --prefill_samples_per_dataset 32
  --use_film False
  --num_images_in_input 1
  --use_proprio False
  --image_aug True
  --batch_size "$BATCH_SIZE"
  --learning_rate "$LEARNING_RATE"
  --lr_warmup_steps 1000
  --num_steps_before_decay 200000
  --grad_accumulation_steps 1
  --max_steps "$MAX_STEPS"
  --save_freq 5000
  --save_latest_checkpoint_only False
  --use_lora True
  --lora_rank "$LORA_RANK"
  --lora_dropout 0.05
  --seed 7
  --wandb_project mos-vla
  --wandb_entity user
)

# torchrun runs in the foreground, so the lines after it execute only once it
# has exited. Capture its status with `||` so that `set -e` does not end the
# script before the failure can be reported.
status=0
torchrun --standalone --nnodes 1 --nproc-per-node "$NUM_GPUS" \
  vla-scripts/pretrain_fe.py "${train_args[@]}" || status=$?

if (( status != 0 )); then
  printf 'error: pre-training exited with status %d\n' "$status" >&2
  exit "$status"
fi

echo ""
echo "========================================="
echo "Pre-training launched!"
echo "Check ${RUN_ROOT_DIR}/ for checkpoints"
echo "========================================="