{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "\n",
    "# Make the Grounded-SAM-2 repo importable (sam2, grounding_dino, utils).\n",
    "# Guarded so re-running this cell does not keep appending the same entry.\n",
    "GROUNDED_SAM2_ROOT = \"../Grounded-SAM-2\"\n",
    "if GROUNDED_SAM2_ROOT not in sys.path:\n",
    "    sys.path.append(GROUNDED_SAM2_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import os.path as osp\n",
    "import shutil\n",
    "import cv2\n",
    "import torch\n",
    "import numpy as np\n",
    "from skimage.color import label2rgb\n",
    "import supervision as sv\n",
    "from glob import glob\n",
    "from torchvision.ops import box_convert\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "from sam2.build_sam import build_sam2_video_predictor, build_sam2\n",
    "from sam2.sam2_image_predictor import SAM2ImagePredictor \n",
    "from grounding_dino.groundingdino.util.inference import load_model, load_image, predict\n",
    "from utils.track_utils import sample_points_from_masks\n",
    "from utils.video_utils import create_video_from_images\n",
    "from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection \n",
    "from utils.track_utils import sample_points_from_masks\n",
    "from utils.video_utils import create_video_from_images\n",
    "from utils.common_utils import CommonUtils\n",
    "from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo\n",
    "import json\n",
    "import copy\n",
    "from helper import reverse_files, make_tmp_folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 198,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Detection / tracking parameters\n",
    "REDETECT_INTERVAL = 1000\n",
    "REVERSE_TRACK = False\n",
    "CAM_TAG = 'left1' # choose from [\"left1\", \"right1\"] (for hypernerf vrig)\n",
    "PROMPT_TYPE_FOR_VIDEO = \"box\" # choose from [\"point\", \"box\", \"mask\"]\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Per-scene Grounding DINO settings: (text prompt, box threshold, text threshold).\n",
    "# Selecting by SCENE replaces re-assigning the thresholds further down the cell.\n",
    "# NOTE(review): the first three scenes are listed with the former cell-wide\n",
    "# defaults (0.35 / 0.25); confirm those are the values they were run with.\n",
    "SCENE_SETTINGS = {\n",
    "    \"broom2\": (\"person.broom.\", 0.35, 0.25),\n",
    "    \"vrig-peel-banana\": (\"person.banana.\", 0.35, 0.25),\n",
    "    \"vrig-chicken\": (\"person.toy.\", 0.35, 0.25),\n",
    "    \"vrig-3dprinter\": (\"machine.cable.motor.3dprinter head.plate.board.\", 0.2, 0.15),\n",
    "}\n",
    "SCENE = \"vrig-3dprinter\"\n",
    "TEXT_PROMPT, BOX_THRESHOLD, TEXT_THRESHOLD = SCENE_SETTINGS[SCENE]\n",
    "\n",
    "# Dataset root; override with the HYPERNERF_ROOT environment variable.\n",
    "DATA_ROOT = os.environ.get(\"HYPERNERF_ROOT\", \"/scratch/xz653/datasets/hypernerf\")\n",
    "\n",
    "# set where to read images from\n",
    "VIDEO_PATH = osp.join(DATA_ROOT, SCENE, f'rgb/2x/{CAM_TAG}_*.png')\n",
    "OUTPUT_PATH = osp.join(DATA_ROOT, SCENE, 'instance/2x')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
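    "# Track + Detect\n",
    "\n",
    "Extract the input frames, detect the objects listed in `TEXT_PROMPT` with Grounding DINO, then segment them with SAM 2 and propagate the masks through the video."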
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 199,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "images from /scratch/xz653/datasets/hypernerf/vrig-3dprinter/rgb/2x/left1_*.png 207\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 207/207 [00:12<00:00, 16.81it/s]\n"
     ]
    }
   ],
   "source": [
    "GROUNDING_DINO_CONFIG = \"grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py\"\n",
    "GROUNDING_DINO_CHECKPOINT = \"gdino_checkpoints/groundingdino_swint_ogc.pth\"\n",
    "\n",
    "# Write the input as a folder of JPEG frames for the SAM 2 video predictor.\n",
    "frame_dir = make_tmp_folder()\n",
    "\n",
    "if VIDEO_PATH.endswith(\".mp4\"):\n",
    "    print(\"Extracting frames from video...\")\n",
    "    video_info = sv.VideoInfo.from_video_path(VIDEO_PATH)  # get video info\n",
    "    print(video_info)\n",
    "    frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)\n",
    "    source_frames = Path(frame_dir)\n",
    "\n",
    "    with sv.ImageSink(\n",
    "        target_dir_path=source_frames,\n",
    "        overwrite=True,\n",
    "        image_name_pattern=\"{:05d}.jpg\"\n",
    "    ) as sink:\n",
    "        for frame in tqdm(frame_generator, desc=\"Saving Video Frames\"):\n",
    "            sink.save_image(frame)\n",
    "else:\n",
    "    # VIDEO_PATH is a glob pattern\n",
    "    im_paths = sorted(glob(VIDEO_PATH))\n",
    "    # Fail here rather than later inside SAM 2 with an empty frame folder.\n",
    "    if not im_paths:\n",
    "        raise FileNotFoundError(f'no images match {VIDEO_PATH}')\n",
    "    print(f'images from {VIDEO_PATH}', len(im_paths))\n",
    "    for im_path in tqdm(im_paths):\n",
    "        # Keep only the part after the last '_', e.g. 'left1_000000' -> '000000'.\n",
    "        bname = osp.splitext(osp.basename(im_path))[0].split('_')[-1]\n",
    "        # Context manager closes the file handle that Image.open keeps open.\n",
    "        with Image.open(im_path) as im:\n",
    "            im.convert(\"RGB\").save(osp.join(frame_dir, f\"{bname}.jpg\"))\n",
    "\n",
    "# Apply REVERSE_TRACK to both input kinds, not only to globbed images.\n",
    "if REVERSE_TRACK:\n",
    "    reverse_files(frame_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 200,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the detection and segmentation models once per kernel. Re-running this\n",
    "# cell skips loading when all four objects it creates already exist.\n",
    "_MODEL_NAMES = (\"processor\", \"grounding_model\", \"video_predictor\", \"image_predictor\")\n",
    "if not all(name in globals() for name in _MODEL_NAMES):\n",
    "    model_id = \"IDEA-Research/grounding-dino-base\"\n",
    "    processor = AutoProcessor.from_pretrained(model_id)\n",
    "    grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(DEVICE)\n",
    "\n",
    "    sam2_checkpoint = \"../Grounded-SAM-2/checkpoints/sam2.1_hiera_large.pt\"\n",
    "    model_cfg = \"../sam2/configs/sam2.1/sam2.1_hiera_l.yaml\"\n",
    "\n",
    "    # Pass DEVICE explicitly so SAM 2 lands on the same device as grounding_model.\n",
    "    video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=DEVICE)\n",
    "    sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)\n",
    "    image_predictor = SAM2ImagePredictor(sam2_image_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "frame loading (JPEG / PNG): 100%|██████████| 207/207 [00:05<00:00, 38.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total frames: 207\n",
      "start_frame_idx 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "objects_count 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "propagate in video: 100%|██████████| 207/207 [01:46<00:00,  1.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "video_segments: 207\n",
      "Path '/scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result' already exists.\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000000.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000001.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000002.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000003.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000004.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000005.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000006.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000007.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000008.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000009.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000010.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000011.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000012.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000013.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000014.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000015.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000016.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000017.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000018.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000019.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000020.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000021.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000022.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000023.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000024.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000025.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000026.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000027.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000028.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000029.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000030.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000031.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000032.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000033.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000034.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000035.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000036.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000037.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000038.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000039.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000040.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000041.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000042.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000043.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000044.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000045.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000046.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000047.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000048.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000049.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000050.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000051.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000052.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000053.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000054.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000055.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000056.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000057.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000058.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000059.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000060.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000061.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000062.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000063.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000064.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000065.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000066.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000067.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000068.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000069.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000070.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000071.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000072.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000073.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000074.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000075.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000076.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000077.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000078.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000079.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000080.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000081.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000082.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000083.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000084.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000085.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000086.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000087.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000088.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000089.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000090.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000091.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000092.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000093.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000094.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000095.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000096.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000097.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000098.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000099.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000100.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000101.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000102.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000103.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000104.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000105.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000106.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000107.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000108.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000109.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000110.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000111.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000112.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000113.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000114.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000115.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000116.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000117.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000118.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000119.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000120.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000121.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000122.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000123.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000124.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000125.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000126.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000127.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000128.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000129.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000130.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000131.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000132.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000133.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000134.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000135.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000136.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000137.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000138.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000139.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000140.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000141.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000142.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000143.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000144.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000145.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000146.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000147.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000148.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000149.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000150.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000151.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000152.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000153.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000154.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000155.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000156.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000157.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000158.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000159.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000160.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000161.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000162.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000163.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000164.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000165.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000166.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000167.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000168.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000169.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000170.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000171.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000172.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000173.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000174.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000175.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000176.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000177.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000178.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000179.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000180.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000181.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000182.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000183.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000184.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000185.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000186.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000187.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000188.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000189.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000190.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000191.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000192.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000193.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000194.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000195.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000196.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000197.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000198.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000199.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000200.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000201.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000202.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000203.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000204.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000205.jpg\n",
      "Annotated image saved as /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/img_result/000206.jpg\n",
      "['000000.jpg', '000001.jpg', '000002.jpg', '000003.jpg', '000004.jpg', '000005.jpg', '000006.jpg', '000007.jpg', '000008.jpg', '000009.jpg', '000010.jpg', '000011.jpg', '000012.jpg', '000013.jpg', '000014.jpg', '000015.jpg', '000016.jpg', '000017.jpg', '000018.jpg', '000019.jpg', '000020.jpg', '000021.jpg', '000022.jpg', '000023.jpg', '000024.jpg', '000025.jpg', '000026.jpg', '000027.jpg', '000028.jpg', '000029.jpg', '000030.jpg', '000031.jpg', '000032.jpg', '000033.jpg', '000034.jpg', '000035.jpg', '000036.jpg', '000037.jpg', '000038.jpg', '000039.jpg', '000040.jpg', '000041.jpg', '000042.jpg', '000043.jpg', '000044.jpg', '000045.jpg', '000046.jpg', '000047.jpg', '000048.jpg', '000049.jpg', '000050.jpg', '000051.jpg', '000052.jpg', '000053.jpg', '000054.jpg', '000055.jpg', '000056.jpg', '000057.jpg', '000058.jpg', '000059.jpg', '000060.jpg', '000061.jpg', '000062.jpg', '000063.jpg', '000064.jpg', '000065.jpg', '000066.jpg', '000067.jpg', '000068.jpg', '000069.jpg', '000070.jpg', '000071.jpg', '000072.jpg', '000073.jpg', '000074.jpg', '000075.jpg', '000076.jpg', '000077.jpg', '000078.jpg', '000079.jpg', '000080.jpg', '000081.jpg', '000082.jpg', '000083.jpg', '000084.jpg', '000085.jpg', '000086.jpg', '000087.jpg', '000088.jpg', '000089.jpg', '000090.jpg', '000091.jpg', '000092.jpg', '000093.jpg', '000094.jpg', '000095.jpg', '000096.jpg', '000097.jpg', '000098.jpg', '000099.jpg', '000100.jpg', '000101.jpg', '000102.jpg', '000103.jpg', '000104.jpg', '000105.jpg', '000106.jpg', '000107.jpg', '000108.jpg', '000109.jpg', '000110.jpg', '000111.jpg', '000112.jpg', '000113.jpg', '000114.jpg', '000115.jpg', '000116.jpg', '000117.jpg', '000118.jpg', '000119.jpg', '000120.jpg', '000121.jpg', '000122.jpg', '000123.jpg', '000124.jpg', '000125.jpg', '000126.jpg', '000127.jpg', '000128.jpg', '000129.jpg', '000130.jpg', '000131.jpg', '000132.jpg', '000133.jpg', '000134.jpg', '000135.jpg', '000136.jpg', '000137.jpg', '000138.jpg', '000139.jpg', '000140.jpg', '000141.jpg', 
'000142.jpg', '000143.jpg', '000144.jpg', '000145.jpg', '000146.jpg', '000147.jpg', '000148.jpg', '000149.jpg', '000150.jpg', '000151.jpg', '000152.jpg', '000153.jpg', '000154.jpg', '000155.jpg', '000156.jpg', '000157.jpg', '000158.jpg', '000159.jpg', '000160.jpg', '000161.jpg', '000162.jpg', '000163.jpg', '000164.jpg', '000165.jpg', '000166.jpg', '000167.jpg', '000168.jpg', '000169.jpg', '000170.jpg', '000171.jpg', '000172.jpg', '000173.jpg', '000174.jpg', '000175.jpg', '000176.jpg', '000177.jpg', '000178.jpg', '000179.jpg', '000180.jpg', '000181.jpg', '000182.jpg', '000183.jpg', '000184.jpg', '000185.jpg', '000186.jpg', '000187.jpg', '000188.jpg', '000189.jpg', '000190.jpg', '000191.jpg', '000192.jpg', '000193.jpg', '000194.jpg', '000195.jpg', '000196.jpg', '000197.jpg', '000198.jpg', '000199.jpg', '000200.jpg', '000201.jpg', '000202.jpg', '000203.jpg', '000204.jpg', '000205.jpg', '000206.jpg']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 207/207 [00:01<00:00, 123.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Video saved at /scratch/xz653/datasets/hypernerf/vrig-3dprinter/instance/2x/left1/groundingsam.mp4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# init video predictor state and the tracking bookkeeping shared by the loop below\n",
    "inference_state = video_predictor.init_state(video_path=frame_dir)\n",
    "step = REDETECT_INTERVAL  # the step to sample frames for Grounding DINO predictor\n",
    "\n",
    "sam2_masks = MaskDictionaryModel()  # masks carried over from the last propagated frame\n",
    "PROMPT_TYPE_FOR_VIDEO = \"mask\"  # box, mask or point (only \"mask\" is implemented below)\n",
    "objects_count = 0  # running instance-id counter, updated by MaskDictionaryModel.update_masks\n",
    "frame_object_count = {}  # detection frame index -> objects_count after merging that frame\n",
    "\n",
    "# NOTE(review): this filter/sort presumably has to match how init_state enumerates frames,\n",
    "# otherwise frame indices would not line up with frame_names - confirm against sam2.\n",
    "frame_names = [\n",
    "    p for p in os.listdir(frame_dir)\n",
    "    if os.path.splitext(p)[-1] in [\".jpg\", \".jpeg\", \".JPG\", \".JPEG\"]\n",
    "]\n",
    "frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))\n",
    "text = TEXT_PROMPT\n",
    "device = DEVICE\n",
    "video_dir = frame_dir\n",
    "output_video_path = osp.join(OUTPUT_PATH, CAM_TAG, \"groundingsam.mp4\")\n",
    "\n",
    "# per-camera outputs: id masks (.npy), per-frame json metadata, annotated images\n",
    "mask_data_dir = osp.join(OUTPUT_PATH, CAM_TAG, \"mask_data\")\n",
    "json_data_dir = osp.join(OUTPUT_PATH, CAM_TAG, \"json_data\")\n",
    "result_dir = osp.join(OUTPUT_PATH, CAM_TAG, \"img_result\")\n",
    "for d in [mask_data_dir, json_data_dir, result_dir]:\n",
    "    os.makedirs(d, exist_ok=True)\n",
    "\n",
    "\"\"\"\n",
    "Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames\n",
    "\"\"\"\n",
    "print(\"Total frames:\", len(frame_names))\n",
    "for start_frame_idx in range(0, len(frame_names), step):\n",
    "    # prompt grounding dino to get the box coordinates on specific frame\n",
    "    print(\"start_frame_idx\", start_frame_idx)\n",
    "    img_path = os.path.join(video_dir, frame_names[start_frame_idx])\n",
    "    image = Image.open(img_path).convert(\"RGB\")\n",
    "    image_base_name = frame_names[start_frame_idx].split(\".\")[0]\n",
    "    mask_dict = MaskDictionaryModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f\"mask_{image_base_name}.npy\")\n",
    "\n",
    "    # run Grounding DINO on the image\n",
    "    inputs = processor(images=image, text=text, return_tensors=\"pt\").to(device)\n",
    "    with torch.no_grad():\n",
    "        outputs = grounding_model(**inputs)\n",
    "\n",
    "    results = processor.post_process_grounded_object_detection(\n",
    "        outputs,\n",
    "        inputs.input_ids,\n",
    "        box_threshold=BOX_THRESHOLD,\n",
    "        text_threshold=TEXT_THRESHOLD,\n",
    "        target_sizes=[image.size[::-1]]  # PIL size is (W, H); the post-processor wants (H, W)\n",
    "    )\n",
    "\n",
    "    # prompt SAM image predictor to get the mask for the object (`image` is already RGB)\n",
    "    image_predictor.set_image(np.array(image))\n",
    "\n",
    "    # process the detection results; OBJECTS is also read by the export cells below,\n",
    "    # so after the loop it holds the labels of the last detection frame only\n",
    "    input_boxes = results[0][\"boxes\"]\n",
    "    OBJECTS = results[0][\"labels\"]\n",
    "    if input_boxes.shape[0] != 0:\n",
    "        # prompt SAM 2 image predictor to get the mask for the object\n",
    "        masks, scores, logits = image_predictor.predict(\n",
    "            point_coords=None,\n",
    "            point_labels=None,\n",
    "            box=input_boxes,\n",
    "            multimask_output=False,\n",
    "        )\n",
    "        # convert the mask shape to (n, H, W)\n",
    "        if masks.ndim == 2:\n",
    "            masks = masks[None]\n",
    "            scores = scores[None]\n",
    "            logits = logits[None]\n",
    "        elif masks.ndim == 4:\n",
    "            masks = masks.squeeze(1)\n",
    "        \"\"\"\n",
    "        Step 3: Record each detected object's mask; the video predictor is prompted with masks\n",
    "        \"\"\"\n",
    "        if mask_dict.promote_type == \"mask\":\n",
    "            # input_boxes is already a tensor: clone().detach() is exactly what\n",
    "            # torch.tensor(input_boxes) did, without the copy-construct UserWarning\n",
    "            mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=input_boxes.clone().detach(), label_list=OBJECTS)\n",
    "        else:\n",
    "            raise NotImplementedError(\"SAM 2 video predictor only support mask prompts\")\n",
    "    else:\n",
    "        print(\"No object detected in frame {}, reusing the previously tracked masks\".format(frame_names[start_frame_idx]))\n",
    "        mask_dict = sam2_masks\n",
    "\n",
    "    \"\"\"\n",
    "    Step 4: Propagate the video predictor to get the segmentation results for each frame\n",
    "    \"\"\"\n",
    "    # merge this frame's detections with the tracked masks; NOTE(review): presumably\n",
    "    # detections overlapping a tracked mask above iou_threshold keep its instance id\n",
    "    objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)\n",
    "    frame_object_count[start_frame_idx] = objects_count\n",
    "    print(\"objects_count\", objects_count)\n",
    "\n",
    "    if len(mask_dict.labels) == 0:\n",
    "        mask_dict.save_empty_mask_and_json(mask_data_dir, json_data_dir, image_name_list = frame_names[start_frame_idx:start_frame_idx+step])\n",
    "        print(\"No object detected in the frame, skip the frame {}\".format(start_frame_idx))\n",
    "        continue\n",
    "    else:\n",
    "        video_predictor.reset_state(inference_state)\n",
    "\n",
    "        for object_id, object_info in mask_dict.labels.items():\n",
    "            frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(\n",
    "                    inference_state,\n",
    "                    start_frame_idx,\n",
    "                    object_id,\n",
    "                    object_info.mask,\n",
    "                )\n",
    "\n",
    "        video_segments = {}  # output the following {step} frames tracking masks\n",
    "        for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):\n",
    "            frame_masks = MaskDictionaryModel()\n",
    "\n",
    "            for i, out_obj_id in enumerate(out_obj_ids):\n",
    "                out_mask = (out_mask_logits[i] > 0.0)\n",
    "                object_info = ObjectInfo(instance_id = out_obj_id, mask = out_mask[0], class_name = mask_dict.get_target_class_name(out_obj_id), logit=mask_dict.get_target_logit(out_obj_id))\n",
    "                object_info.update_box()\n",
    "                frame_masks.labels[out_obj_id] = object_info\n",
    "                image_base_name = frame_names[out_frame_idx].split(\".\")[0]\n",
    "                frame_masks.mask_name = f\"mask_{image_base_name}.npy\"\n",
    "                frame_masks.mask_height = out_mask.shape[-2]\n",
    "                frame_masks.mask_width = out_mask.shape[-1]\n",
    "\n",
    "            video_segments[out_frame_idx] = frame_masks\n",
    "            # keep the latest frame's masks: the next detection round merges against them\n",
    "            sam2_masks = copy.deepcopy(frame_masks)\n",
    "\n",
    "        print(\"video_segments:\", len(video_segments))\n",
    "    \"\"\"\n",
    "    Step 5: save the tracking masks and json files\n",
    "    \"\"\"\n",
    "    for frame_idx, frame_masks_info in video_segments.items():\n",
    "        mask = frame_masks_info.labels\n",
    "        # single-channel id map: 0 = background, otherwise the instance id\n",
    "        mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)\n",
    "        for obj_id, obj_info in mask.items():\n",
    "            mask_img[obj_info.mask == True] = obj_id\n",
    "\n",
    "        mask_img = mask_img.numpy().astype(np.uint16)\n",
    "        np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)\n",
    "\n",
    "        json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(\".npy\", \".json\"))\n",
    "        frame_masks_info.to_json(json_data_path)\n",
    "\n",
    "# draw the tracked masks/boxes onto every frame and save them to result_dir\n",
    "CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)\n",
    "\n",
    "# NOTE(review): presumably undoes a frame reversal applied when the input was prepared\n",
    "# for backward tracking - confirm against helper.reverse_files\n",
    "if REVERSE_TRACK:\n",
    "    for out_dir in (mask_data_dir, json_data_dir, result_dir):\n",
    "        reverse_files(out_dir)\n",
    "\n",
    "# encode the annotated frames into a preview video\n",
    "create_video_from_images(result_dir, output_video_path, frame_rate=15)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Remap Tracked Instance Ids to Class Ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 187,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Labels of each instance: ['##printer head', '##printer head', 'cable'] . Please check them to be aligned when using multiple cameras\n"
     ]
    }
   ],
   "source": [
    "print(\"Labels of each instance:\", OBJECTS, \". Please check them to be aligned when using multiple cameras\")\n",
    "# class_map: instance id stored in mask_data -> id written to imask (0 = background).\n",
    "# Identity by default; edit it so ids agree across cameras, e.g. {0: 0, 1: 2, 2: 1}.\n",
    "class_map = dict(enumerate(range(len(OBJECTS) + 1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 207/207 [00:16<00:00, 12.45it/s]\n"
     ]
    }
   ],
   "source": [
    "# transform the results to a standard format\n",
    "# class_map sends instance id -> output id, so invert it to name each output id after the\n",
    "# instance(s) mapped onto it (the first instance wins when several share an output id).\n",
    "# names[k - 1] describes output id k as \"<label>:<kinematic|deformable>\".\n",
    "label_of_output = {}\n",
    "for instance_id, label in enumerate(OBJECTS, start=1):\n",
    "    label_of_output.setdefault(class_map[instance_id], label)\n",
    "label_of_output.pop(0, None)  # instances mapped to 0 are merged into the background\n",
    "num_outputs = max(label_of_output, default=0)\n",
    "missing_outputs = [k for k in range(1, num_outputs + 1) if k not in label_of_output]\n",
    "if missing_outputs:\n",
    "    raise ValueError(f\"class_map leaves output ids {missing_outputs} without any instance\")\n",
    "names = [\n",
    "    f\"{label_of_output[k]}:{'kinematic' if label_of_output[k] == 'person' else 'deformable'}\"\n",
    "    for k in range(1, num_outputs + 1)\n",
    "]\n",
    "out_path = Path(OUTPUT_PATH)\n",
    "(out_path / \"names.json\").write_text(json.dumps(names))\n",
    "if CAM_TAG:  # same guard as the per-frame file names below; avoids writing names..json\n",
    "    (out_path / f\"names.{CAM_TAG}.json\").write_text(json.dumps(names))\n",
    "\n",
    "imask_path = out_path / \"imask\"\n",
    "cmask_path = out_path / \"cmask\"\n",
    "\n",
    "imask_path.mkdir(parents=True, exist_ok=True)\n",
    "cmask_path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Lookup-table remap: one vectorised gather per frame instead of np.vectorize's Python\n",
    "# call per pixel; ids missing from class_map are reported explicitly instead of failing\n",
    "# inside the dtype cast.\n",
    "lut_size = max(class_map) + 1\n",
    "class_lut = np.zeros(lut_size, dtype=np.int64)\n",
    "is_mapped = np.zeros(lut_size, dtype=bool)\n",
    "for instance_id, output_id in class_map.items():\n",
    "    class_lut[instance_id] = output_id\n",
    "    is_mapped[instance_id] = True\n",
    "\n",
    "for mask_path in tqdm(sorted(Path(mask_data_dir).glob(\"*.npy\"))):\n",
    "    imask = np.load(mask_path)\n",
    "    dtype = imask.dtype\n",
    "    unmapped = [int(i) for i in np.unique(imask) if i >= lut_size or not is_mapped[i]]\n",
    "    if unmapped:\n",
    "        raise ValueError(f\"{mask_path.name}: instance ids {unmapped} are missing from class_map\")\n",
    "    imask = class_lut[imask].astype(dtype)\n",
    "    cmask = label2rgb(imask, bg_label=0)\n",
    "\n",
    "    imask = Image.fromarray(imask)\n",
    "    cmask = Image.fromarray((cmask * 255).astype(np.uint8))\n",
    "\n",
    "    img_id = mask_path.name.split(\".\")[0].split(\"_\")[-1]\n",
    "    if CAM_TAG:\n",
    "        img_id = f\"{CAM_TAG}_{img_id}\"\n",
    "\n",
    "    imask.save(imask_path / (img_id + '.png'))\n",
    "    cmask.save(cmask_path / (img_id + '.png'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Merge All Tracked Instances into a Single Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 207/207 [00:21<00:00,  9.74it/s]\n"
     ]
    }
   ],
   "source": [
    "# Collapse every tracked instance into one foreground class (id 1); 0 stays background.\n",
    "names = [\"object:deformable\"]\n",
    "out_path = Path(OUTPUT_PATH)\n",
    "(out_path / \"names.json\").write_text(json.dumps(names))\n",
    "# also refresh the per-camera names file (written by the per-class export too),\n",
    "# otherwise it would keep stale per-instance names that disagree with these masks\n",
    "if CAM_TAG:\n",
    "    (out_path / f\"names.{CAM_TAG}.json\").write_text(json.dumps(names))\n",
    "\n",
    "imask_path = out_path / \"imask\"\n",
    "cmask_path = out_path / \"cmask\"\n",
    "\n",
    "imask_path.mkdir(parents=True, exist_ok=True)\n",
    "cmask_path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "for mask_path in tqdm(sorted(Path(mask_data_dir).glob(\"*.npy\"))):\n",
    "    imask = np.load(mask_path)\n",
    "    dtype = imask.dtype\n",
    "    imask = (imask > 0).astype(dtype)\n",
    "    cmask = label2rgb(imask, bg_label=0)\n",
    "\n",
    "    imask = Image.fromarray(imask)\n",
    "    cmask = Image.fromarray((cmask * 255).astype(np.uint8))\n",
    "\n",
    "    img_id = mask_path.name.split(\".\")[0].split(\"_\")[-1]\n",
    "    if CAM_TAG:\n",
    "        img_id = f\"{CAM_TAG}_{img_id}\"\n",
    "\n",
    "    imask.save(imask_path / (img_id + '.png'))\n",
    "    cmask.save(cmask_path / (img_id + '.png'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
