#!/bin/bash
#
# Launches `src inference_chat` with the given inference settings file.
#
# Usage: <script> <use_vllm> <path_to_config>
#   use_vllm        "true" runs `python -m src inference_chat` directly;
#                   any other value launches it through `accelerate launch`
#                   with one process per GPU reported by torch.
#   path_to_config  forwarded as --inference_settings_path.
#
# Exit status: 1 on bad arguments, a missing config file or an unusable GPU
# count; otherwise the status of the first failing required command, or 0.

# Abort on the first failing required command so a failed inference run can
# never be reported as "Done!". Best-effort steps below are guarded with `||`.
set -euo pipefail

warn() { printf 'warning: %s\n' "$*" >&2; }
die() { printf 'error: %s\n' "$*" >&2; exit 1; }

if [ "$#" -lt 2 ]; then
    # Usage is a diagnostic, so it goes to stderr.
    {
        echo "Usage: $0 <use_vllm> <path_to_config>"
        echo "  use_vllm: true/false - whether to use VLLM for inference"
        echo "  path_to_config: path to the inference config file"
    } >&2
    exit 1
fi

VLLM_ENABLED=$1
INFERENCE_SETTINGS_PATH=$2

# Only the exact string "true" selects the direct path (unchanged behaviour);
# flag anything unexpected instead of silently treating it as false.
case "$VLLM_ENABLED" in
    true|false) ;;
    *) warn "use_vllm is '$VLLM_ENABLED' (expected true/false); treating it as false" ;;
esac

# Fail fast on a bad path before spending time on the package install.
[[ -e "$INFERENCE_SETTINGS_PATH" ]] \
    || die "inference config not found: $INFERENCE_SETTINGS_PATH"

# Best-effort: a failed upgrade did not stop the run before, so it only warns.
pip install -U flash-attn || warn "pip install -U flash-attn failed; continuing"

# Diagnostic output only; its failure must not abort the run.
nvidia-smi || warn "nvidia-smi failed; continuing"

if [ "$VLLM_ENABLED" = true ]; then
    python -m src inference_chat --inference_settings_path "$INFERENCE_SETTINGS_PATH"
else
    # The GPU count is only needed to size the accelerate launch.
    NUM_GPUS=$(python -c "import torch; print(torch.cuda.device_count())") \
        || die "could not query the GPU count via torch"
    [[ "$NUM_GPUS" =~ ^[1-9][0-9]*$ ]] \
        || die "expected a positive GPU count, got '$NUM_GPUS'"

    accelerate launch --module --multi_gpu --num_processes="$NUM_GPUS" \
        src inference_chat --inference_settings_path "$INFERENCE_SETTINGS_PATH"
fi

# Reached only when inference succeeded; a missing output directory makes
# `ls` fail and (via `set -e`) the script exits non-zero without "Done!".
ls inference_output/
echo "Done!"
