#!/bin/bash
#SBATCH --job-name=build_responses_vlms_yaml
#SBATCH --partition=h100                      # Use GPU partition "a100"
#SBATCH --gres gpu:8                          # set 4 GPUs per job
#SBATCH -c 32                                 # Number of cores
#SBATCH -N 1                                  # Ensure that all cores are on one machine
#SBATCH -t 4-00:00                            # Maximum run-time in D-HH:MM
#SBATCH --mem=128G                            # Memory pool for all cores (see also --mem-per-cpu)
#SBATCH --output=%j.out                       # File to which STDOUT will be written
#SBATCH --error=%j.err                        # File to which STDERR will be written

############################################################
# ENV
############################################################
export PYTHONPATH="./:$PYTHONPATH"
export PYTHONPATH="src/:$PYTHONPATH"
export VLLM_WORKER_MULTIPROC_METHOD=spawn

############################################################
# Parse command line arguments
############################################################
while getopts ":m:c:" opt; do
    case ${opt} in
        m )
            MODEL_KEY=$OPTARG
            ;;
        c )
            MODEL_CONFIG=$OPTARG
            ;;
        \? )
            echo "Invalid option: -$OPTARG" 1>&2
            exit 1
            ;;
        : )
            echo "Option -$OPTARG requires an argument" 1>&2
            exit 1
            ;;
    esac
done

############################################################
# Logging
############################################################
log() {
    echo "$(tput setaf 0)$(date '+%Y-%m-%d %H:%M:%S') $@$(tput sgr0)"
}

log_warning() {
    echo "$(tput setaf 3)$(date '+%Y-%m-%d %H:%M:%S') $@$(tput sgr0)"
}

log_done() {
    echo "$(tput setaf 2)$(date '+%Y-%m-%d %H:%M:%S') $@$(tput sgr0)"
}

log_error() {
    echo "$(tput setaf 1)$(date '+%Y-%m-%d %H:%M:%S') $@$(tput sgr0)"
}

############################################################
# Load YAML configs
############################################################
# Check if yq is installed
if ! command -v yq &> /dev/null; then
    log_error "yq is not installed. Please install it first."
    exit 1
fi

# Check if required arguments are provided
if [ -z "$MODEL_KEY" ] || [ -z "$MODEL_CONFIG" ]; then
    log_error "Both model key (-m) and config file (-c) are required"
    exit 1
fi

# Load default config
CONFIG_DIR="exps/eval_vlms/scripts/configs"
DEFAULT_CONFIG="$CONFIG_DIR/default_config.yaml"
if [ ! -f "$DEFAULT_CONFIG" ]; then
    log_error "Default config file not found: $DEFAULT_CONFIG"
    exit 1
fi

# Check if model config exists
if [ ! -f "$MODEL_CONFIG" ]; then
    log_error "Model config file not found: $MODEL_CONFIG"
    exit 1
fi

# Get model specific config
MODEL_NAME=$(yq -r ".[\"$MODEL_KEY\"].model_name" "$MODEL_CONFIG")
MODEL_FAMILY=$(yq -r ".[\"$MODEL_KEY\"].model_family" "$MODEL_CONFIG")
PROMPT_FILE=$(yq -r ".[\"$MODEL_KEY\"].prompt_file" "$MODEL_CONFIG")
RESPONSE_FILE=$(yq -r ".[\"$MODEL_KEY\"].response_file" "$MODEL_CONFIG")

# Get default config values
PYTHON_VLLM=$(yq -r .python_vllm "$DEFAULT_CONFIG")
MAX_NEW_TOKENS=$(yq -r .max_new_tokens "$DEFAULT_CONFIG")
DO_SAMPLE=$(yq -r .do_sample "$DEFAULT_CONFIG")
TOP_P=$(yq -r .top_p "$DEFAULT_CONFIG")
TEMPERATURE=$(yq -r .temperature "$DEFAULT_CONFIG")
VLLM_MAX_NUM_SEQS=$(yq -r .vllm_max_num_seqs "$DEFAULT_CONFIG")

# Get inference script, prioritize model config over default config
INFERENCE_SCRIPT=$(yq -r "if .[\"$MODEL_KEY\"].inference_script then .[\"$MODEL_KEY\"].inference_script else .inference_script end" "$MODEL_CONFIG" "$DEFAULT_CONFIG")

############################################################
# GPU Configs
############################################################
NUM_GPUS=$(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n' | wc -l)
TENSOR_PARALLEL_SIZE=$NUM_GPUS
VLLM_MAX_NUM_SEQS=16

log "Model key: $MODEL_KEY"
log "Model name: $MODEL_NAME"
log "Model family: $MODEL_FAMILY"
log "Prompt file: $PROMPT_FILE"
log "Response file: $RESPONSE_FILE"
log "TENSOR_PARALLEL_SIZE: $TENSOR_PARALLEL_SIZE"
log "VLLM_MAX_NUM_SEQS: $VLLM_MAX_NUM_SEQS"

############################################################
# Generate responses
############################################################
if [ ! -f "${RESPONSE_FILE}" ]; then
    log "Executing Python script with the following parameters:"
    log "Inference script: ${INFERENCE_SCRIPT}"
    log "Model name: ${MODEL_NAME}"
    log "Prompt file: ${PROMPT_FILE}"
    log "Response file: ${RESPONSE_FILE}"
    log "VLLM batch size: ${VLLM_MAX_NUM_SEQS}"
    log "Tensor parallel size: ${TENSOR_PARALLEL_SIZE}"
  
    ${PYTHON_VLLM} ${INFERENCE_SCRIPT} \
        --model_name ${MODEL_NAME} \
        --prompt_file ${PROMPT_FILE} \
        --max_new_tokens ${MAX_NEW_TOKENS} \
        --do_sample \
        --top_p ${TOP_P} \
        --temperature ${TEMPERATURE} \
        --vllm_batch_size ${VLLM_MAX_NUM_SEQS} \
        --tensor_parallel_size ${TENSOR_PARALLEL_SIZE} \
        --output_path ${RESPONSE_FILE}
else
    log_done "File ${RESPONSE_FILE} already exists. Skipping generation."
fi 