{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X5_RGz_P9R46"
      },
      "outputs": [],
      "source": [
        "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jvTp4UVz9R5A"
      },
      "source": [
        "# Inference Tutorial\n",
        "\n",
        "In this tutorial you'll learn:\n",
        "- [How to load an Omnivore model](#Load-Model)\n",
        "- [Inference with Images](#Inference-with-Images)\n",
        "- [Inference with Videos](#Inference-with-Videos)\n",
        "- [Inference with RGBD images](#Inference-with-RGBD-Images)\n",
        "- [Inference with Egocentric videos](#Inference-with-Egocentric-videos)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-S5NNdsM9R5C"
      },
      "source": [
        "### Install modules\n",
        "\n",
        "We assume that `torch` and `torchvision` have already been installed using the instructions in the [README.md](https://github.com/facebookresearch/omnivore/blob/main/README.md#setup-and-installation). We install the other dependencies required for using Omnivore models: `einops`, `pytorchvideo` and `timm`.\n",
        "\n",
        "For this tutorial, we additionally install `ipywidgets` and `matplotlib`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFUBrIGt9R5D"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "\n",
        "!{sys.executable} -m pip install einops pytorchvideo timm -q\n",
        "\n",
        "# only needed for the tutorial\n",
        "# if the video rendering doesn't work, restart the kernel after installation\n",
        "!{sys.executable} -m pip install ipywidgets matplotlib -q"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PF6cS6LM9R5E"
      },
      "source": [
        "### Import modules"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rb8K79k79R5F"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "try:\n",
        "    from omnivore.transforms import SpatialCrop, TemporalCrop, DepthNorm\n",
        "except ImportError:\n",
        "    # need to also make the omnivore transform module available\n",
        "    !git clone https://github.com/facebookresearch/omnivore.git\n",
        "    sys.path.append(\"./omnivore\")\n",
        "\n",
        "    from omnivore.transforms import SpatialCrop, TemporalCrop, DepthNorm\n",
        "\n",
        "import csv\n",
        "import json\n",
        "from typing import List\n",
        "\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "import torchvision.transforms as T\n",
        "from PIL import Image\n",
        "from pytorchvideo.data.encoded_video import EncodedVideo\n",
        "from torchvision.transforms._transforms_video import NormalizeVideo\n",
        "\n",
        "from pytorchvideo.transforms import (\n",
        "    ApplyTransformToKey,\n",
        "    ShortSideScale,\n",
        "    UniformTemporalSubsample,\n",
        ")\n",
        "\n",
        "%matplotlib inline\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib.image as mpimg\n",
        "from ipywidgets import Video"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cwylyPKE9R5H"
      },
      "source": [
        "<a id='load-model'></a>\n",
        "# Load Model\n",
        "\n",
        "We provide several pretrained Omnivore models via Torch Hub. Available models are described in the [model zoo documentation](https://github.com/facebookresearch/omnivore/blob/main/README.md#model-zoo). \n",
        "\n",
        "Here we select the `omnivore_swinB` model, which was trained on ImageNet 1K, Kinetics 400 and SUN RGBD. \n",
        "\n",
        "**NOTE**: to run on GPU in Google Colab, select in the menu bar: Runtime -> Change runtime type -> Hardware accelerator -> GPU\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qHfPB3mo9R5I"
      },
      "outputs": [],
      "source": [
        "# Device on which to run the model\n",
        "# Set to cuda to load on GPU\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\" \n",
        "\n",
        "# Pick a pretrained model \n",
        "model_name = \"omnivore_swinB\"\n",
        "model = torch.hub.load(\"facebookresearch/omnivore:main\", model=model_name, force_reload=True)\n",
        "\n",
        "# Set to eval mode and move to desired device\n",
        "model = model.to(device)\n",
        "model = model.eval()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BQuM7a0m9R5K"
      },
      "source": [
        "# Inference with Images\n",
        "\n",
        "First we'll load an image and use the Omnivore model to classify it. \n",
        "\n",
        "### Setup\n",
        "\n",
        "Download the id to label mapping for the Imagenet1K dataset. This will be used to get the category label names from the predicted class ids."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "482YS3Rj9R5L"
      },
      "outputs": [],
      "source": [
        "!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O imagenet_class_index.json"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4yxx-F569R5M"
      },
      "outputs": [],
      "source": [
        "with open(\"imagenet_class_index.json\", \"r\") as f:\n",
        "    imagenet_classnames = json.load(f)\n",
        "\n",
        "# Create an id to label name mapping\n",
        "imagenet_id_to_classname = {}\n",
        "for k, v in imagenet_classnames.items():\n",
        "    imagenet_id_to_classname[k] = v[1] "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gRDpifXh9R5N"
      },
      "source": [
        "### Load and visualize the image\n",
        "\n",
        "You can download the test image in the cell below or specify a path to your own image. Before passing the image into the model we need to apply some input transforms. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vGVAZ2pE9R5O"
      },
      "outputs": [],
      "source": [
        "# Download the example image file\n",
        "!wget -O library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5Lghqzew9R5P"
      },
      "outputs": [],
      "source": [
        "image_path = \"library.jpg\"\n",
        "image = Image.open(image_path).convert(\"RGB\")\n",
        "plt.figure(figsize=(6, 6))\n",
        "plt.imshow(image)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zxbQQu2J9R5P"
      },
      "outputs": [],
      "source": [
        "image_transform = T.Compose(\n",
        "    [\n",
        "        T.Resize(224),\n",
        "        T.CenterCrop(224),\n",
        "        T.ToTensor(),\n",
        "        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
        "    ]\n",
        ")\n",
        "image = image_transform(image)\n",
        "\n",
        "# The model expects inputs of shape: B x C x T x H x W\n",
        "image = image[None, :, None, ...]"
      ]
    },
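    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of the indexing used above, with a random tensor standing in for a transformed image (the name `dummy_image` and its values are illustrative, not part of the tutorial data). Indexing with `None` inserts an axis of size 1, so a `C x H x W` tensor becomes `B x C x T x H x W` with a batch of one and a single frame."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Stand-in for a transformed image of shape C x H x W (hypothetical values)\n",
        "dummy_image = torch.rand(3, 224, 224)\n",
        "\n",
        "# None adds a size-1 axis: one for the batch (B) and one for time (T)\n",
        "dummy_input = dummy_image[None, :, None, ...]\n",
        "print(dummy_input.shape)\n",
        "\n",
        "# unsqueeze gives the same result, one axis at a time\n",
        "same_input = dummy_image.unsqueeze(0).unsqueeze(2)\n",
        "print(torch.equal(dummy_input, same_input))"
      ]
    },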
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2pnAKqiy9R5Q"
      },
      "source": [
        "### Run the model \n",
        "\n",
        "The transformed image can be passed through the model to get class predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xf1usPrp9R5R"
      },
      "outputs": [],
      "source": [
        "with torch.no_grad():\n",
        "    prediction = model(image.to(device), input_type=\"image\")\n",
        "    pred_classes = prediction.topk(k=5).indices\n",
        "\n",
        "pred_class_names = [imagenet_id_to_classname[str(i.item())] for i in pred_classes[0]]\n",
        "print(\"Top 5 predicted labels: %s\" % \", \".join(pred_class_names))"
      ]
    },
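    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model output is a tensor of per-class scores. As a sketch, assuming these scores are unnormalized logits (an assumption worth checking against the model code), a softmax turns them into probabilities that can be reported next to the labels. The random `dummy_logits` below stand in for a real prediction, and the class count of 1000 is chosen only to mirror ImageNet 1K."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "torch.manual_seed(0)\n",
        "# Stand-in for a model output of shape B x num_classes (hypothetical values)\n",
        "dummy_logits = torch.randn(1, 1000)\n",
        "\n",
        "# Normalize the scores over the class dimension, then keep the 5 most likely classes\n",
        "probs = F.softmax(dummy_logits, dim=1)\n",
        "top5 = probs.topk(k=5)\n",
        "\n",
        "for class_id, p in zip(top5.indices[0].tolist(), top5.values[0].tolist()):\n",
        "    print(f\"class {class_id}: {p:.4f}\")"
      ]
    },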
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b1ub-hwp9R5R"
      },
      "source": [
        "# Inference with Videos\n",
        "\n",
        "Now we'll see how to use the Omnivore model to classify a video. \n",
        "\n",
        "\n",
        "### Setup \n",
        "\n",
        "Download the id to label mapping for the Kinetics 400 dataset. \n",
        "This will be used to get the category label names from the predicted class ids."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SiIZ-_T29R5S"
      },
      "outputs": [],
      "source": [
        "!wget https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json -O kinetics_classnames.json"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eQ14kfY89R5S"
      },
      "outputs": [],
      "source": [
        "with open(\"kinetics_classnames.json\", \"r\") as f:\n",
        "    kinetics_classnames = json.load(f)\n",
        "\n",
        "# Create an id to label name mapping\n",
        "kinetics_id_to_classname = {}\n",
        "for k, v in kinetics_classnames.items():\n",
        "    kinetics_id_to_classname[v] = str(k).replace('\"', \"\")"
      ]
    },
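    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop above inverts a name-to-id dictionary and strips double quotes from the names. A small sketch of the same step on a made-up mapping (the three entries and their ids are illustrative, not taken from the downloaded file):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical excerpt in the same name -> id layout as the loaded JSON\n",
        "example_classnames = {'\"archery\"': 0, '\"dancing ballet\"': 1, 'yoga': 2}\n",
        "\n",
        "# Invert the mapping and strip stray double quotes from the names\n",
        "example_id_to_classname = {\n",
        "    class_id: str(name).replace('\"', \"\")\n",
        "    for name, class_id in example_classnames.items()\n",
        "}\n",
        "print(example_id_to_classname[1])"
      ]
    },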
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "alMw8PVv9R5T"
      },
      "source": [
        "### Define the transformations for the input required by the model\n",
        "\n",
        "Before passing the video into the model we need to apply some input transforms and sample a clip of the correct duration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xg0vPreI9R5T"
      },
      "outputs": [],
      "source": [
        "num_frames = 160\n",
        "sampling_rate = 2\n",
        "frames_per_second = 30\n",
        "\n",
        "clip_duration = (num_frames * sampling_rate) / frames_per_second\n",
        "\n",
        "video_transform = ApplyTransformToKey(\n",
        "    key=\"video\",\n",
        "    transform=T.Compose(\n",
        "        [\n",
        "            UniformTemporalSubsample(num_frames), \n",
        "            T.Lambda(lambda x: x / 255.0),  \n",
        "            ShortSideScale(size=224),\n",
        "            NormalizeVideo(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
        "            TemporalCrop(frames_per_clip=32, stride=40),\n",
        "            SpatialCrop(crop_size=224, num_crops=3),\n",
        "        ]\n",
        "    ),\n",
        ")"
      ]
    },
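    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the arithmetic concrete: with the values above, `clip_duration` is 160 * 2 / 30, roughly 10.67 seconds. The sketch below also illustrates the idea behind uniform temporal subsampling as picking evenly spaced frame indices. It is a simplified stand-in written for illustration, not the `pytorchvideo` implementation, and the helper name `evenly_spaced_indices` and the `demo_` variables are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "demo_num_frames = 160\n",
        "demo_sampling_rate = 2\n",
        "demo_frames_per_second = 30\n",
        "\n",
        "# Duration in seconds covered by the sampled frames\n",
        "demo_clip_duration = (demo_num_frames * demo_sampling_rate) / demo_frames_per_second\n",
        "print(round(demo_clip_duration, 2))\n",
        "\n",
        "\n",
        "def evenly_spaced_indices(total_frames, num_samples):\n",
        "    \"\"\"Return num_samples frame indices spread evenly over [0, total_frames - 1].\"\"\"\n",
        "    if num_samples == 1:\n",
        "        return [0]\n",
        "    step = (total_frames - 1) / (num_samples - 1)\n",
        "    return [int(i * step) for i in range(num_samples)]\n",
        "\n",
        "\n",
        "# Pick 4 frames out of a 10-frame video\n",
        "print(evenly_spaced_indices(10, 4))"
      ]
    },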
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "45CBNmWX9R5U"
      },
      "source": [
        "### Load and visualize an example video\n",
        "We can test the classification of an example video from the Kinetics validation set, such as this [archery video](https://www.youtube.com/watch?v=3and4vWkW4s). The cells below download and use a sample clip saved as `dance.mp4`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RcXsmi0e9R5U"
      },
      "outputs": [],
      "source": [
        "# Download the example video file\n",
        "!wget https://dl.fbaipublicfiles.com/omnivore/example_data/dance.mp4 -O dance.mp4"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0Pg9vyCe9R5V"
      },
      "outputs": [],
      "source": [
        "# Load the example video\n",
        "video_path = \"dance.mp4\" \n",
        "\n",
        "Video.from_file(video_path, width=500)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vGYZedo99R5V"
      },
      "outputs": [],
      "source": [
        "# Downscale the video to a width of 224 pixels and keep only the first second to save RAM\n",
        "!ffmpeg -y -ss 0 -i dance.mp4 -filter:v scale=224:-1 -t 1 -v 0 dance_cropped.mp4\n",
        "\n",
        "video_path = \"dance_cropped.mp4\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qekSOQWt9R5W"
      },
      "outputs": [],
      "source": [
        "# Initialize an EncodedVideo helper class\n",
        "video = EncodedVideo.from_path(video_path)\n",
        "\n",
        "# Load the desired clip and specify the start and end times in seconds.\n",
        "# The start_sec should correspond to where the action occurs in the video\n",
        "video_data = video.get_clip(start_sec=0, end_sec=2.0)\n",
        "\n",
        "# Apply the transform defined above to preprocess the video input\n",
        "video_data = video_transform(video_data)\n",
        "\n",
        "# Extract the transformed clips; they are moved to the device at inference time\n",
        "video_inputs = video_data[\"video\"]\n",
        "\n",
        "# Take the first clip \n",
        "# The model expects inputs of shape: B x C x T x H x W\n",
        "video_input = video_inputs[0][None, ...]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jimUjTH49R5W"
      },
      "source": [
        "### Get model predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A_Mq1iKd9R5W"
      },
      "outputs": [],
      "source": [
        "# Pass the input clip through the model \n",
        "with torch.no_grad():\n",
        "    prediction = model(video_input.to(device), input_type=\"video\")\n",
        "\n",
        "    # Get the predicted classes \n",
        "    pred_classes = prediction.topk(k=5).indices\n",
        "\n",
        "# Map the predicted classes to the label names\n",
        "pred_class_names = [kinetics_id_to_classname[int(i)] for i in pred_classes[0]]\n",
        "print(\"Top 5 predicted labels: %s\" % \", \".join(pred_class_names))"
      ]
    },
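    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell above classifies only the first clip. One possible alternative is to average the class probabilities over several crops. The sketch below shows that averaging step in isolation: `dummy_model` is a hypothetical stand-in (a single linear layer with 10 output classes), and `dummy_crops` is random data rather than real video. With the tutorial's objects this would correspond to stacking the entries of `video_inputs`, assuming they all share the same shape."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "torch.manual_seed(0)\n",
        "\n",
        "# Hypothetical stand-in classifier: flattens a C x T x H x W clip and outputs 10 scores\n",
        "dummy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 8 * 8, 10)).eval()\n",
        "\n",
        "# Three random crops, each of shape C x T x H x W\n",
        "dummy_crops = [torch.rand(3, 4, 8, 8) for _ in range(3)]\n",
        "\n",
        "with torch.no_grad():\n",
        "    # Stack the crops into a batch of shape B x C x T x H x W\n",
        "    crop_batch = torch.stack(dummy_crops, dim=0)\n",
        "    crop_probs = dummy_model(crop_batch).softmax(dim=1)\n",
        "    # Average the probabilities over the crop dimension\n",
        "    mean_probs = crop_probs.mean(dim=0)\n",
        "\n",
        "print(mean_probs.topk(k=3).indices.tolist())"
      ]
    },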
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iKJXLfXl9R5X"
      },
      "source": [
        "# Inference with RGBD Images\n",
        "\n",
        "Now we'll see how to use the Omnivore model to classify an image with a depth map. \n",
        "\n",
        "\n",
        "### Setup \n",
        "\n",
        "Download the id to label mapping for the SUN RGBD dataset with 19 classes. \n",
        "This will be used to get the category label names from the predicted class ids."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iJBEELKL9R5X"
      },
      "outputs": [],
      "source": [
        "!wget https://dl.fbaipublicfiles.com/omnivore/sunrgbd_classnames.json -O sunrgbd_classnames.json "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uGy1HnxH9R5Y"
      },
      "outputs": [],
      "source": [
        "with open(\"sunrgbd_classnames.json\", \"r\") as f:\n",
        "    sunrgbd_id_to_classname = json.load(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AJ_gNmmE9R5Y"
      },
      "source": [
        "### Define the transformations for the input required by the model\n",
        "\n",
        "Before passing the RGBD image into the model we need to apply some input transforms. \n",
        "\n",
        "Here, the depth statistics (max, mean, std) are based on the SUN RGBD dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M5CqLA9m9R5Z"
      },
      "outputs": [],
      "source": [
        "rgbd_transform = T.Compose(\n",
        "    [\n",
        "        DepthNorm(max_depth=75.0, clamp_max_before_scale=True),\n",
        "        T.Resize(224),\n",
        "        T.CenterCrop(224),\n",
        "        T.Normalize(\n",
        "            mean=[0.485, 0.456, 0.406, 0.0418], \n",
        "            std=[0.229, 0.224, 0.225, 0.0295]\n",
        "        ),\n",
        "    ]\n",
        ")"
      ]
    },
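    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`T.Normalize` subtracts a per-channel mean and divides by a per-channel standard deviation. A minimal sketch of that arithmetic on a random four-channel tensor (three RGB channels plus one depth channel), written with plain tensor operations. The input `dummy_rgbd` is made-up data, and the `demo_` names are introduced only for this illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "torch.manual_seed(0)\n",
        "# Stand-in RGBD tensor of shape 4 x H x W (hypothetical values)\n",
        "dummy_rgbd = torch.rand(4, 8, 8)\n",
        "\n",
        "# Same statistics as in the transform above: RGB first, depth last\n",
        "demo_mean = torch.tensor([0.485, 0.456, 0.406, 0.0418])\n",
        "demo_std = torch.tensor([0.229, 0.224, 0.225, 0.0295])\n",
        "\n",
        "# Reshape the statistics to 4 x 1 x 1 so they broadcast over H and W\n",
        "demo_normalized = (dummy_rgbd - demo_mean[:, None, None]) / demo_std[:, None, None]\n",
        "print(demo_normalized.shape)"
      ]
    },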
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bT-6CSHU9R5Z"
      },
      "source": [
        "### Download an example image and depth\n",
        "\n",
        "We can test the classification of an example RGBD image from the SUN RGBD validation set. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U2q0nOYF9R5Z"
      },
      "outputs": [],
      "source": [
        "# Download the example image and disparity file\n",
        "!wget -O store.png https://upload.wikimedia.org/wikipedia/commons/thumb/f/f4/Interior_of_the_IKEA_B%C4%83neasa_33.jpg/791px-Interior_of_the_IKEA_B%C4%83neasa_33.jpg\n",
        "!wget -O store_disparity.pt https://dl.fbaipublicfiles.com/omnivore/example_data/store_disparity.pt "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k9mAgJkb9R5a"
      },
      "source": [
        "### Load and visualize the image and depth map\n",
        "\n",
        "(Notice the price tags on the furniture items in the image of the bedroom)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gpx1W3Il9R5a"
      },
      "outputs": [],
      "source": [
        "image_path = \"./store.png\"\n",
        "depth_path = \"./store_disparity.pt\"\n",
        "image = Image.open(image_path).convert(\"RGB\")\n",
        "depth = torch.load(depth_path)[None, ...]\n",
        "\n",
        "plt.figure(figsize=(20, 10))\n",
        "plt.subplot(1, 2, 1)\n",
        "plt.title(\"RGB\", fontsize=20)\n",
        "plt.imshow(image)\n",
        "plt.subplot(1, 2, 2)\n",
        "plt.imshow(depth.numpy().squeeze())\n",
        "plt.title(\"Depth\", fontsize=20)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e5wYmU229R5b"
      },
      "outputs": [],
      "source": [
        "# Convert to tensor and transform\n",
        "image = T.ToTensor()(image)\n",
        "rgbd = torch.cat([image, depth], dim=0)\n",
        "rgbd = rgbd_transform(rgbd)\n",
        "\n",
        "# The model expects inputs of shape: B x C x T x H x W\n",
        "rgbd_input = rgbd[None, :, None, ...]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fE0ngqcm9R5b"
      },
      "source": [
        "### Get model predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BnU_eHHr9R5b"
      },
      "outputs": [],
      "source": [
        "with torch.no_grad():\n",
        "    prediction = model(rgbd_input.to(device), input_type=\"rgbd\")\n",
        "    pred_classes = prediction.topk(k=5).indices\n",
        "\n",
        "pred_class_names = [sunrgbd_id_to_classname[str(i.item())] for i in pred_classes[0]]\n",
        "print(\"Top 5 predicted labels: %s\" % \", \".join(pred_class_names))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E-Cxg7bM9R5c"
      },
      "source": [
        "# Inference with Egocentric videos\n",
        "\n",
        "We can also use the Omnivore model to classify an egocentric video from the EPIC Kitchens dataset. \n",
        "\n",
        "### Setup \n",
        "\n",
        "First, load the Omnivore model trained on the EPIC Kitchens dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0cewNDk39R5c"
      },
      "outputs": [],
      "source": [
        "model_name = \"omnivore_swinB_epic\"\n",
        "model = torch.hub.load(\"facebookresearch/omnivore:main\", model=model_name)\n",
        "\n",
        "# Set to eval mode and move to desired device\n",
        "model = model.to(device)\n",
        "model = model.eval()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d6a6Knb99R5d"
      },
      "source": [
        "Download the id to label mapping for the EPIC Kitchens 100 dataset. \n",
        "This will be used to get the action class names from the predicted class ids."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jUJnRmTN9R5d"
      },
      "outputs": [],
      "source": [
        "!wget https://dl.fbaipublicfiles.com/omnivore/epic_action_classes.csv"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RBIdxnI39R5d"
      },
      "outputs": [],
      "source": [
        "with open('epic_action_classes.csv', mode='r') as inp:\n",
        "    reader = csv.reader(inp)\n",
        "    epic_id_to_action = {idx: \" \".join(rows) for idx, rows in enumerate(reader)}"
      ]
    },
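    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sketch of what this comprehension produces, the same code can be run on an in-memory CSV. The rows below are made up for illustration and assume one class per row with no header line; the contents of the real file may differ."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import csv\n",
        "import io\n",
        "\n",
        "# Hypothetical CSV content: one action per row, verb and noun in separate columns\n",
        "example_csv = io.StringIO(\"take,plate\\nopen,fridge\\nwash,cup\\n\")\n",
        "\n",
        "# Row index becomes the class id; the columns are joined with a space\n",
        "example_reader = csv.reader(example_csv)\n",
        "example_id_to_action = {idx: \" \".join(row) for idx, row in enumerate(example_reader)}\n",
        "print(example_id_to_action)"
      ]
    },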
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jvsAiKPn9R5e"
      },
      "source": [
        "### Define the transformations for the input required by the model\n",
        "\n",
        "Before passing the video into the model we need to apply some input transforms and sample a clip of the correct duration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SuBKJZ4C9R5e"
      },
      "outputs": [],
      "source": [
        "num_frames = 32\n",
        "sampling_rate = 2\n",
        "frames_per_second = 30\n",
        "\n",
        "clip_duration = (num_frames * sampling_rate) / frames_per_second\n",
        "\n",
        "video_transform = ApplyTransformToKey(\n",
        "    key=\"video\",\n",
        "    transform=T.Compose(\n",
        "        [\n",
        "            UniformTemporalSubsample(num_frames), \n",
        "            T.Lambda(lambda x: x / 255.0),  \n",
        "            ShortSideScale(size=224),\n",
        "            NormalizeVideo(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
        "            TemporalCrop(frames_per_clip=32, stride=40),\n",
        "            SpatialCrop(crop_size=224, num_crops=3),\n",
        "        ]\n",
        "    ),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YJ9kH8GL9R5e"
      },
      "source": [
        "### Load and visualize an example video\n",
        "We can test the classification of an example video from the EPIC Kitchens dataset (see the [trailer](https://www.youtube.com/watch?v=8IzkrWAfAGg&t=5s))."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LIXTPexc9R5f"
      },
      "outputs": [],
      "source": [
        "# Download the example video file\n",
        "!wget https://dl.fbaipublicfiles.com/omnivore/example_data/epic.mp4 -O epic.mp4"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7hTRd7FX9R5f"
      },
      "outputs": [],
      "source": [
        "# Load the example video\n",
        "video_path = \"epic.mp4\" \n",
        "\n",
        "Video.from_file(video_path, width=500)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DXa-9woe9R5f"
      },
      "outputs": [],
      "source": [
        "# Downscale the video to a width of 224 pixels and keep only the first second to save RAM\n",
        "!ffmpeg -y -ss 0 -i epic.mp4 -filter:v scale=224:-1 -t 1 -v 0 epic_cropped.mp4\n",
        "\n",
        "video_path = \"epic_cropped.mp4\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uWxqjD_p9R5g"
      },
      "outputs": [],
      "source": [
        "# Initialize an EncodedVideo helper class\n",
        "video = EncodedVideo.from_path(video_path)\n",
        "\n",
        "# Load the desired clip\n",
        "video_data = video.get_clip(start_sec=0.0, end_sec=2.0)\n",
        "\n",
        "# Apply the transform defined above to preprocess the video input\n",
        "video_data = video_transform(video_data)\n",
        "\n",
        "# Extract the transformed clips; they are moved to the device at inference time\n",
        "video_inputs = video_data[\"video\"]\n",
        "\n",
        "# Take the first clip \n",
        "# The model expects inputs of shape: B x C x T x H x W\n",
        "video_input = video_inputs[0][None, ...]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eSg2TA6a9R5g"
      },
      "source": [
        "### Get model predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YFciETpn9R5g"
      },
      "outputs": [],
      "source": [
        "# Pass the input clip through the model \n",
        "with torch.no_grad():\n",
        "    prediction = model(video_input.to(device), input_type=\"video\")\n",
        "\n",
        "    # Get the predicted classes \n",
        "    pred_classes = prediction.topk(k=5).indices\n",
        "\n",
        "# Map the predicted classes to the label names\n",
        "pred_class_names = [epic_id_to_action[int(i)] for i in pred_classes[0]]\n",
        "print(\"Top 5 predicted actions: %s\" % \", \".join(pred_class_names))"
      ]
    }
  ],
  "metadata": {
    "bento_stylesheets": {
      "bento/extensions/flow/main.css": true,
      "bento/extensions/kernel_selector/main.css": true,
      "bento/extensions/kernel_ui/main.css": true,
      "bento/extensions/new_kernel/main.css": true,
      "bento/extensions/system_usage/main.css": true,
      "bento/extensions/theme/main.css": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "colab": {
      "name": "inference_tutorial.ipynb",
      "provenance": []
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}