#!/bin/bash
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0


# Fail fast: abort on any failing command, on use of an unset variable, and
# when any stage of a pipeline fails.
set -euo pipefail

# Server configuration. Every setting can be overridden with the matching
# BAGEL_* environment variable; the value after ':-' is the default used when
# that variable is unset or empty.
MODEL_PATH=${BAGEL_MODEL_PATH:-./models}
ACTION_NORM_PATH=${BAGEL_ACTION_NORM_PATH:-./data/bagel_data/action_normalizer.json}
PORT=${BAGEL_PORT:-8000}
# Space-separated list of GPU IDs; each one is passed as a separate value to
# the server's --gpu-ids option.
GPU_IDS=${BAGEL_GPU_IDS:-"0 1 2 3"}
NUM_WORKERS=${BAGEL_NUM_WORKERS:-4}
# Forwarded verbatim to --max-mem-per-gpu.
MAX_MEM_PER_GPU=${BAGEL_MAX_MEM_PER_GPU:-80GiB}

# Print a startup banner summarising the effective configuration.
# printf reuses its '%s\n' format for every argument, emitting one line each.
printf '%s\n' \
    "=========================================" \
    "Starting Bagel Inference Server (Multi-Worker)" \
    "=========================================" \
    "Model Path: $MODEL_PATH" \
    "Action Norm Path: $ACTION_NORM_PATH" \
    "Port: $PORT" \
    "GPU IDs: $GPU_IDS" \
    "Num Workers: $NUM_WORKERS" \
    "========================================="

# Split the space-separated GPU ID list into an array, so each ID becomes its
# own argument to --gpu-ids. This avoids unquoted word splitting, which would
# also glob-expand the IDs.
read -r -a gpu_id_list <<<"$GPU_IDS"

# Launch the multi-worker Bagel websocket server with the configured settings
# and fixed edit/understanding sampling parameters.
# exec replaces this shell with the Python process, so signals sent to the
# script (Ctrl-C, SIGTERM from a process manager) reach the server directly.
# Its exit status becomes the script's exit status.
# NOTE(review): the script path is relative, so this assumes the script is run
# from the repository root — confirm.
exec python bagel_client/websocket_bagel_server.py \
    --model-path "$MODEL_PATH" \
    --action-norm-path "$ACTION_NORM_PATH" \
    --max-mem-per-gpu "$MAX_MEM_PER_GPU" \
    --num-workers "$NUM_WORKERS" \
    --gpu-ids "${gpu_id_list[@]}" \
    --start-method spawn \
    --host 0.0.0.0 \
    --port "$PORT" \
    --edit-cfg-text-scale 4.0 \
    --edit-cfg-img-scale 2.0 \
    --edit-timestep-shift 3.0 \
    --edit-num-timesteps 50 \
    --edit-cfg-renorm-type text_channel \
    --understand-max-tokens 1000 \
    --understand-temperature 0.3

