{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Otter Video Demo\n",
    "\n",
    "Current Otter Video is Otter-v0.2-DC (0612), means it’s trianed on MIMIC-IT-DC at June 12th. The code reads a video and uniformly extracts 16 frames, so avoid using excessively long videos if you want the model to generate specific descriptions.\n",
    "\n",
    "If your machine has over 16G GPU memory, you can run our model locally in fp16 mode for tasks like video labeling and identifying harmful content. For machines with over 36G GPU memory (by combining multiple cards with [device_map='auto'](https://huggingface.co/docs/accelerate/usage_guides/big_modeling) to one model different cards), you can run our model in the more accurate fp32 mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mimetypes\n",
    "import os\n",
    "from typing import Union\n",
    "import cv2\n",
    "import requests\n",
    "import torch\n",
    "import transformers\n",
    "from PIL import Image\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../..\")\n",
    "from otter.modeling_otter import OtterForConditionalGeneration\n",
    "\n",
    "# Disable warnings\n",
    "requests.packages.urllib3.disable_warnings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------------- Utility Functions -------------------\n",
    "\n",
    "\n",
    "def get_content_type(file_path):\n",
    "    content_type, _ = mimetypes.guess_type(file_path)\n",
    "    return content_type\n",
    "\n",
    "\n",
    "# ------------------- Image and Video Handling Functions -------------------\n",
    "\n",
    "\n",
    "def extract_frames(video_path, num_frames=16):\n",
    "    video = cv2.VideoCapture(video_path)\n",
    "    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))\n",
    "    frame_step = total_frames // num_frames\n",
    "    frames = []\n",
    "\n",
    "    for i in range(num_frames):\n",
    "        video.set(cv2.CAP_PROP_POS_FRAMES, i * frame_step)\n",
    "        ret, frame = video.read()\n",
    "        if ret:\n",
    "            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
    "            frame = Image.fromarray(frame).convert(\"RGB\")\n",
    "            frames.append(frame)\n",
    "\n",
    "    video.release()\n",
    "    return frames\n",
    "\n",
    "\n",
    "def get_image(url: str) -> Union[Image.Image, list]:\n",
    "    if \"://\" not in url:  # Local file\n",
    "        content_type = get_content_type(url)\n",
    "    else:  # Remote URL\n",
    "        content_type = requests.head(url, stream=True, verify=False).headers.get(\"Content-Type\")\n",
    "\n",
    "    if \"image\" in content_type:\n",
    "        if \"://\" not in url:  # Local file\n",
    "            return Image.open(url)\n",
    "        else:  # Remote URL\n",
    "            return Image.open(requests.get(url, stream=True, verify=False).raw)\n",
    "    elif \"video\" in content_type:\n",
    "        video_path = \"temp_video.mp4\"\n",
    "        if \"://\" not in url:  # Local file\n",
    "            video_path = url\n",
    "        else:  # Remote URL\n",
    "            with open(video_path, \"wb\") as f:\n",
    "                f.write(requests.get(url, stream=True, verify=False).content)\n",
    "        frames = extract_frames(video_path)\n",
    "        if \"://\" in url:  # Only remove the temporary video file if it was downloaded\n",
    "            os.remove(video_path)\n",
    "        return frames\n",
    "    else:\n",
    "        raise ValueError(\"Invalid content type. Expected image or video.\")\n",
    "\n",
    "\n",
    "# ------------------- OTTER Prompt and Response Functions -------------------\n",
    "\n",
    "\n",
    "def get_formatted_prompt(prompt: str) -> str:\n",
    "    return f\"<image>User: {prompt} GPT:<answer>\"\n",
    "\n",
    "\n",
    "def get_response(input_data, prompt: str, model=None, image_processor=None, tensor_dtype=None) -> str:\n",
    "    if isinstance(input_data, Image.Image):\n",
    "        vision_x = image_processor.preprocess([input_data], return_tensors=\"pt\")[\"pixel_values\"].unsqueeze(1).unsqueeze(0)\n",
    "    elif isinstance(input_data, list):  # list of video frames\n",
    "        vision_x = image_processor.preprocess(input_data, return_tensors=\"pt\")[\"pixel_values\"].unsqueeze(0).unsqueeze(0)\n",
    "    else:\n",
    "        raise ValueError(\"Invalid input data. Expected PIL Image or list of video frames.\")\n",
    "\n",
    "    lang_x = model.text_tokenizer(\n",
    "        [\n",
    "            get_formatted_prompt(prompt),\n",
    "        ],\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "\n",
    "    # Get the data type from model's parameters\n",
    "    model_dtype = next(model.parameters()).dtype\n",
    "\n",
    "    # Convert tensors to the model's data type\n",
    "    vision_x = vision_x.to(dtype=model_dtype)\n",
    "    lang_x_input_ids = lang_x[\"input_ids\"]\n",
    "    lang_x_attention_mask = lang_x[\"attention_mask\"]\n",
    "\n",
    "    bad_words_id = model.text_tokenizer([\"User:\", \"GPT1:\", \"GFT:\", \"GPT:\"], add_special_tokens=False).input_ids\n",
    "    generated_text = model.generate(\n",
    "        vision_x=vision_x.to(model.device),\n",
    "        lang_x=lang_x_input_ids.to(model.device),\n",
    "        attention_mask=lang_x_attention_mask.to(model.device),\n",
    "        max_new_tokens=512,\n",
    "        num_beams=3,\n",
    "        no_repeat_ngram_size=3,\n",
    "        bad_words_ids=bad_words_id,\n",
    "    )\n",
    "    parsed_output = (\n",
    "        model.text_tokenizer.decode(generated_text[0])\n",
    "        .split(\"<answer>\")[-1]\n",
    "        .lstrip()\n",
    "        .rstrip()\n",
    "        .split(\"<|endofchunk|>\")[0]\n",
    "        .lstrip()\n",
    "        .rstrip()\n",
    "        .lstrip('\"')\n",
    "        .rstrip('\"')\n",
    "    )\n",
    "    return parsed_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------------- Main Function -------------------\n",
    "load_bit = \"fp32\"\n",
    "if load_bit == \"fp16\":\n",
    "    precision = {\"torch_dtype\": torch.float16}\n",
    "elif load_bit == \"bf16\":\n",
    "    precision = {\"torch_dtype\": torch.bfloat16}\n",
    "elif load_bit == \"fp32\":\n",
    "    precision = {\"torch_dtype\": torch.float32}\n",
    "\n",
    "# This model version is trained on MIMIC-IT DC dataset.\n",
    "model = OtterForConditionalGeneration.from_pretrained(\"annonymous/OTTER-9B-DenseCaption\", device_map=\"auto\", **precision)\n",
    "tensor_dtype = {\"fp16\": torch.float16, \"bf16\": torch.bfloat16, \"fp32\": torch.float32}[load_bit]\n",
    "\n",
    "model.text_tokenizer.padding_side = \"left\"\n",
    "tokenizer = model.text_tokenizer\n",
    "image_processor = transformers.CLIPImageProcessor()\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "while True:\n",
    "    video_url = input(\"Enter video path: \")  # Replace with the path to your video file, could be any common format.\n",
    "\n",
    "    frames_list = get_image(video_url)\n",
    "\n",
    "    while True:\n",
    "        prompts_input = input(\"Enter prompts: \")\n",
    "\n",
    "        if prompts_input.lower() == \"quit\":\n",
    "            break\n",
    "\n",
    "        print(f\"\\nPrompt: {prompts_input}\")\n",
    "        response = get_response(frames_list, prompts_input, model, image_processor, tensor_dtype)\n",
    "        print(f\"Response: {response}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "otter",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
