{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PerceptionLM: Open-Access Data and Models for Detailed Visual Understanding\n",
    "Perception Language Model (PLM) is a state-of-the-art, fully open and reproducible MLLM for transparent research in image and video understanding.\n",
    "\n",
    "[![Hugging Face Collection](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face%20Collection-Models,%20Data,%20and%20Benchmarks-blue)](https://huggingface.co/collections/facebook/perception-lm-67f9783f171948c383ee7498)\n",
    "[![Paper](https://img.shields.io/badge/Technical%20Report-PerceptionLM-b31b1b.svg)](https://ai.meta.com/research/publications/perceptionlm-open-access-data-and-models-for-detailed-visual-understanding)\n",
    "[![Paper](https://img.shields.io/badge/arXiv-2504.13180-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2504.13180)\n",
    "[![ModelLicense](https://img.shields.io/badge/Model_License-FAIR_Research_License-lightgrey)](../../LICENSE.PLM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Log in to the Hugging Face Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Create an access token at https://huggingface.co/settings/tokens and expose it\n",
    "# through the HF_TOKEN environment variable instead of pasting it into this cell,\n",
    "# so the secret never ends up in a saved or shared copy of the notebook.\n",
    "# When HF_TOKEN is unset, login() asks for the token interactively.\n",
    "login(token=os.environ.get(\"HF_TOKEN\"))"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "source": [
    "import os\n",
    "import re \n",
    "\n",
    "import torch\n",
    "from PIL import Image, ImageDraw\n",
    "\n",
    "import time\n",
    "from IPython.display import HTML\n",
    "from base64 import b64encode\n",
    "import textwrap\n",
    "import requests\n",
    "import urllib.request\n",
    "\n",
    "from core.args import dataclass_from_dict\n",
    "from core.transforms.image_transform import get_image_transform\n",
    "from core.transforms.video_transform import get_video_transform\n",
    "from apps.plm.generate import PackedCausalTransformerGeneratorArgs, PackedCausalTransformerGenerator, load_consolidated_model_and_tokenizer"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load PLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "source": [
    "# Checkpoint identifier to load; keep exactly one of the assignments below active.\n",
    "# ckpt = \"facebook/Perception-LM-1B\"\n",
    "# ckpt = \"facebook/Perception-LM-3B\" \n",
    "ckpt = \"facebook/Perception-LM-8B\" \n",
    "# `model`, `tokenizer` and `config` are notebook-level globals read by generate() below.\n",
    "model, tokenizer, config = load_consolidated_model_and_tokenizer(ckpt)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "metadata": {},
   "source": [
    "def generate(\n",
    "    media_path,\n",
    "    question=\"Describe the image in details.\",\n",
    "    media_type=\"image\",\n",
    "    number_of_frames=4,\n",
    "    number_of_tiles=1,\n",
    "    temperature=0.0,\n",
    "    top_p=None,\n",
    "    top_k=None,\n",
    "    return_text=False,\n",
    "):\n",
    "    prompts = []\n",
    "    if media_type == \"image\":\n",
    "        transform = get_image_transform(\n",
    "            vision_input_type=(\n",
    "                \"vanilla\" if number_of_tiles == 1 else config.data.vision_input_type\n",
    "            ),\n",
    "            image_res=model.vision_model.image_size,\n",
    "            max_num_tiles=number_of_tiles,\n",
    "        )\n",
    "        if isinstance(media_path, str):\n",
    "            image = Image.open(media_path).convert(\"RGB\")\n",
    "        else:\n",
    "            image = media_path\n",
    "        image, _ = transform(image)\n",
    "        prompts.append((question, image))\n",
    "    elif media_type == \"video\":\n",
    "        transform = get_video_transform(\n",
    "            image_res=model.vision_model.image_size,\n",
    "        )\n",
    "        video_info = (media_path, number_of_frames, None, None, None)\n",
    "        frames, _ = transform(video_info)\n",
    "        prompts.append((question, frames))\n",
    "    else:\n",
    "        raise NotImplementedError(\n",
    "            f\"The provided generate function only supports image and video.\"\n",
    "        )\n",
    "    # Create generator\n",
    "    gen_cfg = dataclass_from_dict(\n",
    "        PackedCausalTransformerGeneratorArgs,\n",
    "        {\"temperature\": temperature, \"top_p\": top_p, \"top_k\": top_k},\n",
    "        strict=False,\n",
    "    )\n",
    "    generator = PackedCausalTransformerGenerator(gen_cfg, model, tokenizer)\n",
    "    # Run generation\n",
    "    start_time = time.time()\n",
    "    generation, loglikelihood, greedy = generator.generate(prompts)\n",
    "    end_time = time.time()\n",
    "    if return_text:\n",
    "        print(generation[0])\n",
    "        return generation[0]\n",
    "    \n",
    "    for i, gen in enumerate(generation):\n",
    "        # Calculate tokens per second\n",
    "        total_tokens = sum(\n",
    "            len(tokenizer.encode(gen, False, False)) for gen in generation\n",
    "        )\n",
    "        tokens_per_second = total_tokens / (end_time - start_time)\n",
    "        print(\"=================================================\")\n",
    "        print(textwrap.fill(gen, width=75))\n",
    "        print(f\"Tokens per second: {tokens_per_second:.2f}\")\n",
    "        print(\"=================================================\")\n",
    "\n",
    "def extract_all_bounding_boxes(text: str) -> list[list[str]]:\n",
    "    \"\"\"Extracts any list of arbitrary length from a string.\"\"\"\n",
    "    pattern = r\"\\[\\s*([^\\[\\]]*?)\\s*\\]\"\n",
    "    extracted_lists = [\n",
    "        [num.strip() for num in match.split(\",\")] for match in re.findall(pattern, text)\n",
    "    ]\n",
    "    return extracted_lists\n",
    "\n",
    "def rescale_2d_bboxes(bboxes, img_w, img_h, box_format=\"000\", verbose=True):\n",
    "    w, h = img_w, img_h\n",
    "    rescaled_bboxes = []\n",
    "    for bbox in bboxes:\n",
    "        try:\n",
    "            if box_format == \"000\":\n",
    "                bbox = [float(\"0.\" + b.strip()) for b in bbox]\n",
    "            elif box_format == \"standard\":\n",
    "                bbox = [float(b.strip()) for b in bbox]\n",
    "            else:\n",
    "                # we don't know the format. try both\n",
    "                try:\n",
    "                    bbox = [float(\"0.\" + b.strip()) for b in bbox]\n",
    "                except:\n",
    "                    bbox = [float(b.strip()) for b in bbox]\n",
    "\n",
    "            x1, y1, x2, y2 = bbox\n",
    "            bbox = [x1 * w, y1 * h, x2 * w, y2 * h]\n",
    "\n",
    "            rescaled_bboxes.append(bbox)\n",
    "        except Exception as e:\n",
    "            if verbose:\n",
    "                print(\"[rescale_2d_bboxes]:\", e, bbox, flush=True)\n",
    "            pass\n",
    "    return rescaled_bboxes\n",
    "\n",
    "def postprocess_grounding(x: str, img_w: int, img_h: int) -> list[float]:\n",
    "    \"\"\"Return the first box parsed from model output `x`, scaled by the image size.\n",
    "\n",
    "    Falls back to the full-image box [0, 0, img_w, img_h] when nothing usable\n",
    "    is found in the text.\n",
    "    \"\"\"\n",
    "    candidates = rescale_2d_bboxes(extract_all_bounding_boxes(x), img_w, img_h)\n",
    "    if not candidates:\n",
    "        # no box found.\n",
    "        return [0, 0, img_w, img_h]\n",
    "    return candidates[0]\n",
    "\n",
    "\n",
    "def generate_grounding(media_path: str, question: str, number_of_tiles: int):\n",
    "    \"\"\"Ask the model for a bounding box and draw it on the image.\n",
    "\n",
    "    Args:\n",
    "        media_path: Path of the image file to ground on.\n",
    "        question: Grounding prompt passed to generate().\n",
    "        number_of_tiles: Tile count passed to generate().\n",
    "\n",
    "    Returns:\n",
    "        An RGB copy of the image with the predicted box outlined in red. When no\n",
    "        box can be parsed, postprocess_grounding() yields the full-image box, so\n",
    "        the outline follows the image border. A box with x_min >= x_max or\n",
    "        y_min >= y_max is not drawn.\n",
    "    \"\"\"\n",
    "    # Convert to RGB so the red outline renders on grayscale/palette images too\n",
    "    # (generate() also converts path inputs to RGB), and release the file handle\n",
    "    # once the pixels are decoded.\n",
    "    with Image.open(media_path) as image_file:\n",
    "        image = image_file.convert(\"RGB\")\n",
    "    w, h = image.size\n",
    "\n",
    "    print(\"Generating...\")\n",
    "    output = generate(\n",
    "        media_path=media_path,\n",
    "        question=question,\n",
    "        number_of_tiles=number_of_tiles,\n",
    "        media_type=\"image\",\n",
    "        return_text=True,\n",
    "    )\n",
    "\n",
    "    box = postprocess_grounding(output, w, h)\n",
    "\n",
    "    draw = ImageDraw.Draw(image)\n",
    "    try:\n",
    "        x_min, y_min, x_max, y_max = box\n",
    "\n",
    "        if x_min < x_max and y_min < y_max:\n",
    "            # Draw the bounding box\n",
    "            draw.rectangle([x_min, y_min, x_max, y_max], outline=\"red\", width=2)\n",
    "    except Exception as e:\n",
    "        print(f\"Error drawing bounding box: {e}\")\n",
    "    return image"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run inference for image grounding task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "metadata": {},
   "source": [
    "# Prompt template for grounding; {caption} is filled with the referring expression.\n",
    "question_template = \"Provide a bounding box of the region this sentence describes: '{caption}'.\\nUse the format [x1, y1, x2, y2].\"\n",
    "\n",
    "image_url = \"http://farm3.staticflickr.com/2453/3867429392_ed6f3d337a_z.jpg\"\n",
    "image_path = \"3867429392_ed6f3d337a_z.jpg\"\n",
    "\n",
    "# Download the demo image into the current working directory (repeated on every run).\n",
    "urllib.request.urlretrieve(image_url, image_path)\n",
    "\n",
    "description = \"white fire hydrant in the back\"\n",
    "question = question_template.format(caption=description)\n",
    "\n",
    "# Show the untouched input image first.\n",
    "img = Image.open(image_path)\n",
    "display(img)\n",
    "\n",
    "# Query the model and show the image with the returned box outlined in red.\n",
    "# `img_drawn` is reused by the region-captioning cells below.\n",
    "img_drawn = generate_grounding(media_path=image_path, question=question, number_of_tiles=36)\n",
    "display(img_drawn)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Region captioning (bounding box as text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "metadata": {},
   "source": [
    "# Region captioning where the box is handed to the model as text inside the prompt.\n",
    "question = (\n",
    "    \"Please describe the region ({bbox}) in details.\\nThe region is in the format of [x1, y1, x2, y2].\"\n",
    ").format(bbox=\"[040,482,112,576]\")\n",
    "generate(\n",
    "    media_path=img_drawn,\n",
    "    question=question,\n",
    "    media_type=\"image\",\n",
    "    number_of_tiles=36,\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Region captioning (bounding box as drawing)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "metadata": {},
   "source": [
    "# The region is conveyed visually: img_drawn carries the red outline drawn by\n",
    "# generate_grounding(), so the prompt refers to that rectangle by its colour.\n",
    "question = \"Please describe the region inside the red rectangle.\"\n",
    "generate(media_path=img_drawn, question=question, number_of_tiles=36, media_type=\"image\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
