#!/bin/bash
#SBATCH --job-name=relabel
#SBATCH --partition=a100                   # Use GPU partition "a100", "h100sxm"
#SBATCH --gres gpu:2                       # Set 2 GPUs per job
#SBATCH -c 32                              # Number of cores
#SBATCH -N 1                               # Ensure that all cores are on one machine
#SBATCH -t 4-00:00                         # Maximum run-time in D-HH:MM
#SBATCH --mem=256G                         # Memory pool for all cores
#SBATCH --output=%j.out                    # File to which STDOUT will be written
#SBATCH --error=%j.err                     # File to which STDERR will be written


# Example submissions that override both the partition and the GPU count:
# sbatch -p a100 --gres gpu:a100:8 exps/datasetgen/scripts/build_dataset_relabel.sh
# sbatch -p h100 --gres gpu:h100:8 exps/datasetgen/scripts/build_dataset.sh -c exps/datasetgen/scripts/configs/dataset_inferred_spec_sz2.sh


# Print usage information to stderr and exit with status 1.
# Only reached on invalid command-line options, so the text is a diagnostic.
usage() {
    {
        echo "Usage: $0 [-c <experiment_config_file>]"
        echo "  -c <config>     Experiment-specific config file (optional)"
    } >&2
    exit 1
}

# Shared implementation of the logging helpers below.
#   $1  - tput colour index for the message
#   $2… - message words, joined with single spaces
# Prints "<colour><YYYY-mm-dd HH:MM:SS> <message><reset>" on stdout.
# tput errors (e.g. TERM unset in a non-interactive batch job) are discarded;
# the message is then simply printed without colour.
_log_with_colour() {
    local colour=$1
    shift
    printf '%s%s %s%s\n' \
        "$(tput setaf "$colour" 2>/dev/null)" \
        "$(date '+%Y-%m-%d %H:%M:%S')" \
        "$*" \
        "$(tput sgr0 2>/dev/null)"
}

# Progress message (colour 3).
log() {
    _log_with_colour 3 "$@"
}

# Step-completed message (colour 2).
log_done() {
    _log_with_colour 2 "$@"
}

# Error/warning message (colour 1). Written to stderr so it also lands in the
# job's --error file instead of being mixed into normal progress output.
log_error() {
    _log_with_colour 1 "$@" >&2
}

# Parse command line arguments.
#   -c <file>  experiment config file; sets the global CONFIG_FILE.
# Unknown options and a -c without a value both print usage and exit.
# OPTIND is local so the parser can be invoked more than once.
parse_args() {
    local opt OPTIND=1
    while getopts ":c:" opt; do
        case "${opt}" in
            c )
                CONFIG_FILE=$OPTARG
                ;;
            : )
                # With the leading ':' in the optstring, a missing argument
                # arrives here (OPTARG holds the option letter).
                log_error "Error: option -$OPTARG requires an argument."
                usage
                ;;
            \? )
                usage
                ;;
        esac
    done
}

parse_args "$@"


############################################################
# Source configuration files
############################################################
# Fall back to the default relabel config when -c was not given.
if [ -z "$CONFIG_FILE" ]; then
    CONFIG_FILE="exps/datasetgen/scripts/configs/dataset_relabel_config.sh"
fi

# Check if config file exists
if [ ! -f "$CONFIG_FILE" ]; then
    log_error "Error: Experiment config file '$CONFIG_FILE' not found."
    exit 1
fi

