#!/bin/bash
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

# Bagel Server startup script
# Usage: ./start_server.sh [model_path] [port]

# Environment setup: load the CUDA 12.8 environment module and activate the
# "uni-plan" environment. Both steps are best-effort (the script keeps going
# if they fail, e.g. when the environment is already prepared), but failures
# are reported on stderr so a wrong toolchain is easy to diagnose.
if command -v module >/dev/null 2>&1; then
  module load cuda/12.8 \
    || echo "Warning: 'module load cuda/12.8' failed; continuing with the current CUDA setup." >&2
else
  # 'module' is usually a shell function that is only defined on hosts with
  # environment modules; it may be absent in non-interactive shells.
  echo "Warning: 'module' command not found; skipping 'module load cuda/12.8'." >&2
fi

source activate uni-plan \
  || echo "Warning: could not activate environment 'uni-plan'; using python from PATH." >&2

# Default configuration.
# Precedence for each setting: positional argument (where one exists)
#   > environment variable > built-in default.
#
#   $1 / BAGEL_MODEL_PATH    model weights path
#   $2 / BAGEL_PORT          port passed to the server
#   BAGEL_CONFIG_PATH        model config path
#   BAGEL_ACTION_NORM_PATH   action normalizer JSON
#   BAGEL_GPU_IDS            whitespace-separated GPU list, e.g. "0 1 2 3"
#   BAGEL_NUM_WORKERS        number of workers (defaults to the GPU count)
#   BAGEL_MAX_MEM_PER_GPU    maximum memory per GPU, e.g. 80GiB
CONFIG_PATH=${BAGEL_CONFIG_PATH:-./models/BAGEL-7B-MoT}
MODEL_PATH=${1:-${BAGEL_MODEL_PATH:-./models/bagel/ckpt/0005000}}
ACTION_NORM_PATH=${BAGEL_ACTION_NORM_PATH:-./data/bagel_data/dynamics/libero_spatial_with_wrist/action_normalizer.json}
PORT=${2:-${BAGEL_PORT:-8000}}

# GPU list, each GPU runs one worker. Kept as a whitespace-separated string.
GPU_IDS="${BAGEL_GPU_IDS:-0 1 2 3}"

# Number of workers (usually equals number of GPUs): unless overridden, count
# the entries in GPU_IDS so the two settings cannot drift apart.
read -r -a _bagel_gpu_list <<< "$GPU_IDS"
NUM_WORKERS=${BAGEL_NUM_WORKERS:-${#_bagel_gpu_list[@]}}
unset _bagel_gpu_list

# Maximum memory per GPU (NF4 quantization requires ~15-20GB, recommend 40GB+)
MAX_MEM_PER_GPU=${BAGEL_MAX_MEM_PER_GPU:-80GiB}

# Print a startup banner listing every configured value that is handed to the
# server, including the config path and per-GPU memory limit.
# Globals (read): CONFIG_PATH, MODEL_PATH, ACTION_NORM_PATH, PORT, GPU_IDS,
#                 NUM_WORKERS, MAX_MEM_PER_GPU
# Outputs: banner lines on stdout.
print_banner() {
  local rule="========================================="
  # printf '%s\n' prints each argument verbatim on its own line, so values
  # containing backslashes or leading dashes are shown unchanged.
  printf '%s\n' \
    "$rule" \
    "Starting Bagel Inference Server (Multi-Worker)" \
    "$rule" \
    "Config Path: $CONFIG_PATH" \
    "Model Path: $MODEL_PATH" \
    "Action Norm Path: $ACTION_NORM_PATH" \
    "Port: $PORT" \
    "GPU IDs: $GPU_IDS" \
    "Num Workers: $NUM_WORKERS" \
    "Max Mem Per GPU: $MAX_MEM_PER_GPU" \
    "$rule"
}

print_banner

# Launch the Bagel websocket inference server with the configured settings.
# Globals (read): CONFIG_PATH, MODEL_PATH, ACTION_NORM_PATH, MAX_MEM_PER_GPU,
#                 NUM_WORKERS, GPU_IDS, PORT
# Returns: the exit status of the python process.
launch_server() {
  # GPU_IDS is a whitespace-separated list. Split it into an array so each ID
  # becomes its own argument without relying on unquoted expansion (which
  # would also glob-expand the value).
  local -a gpu_ids=()
  read -r -a gpu_ids <<< "$GPU_IDS"

  # Every path is quoted so values containing spaces or glob characters reach
  # the server as a single argument.
  python bagel_client/websocket_bagel_server.py \
    --model-config-path "$CONFIG_PATH" \
    --model-weights-path "$MODEL_PATH" \
    --action-norm-path "$ACTION_NORM_PATH" \
    --max-mem-per-gpu "$MAX_MEM_PER_GPU" \
    --num-workers "$NUM_WORKERS" \
    --gpu-ids "${gpu_ids[@]}" \
    --start-method spawn \
    --host 0.0.0.0 \
    --port "$PORT" \
    --edit-cfg-text-scale 4.0 \
    --edit-cfg-img-scale 2.0 \
    --edit-timestep-shift 3.0 \
    --edit-num-timesteps 50 \
    --edit-cfg-renorm-type text_channel \
    --understand-max-tokens 1000 \
    --understand-temperature 0.3
}

# Last command of the script, so the script exits with the server's status.
launch_server

