#!/bin/bash
#SBATCH --job-name=failure_labelling
#SBATCH --partition=h100                      # GPU partition "h100"
#SBATCH --gres gpu:8                          # 8 GPUs per job
#SBATCH -c 32                                 # Number of CPU cores
#SBATCH -N 1                                  # Ensure that all cores are on one machine
#SBATCH -t 4-00:00                            # Maximum run-time in D-HH:MM
#SBATCH --mem=250G                            # Memory pool for all cores (see also --mem-per-cpu)
#SBATCH --output=%j.out                       # File to which STDOUT will be written (%j = job ID)
#SBATCH --error=%j.err                        # File to which STDERR will be written (%j = job ID)
#
# For every top-level key in exps/failure_analysis/scripts/models_config.yaml,
# generate responses with src/turtlegfx_datagen/inference/build_responses_pixtral.py
# (vLLM), skipping models whose response file already exists.
# All paths are relative to the working directory (presumably the repository
# root — submit the job from there).
# NOTE(review): each model's gpu_count is used as the tensor-parallel size, so
# it presumably must not exceed the 8 GPUs requested above — confirm.


############################################################
# Environment Setup
############################################################
# Put the working directory (presumably the repository root) first on
# Python's import path. ${PYTHONPATH:+:...} appends the existing value only
# when it is non-empty, so an unset PYTHONPATH no longer leaves a trailing
# empty entry ("./:").
export PYTHONPATH="./${PYTHONPATH:+:${PYTHONPATH}}"
# Start vLLM worker processes with "spawn" rather than "fork"; presumably to
# avoid CUDA initialisation problems in forked children — confirm.
export VLLM_WORKER_MULTIPROC_METHOD=spawn

############################################################
# Logging Functions
############################################################
# Expected to define the log, log_error, log_warning and log_done helpers used
# throughout this script. Without it every later log* call would fail with
# "command not found", so stop immediately with a plain stderr message.
LOGGING_SH="src/llmutils/bash_utils/logging.sh"
if [[ ! -f "$LOGGING_SH" ]]; then
    printf 'Error: logging helpers not found: %s (is the working directory the repository root?)\n' \
        "$LOGGING_SH" >&2
    exit 1
fi
# shellcheck source=src/llmutils/bash_utils/logging.sh
source "$LOGGING_SH"


############################################################
# Check dependencies
############################################################
# yq reads both YAML configs below; fail fast if it is not on PATH.
command -v yq >/dev/null 2>&1 || {
    log_error "Error: yq is not installed. Please install it first."
    exit 1
}

############################################################
# Load configs
############################################################
# Paths are relative to the current working directory.
CONFIG_DIR="exps/failure_analysis/scripts"
CONFIG_FILE="${CONFIG_DIR}/models_config.yaml"        # per-model settings, keyed by model
RESPONSE_CONFIG="${CONFIG_DIR}/response_config.yaml"  # generation settings shared by all models

# Both files are required; abort before any model is processed if either is absent.
[[ -f "$CONFIG_FILE" ]] || { log_error "Config file not found: $CONFIG_FILE"; exit 1; }
[[ -f "$RESPONSE_CONFIG" ]] || { log_error "Response config file not found: $RESPONSE_CONFIG"; exit 1; }

############################################################
# Process each model
############################################################
# Succeeds when a `yq -r` result means "not configured": yq prints "null" for
# an absent key and nothing at all if it fails.
_is_unset_config_value() {
    [[ -z "$1" || "$1" == "null" ]]
}

