{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "auuaEk9iaMyg"
   },
   "source": [
    "Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.\n",
    "\n",
    "# Video Seal - Video inference, optimized for low RAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/private/home/pfz/09-videoseal/videoseal\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/private/home/pfz/miniconda3/envs/img/lib/python3.12/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
      "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
     ]
    }
   ],
   "source": [
    "# Reload edited project modules automatically, without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    " \n",
    "# Move up one directory so the rest of the notebook runs from the repository root.\n",
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZIicYPSXaMyl"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/private/home/pfz/miniconda3/envs/img/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.animation import FuncAnimation\n",
    "import logging\n",
    "logging.getLogger(\"matplotlib.image\").setLevel(logging.ERROR)\n",
    "from IPython.display import HTML, display\n",
    "\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import ffmpeg\n",
    "import os\n",
    "import cv2\n",
    "import subprocess\n",
    "import torch\n",
    "\n",
    "from videoseal.utils.display import save_vid\n",
    "from videoseal.utils import Timer\n",
    "from videoseal.evals.full import setup_model_from_checkpoint\n",
    "from videoseal.evals.metrics import bit_accuracy, pvalue, capacity, psnr, ssim, msssim, linf\n",
    "from videoseal.data.datasets import VideoDataset\n",
    "from videoseal.augmentation import Identity, H264, Crop\n",
    "from videoseal.models.videoseal import Videoseal\n",
    "from videoseal.modules.jnd import JND\n",
    "\n",
    "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Uncomment one of the two lines below to force a specific device.\n",
    "# device = \"cpu\"\n",
    "# device = \"cuda\"\n",
    "\n",
    "def get_video_info(input_path):\n",
    "    \"\"\"Read basic properties of a video file with OpenCV.\n",
    "\n",
    "    Args:\n",
    "        input_path: Path to the video file.\n",
    "\n",
    "    Returns:\n",
    "        A dict with the keys \"width\" and \"height\" (int), \"fps\" (float),\n",
    "        \"codec\" (the FourCC code decoded to a 4-character string) and\n",
    "        \"num_frames\" (int), all as reported by OpenCV.\n",
    "\n",
    "    Raises:\n",
    "        IOError: If OpenCV cannot open the file. Without this check an\n",
    "            unreadable path would silently yield zero width/height/fps.\n",
    "    \"\"\"\n",
    "    video = cv2.VideoCapture(input_path)\n",
    "    try:\n",
    "        if not video.isOpened():\n",
    "            raise IOError(f\"Could not open video file: {input_path}\")\n",
    "\n",
    "        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))\n",
    "        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))\n",
    "        fps = video.get(cv2.CAP_PROP_FPS)\n",
    "        codec = int(video.get(cv2.CAP_PROP_FOURCC))\n",
    "        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))\n",
    "    finally:\n",
    "        # Release the capture even if reading a property fails.\n",
    "        video.release()\n",
    "\n",
    "    # The FourCC code packs four characters into one integer, lowest byte first.\n",
    "    codec_str = \"\".join(chr((codec >> 8 * i) & 0xFF) for i in range(4))\n",
    "\n",
    "    return {\n",
    "        \"width\": width,\n",
    "        \"height\": height,\n",
    "        \"fps\": fps,\n",
    "        \"codec\": codec_str,\n",
    "        \"num_frames\": num_frames\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4tGZxcKuaMyl"
   },
   "source": [
    "## Load the model\n",
    "\n",
    "The videoseal library provides pretrained models for embedding and extracting watermarks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1_8BMj5UaMym",
    "outputId": "cf796b09-9aea-4f56-eeda-bf8496d2b2db"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model loaded successfully from ckpts/y_128b_img.pth with message: <All keys matched successfully>\n"
     ]
    }
   ],
   "source": [
    "# Build the pretrained VideoSeal model, switch it to evaluation mode\n",
    "# and place it on the selected device.\n",
    "model = setup_model_from_checkpoint(\"videoseal\").eval().to(device)\n",
    "model.compile()\n",
    "\n",
    "# Step size used by the model: a bigger value makes embedding faster\n",
    "# but loses a bit of robustness.\n",
    "model.step_size = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FDB8D9e6aMym"
   },
   "source": [
    "## Embedding\n",
    "\n",
    "Embedding hides the watermark message in the frames of the video."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_and_add_ffmpeg():\n",
    "    \"\"\"Locate a working ffmpeg binary and make sure its directory is on PATH.\n",
    "\n",
    "    Candidates are tried in order: `ffmpeg` as resolved through PATH, then the\n",
    "    Homebrew and /usr/local install locations. The first candidate for which\n",
    "    `ffmpeg -version` runs successfully is used. When that candidate was given\n",
    "    with a directory, the directory is prepended to PATH unless it is already\n",
    "    one of the PATH entries.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If none of the candidates can be executed.\n",
    "    \"\"\"\n",
    "    ffmpeg_paths = [\n",
    "        'ffmpeg',\n",
    "        '/opt/homebrew/bin/ffmpeg',\n",
    "        '/usr/local/bin/ffmpeg'\n",
    "    ]\n",
    "\n",
    "    for path in ffmpeg_paths:\n",
    "        try:\n",
    "            subprocess.run(\n",
    "                [path, '-version'],\n",
    "                stdout=subprocess.DEVNULL,\n",
    "                stderr=subprocess.DEVNULL,\n",
    "                check=True,\n",
    "            )\n",
    "        except (OSError, subprocess.CalledProcessError):\n",
    "            # Missing, not executable, or broken binary: try the next candidate.\n",
    "            continue\n",
    "\n",
    "        print(f\"ffmpeg found at: {path}\")\n",
    "        ffmpeg_dir = os.path.dirname(path)\n",
    "        current_path = os.environ.get('PATH', '')\n",
    "        # Compare against the individual PATH entries: a substring test on the raw\n",
    "        # string would treat '/usr/local/bin' as present inside '/usr/local/bin/foo'.\n",
    "        if ffmpeg_dir and ffmpeg_dir not in current_path.split(os.pathsep):\n",
    "            print(f\"Adding {ffmpeg_dir} to PATH\")\n",
    "            os.environ['PATH'] = (\n",
    "                ffmpeg_dir + os.pathsep + current_path if current_path else ffmpeg_dir\n",
    "            )\n",
    "        return\n",
    "\n",
    "    raise RuntimeError(\"ffmpeg check failed: No ffmpeg installation found\")\n",
    "\n",
    "check_and_add_ffmpeg()  # add path to the ffmpeg binary, from Mac homebrew or system\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WeT-7WzXaMym"
   },
   "outputs": [],
   "source": [
    "def embed_video_clip(\n",
    "    model: Videoseal,\n",
    "    clip: np.ndarray,\n",
    "    msgs: torch.Tensor\n",
    ") -> np.ndarray:\n",
    "    \"\"\"Embed the watermark message into one clip of frames.\n",
    "\n",
    "    Args:\n",
    "        model: Watermarking model exposing `embed`.\n",
    "        clip: Frames in (frames, height, width, channels) order with values in 0-255.\n",
    "        msgs: Message to hide, passed through to `model.embed`.\n",
    "\n",
    "    Returns:\n",
    "        The watermarked frames as a uint8 array in the same axis order as `clip`.\n",
    "    \"\"\"\n",
    "    clip_tensor = torch.tensor(clip, dtype=torch.float32).permute(0, 3, 1, 2) / 255.0\n",
    "    # Inference only: skip autograd bookkeeping to keep memory usage low.\n",
    "    with torch.no_grad():\n",
    "        outputs = model.embed(clip_tensor, msgs=msgs, is_video=True, lowres_attenuation=True)\n",
    "    # `.numpy()` needs a CPU tensor that does not track gradients.\n",
    "    processed_clip = outputs[\"imgs_w\"].detach().cpu()\n",
    "    # Round and clamp before the cast: a bare `.byte()` truncates (so x/255*255 can\n",
    "    # come back one level too low) and overflows for values outside 0-255.\n",
    "    processed_clip = (processed_clip * 255.0).round().clamp(0, 255).byte()\n",
    "    return processed_clip.permute(0, 2, 3, 1).numpy()\n",
    "\n",
    "def embed_video(\n",
    "    model: Videoseal,\n",
    "    input_path: str,\n",
    "    output_path: str,\n",
    "    chunk_size: int,\n",
    "    crf: int = 23\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Watermark a video chunk by chunk, holding at most `chunk_size` frames in memory.\n",
    "\n",
    "    One ffmpeg process decodes the input to raw RGB frames, the frames are\n",
    "    watermarked with a random message from `model.get_random_msg()`, and a second\n",
    "    ffmpeg process encodes the result with libx264. The message bits are also\n",
    "    written to a `.txt` file that shares the output file's base name.\n",
    "\n",
    "    Args:\n",
    "        model: Watermarking model.\n",
    "        input_path: Path of the video to watermark.\n",
    "        output_path: Path of the watermarked video (overwritten if it exists).\n",
    "        chunk_size: Number of frames embedded per model call.\n",
    "        crf: Constant rate factor passed to the libx264 encoder.\n",
    "\n",
    "    Returns:\n",
    "        The message that was embedded.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `chunk_size` is not positive.\n",
    "        RuntimeError: If the frame size cannot be determined or ffmpeg exits with an error.\n",
    "    \"\"\"\n",
    "    if chunk_size <= 0:\n",
    "        raise ValueError(f\"chunk_size must be a positive integer, got {chunk_size}\")\n",
    "\n",
    "    # Read video dimensions\n",
    "    video_info = get_video_info(input_path)\n",
    "    width = int(video_info['width'])\n",
    "    height = int(video_info['height'])\n",
    "    fps = float(video_info['fps'])\n",
    "    num_frames = int(video_info['num_frames'])\n",
    "    frame_size = width * height * 3\n",
    "    if frame_size <= 0:\n",
    "        raise RuntimeError(f\"Could not read valid frame dimensions from {input_path}\")\n",
    "    frame_shape = '{}x{}'.format(width, height)\n",
    "\n",
    "    # Create a random message and save it next to the output video. splitext keeps\n",
    "    # the message file distinct from the video even when the output is not `.mp4`.\n",
    "    msgs = model.get_random_msg()\n",
    "    with open(os.path.splitext(output_path)[0] + \".txt\", \"w\") as f:\n",
    "        f.write(\"\".join([str(msg.item()) for msg in msgs[0]]))\n",
    "\n",
    "    decoder = None\n",
    "    encoder = None\n",
    "    try:\n",
    "        # stderr is deliberately not piped: a pipe that is never read can fill up\n",
    "        # and block ffmpeg on long videos. `-loglevel error` keeps the log quiet.\n",
    "        decoder = (\n",
    "            ffmpeg\n",
    "            .input(input_path)\n",
    "            .output('pipe:', format='rawvideo', pix_fmt='rgb24', s=frame_shape, r=fps)\n",
    "            .global_args('-loglevel', 'error')\n",
    "            .run_async(pipe_stdout=True)\n",
    "        )\n",
    "        encoder = (\n",
    "            ffmpeg\n",
    "            .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=frame_shape, r=fps)\n",
    "            .output(output_path, vcodec='libx264', pix_fmt='yuv420p', r=fps, crf=crf)\n",
    "            .overwrite_output()\n",
    "            .global_args('-loglevel', 'error')\n",
    "            .run_async(pipe_stdin=True)\n",
    "        )\n",
    "\n",
    "        # Process the video\n",
    "        chunk = np.zeros((chunk_size, height, width, 3), dtype=np.uint8)\n",
    "        frames_in_chunk = 0\n",
    "        with tqdm(total=num_frames, desc=\"Watermark embedding\") as pbar:\n",
    "            while True:\n",
    "                in_bytes = decoder.stdout.read(frame_size)\n",
    "                if len(in_bytes) < frame_size:\n",
    "                    break  # end of stream (a truncated trailing frame is ignored)\n",
    "                chunk[frames_in_chunk] = np.frombuffer(in_bytes, np.uint8).reshape([height, width, 3])\n",
    "                frames_in_chunk += 1\n",
    "                pbar.update(1)\n",
    "                if frames_in_chunk == chunk_size:\n",
    "                    encoder.stdin.write(embed_video_clip(model, chunk, msgs).tobytes())\n",
    "                    frames_in_chunk = 0\n",
    "\n",
    "        # Flush the last, partially filled chunk so the output keeps every frame.\n",
    "        if frames_in_chunk > 0:\n",
    "            encoder.stdin.write(embed_video_clip(model, chunk[:frames_in_chunk], msgs).tobytes())\n",
    "    finally:\n",
    "        # Always close the pipes and reap both ffmpeg processes.\n",
    "        for proc, pipe_name in ((decoder, 'stdout'), (encoder, 'stdin')):\n",
    "            if proc is None:\n",
    "                continue\n",
    "            try:\n",
    "                getattr(proc, pipe_name).close()\n",
    "            except OSError:\n",
    "                # The pipe is already broken; the failure is reported through the\n",
    "                # exit code check below or the exception already propagating.\n",
    "                pass\n",
    "            proc.wait()\n",
    "\n",
    "    if decoder.returncode != 0:\n",
    "        raise RuntimeError(f\"ffmpeg failed to decode {input_path} (exit code {decoder.returncode})\")\n",
    "    if encoder.returncode != 0:\n",
    "        raise RuntimeError(f\"ffmpeg failed to encode {output_path} (exit code {encoder.returncode})\")\n",
    "\n",
    "    return msgs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gKOJP7GTa1xO"
   },
   "source": [
    "You are free to upload any video and change the `video_path`.\n",
    "\n",
    "The watermarked video is saved in the `outputs` folder, where you can inspect it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "XX1TKDega1Wg",
    "outputId": "47248e63-0b10-41af-bf2b-66cf3ec7cc6f"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Watermark embedding: 100%|██████████| 256/256 [00:14<00:00, 18.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Saved watermarked video to ./outputs/1.mp4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Input video to watermark\n",
    "video_path = \"assets/videos/1.mp4\"\n",
    "\n",
    "# The watermarked copy keeps the input file name, inside the output directory\n",
    "output_dir = \"./outputs\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "output_path = os.path.join(output_dir, os.path.basename(video_path))\n",
    "\n",
    "# Hide a random message in the video, processing 32 frames at a time\n",
    "msgs_ori = embed_video(\n",
    "    model=model,\n",
    "    input_path=video_path,\n",
    "    output_path=output_path,\n",
    "    chunk_size=32,\n",
    ")\n",
    "print(f\"\\nSaved watermarked video to {output_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nJPpKOMDaMym"
   },
   "source": [
    "## Extraction\n",
    "\n",
    "Load the video output from the embedding process and extract the watermark."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "wJxWT4F6aMym",
    "outputId": "89b323a7-919e-4a29-8935-29dbd08fae52"
   },
   "outputs": [],
   "source": [
    "def detect_video_clip(\n",
    "    model: Videoseal,\n",
    "    clip: np.ndarray\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Run watermark detection on one clip of frames.\n",
    "\n",
    "    Args:\n",
    "        model: Watermarking model exposing `detect`.\n",
    "        clip: Frames in (frames, height, width, channels) order with values in 0-255.\n",
    "\n",
    "    Returns:\n",
    "        The model's \"preds\" output without its first column.\n",
    "    \"\"\"\n",
    "    clip_tensor = torch.tensor(clip, dtype=torch.float32).permute(0, 3, 1, 2) / 255.0\n",
    "    # Inference only: without no_grad the predictions kept by the caller could\n",
    "    # hold on to autograd state and inflate memory usage.\n",
    "    with torch.no_grad():\n",
    "        outputs = model.detect(clip_tensor, is_video=True)\n",
    "    output_bits = outputs[\"preds\"][:, 1:]  # exclude the first which may be used for detection\n",
    "    return output_bits\n",
    "\n",
    "def detect_video(\n",
    "    model: Videoseal,\n",
    "    input_path: str,\n",
    "    num_frames_for_extraction: int,\n",
    "    chunk_size: int\n",
    ") -> torch.Tensor:\n",
    "    \"\"\"Extract the watermark message from the first frames of a video.\n",
    "\n",
    "    Frames are decoded with ffmpeg and passed to the model `chunk_size` at a\n",
    "    time, so at most one chunk of frames is held in memory.\n",
    "\n",
    "    Args:\n",
    "        model: Watermarking model.\n",
    "        input_path: Path of the video to analyse.\n",
    "        num_frames_for_extraction: Maximum number of frames read from the start\n",
    "            of the video.\n",
    "        chunk_size: Number of frames passed to the model per call.\n",
    "\n",
    "    Returns:\n",
    "        The predictions averaged across all processed frames.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `chunk_size` is not positive.\n",
    "        RuntimeError: If the frame size cannot be determined or no frame was processed.\n",
    "    \"\"\"\n",
    "    if chunk_size <= 0:\n",
    "        raise ValueError(f\"chunk_size must be a positive integer, got {chunk_size}\")\n",
    "\n",
    "    # Read video dimensions\n",
    "    video_info = get_video_info(input_path)\n",
    "    width = int(video_info['width'])\n",
    "    height = int(video_info['height'])\n",
    "    num_frames = int(video_info['num_frames'])\n",
    "    frame_size = width * height * 3\n",
    "    if frame_size <= 0:\n",
    "        raise RuntimeError(f\"Could not read valid frame dimensions from {input_path}\")\n",
    "\n",
    "    # Only the requested frames are read, so size the progress bar accordingly.\n",
    "    if num_frames > 0:\n",
    "        frames_expected = max(0, min(num_frames, num_frames_for_extraction))\n",
    "    else:\n",
    "        frames_expected = max(0, num_frames_for_extraction)\n",
    "\n",
    "    soft_msgs = []\n",
    "    process1 = None\n",
    "    try:\n",
    "        # stderr is deliberately not piped: a pipe that is never read can fill up\n",
    "        # and block ffmpeg. `-loglevel error` keeps the log quiet.\n",
    "        process1 = (\n",
    "            ffmpeg\n",
    "            .input(input_path)\n",
    "            .output('pipe:', format='rawvideo', pix_fmt='rgb24')\n",
    "            .global_args('-loglevel', 'error')\n",
    "            .run_async(pipe_stdout=True)\n",
    "        )\n",
    "\n",
    "        # Process the video\n",
    "        chunk = np.zeros((chunk_size, height, width, 3), dtype=np.uint8)\n",
    "        frame_count = 0\n",
    "        frames_in_chunk = 0\n",
    "        with tqdm(total=frames_expected, desc=\"Watermark extraction\") as pbar:\n",
    "            while frame_count < num_frames_for_extraction:\n",
    "                in_bytes = process1.stdout.read(frame_size)\n",
    "                if len(in_bytes) < frame_size:\n",
    "                    break  # end of stream (a truncated trailing frame is ignored)\n",
    "                chunk[frames_in_chunk] = np.frombuffer(in_bytes, np.uint8).reshape([height, width, 3])\n",
    "                frame_count += 1\n",
    "                frames_in_chunk += 1\n",
    "                pbar.update(1)\n",
    "\n",
    "                if frames_in_chunk == chunk_size:\n",
    "                    soft_msgs.append(detect_video_clip(model, chunk))\n",
    "                    frames_in_chunk = 0\n",
    "\n",
    "        # Process any remaining frames in the last chunk. This runs on the normal\n",
    "        # path only, so a failure above is not followed by another model call.\n",
    "        if frames_in_chunk > 0:\n",
    "            soft_msgs.append(detect_video_clip(model, chunk[:frames_in_chunk]))\n",
    "    finally:\n",
    "        # Ensure the decoder is shut down even when extraction stops early.\n",
    "        if process1 is not None:\n",
    "            process1.stdout.close()\n",
    "            try:\n",
    "                process1.wait(timeout=5)\n",
    "            except subprocess.TimeoutExpired:\n",
    "                # ffmpeg did not exit after its output pipe was closed: stop it\n",
    "                # instead of leaving the process running in the background.\n",
    "                process1.kill()\n",
    "                process1.wait()\n",
    "\n",
    "    if not soft_msgs:\n",
    "        raise RuntimeError(\"No frames were successfully processed for watermark extraction\")\n",
    "\n",
    "    # Average the predictions across all frames\n",
    "    return torch.cat(soft_msgs, dim=0).mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Watermark extraction:  12%|█▎        | 32/256 [00:00<00:06, 33.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Binary message extracted with 99.2% bit accuracy\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Extract the message from the first frames of the watermarked video,\n",
    "# then compare it with the message that was embedded\n",
    "num_frames_for_extraction = 32\n",
    "soft_msgs = detect_video(\n",
    "    model=model,\n",
    "    input_path=output_path,\n",
    "    num_frames_for_extraction=num_frames_for_extraction,\n",
    "    chunk_size=16,\n",
    ")\n",
    "bit_acc = 100 * bit_accuracy(soft_msgs, msgs_ori).item()\n",
    "print(f\"\\nBinary message extracted with {bit_acc:.1f}% bit accuracy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "img",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
