#!/bin/bash
# This example will start serving the 345M model.
# Restrict CUDA to a single host-to-device connection (work queue) per device.
# NOTE(review): presumably required by the server launched further down in
# this script - confirm against the server's documentation.
export CUDA_DEVICE_MAX_CONNECTIONS=1

# Install flask-restful on every start.
# NOTE(review): presumably a dependency of the server script - confirm.
# The install stays best-effort: a failure is reported on stderr instead of
# being silently ignored, but the script still continues.
if ! pip install flask-restful; then
  printf 'warning: "pip install flask-restful" failed; continuing anyway\n' >&2
fi


#######################################
# Derive the run settings from a script path whose file name is expected to
# look like <task_version>_<server_port>_<cuda_device>.<ext>, for example
# gpt345m_5000_0.sh. Fields are peeled off from the right, so <task_version>
# may itself contain underscores.
# Globals:   SH_FILE, FILE_NAME, CUDA_DEVICE, SERVER_PORT, TASK_VERSION
#            (all written; FILE_NAME ends up equal to TASK_VERSION)
# Arguments: $1 - path of the script whose name encodes the settings
# Outputs:   prints the name stem and then each parsed field to stdout, one
#            per line; prints a warning to stderr when the stem has fewer
#            than three '_'-separated fields
# Returns:   0 always; a malformed name is reported but not fatal
#######################################
parse_script_name() {
  # Quoted so that a path containing spaces is handled as one argument.
  SH_FILE=$(basename -- "$1")
  # Drop the extension (everything from the last dot).
  FILE_NAME=${SH_FILE%.*}
  printf '%s\n' "$FILE_NAME"

  # With fewer than two underscores the suffix/prefix stripping below does
  # not separate the fields and they collapse to the same value.
  if [[ "$FILE_NAME" != *_*_* ]]; then
    printf 'warning: script name "%s" does not match <task_version>_<server_port>_<cuda_device>\n' \
      "$FILE_NAME" >&2
  fi

  # Last field: the CUDA device.
  CUDA_DEVICE=${FILE_NAME##*_}
  printf '%s\n' "$CUDA_DEVICE"
  FILE_NAME=${FILE_NAME%_*}

  # Next field from the right: the server port.
  SERVER_PORT=${FILE_NAME##*_}
  printf '%s\n' "$SERVER_PORT"
  FILE_NAME=${FILE_NAME%_*}

  # Whatever remains is the task version.
  TASK_VERSION=$FILE_NAME
  printf '%s\n' "$TASK_VERSION"
}

parse_script_name "$0"

# Expose only the device named in the script's file name to CUDA.
export CUDA_VISIBLE_DEVICES="$CUDA_DEVICE"

# torchrun settings for a single-node, single-process launch with the
# rendezvous on localhost:6000. Kept as an array so each flag and value
# reaches torchrun as its own argument without relying on word splitting.
# NOTE(review): the master port is fixed at 6000, so two copies of this
# script started on the same host would presumably compete for it - confirm.
DISTRIBUTED_ARGS=(
  --nproc_per_node 1
  --nnodes 1
  --node_rank 0
  --master_addr localhost
  --master_port 6000
)

# The checkpoint directory is selected by TASK_VERSION; the vocabulary and
# merge files are fixed paths.
CHECKPOINT_PATH="/workspace/checkpoints/checkpoints/${TASK_VERSION}"
VOCAB_FILE=/workspace/dataset/gpt2_vocabulary/gpt2-vocab.json
MERGE_FILE=/workspace/dataset/gpt2_vocabulary/gpt2-merges.txt

# Start the text generation server on SERVER_PORT. The server script path is
# relative, so this must be run from the directory that contains tools/.
# All expansions are quoted so that values containing spaces or glob
# characters are passed through unchanged.
torchrun "${DISTRIBUTED_ARGS[@]}" tools/run_text_generation_server.py \
  --port "$SERVER_PORT" \
  --tensor-model-parallel-size 1 \
  --pipeline-model-parallel-size 1 \
  --num-layers 24 \
  --hidden-size 1024 \
  --load "$CHECKPOINT_PATH" \
  --num-attention-heads 16 \
  --max-position-embeddings 1024 \
  --tokenizer-type GPT2BPETokenizer \
  --fp16 \
  --micro-batch-size 1 \
  --seq-length 1024 \
  --vocab-file "$VOCAB_FILE" \
  --merge-file "$MERGE_FILE" \
  --seed 42