# For a name without a slash, bash's `source` searches $PATH before the current
# directory. Such names are therefore pinned to ./ so that the file checked with
# -f above is the one loaded. CONFIG_FILE itself is left unchanged because it is
# passed on to the per-part scripts via -c.
case "$CONFIG_FILE" in
    */*) source "$CONFIG_FILE" ;;
    *)   source "./$CONFIG_FILE" ;;
esac

############################################################
# GPU Configs
############################################################
# Print the number of comma-separated device ids in $1 (a CUDA_VISIBLE_DEVICES
# value). An empty or unset value means no visible GPUs and yields 0.
count_visible_gpus() {
    local -a devices=()
    IFS=',' read -r -a devices <<< "$1"
    echo "${#devices[@]}"
}

NUM_GPUS=$(count_visible_gpus "$CUDA_VISIBLE_DEVICES")

# Warn (but keep going) when the configured tensor-parallel size does not match
# the number of GPUs visible to this job. A missing or non-numeric value gets
# its own warning instead of a shell "integer expression expected" error.
if [[ ! "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]]; then
    log_error "WARNING: TENSOR_PARALLEL_SIZE ('$TENSOR_PARALLEL_SIZE') is not a valid integer"
elif [ "$TENSOR_PARALLEL_SIZE" -ne "$NUM_GPUS" ]; then
    log_error "WARNING: TENSOR_PARALLEL_SIZE ($TENSOR_PARALLEL_SIZE) is not the same as NUM_GPUS ($NUM_GPUS)"
fi

############################################################
# Configuration
############################################################
# Print the effective run settings so each job log records how it was launched.
# Each entry is passed to `log` as one message, in this order.
run_summary_lines=(
    "ITERATION: $ITERATION"
    "DATE: $DATE"
    "--------------------------------"
    "Experiment Config: $CONFIG_FILE"
    "Input file: $INPUT_FILE"
    "--------------------------------"
    "SLURM_JOB_PARTITION: $SLURM_JOB_PARTITION"
    "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
    "Number of GPUs allocated: $NUM_GPUS"
    "TENSOR_PARALLEL_SIZE: $TENSOR_PARALLEL_SIZE"
    "VLLM_MAX_NUM_SEQS: $VLLM_MAX_NUM_SEQS"
    "--------------------------------"
)
for summary_line in "${run_summary_lines[@]}"; do
    log "$summary_line"
done
unset run_summary_lines summary_line

############################################################
# Nested Functions for Modular Structure
############################################################

codededup() {

    # Merge the post-processed relabel parts in POSTPROCESS_DIR, together with
    # SEED_FILE, into one dataset JSON under
    # exps/datasetgen/results/<DATE>/iter<ITERATION>/coderelabel/.
    # Skips when the merged file already exists or when no parts are found.
    # Returns 1 if the merge command fails.
    codededup_merge_datasets() {
        log "CODEDEDUP (1. Merge Datasets): Starting..."

        local output="exps/datasetgen/results/${DATE}/iter${ITERATION}/coderelabel/seed_dataset_iter${ITERATION}__${MODEL_NAME_SHORT}.json"
        local -a inputs=()

        # Collect ..._part_<N>.json files in natural numeric order (part_2 before
        # part_10) so the merge order is deterministic.
        mapfile -t inputs < <(
            find "${POSTPROCESS_DIR}" -regex ".*/${PROMPT_NAME}__${MODEL_NAME_SHORT}__${DATE}_part_[0-9]+\.json" 2>/dev/null \
                | sort -V
        )

        if [ -f "${output}" ]; then
            log_done "CODEDEDUP (1. Merge Datasets): Merged file ${output} already exists. Skipping."
        elif (( ${#inputs[@]} == 0 )); then
            # An explicit count check: a bare `ls` with an empty list would list
            # the cwd and succeed, so it cannot detect "no parts".
            log_error "CODEDEDUP (1. Merge Datasets): WARNING: no input parts found in ${POSTPROCESS_DIR}. Skipping."
        else
            log "CODEDEDUP (1. Merge Datasets): Merging multiple parts with seed file"
            # PYTHON_TURTLE is deliberately unquoted so it may hold a multi-word
            # command. SEED_FILE is only passed when non-empty.
            if ! ${PYTHON_TURTLE} src/turtlegfx_datagen/codededup/merge_datasets.py \
                --input_paths "${inputs[@]}" ${SEED_FILE:+"$SEED_FILE"} \
                --output_path "${output}"; then
                log_error "CODEDEDUP (1. Merge Datasets): Merge failed."
                return 1
            fi
        fi
        log_done "CODEDEDUP (1. Merge Datasets): Finished."
    }

    codededup_merge_datasets
}


coderelabel() {
    # Glob for the per-part prompt files.
    # NOTE(review): step 1 writes ${PROMPT_FILE} (from the config); this assumes
    # the generator emits prompts_${PROMPT_NAME}_part_<i>.json files that match
    # this pattern — confirm.
    PROMPT_PATTERN="exps/datasetgen/results/${DATE}/iter${ITERATION}/coderelabel/prompts/prompts_${PROMPT_NAME}_part_*.json"

    # Print the number of prompt part files currently matching PROMPT_PATTERN.
    coderelabel_count_prompt_parts() {
        compgen -G "${PROMPT_PATTERN}" | wc -l
    }

    # Step 1: build the prompt parts unless at least one already exists.
    # Exits 1 if the generator fails.
    coderelabel_generate_prompts() {
        log "CODERELABEL (1. Generate Prompts): Starting..."

        # Any matching part file is treated as "prompts already generated".
        if compgen -G "${PROMPT_PATTERN}" > /dev/null; then
            log_done "CODERELABEL (1. Generate Prompts): All prompt parts exist. Skipping generation."
        else
            log "CODERELABEL (1. Generate Prompts): Generating new prompts..."
            # PYTHON_TURTLE is deliberately unquoted so it may hold a multi-word command.
            if ! ${PYTHON_TURTLE} src/turtlegfx_datagen/coderelabel/build_prompts_relabel.py \
                --input_file "${INPUT_FILE}" \
                --prompt_template "${PROMPT_TEMPLATE}" \
                --output_file "${PROMPT_FILE}" \
                --max_num_items 50000; then
                log_error "CODERELABEL (1. Generate Prompts): Prompt generation failed."
                exit 1
            fi
        fi

        log_done "CODERELABEL (1. Generate Prompts): Finished."
    }

    # Step 2: make sure every prompt part has a response file in RESPONSES_DIR.
    # Parts 1..N-1 that are missing are submitted with sbatch (odd parts to
    # a100, even parts to h100). Part 0 runs in this session. If any job was
    # submitted, the script exits 0 so it can be re-run once the jobs finish.
    # Exits 1 when there are no prompt parts, part 0 fails, or sbatch fails.
    coderelabel_generate_responses() {
        log "CODERELABEL (2. Generate Responses): Starting..."

        local prompt_parts_count
        prompt_parts_count=$(coderelabel_count_prompt_parts)
        log "CODERELABEL (2. Generate Responses): Prompt parts count: $prompt_parts_count"

        # No parts means step 1 produced nothing. Stop here rather than let the
        # later stages "finish" on an empty set.
        if (( prompt_parts_count == 0 )); then
            log_error "CODERELABEL (2. Generate Responses): No prompt parts match ${PROMPT_PATTERN}."
            exit 1
        fi

        local job_submitted=false
        local i part_file partition
        # Highest part first, so the cluster jobs are queued before part 0
        # occupies this session.
        for (( i = prompt_parts_count - 1; i >= 0; i-- )); do
            part_file="${RESPONSES_DIR}/${PROMPT_NAME}__${MODEL_NAME_SHORT}__${DATE}_part_${i}.json"

            if [ -f "${part_file}" ]; then
                log "CODERELABEL (2. Generate Responses): Part file $part_file already exists. Skipping."
                continue
            fi

            if (( i == 0 )); then
                log "CODERELABEL (2. Generate Responses): Generating part $i in the current bash session"
                if ! bash exps/datasetgen/scripts/coderelabel/build_responses.sh -c "$CONFIG_FILE" -p "$i"; then
                    log_error "CODERELABEL (2. Generate Responses): Part $i failed."
                    exit 1
                fi
            else
                log "CODERELABEL (2. Generate Responses): Submitting job for part $i"
                if (( i % 2 == 1 )); then
                    partition=a100
                else
                    partition=h100
                fi
                if ! sbatch \
                    -p "$partition" \
                    --gres "gpu:${partition}:${TENSOR_PARALLEL_SIZE}" \
                    exps/datasetgen/scripts/coderelabel/build_responses.sh -c "$CONFIG_FILE" -p "$i"; then
                    log_error "CODERELABEL (2. Generate Responses): Failed to submit job for part $i."
                    exit 1
                fi
                job_submitted=true
            fi
        done

        # Exit if jobs are submitted
        if [[ "$job_submitted" == true ]]; then
            log_done "CODERELABEL (2. Generate Responses): Jobs submitted for missing parts. Exiting."
            exit 0
        fi

        log_done "CODERELABEL (2. Generate Responses): Finished."
    }

    # Step 3: make sure every part has a post-processed file in POSTPROCESS_DIR.
    # Missing parts 1..N-1 are submitted with sbatch, and part 0 runs in this
    # session. Exits 0 after submitting jobs, and exits 1 if part 0 or sbatch
    # fails.
    coderelabel_postprocess() {
        log "CODERELABEL (3. Postprocess Responses): Starting..."

        local prompt_parts_count
        prompt_parts_count=$(coderelabel_count_prompt_parts)

        local job_submitted=false
        local i output
        # Highest part first (part 0 last), mirroring step 2.
        for (( i = prompt_parts_count - 1; i >= 0; i-- )); do
            output="${POSTPROCESS_DIR}/${PROMPT_NAME}__${MODEL_NAME_SHORT}__${DATE}_part_${i}.json"

            if [ -f "${output}" ]; then
                log "CODERELABEL (3. Postprocess Responses): Output file $output already exists. Skipping."
                continue
            fi

            if (( i == 0 )); then
                log "CODERELABEL (3. Postprocess Responses): Generating part $i in the current bash session"
                if ! bash exps/datasetgen/scripts/coderelabel/postprocess.sh -c "$CONFIG_FILE" -p "$i"; then
                    log_error "CODERELABEL (3. Postprocess Responses): Part $i failed."
                    exit 1
                fi
            else
                log "CODERELABEL (3. Postprocess Responses): Submitting job for part $i"
                if ! sbatch exps/datasetgen/scripts/coderelabel/postprocess.sh -c "$CONFIG_FILE" -p "$i"; then
                    log_error "CODERELABEL (3. Postprocess Responses): Failed to submit job for part $i."
                    exit 1
                fi
                job_submitted=true
            fi
        done

        # Exit if jobs are submitted
        if [[ "$job_submitted" == true ]]; then
            log_done "CODERELABEL (3. Postprocess Responses): Jobs submitted for missing parts. Exiting."
            exit 0
        fi

        log_done "CODERELABEL (3. Postprocess Responses): Finished."
    }

    # Execute all steps
    coderelabel_generate_prompts
    coderelabel_generate_responses
    coderelabel_postprocess
}


# Apply optional overrides: a non-empty CLI_ITERATION / CLI_DATE replaces the
# value from the config file.
# NOTE(review): the getopts parser above never sets CLI_ITERATION or CLI_DATE,
# so they can only come from the environment. They are also applied after the
# config was sourced and after ITERATION/DATE were logged, so any config values
# derived from ITERATION/DATE are not recomputed — confirm this is intended.
if [[ -n "$CLI_ITERATION" ]]; then
    ITERATION=$CLI_ITERATION
    log "Using command line ITERATION: $ITERATION"
fi
if [[ -n "$CLI_DATE" ]]; then
    DATE=$CLI_DATE
    log "Using command line DATE: $DATE"
fi

# Both values are part of every results path used by the pipeline, so abort
# early if either is missing.
if [[ -z "$ITERATION" ]]; then
    log_error "Error: ITERATION must be set either in config file or via command line"
    exit 1
fi
if [[ -z "$DATE" ]]; then
    log_error "Error: DATE must be set either in config file or via command line"
    exit 1
fi

############################################################
# Execute
############################################################
# Run the pipeline stages in order: relabel (prompts -> responses ->
# postprocess), then merge the post-processed parts. coderelabel may exit the
# script early after submitting cluster jobs.
for stage in coderelabel codededup; do
    "$stage"
done