#######################################
# Generate vLLM responses for one model entry in $CONFIG_FILE.
# Globals:
#   CONFIG_FILE     (read) per-model settings under .["<model_key>"]:
#                   prompt_file, response_file, gpu_type, gpu_count
#   RESPONSE_CONFIG (read) shared settings: python_vllm, model_name,
#                   max_new_tokens, top_p, temperature
# Arguments:
#   $1 - top-level key of the model entry in $CONFIG_FILE
# Outputs:
#   Progress, warnings and errors via the log helpers from logging.sh; the
#   responses themselves are written by build_responses_pixtral.py to
#   response_file.
# Returns:
#   0 on success, or when response_file already exists (skipped);
#   1 when a setting is missing/invalid, prompt_file does not exist, the
#   output directory cannot be created, or generation fails.
#######################################
process_model() {
    local model_key=$1
    local prompt_file response_file gpu_type gpu_count
    local python_vllm model_name max_new_tokens top_p temperature
    local setting

    # Per-model configuration
    prompt_file=$(yq -r ".[\"$model_key\"].prompt_file" "$CONFIG_FILE")
    response_file=$(yq -r ".[\"$model_key\"].response_file" "$CONFIG_FILE")
    gpu_type=$(yq -r ".[\"$model_key\"].gpu_type" "$CONFIG_FILE")
    gpu_count=$(yq -r ".[\"$model_key\"].gpu_count" "$CONFIG_FILE")

    # Generation settings shared by all models
    python_vllm=$(yq -r ".python_vllm" "$RESPONSE_CONFIG")
    model_name=$(yq -r ".model_name" "$RESPONSE_CONFIG")
    max_new_tokens=$(yq -r ".max_new_tokens" "$RESPONSE_CONFIG")
    top_p=$(yq -r ".top_p" "$RESPONSE_CONFIG")
    temperature=$(yq -r ".temperature" "$RESPONSE_CONFIG")

    # Reject missing paths instead of using them verbatim (e.g. writing
    # responses to a file literally named "null").
    for setting in prompt_file response_file; do
        if _is_unset_config_value "${!setting}"; then
            log_error "Missing ${setting} for $model_key in $CONFIG_FILE"
            return 1
        fi
    done

    # Cannot generate without prompts
    if [[ ! -f "$prompt_file" ]]; then
        log_error "Prompt file not found for $model_key: $prompt_file"
        return 1
    fi

    # Skip if response file exists
    if [[ -f "$response_file" ]]; then
        log_warning "Skipping response generation for $model_key (file exists)"
        return 0
    fi

    # The remaining settings are only needed when generation actually runs.
    for setting in gpu_count python_vllm model_name max_new_tokens top_p temperature; do
        if _is_unset_config_value "${!setting}"; then
            log_error "Missing ${setting} for $model_key"
            return 1
        fi
    done
    # gpu_count feeds shell arithmetic below, where a non-numeric value such
    # as "null" would silently evaluate to 0.
    if [[ ! "$gpu_count" =~ ^[1-9][0-9]*$ ]]; then
        log_error "Invalid gpu_count for $model_key: $gpu_count"
        return 1
    fi

    # vLLM settings: shard across gpu_count GPUs; batch size scales at 8 per GPU.
    local tensor_parallel_size=$gpu_count
    local vllm_max_num_seqs=$((tensor_parallel_size * 8))

    log "Generating responses for $model_key"
    log "Model: $model_name"
    log "Prompt file: $prompt_file"
    log "Response file: $response_file"
    log "GPU type: $gpu_type"
    log "GPU count: $gpu_count"

    # Create output directory if it doesn't exist
    local output_dir
    output_dir=$(dirname "$response_file")
    if ! mkdir -p "$output_dir"; then
        log_error "Cannot create output directory for ${model_key}: ${output_dir}"
        return 1
    fi

    # python_vllm is split on whitespace so it may hold a multi-word command;
    # reading it into an array (rather than expanding it unquoted) avoids
    # pathname expansion of its words.
    local -a python_cmd
    read -ra python_cmd <<< "$python_vllm"

    # Every argument is quoted so paths and names containing spaces or glob
    # characters reach the script intact.
    if ! "${python_cmd[@]}" src/turtlegfx_datagen/inference/build_responses_pixtral.py \
        --model_name "$model_name" \
        --prompt_file "$prompt_file" \
        --max_new_tokens "$max_new_tokens" \
        --do_sample \
        --quantization \
        --top_p "$top_p" \
        --temperature "$temperature" \
        --vllm_batch_size "$vllm_max_num_seqs" \
        --tensor_parallel_size "$tensor_parallel_size" \
        --output_path "$response_file"; then
        # NOTE(review): if the script can leave a partial file at
        # response_file on failure, the next run will skip this model via the
        # existence check above — confirm, and clean up here if so.
        log_error "Failed to generate responses for ${model_key}"
        return 1
    fi

    log_done "Successfully generated responses for ${model_key}"
}

############################################################
# Main execution
############################################################
# Read the top-level keys of $CONFIG_FILE, one model per line. mapfile keeps
# each key intact even if it contains whitespace. A yq failure (e.g.
# malformed YAML) aborts instead of silently processing zero models.
if ! model_keys_output=$(yq -r 'keys | .[]' "$CONFIG_FILE"); then
    log_error "Failed to read model keys from $CONFIG_FILE"
    exit 1
fi
mapfile -t model_keys <<< "$model_keys_output"

# Process each model. Keep going past failures, but remember them so the
# job's exit status reflects whether every model succeeded.
failed_models=()
for model_key in "${model_keys[@]}"; do
    # An empty yq result still yields one empty line; skip it.
    [[ -n "$model_key" ]] || continue
    if ! process_model "$model_key"; then
        failed_models+=("$model_key")
    fi
done

if (( ${#failed_models[@]} > 0 )); then
    log_error "Response generation failed for: ${failed_models[*]}"
    exit 1
fi

log_done "All responses generation complete"