{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.\n",
    "\n",
    "# Video Seal - Video inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/private/home/pfz/09-videoseal/fbresearch-new\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/private/home/pfz/miniconda3/envs/img/lib/python3.12/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
      "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
     ]
    }
   ],
   "source": [
    "# run in the root of the repository\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    " \n",
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/private/home/pfz/miniconda3/envs/img/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from videoseal.utils.display import save_vid\n",
    "from videoseal.utils import Timer\n",
    "from videoseal.evals.full import setup_model_from_checkpoint\n",
    "from videoseal.evals.metrics import bit_accuracy, pvalue, capacity, psnr, ssim, msssim, linf\n",
    "from videoseal.data.datasets import VideoDataset\n",
    "from videoseal.augmentation import Identity, H264, Crop\n",
    "from videoseal.modules.jnd import JND\n",
    "\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import gc\n",
    "\n",
    "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# device = \"cuda\" \n",
    "device = \"cpu\" "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "File /private/home/pfz/09-videoseal/fbresearch-new/ckpts/y_256b_img.pth exists, skipping download\n",
      "Model loaded successfully from /private/home/pfz/09-videoseal/fbresearch-new/ckpts/y_256b_img.pth with message: <All keys matched successfully>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing Videos for videoseal:   0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading video assets/videos/1.mp4 - took 24.40s\n",
      "embedding watermark  - took 14.85s\n",
      "compressing and detecting watermarks\n",
      "{'psnr': 48.327247619628906, 'bit_accuracy_original': 1.0, 'pvalue_original': 8.636168555094445e-78, 'capacity_original': 256.0, 'bit_accuracy_h264_30_+_crop_08': 0.61328125, 'pvalue_h264_30_+_crop_08': 0.00017424190213637938, 'capacity_h264_30_+_crop_08': 9.561767578125, 'bit_accuracy_h264_40': 0.796875, 'pvalue_h264_40': 9.255391288366052e-23, 'capacity_h264_40': 69.59762573242188, 'bit_accuracy_h264_50': 0.53515625, 'pvalue_h264_50': 0.1439946432658289, 'capacity_h264_50': 0.9136962890625}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing Videos for videoseal: 100%|██████████| 1/1 [02:54<00:00, 174.94s/it]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "\n",
    "# seed\n",
    "torch.manual_seed(0)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed(0)\n",
    "\n",
    "# Directory containing videos\n",
    "video_dir = \"assets/videos/\"\n",
    "base_output_folder = \"outputs\"\n",
    "if not os.path.exists(base_output_folder):\n",
    "    os.makedirs(base_output_folder)\n",
    "\n",
    "# Example usage\n",
    "ckpts = {\n",
    "    # \"trustmark\": \"baseline/trustmark\",\n",
    "    # \"model\": \"baseline/model\",\n",
    "    # \"cin\": \"baseline/cin\",\n",
    "    # \"mbrs\": \"baseline/mbrs\",\n",
    "    # \"videoseal_0.0\": 'videoseal_0.0',\n",
    "    \"videoseal\": 'videoseal',\n",
    "}\n",
    "\n",
    "fps = 24 // 1\n",
    "num_videos = None\n",
    "\n",
    "# a timer to measure the time\n",
    "timer = Timer()\n",
    "\n",
    "# Iterate over all ckpts\n",
    "for model_name, ckpt in ckpts.items():\n",
    "    model = setup_model_from_checkpoint(ckpt)\n",
    "    model.eval()\n",
    "    model.compile()\n",
    "    model.to(device)\n",
    "\n",
    "    model.step_size = 4\n",
    "    model.video_mode = \"repeat\"\n",
    "\n",
    "    # Iterate over all video files in the directory\n",
    "    video_files = [f for f in os.listdir(video_dir) if f.endswith(\".mp4\")][:num_videos]\n",
    "\n",
    "    for video_file in tqdm(video_files, desc=f\"Processing Videos for {model_name}\"):\n",
    "        video_path = os.path.join(video_dir, video_file)\n",
    "        base_name = os.path.splitext(video_file)[0]\n",
    "\n",
    "        # Load video (assuming a function `load_video` exists)\n",
    "        timer.start()\n",
    "        vid, mask = VideoDataset.load_full_video_decord(video_path)\n",
    "        print(f\"loading video {video_path} - took {timer.stop():.2f}s\")\n",
    "\n",
    "        # Watermark embedding\n",
    "        timer.start()\n",
    "        outputs = model.embed(vid, is_video=True, lowres_attenuation=True)\n",
    "        print(f\"embedding watermark  - took {timer.stop():.2f}s\")\n",
    "\n",
    "        # compute diff\n",
    "        imgs = vid  # b c h w\n",
    "        imgs_w = outputs[\"imgs_w\"]  # b c h w\n",
    "        msgs = outputs[\"msgs\"]  # b k\n",
    "        diff = imgs_w - imgs\n",
    "\n",
    "        # # save\n",
    "        timer.start()\n",
    "        save_vid(imgs, f\"{base_output_folder}/{model_name}_{base_name}_ori.mp4\", fps)\n",
    "        save_vid(imgs_w, f\"{base_output_folder}/{model_name}_{base_name}_wm.mp4\", fps)\n",
    "        save_vid(10*diff.abs(), f\"{base_output_folder}/{model_name}_{base_name}_diff.mp4\", fps)\n",
    "\n",
    "        # Metrics\n",
    "        metrics = {\n",
    "            \"psnr\": psnr(imgs, imgs_w, is_video=True).mean().item(),\n",
    "            # \"ssim\": ssim(imgs, imgs_w).mean().item(),\n",
    "            # \"msssim\": msssim(imgs, imgs_w).mean().item(),\n",
    "            # \"linf\": linf(imgs, imgs_w).mean().item()\n",
    "        }\n",
    "\n",
    "        # Augment video\n",
    "        print(f\"compressing and detecting watermarks\")\n",
    "        for ii in range(4):\n",
    "        # for ii in range(1):\n",
    "            if ii == 0:\n",
    "                imgs_aug = imgs_w\n",
    "                label = \"Original\"\n",
    "            if ii == 1: \n",
    "                imgs_aug, _ = H264()(imgs_w, crf=30)\n",
    "                imgs_aug, _ = Crop()(imgs_aug, size=0.75)\n",
    "                label = \"H264 30 + Crop 0.8\"\n",
    "            if ii == 2: \n",
    "                imgs_aug, _ = H264()(imgs_w, crf=40)\n",
    "                label = \"H264 40\"\n",
    "            if ii == 3: \n",
    "                imgs_aug, _ = H264()(imgs_w, crf=50)\n",
    "                label = \"H264 50\"\n",
    "\n",
    "            # detect\n",
    "            timer.start()\n",
    "            aggregate = True\n",
    "            if not aggregate:\n",
    "                outputs = model.detect(imgs_aug, is_video=True)\n",
    "                preds = outputs[\"preds\"]\n",
    "                bit_preds = preds[:, 1:]  # b k ...\n",
    "                bit_accuracy_ = bit_accuracy(\n",
    "                    bit_preds,\n",
    "                    msgs\n",
    "                ).nanmean().item()\n",
    "                metrics[f\"bit_accuracy_{label.lower().replace(' ', '_').replace('.', '')}\"] = bit_accuracy_\n",
    "                # print(f\"{label} - Bit Accuracy: {bit_accuracy_:.3f} - took {timer.stop():.2f}s\")\n",
    "            else:\n",
    "                bit_preds = model.extract_message(imgs_aug)\n",
    "                bit_accuracy_ = bit_accuracy(\n",
    "                    bit_preds,\n",
    "                    msgs[:1]\n",
    "                ).nanmean().item()\n",
    "                pvalue_ = pvalue(\n",
    "                    bit_preds,\n",
    "                    msgs[:1]\n",
    "                ).nanmean().item()\n",
    "                capacity_ = capacity(\n",
    "                    bit_preds,\n",
    "                    msgs[:1]\n",
    "                ).nanmean().item()\n",
    "                metrics[f\"bit_accuracy_{label.lower().replace(' ', '_').replace('.', '')}\"] = bit_accuracy_\n",
    "                metrics[f\"pvalue_{label.lower().replace(' ', '_').replace('.', '')}\"] = pvalue_\n",
    "                metrics[f\"capacity_{label.lower().replace(' ', '_').replace('.', '')}\"] = capacity_\n",
    "                # print(f\"{label} - Bit Accuracy: {bit_accuracy_:.3f} - P-Value: {pvalue_:0.2e} - Capacity: {capacity_:.3f} - took {timer.stop():.2f}s\")\n",
    "        print(metrics)\n",
    "\n",
    "        del vid, outputs, imgs, imgs_w, diff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "img",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
