{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e6f42b3a",
   "metadata": {},
   "source": [
    "# Perception Encoder Demo\n",
    "[![Paper](https://img.shields.io/badge/Paper-Perception%20Encoder-b31b1b.svg)](https://ai.meta.com/research/publications/perception-encoder-the-best-visual-embeddings-are-not-at-the-output-of-the-network) \n",
    "[![Paper](https://img.shields.io/badge/arXiv-2504.13181-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2504.13181)\n",
    "[![Hugging Face Collection](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Collection-blue)](https://huggingface.co/collections/facebook/perception-encoder-67f977c9a65ca5895a7f6ba1)\n",
    "[![Model License](https://img.shields.io/badge/Model_License-Apache_2.0-olive)](https://opensource.org/licenses/Apache-2.0)\n",
    "\n",
    "This notebook provides examples of image and video feature extraction with pre-trained Perception Encoder (PE). These featuire can be used for image and video zero-shot classification and retrieval.\n",
    "\n",
    "You can run the demo locally or run it on Google colab. You can also run it with (faster) or wihtout GPU."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75e90ff6",
   "metadata": {},
   "source": [
    "### Prepare Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "89375a2a",
   "metadata": {},
   "source": [
    "import os, sys\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "# check whether run in Colab\n",
    "if 'google.colab' in sys.modules:\n",
    "    print('Running in Colab.')\n",
    "    !git clone https://github.com/facebookresearch/perception_models.git\n",
    "    !pip install decord\n",
    "    !pip install ftfy\n",
    "    sys.path.append('./perception_models')\n",
    "    os.chdir('./perception_models')\n",
    "else:\n",
    "    sys.path.append('../../../')\n",
    "import decord\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    print('GPU is available. Use GPU for this demo')\n",
    "else:\n",
    "    print('Use CPU for this demo')\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "6e04507d",
   "metadata": {},
   "source": [
    "### Create Model and Preprocess Transform\n",
    "\n",
    "The following code creates a PE-core model and loads pretrained checkpoints. Then it gets the preprocess image transfrom and text tokenizer. The models available are:\n",
    "- PE-core-G14-448 \n",
    "- PE-core-L14-336\n",
    "- PE-core-B16-224"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "94a3c87c",
   "metadata": {},
   "source": [
    "import core.vision_encoder.pe as pe\n",
    "import core.vision_encoder.transforms as transforms\n",
    "\n",
    "model_name = 'PE-Core-G14-448' \n",
    "# models available: ['PE-Core-G14-448', 'PE-Core-L14-336', 'PE-Core-B16-224']\n",
    "\n",
    "model = pe.CLIP.from_config(model_name, pretrained=True)  # Downloads from HF\n",
    "model = model.to(device)\n",
    "\n",
    "preprocess = transforms.get_image_transform(model.image_size)\n",
    "tokenizer = transforms.get_text_tokenizer(model.context_length)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "8587d072",
   "metadata": {},
   "source": [
    "### Example 1: Image and Text Feature Extraction for Zero-shot Image Classification/Retrieval\n",
    "\n",
    "In this example, we extract the embedding of a cat image, and the embeddings of 3 senetence. We measure the cosine similarites between the image and sentences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "38f80aa1",
   "metadata": {},
   "source": [
    "image = preprocess(Image.open(\"./apps/pe/docs/assets/cat.png\")).unsqueeze(0).to(device)\n",
    "captions = [\"a diagram\", \"a dog\", \"a cat\"]\n",
    "text = tokenizer(captions).to(device)\n",
    "with torch.no_grad():\n",
    "    image_features = model.encode_image(image)\n",
    "    text_features = model.encode_text(text)\n",
    "    image_features /= image_features.norm(dim=-1, keepdim=True)\n",
    "    text_features /= text_features.norm(dim=-1, keepdim=True)\n",
    "    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1).cpu().numpy()[0]\n",
    "\n",
    "plt.imshow(Image.open(\"./apps/pe/docs/assets/cat.png\"))\n",
    "plt.axis('off')\n",
    "plt.show()\n",
    "print(\"Captions:\", captions)\n",
    "print(\"Label probs:\", ' '.join(['{:.2f}'.format(prob) for prob in text_probs]))  # prints: [[0.00, 0.00, 1.00]]\n",
    "print(f\"This image is about {captions[text_probs.argmax()]}\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "3f19339c",
   "metadata": {},
   "source": [
    "### Example 2: Video and Text Feature Extraction for Zero-shot Video Classification/Retrieval\n",
    "\n",
    "In this example, we extract the embedding of a dog video, and the embeddings of 3 senetence. We measure the cosine similarites between the video and sentences. We first create a simple video preprocess function to process video input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "530e4d78",
   "metadata": {},
   "source": [
    "def preprocess_video(video_path, num_frames=8, transform=None, return_first_frame_for_demo=True):\n",
    "    \"\"\"\n",
    "    Uniformly samples a specified number of frames from a video and preprocesses them.\n",
    "    Parameters:\n",
    "    - video_path: str, path to the video file.\n",
    "    - num_frames: int, number of frames to sample. Defaults to 8.\n",
    "    - transform: torchvision.transforms, a transform function to preprocess frames.\n",
    "    Returns:\n",
    "    - Video Tensor: a tensor of shape (num_frames, 3, H, W) where H and W are the height and width of the frames.\n",
    "    \"\"\"\n",
    "    # Load the video\n",
    "    vr = decord.VideoReader(video_path)\n",
    "    total_frames = len(vr)\n",
    "    # Uniformly sample frame indices\n",
    "    frame_indices = [int(i * (total_frames / num_frames)) for i in range(num_frames)]\n",
    "    frames = vr.get_batch(frame_indices).asnumpy()\n",
    "    # Preprocess frames\n",
    "    preprocessed_frames = [transform(Image.fromarray(frame)) for frame in frames]\n",
    "\n",
    "    first_frame = None\n",
    "    if return_first_frame_for_demo:\n",
    "        first_frame = frames[0]\n",
    "    return torch.stack(preprocessed_frames, dim=0), first_frame"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "53653e7d",
   "metadata": {},
   "source": [
    "Then we use the encode_video function to encode video to get video embeddings. And compare the cosine similaries with the text emebddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9fe4ab59",
   "metadata": {},
   "source": [
    "video, first_frame = preprocess_video(\"./apps/pe/docs/assets/dog.mp4\", 8, transform=preprocess)\n",
    "video = video.unsqueeze(0).to(device)\n",
    "text = tokenizer([\"a diagram\", \"a dog\", \"a cat\"]).to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    image_features = model.encode_video(video)\n",
    "    text_features = model.encode_text(text)\n",
    "    image_features /= image_features.norm(dim=-1, keepdim=True)\n",
    "    text_features /= text_features.norm(dim=-1, keepdim=True)\n",
    "    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1).cpu().numpy()[0]\n",
    "\n",
    "plt.imshow(Image.fromarray(first_frame))\n",
    "plt.axis('off')\n",
    "plt.show()\n",
    "print(\"Captions:\", captions)\n",
    "print(\"Label probs:\", ' '.join(['{:.2f}'.format(prob) for prob in text_probs]))  # prints: [[0.00, 1.00, 0.00]]\n",
    "print(f\"This video is about {captions[text_probs.argmax()]}\")"
   ],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
