#!/bin/bash

#SBATCH --job-name=codegen
#SBATCH --partition=a100                     # Use GPU partition "a100"
#SBATCH --gres gpu:2                          # Request 2 GPUs per job
#SBATCH -c 32                                 # Number of cores
#SBATCH -N 1                                  # Ensure that all cores are on one machine
#SBATCH -t 3-10:30                            # Maximum run-time in D-HH:MM
#SBATCH --mem=150G                            # Memory pool for all cores (see also --mem-per-cpu)
#SBATCH --output=%j.out                       # File to which STDOUT will be written
#SBATCH --error=%j.err                        # File to which STDERR will be written

# Example usage:
# bash src/turtlegfx_datagen/codegen/scripts/build_responses.sh --model_name "meta-llama/Llama-3.2-1B-Instruct"

# Models:
# "meta-llama/Llama-3.2-1B-Instruct"
# "meta-llama/Llama-3.1-70B-Instruct"

# Prepend src/ and the current directory (./) to the Python module search
# path, with src/ listed first, and export the result to child processes.
PYTHONPATH="src/:./:${PYTHONPATH}"
# Environment setting passed through to the vLLM-based inference script.
VLLM_WORKER_MULTIPROC_METHOD="spawn"
export PYTHONPATH VLLM_WORKER_MULTIPROC_METHOD

# Parameter defaults; each one can be overridden by the matching command-line
# flag. MODEL_NAME has no default and must be supplied by the caller.
MODEL_NAME=
declare -a PROMPT_FILES=(
  'src/turtlegfx_datagen/codegen/results/prompts/prompts_codeedit.json'
)
MAX_NEW_TOKENS='1024'
TOP_P='1'
TEMPERATURE='0.5'

# Abort unless the flag in $1 is followed by a value.
# Arguments: the remaining command-line arguments, flag first.
# Exits with status 1 (message on stderr) when the value is missing.
require_flag_value() {
  if [[ "$#" -lt 2 ]]; then
    echo "Error: $1 requires a value" >&2
    exit 1
  fi
}

# Parse command line arguments.
#
# Recognised flags (each takes exactly one value):
#   --model_name      -> MODEL_NAME
#   --prompt_files    -> PROMPT_FILES (comma-separated list, split into an array)
#   --max_new_tokens  -> MAX_NEW_TOKENS
#   --top_p           -> TOP_P
#   --temperature     -> TEMPERATURE
# Any other argument, or a flag given without a value, prints an error to
# stderr and exits with status 1.
parse_args() {
  while [[ "$#" -gt 0 ]]; do
    case "$1" in
    --model_name)
      require_flag_value "$@"
      MODEL_NAME="$2"
      shift
      ;;
    --prompt_files)
      require_flag_value "$@"
      # Comma-separated list of prompt files
      IFS=',' read -r -a PROMPT_FILES <<<"$2"
      shift
      ;;
    --max_new_tokens)
      require_flag_value "$@"
      MAX_NEW_TOKENS="$2"
      shift
      ;;
    --top_p)
      require_flag_value "$@"
      TOP_P="$2"
      shift
      ;;
    --temperature)
      require_flag_value "$@"
      TEMPERATURE="$2"
      shift
      ;;
    *)
      echo "Unknown parameter passed: $1" >&2
      exit 1
      ;;
    esac
    # Drop the flag itself; its value was already dropped inside the case arm.
    shift
  done
}

parse_args "$@"

# MODEL_NAME has no default, so refuse to continue when it was not supplied.
# The message goes to stderr so it lands in the job's error file rather than
# being mixed into regular output.
if [[ -z "${MODEL_NAME}" ]]; then
  echo "Error: model_name argument is required" >&2
  exit 1
fi

# Invoke chat_completion_vllm.py for a single prompt file.
# Globals:   MODEL_NAME, MAX_NEW_TOKENS, TOP_P, TEMPERATURE (read)
# Arguments: $1 - path to the prompt file
# Outputs:   echoes the effective settings to stdout before the run
# Returns:   the exit status of the python invocation
run_prompt_file() {
  local prompt_file="$1"
  local output_path

  # MODEL_NAME is embedded verbatim, so a name containing '/' (for example
  # "org/model") yields a nested path below responses/.
  # NOTE(review): confirm the Python script creates missing parent directories.
  output_path="src/turtlegfx_datagen/codegen/results/responses/prompts_codeedit__${MODEL_NAME}__$(date +'%Y%m%d_%H%M%S').json"

  echo "PROMPT_FILE: ${prompt_file}"
  echo "MAX_NEW_TOKENS: ${MAX_NEW_TOKENS}"
  echo "TOP_P: ${TOP_P}"
  echo "TEMPERATURE: ${TEMPERATURE}"
  echo "MODEL_NAME: ${MODEL_NAME}"

  # Every expansion is quoted so paths or values containing spaces or glob
  # characters reach the script as a single argument.
  python src/turtlegfx_datagen/inference/chat_completion_vllm.py \
    --model_name "${MODEL_NAME}" \
    --prompt_file "${prompt_file}" \
    --max_new_tokens "${MAX_NEW_TOKENS}" \
    --do_sample \
    --quantization \
    --top_p "${TOP_P}" \
    --temperature "${TEMPERATURE}" \
    --output_path "${output_path}"
}

# Iterate over each prompt file. A failed run does not stop the remaining
# files, but it is counted so the job finishes with a non-zero status instead
# of having the failure masked by a later successful run.
FAILED_RUNS=0
for PROMPT_FILE in "${PROMPT_FILES[@]}"; do
  if ! run_prompt_file "${PROMPT_FILE}"; then
    echo "Error: inference failed for ${PROMPT_FILE}" >&2
    FAILED_RUNS=$((FAILED_RUNS + 1))
  fi
done

if [[ "${FAILED_RUNS}" -gt 0 ]]; then
  exit 1
fi
