{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f400486b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) Meta Platforms, Inc. and affiliates."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1ae39ff",
   "metadata": {},
   "source": [
    "# Object masks from prompts with SAM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4a4b25c",
   "metadata": {},
   "source": [
    "The Segment Anything Model (SAM) predicts object masks given prompts that indicate the desired object. The model first converts the image into an image embedding, from which high-quality masks can be produced efficiently for any prompt.\n",
    "\n",
    "The `SamPredictor` class provides a simple interface for prompting the model. First, set an image with the `set_image` method, which calculates the necessary image embeddings. Then, pass prompts to the `predict` method to efficiently predict masks from them. The model accepts both point and box prompts, as well as masks from a previous iteration of prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18ab8c70",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "display(HTML(\n",
    "\"\"\"\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>\n",
    "\"\"\"\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "644532a8",
   "metadata": {},
   "source": [
    "## Environment Set-up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07fabfee",
   "metadata": {},
   "source": [
    "If running locally using Jupyter, first install `segment_anything` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ea65efc",
   "metadata": {},
   "outputs": [],
   "source": [
    "using_colab = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91dd9a89",
   "metadata": {},
   "outputs": [],
   "source": [
    "if using_colab:\n",
    "    import torch\n",
    "    import torchvision\n",
    "    print(\"PyTorch version:\", torch.__version__)\n",
    "    print(\"Torchvision version:\", torchvision.__version__)\n",
    "    print(\"CUDA is available:\", torch.cuda.is_available())\n",
    "    import sys\n",
    "    !{sys.executable} -m pip install opencv-python matplotlib\n",
    "    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'\n",
    "\n",
    "    !mkdir images\n",
    "    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg\n",
    "    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg\n",
    "\n",
    "    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0be845da",
   "metadata": {},
   "source": [
    "## Set-up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33681dd1",
   "metadata": {},
   "source": [
    "Necessary imports and helper functions for displaying points, boxes, and masks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69b28288",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29bc90d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_mask(mask, ax, random_color=False):\n",
    "    if random_color:\n",
    "        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
    "    else:\n",
    "        color = np.array([30/255, 144/255, 255/255, 0.6])\n",
    "    h, w = mask.shape[-2:]\n",
    "    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
    "    ax.imshow(mask_image)\n",
    "\n",
    "def show_points(coords, labels, ax, marker_size=375):\n",
    "    pos_points = coords[labels==1]\n",
    "    neg_points = coords[labels==0]\n",
    "    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
    "    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
    "\n",
    "def show_box(box, ax):\n",
    "    x0, y0 = box[0], box[1]\n",
    "    w, h = box[2] - box[0], box[3] - box[1]\n",
    "    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23842fb2",
   "metadata": {},
   "source": [
    "## Example image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c2e4f6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "image = cv2.imread('images/truck.jpg')\n",
    "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e30125fd",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "plt.axis('on')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98b228b8",
   "metadata": {},
   "source": [
    "## Selecting objects with SAM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bb1927b",
   "metadata": {},
   "source": [
    "First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17ccff22",
   "metadata": {},
   "outputs": [],
   "source": [
    "sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
    "device = \"cuda\"\n",
    "model_type = \"default\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e28150b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from segment_anything import sam_model_registry, SamPredictor\n",
    "\n",
    "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
    "sam.to(device=device)\n",
    "\n",
    "predictor = SamPredictor(sam)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c925e829",
   "metadata": {},
   "source": [
    "Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d95d48dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.set_image(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8fc7a46",
   "metadata": {},
   "source": [
    "To select the truck, choose a point on it. Points are input to the model in (x,y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one. The chosen point will be shown as a star on the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c69570c",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_point = np.array([[500, 375]])\n",
    "input_label = np.array([1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a91ba973",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('on')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c765e952",
   "metadata": {},
   "source": [
    "Predict with `SamPredictor.predict`. The model returns masks, quality predictions for those masks, and low-resolution mask logits that can be passed to the next iteration of prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5373fd68",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, scores, logits = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    multimask_output=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7f0e938",
   "metadata": {},
   "source": [
    "With `multimask_output=True` (the default setting), SAM outputs 3 masks, where `scores` gives the model's own estimation of the quality of these masks. This setting is intended for ambiguous input prompts, and helps the model disambiguate different objects consistent with the prompt. When `False`, it will return a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score returned in `scores`. This will often result in a better mask."
   ]
  },
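  {
   "cell_type": "markdown",
   "id": "5d2a9c1e",
   "metadata": {},
   "source": [
    "As a minimal sketch of the selection described above, the highest-scoring mask can be picked with `np.argmax` over `scores`. The names `best_idx` and `best_mask` are illustrative and are not used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b4f0e63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: pick the single mask that the model itself rates highest.\n",
    "best_idx = int(np.argmax(scores))  # index of the highest predicted quality\n",
    "best_mask = masks[best_idx]        # H x W mask for that index\n",
    "print(f\"Best mask: {best_idx + 1}, Score: {scores[best_idx]:.3f}\")"
   ]
  },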
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47821187",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks.shape  # (number_of_masks) x H x W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9c227a6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for i, (mask, score) in enumerate(zip(masks, scores)):\n",
    "    plt.figure(figsize=(10,10))\n",
    "    plt.imshow(image)\n",
    "    show_mask(mask, plt.gca())\n",
    "    show_points(input_point, input_label, plt.gca())\n",
    "    plt.title(f\"Mask {i+1}, Score: {score:.3f}\", fontsize=18)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fa31f7c",
   "metadata": {},
   "source": [
    "## Selecting a specific object with additional points"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88d6d29a",
   "metadata": {},
   "source": [
    "The single input point is ambiguous, and the model has returned multiple objects consistent with it. To obtain a single object, multiple points can be provided. If available, a mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting `multimask_output=False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6923b94",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_point = np.array([[500, 375], [1125, 625]])\n",
    "input_label = np.array([1, 1])\n",
    "\n",
    "mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d98f96a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    mask_input=mask_input[None, :, :],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ce8b82f",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e06d5c8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10,10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks, plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c93e2087",
   "metadata": {},
   "source": [
    "To exclude the car and specify just the window, a background point (with label 0, here shown in red) can be supplied."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a196f68",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_point = np.array([[500, 375], [1125, 625]])\n",
    "input_label = np.array([1, 0])\n",
    "\n",
    "mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81a52282",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    mask_input=mask_input[None, :, :],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfca709f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks, plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41e2d5a9",
   "metadata": {},
   "source": [
    "## Selecting a specific object with a box"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d61ca7ac",
   "metadata": {},
   "source": [
    "The model can also take a box as input, provided in xyxy format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ea92a7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_box = np.array([425, 600, 700, 875])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b35a8814",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=None,\n",
    "    point_labels=None,\n",
    "    box=input_box[None, :],\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "984b79c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0], plt.gca())\n",
    "show_box(input_box, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c1ed9f0a",
   "metadata": {},
   "source": [
    "## Combining points and boxes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8455d1c5",
   "metadata": {},
   "source": [
    "Points and boxes can be combined by passing both types of prompts to the predictor. Here, this is used to select just the truck's tire instead of the entire wheel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90e2e547",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_box = np.array([425, 600, 700, 875])\n",
    "input_point = np.array([[575, 750]])\n",
    "input_label = np.array([0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6956d8c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks, _, _ = predictor.predict(\n",
    "    point_coords=input_point,\n",
    "    point_labels=input_label,\n",
    "    box=input_box,\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e13088a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "show_mask(masks[0], plt.gca())\n",
    "show_box(input_box, plt.gca())\n",
    "show_points(input_point, input_label, plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45ddbca3",
   "metadata": {},
   "source": [
    "## Batched prompt inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df6f18a0",
   "metadata": {},
   "source": [
    "`SamPredictor` can take multiple input prompts for the same image using the `predict_torch` method. This method assumes the input points are already torch tensors and have already been transformed to the input frame. For example, imagine we have several box outputs from an object detector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a06681b",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_boxes = torch.tensor([\n",
    "    [75, 275, 1725, 850],\n",
    "    [425, 600, 700, 875],\n",
    "    [1375, 550, 1650, 800],\n",
    "    [1240, 675, 1400, 750],\n",
    "], device=predictor.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf957d16",
   "metadata": {},
   "source": [
    "Transform the boxes to the input frame, then predict masks. `SamPredictor` stores the necessary transform as the `transform` field for easy access, though it can also be instantiated directly for use in e.g. a dataloader (see `segment_anything.utils.transforms`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "117521a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])\n",
    "masks, _, _ = predictor.predict_torch(\n",
    "    point_coords=None,\n",
    "    point_labels=None,\n",
    "    boxes=transformed_boxes,\n",
    "    multimask_output=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a8f5d49",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks.shape  # (batch_size) x (num_predicted_masks_per_input) x H x W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c00c3681",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(image)\n",
    "for mask in masks:\n",
    "    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)\n",
    "for box in input_boxes:\n",
    "    show_box(box.cpu().numpy(), plt.gca())\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bea70c0",
   "metadata": {},
   "source": [
    "## End-to-end batched inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89c3ba52",
   "metadata": {},
   "source": [
    "If all prompts are available in advance, it is possible to run SAM directly in an end-to-end fashion. This also allows batching over images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45c01ae4",
   "metadata": {},
   "outputs": [],
   "source": [
    "image1 = image  # truck.jpg from above\n",
    "image1_boxes = torch.tensor([\n",
    "    [75, 275, 1725, 850],\n",
    "    [425, 600, 700, 875],\n",
    "    [1375, 550, 1650, 800],\n",
    "    [1240, 675, 1400, 750],\n",
    "], device=sam.device)\n",
    "\n",
    "image2 = cv2.imread('images/groceries.jpg')\n",
    "image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)\n",
    "image2_boxes = torch.tensor([\n",
    "    [450, 170, 520, 350],\n",
    "    [350, 190, 450, 350],\n",
    "    [500, 170, 580, 350],\n",
    "    [580, 170, 640, 350],\n",
    "], device=sam.device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce56c57d",
   "metadata": {},
   "source": [
    "Both images and prompts are input as PyTorch tensors that are already transformed to the correct frame. Inputs are packaged as a list over images, where each element is a dict that takes the following keys:\n",
    "* `image`: The input image as a PyTorch tensor in CHW format.\n",
    "* `original_size`: The size of the image before transforming for input to SAM, in (H, W) format.\n",
    "* `point_coords`: Batched coordinates of point prompts.\n",
    "* `point_labels`: Batched labels of point prompts.\n",
    "* `boxes`: Batched input boxes.\n",
    "* `mask_inputs`: Batched input masks.\n",
    "\n",
    "If a prompt is not present, its key can be omitted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79f908ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from segment_anything.utils.transforms import ResizeLongestSide\n",
    "resize_transform = ResizeLongestSide(sam.image_encoder.img_size)\n",
    "\n",
    "def prepare_image(image, transform, device):\n",
    "    # Note: `device` receives the SAM model itself (see the call below);\n",
    "    # its `.device` attribute supplies the torch device.\n",
    "    image = transform.apply_image(image)  # resize to the model's input frame\n",
    "    image = torch.as_tensor(image, device=device.device)\n",
    "    return image.permute(2, 0, 1).contiguous()  # HWC -> CHW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23f63723",
   "metadata": {},
   "outputs": [],
   "source": [
    "batched_input = [\n",
    "     {\n",
    "         'image': prepare_image(image1, resize_transform, sam),\n",
    "         'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),\n",
    "         'original_size': image1.shape[:2]\n",
    "     },\n",
    "     {\n",
    "         'image': prepare_image(image2, resize_transform, sam),\n",
    "         'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),\n",
    "         'original_size': image2.shape[:2]\n",
    "     }\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fbeb831",
   "metadata": {},
   "source": [
    "Run the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b311b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "batched_output = sam(batched_input, multimask_output=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27bb50fd",
   "metadata": {},
   "source": [
    "The output is a list of results, one per input image, where each element is a dictionary with the following keys:\n",
    "* `masks`: A batched torch tensor of predicted binary masks, the size of the original image.\n",
    "* `iou_predictions`: The model's prediction of the quality for each mask.\n",
    "* `low_res_logits`: Low-resolution logits for each mask, which can be passed back to the model as mask input on a later iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb3dba0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "batched_output[0].keys()"
   ]
  },
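  {
   "cell_type": "markdown",
   "id": "2c8e6f1a",
   "metadata": {},
   "source": [
    "To make the output structure described above concrete, the following sketch prints the tensor shapes in each result. It assumes only the three keys listed above; the loop variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e0d4b57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative only: inspect the shape of every tensor returned for each image.\n",
    "for i, output in enumerate(batched_output):\n",
    "    print(f\"Image {i + 1}:\")\n",
    "    for key in ('masks', 'iou_predictions', 'low_res_logits'):\n",
    "        print(f\"  {key}: {tuple(output[key].shape)}\")"
   ]
  },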
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1108f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=(20, 20))\n",
    "\n",
    "ax[0].imshow(image1)\n",
    "for mask in batched_output[0]['masks']:\n",
    "    show_mask(mask.cpu().numpy(), ax[0], random_color=True)\n",
    "for box in image1_boxes:\n",
    "    show_box(box.cpu().numpy(), ax[0])\n",
    "ax[0].axis('off')\n",
    "\n",
    "ax[1].imshow(image2)\n",
    "for mask in batched_output[1]['masks']:\n",
    "    show_mask(mask.cpu().numpy(), ax[1], random_color=True)\n",
    "for box in image2_boxes:\n",
    "    show_box(box.cpu().numpy(), ax[1])\n",
    "ax[1].axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
