#!/bin/bash

#SBATCH --job-name=openmask3d
#SBATCH --gpus=1
#SBATCH --gres=gpumem:40g
#SBATCH --time=24:00:00
#SBATCH --ntasks=2
#SBATCH --mem-per-cpu=50G

# Strict mode: abort on the first failing command (-e), on any expansion of an
# unset variable (-u; catches typos in the parameter names configured below),
# and on a failure in any stage of a pipeline (-o pipefail).
set -euo pipefail
export OMP_NUM_THREADS=3  # speeds up MinkowskiEngine

# OPENMASK3D SCANNET++ PIPELINE SCRIPT
# This script runs the following OpenMask3D stages on the ScanNet++ data configured below:
# 1. Compute class-agnostic masks and save them
# 2. Compute features for each mask and save them
# 3. Evaluate closed-set 3D semantic instance segmentation (currently commented out at the end of this script)

# --------
# NOTE: SET THESE PARAMETERS!
# -- Input data ---------------------------------------------------------------
# NOTE(review): both paths below are relative, while the pipeline later runs
# `cd openmask3d` before invoking Python. Confirm they resolve where intended
# from that directory; absolute paths avoid the ambiguity.
SCANS_PATH="scans"
SCANNET_PROCESSED_DIR="./openmask3d/data/scannetpp"

# -- Pretrained checkpoints (absolute, anchored at the launch directory) -----
MASK_MODULE_CKPT_PATH="$(pwd)/resources/scannet200_model.ckpt"
SAM_CKPT_PATH="$(pwd)/resources/sam_vit_h_4b8939.pth"

# -- Outputs: each run gets its own timestamped folder ------------------------
EXPERIMENT_NAME="scannetpp_office"
OUTPUT_DIRECTORY="$(pwd)/output"
TIMESTAMP=$(date +"%Y-%m-%d-%H-%M-%S")
OUTPUT_FOLDER_DIRECTORY="${OUTPUT_DIRECTORY}/${TIMESTAMP}-${EXPERIMENT_NAME}"
# To reuse an earlier run's folder instead, uncomment and adjust:
# OUTPUT_FOLDER_DIRECTORY="$(pwd)/output/2024-03-10-13-15-22-scannetpp_office"
MASK_SAVE_DIR="${OUTPUT_FOLDER_DIRECTORY}/masks"
# MASK_SAVE_DIR="./openmask3d/data/scannetpp/predictions/masks"
MASK_FEATURE_SAVE_DIR="${OUTPUT_FOLDER_DIRECTORY}/mask_features"
# MASK_FEATURE_SAVE_DIR="./openmask3d/data/scannetpp/predictions/mask_features"

# -- Toggles ------------------------------------------------------------------
SAVE_VISUALIZATIONS=true  # when true, pyviz3d visualizations are saved
SAVE_CROPS=true           # forwarded to feature computation as output.save_crops
OPTIMIZE_GPU_USAGE=false  # forwarded to feature computation as gpu.optimize_gpu_usage

# -- Derived from the settings above; normally no need to edit ---------------
SCANNET_LABEL_DB_PATH="${SCANNET_PROCESSED_DIR%/}/label_database.yaml"
SCANNET_INSTANCE_GT_DIR="${SCANNET_PROCESSED_DIR%/}/instance_gt/validation"

# The Python entry points below are invoked with paths relative to the
# openmask3d/ directory; stop with a clear message if it is not present.
cd openmask3d || { echo "[ERROR] Directory 'openmask3d' not found in $(pwd)." >&2; exit 1; }

# 1. Compute class-agnostic masks and save them.
# Every expansion is double-quoted so a value containing spaces (e.g. a path
# derived from $(pwd)) reaches Hydra as one override instead of being
# word-split by the shell.
python class_agnostic_mask_computation/get_masks_scannetpp.py \
    general.experiment_name="${EXPERIMENT_NAME}" \
    general.project_name="scannet++" \
    general.checkpoint="${MASK_MODULE_CKPT_PATH}" \
    general.train_mode=false \
    model.num_queries=150 \
    general.use_dbscan=true \
    general.dbscan_eps=0.95 \
    general.save_visualizations="${SAVE_VISUALIZATIONS}" \
    data.test_dataset.data_dir="${SCANNET_PROCESSED_DIR}" \
    data.validation_dataset.data_dir="${SCANNET_PROCESSED_DIR}" \
    data.train_dataset.data_dir="${SCANNET_PROCESSED_DIR}" \
    data.train_dataset.label_db_filepath="${SCANNET_LABEL_DB_PATH}" \
    data.validation_dataset.label_db_filepath="${SCANNET_LABEL_DB_PATH}" \
    data.test_dataset.label_db_filepath="${SCANNET_LABEL_DB_PATH}" \
    general.mask_save_dir="${MASK_SAVE_DIR}" \
    hydra.run.dir="${OUTPUT_FOLDER_DIRECTORY}/hydra_outputs/class_agnostic_mask_computation"
echo "[INFO] Mask computation done!"
# Report where the masks were written (this is the input for step 2).
echo "[INFO] Masks saved to ${MASK_SAVE_DIR}."

# 2. Compute a feature for each saved mask.
echo "[INFO] Computing mask features..."
python compute_features_scannetpp.py \
    data.scans_path="${SCANS_PATH}" \
    data.masks.masks_path="${MASK_SAVE_DIR}" \
    output.output_directory="${MASK_FEATURE_SAVE_DIR}" \
    output.experiment_name="${EXPERIMENT_NAME}" \
    output.save_crops="${SAVE_CROPS}" \
    external.sam_checkpoint="${SAM_CKPT_PATH}" \
    gpu.optimize_gpu_usage="${OPTIMIZE_GPU_USAGE}" \
    hydra.run.dir="${OUTPUT_FOLDER_DIRECTORY}/hydra_outputs/mask_features_computation" \
    openmask3d.frequency=30
echo "[INFO] Feature computation done!"

# 3. Closed-set 3D semantic instance segmentation evaluation (currently disabled).
echo "[INFO] Reading masks from ${MASK_SAVE_DIR}."
echo "[INFO] Reading mask features from ${MASK_FEATURE_SAVE_DIR}."
# python evaluation/run_eval_close_vocab_inst_seg.py \
#     --gt_dir="${SCANNET_INSTANCE_GT_DIR}" \
#     --mask_pred_dir="${MASK_SAVE_DIR}" \
#     --mask_features_dir="${MASK_FEATURE_SAVE_DIR}"
