{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "da26c286-a234-4185-9701-fbbe10523b1a",
   "metadata": {},
   "source": [
    "# About this Notebook\n",
    "\n",
    "This notebook visualizes the network's performance on a given dataset and compares it to the MC-EMVS method. It consists of the following steps:\n",
    "1. Load dataset\n",
    "2. Load network\n",
    "3. Apply network to dataset\n",
    "4. Create frames\n",
    "5. Display video"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7eb430b-2286-4482-ace8-f2354d0eb5de",
   "metadata": {},
   "source": [
    "# Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14336e28-d4bf-41b5-9ae3-b927ac0fc5ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library imports\n",
    "import random\n",
    "import os\n",
    "import gc\n",
    "import re\n",
    "import time\n",
    "\n",
    "# Third-party library imports\n",
    "import numpy as np\n",
    "import cv2  # OpenCV for adaptive filtering\n",
    "import psutil  # For system resource management\n",
    "from scipy.ndimage import convolve  # To convolve filtering masks\n",
    "\n",
    "# PyTorch specific imports\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, Subset\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Matplotlib for plots\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.animation as animation\n",
    "from matplotlib.colors import LinearSegmentedColormap, ListedColormap\n",
    "# HTML for video rendering\n",
    "from IPython.display import HTML\n",
    "# Allow embedded HTML animations of up to 200 MB\n",
    "plt.rcParams['animation.embed_limit'] = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fa59004-885b-4513-94dd-a58bfd5d839c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebooks\n",
    "import import_ipynb\n",
    "from Classes_and_Functions import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f92524e-12c6-41c3-a3e3-1b91001fa8c6",
   "metadata": {},
   "source": [
    "# Hyperparameters\n",
    "\n",
    "First, define the hyperparameters: which dataset to use, which filter to apply, how the Sub-DSIs are constructed, and whether to use the single-pixel or the multi-pixel version of the network. More options exist for the dataset; see *Classes_and_Functions.ipynb*.\n",
    "\n",
    "Quick overview:\n",
    "* Everything can be left at its default except the paths for the <b>dsi_directory</b> and the <b>depthmap_directory</b>.\n",
    "* The default is the single-pixel version of the network; to use the multi-pixel version, set <b>multi_pixel=True</b>.\n",
    "* The dataset is set to MVSEC stereo by default. If desired, switch to <b>dataset=\"mvsec_mono\"</b>, <b>dataset=\"dsec_stereo\"</b> or <b>dataset=\"dsec_mono\"</b>.\n",
    "* The filter parameters are left at their defaults, but for training and testing we used <b>filter_size=9</b> together with <b>adaptive_threshold_c=-10</b> for MVSEC and <b>adaptive_threshold_c=-2</b> for DSEC instead. Feel free to replicate this."
   ]
  },
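  {
   "cell_type": "markdown",
   "id": "sketch-adaptive-threshold",
   "metadata": {},
   "source": [
    "The pixel selection is controlled by <b>filter_size</b> and <b>adaptive_threshold_c</b>. As an illustration, the following is a NumPy-only sketch of mean adaptive thresholding. It assumes that a pixel of the confidence map is kept when its value exceeds the mean of its <b>filter_size</b> x <b>filter_size</b> neighbourhood minus <b>adaptive_threshold_c</b>. The helper `adaptive_mean_select` is hypothetical, and the actual filter in *Classes_and_Functions.ipynb* may differ, e.g. in weighting or border handling.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def adaptive_mean_select(confidence, filter_size=9, c=-10):\n",
    "    # Hypothetical helper: boolean mask of pixels whose value exceeds the\n",
    "    # local box mean minus c (similar in spirit to a mean adaptive threshold in OpenCV)\n",
    "    if filter_size % 2 != 1:\n",
    "        raise ValueError('filter_size must be odd')\n",
    "    r = filter_size // 2\n",
    "    img = confidence.astype(np.float64)\n",
    "    # Replicate border pixels so that every window is full-sized\n",
    "    padded = np.pad(img, r, mode='edge')\n",
    "    # Summed-area table with a leading row and column of zeros\n",
    "    sat = np.pad(padded.cumsum(axis=0).cumsum(axis=1), ((1, 0), (1, 0)))\n",
    "    h, w, k = img.shape[0], img.shape[1], filter_size\n",
    "    window_sum = (sat[k:k + h, k:k + w] - sat[:h, k:k + w]\n",
    "                  - sat[k:k + h, :w] + sat[:h, :w])\n",
    "    local_mean = window_sum / (k * k)\n",
    "    # A negative c raises the threshold above the local mean\n",
    "    return img > local_mean - c\n",
    "\n",
    "confidence = np.zeros((5, 5))\n",
    "confidence[2, 2] = 100.0\n",
    "mask = adaptive_mean_select(confidence, filter_size=3, c=-10)\n",
    "print(int(mask.sum()), bool(mask[2, 2]))  # 1 True\n",
    "```\n",
    "\n",
    "With <b>adaptive_threshold_c=-10</b>, a pixel must exceed its local mean by more than 10 to be selected. A negative constant therefore makes the selection stricter."
   ]
  },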
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7486bbe4-3a5a-4e04-91b0-a7d08589326a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Hyperparameters for the dataset:\n",
    "    # DSI Selection Arguments\n",
    "    dataset (str): The dataset used.\n",
    "    data_seq (int): Sequence (MVSEC) or half (DSEC) to be visualized.\n",
    "    dsi_directory (str): Directory of the DSIs. Must be adjusted by the user.\n",
    "    depthmap_directory (str): Directory of the ground-truth depth maps for each DSI.\n",
    "    dsi_num_expression (str): Regular expression for the number in the DSI file names, used for sorting.\n",
    "    depthmap_num_expression (str): Regular expression for the number in the depth map file names, used for sorting.\n",
    "    start_idx, end_idx (int or None): Start and end indices of the DSIs to consider.\n",
    "\n",
    "    # Pixel selection\n",
    "    filter_size (int): Size of the neighbourhood area when applying the adaptive threshold filter.\n",
    "    adaptive_threshold_c (int): Constant that is subtracted from the mean of the neighbourhood pixels when applying the adaptive threshold filter.\n",
    "\n",
    "    # Sub-DSIs sizes\n",
    "    sub_frame_radius_h (int): Defines the radius of the frame at the height axis around the central pixel for the Sub-DSI.\n",
    "    sub_frame_radius_w (int): Defines the radius of the frame at the width axis around the central pixel for the Sub-DSI.\n",
    "\n",
    "    # Network version\n",
    "    multi_pixel (bool): Determines whether depth is predicted only for the central selected pixel or for the 8 neighbouring pixels as well.\n",
    "\"\"\"\n",
    "\n",
    "# Dataset selection\n",
    "dataset = \"mvsec_stereo\" #  Options: mvsec_stereo, mvsec_mono, dsec_stereo, dsec_mono\n",
    "data_seq = 1 #  Options: 1,2,3 for MVSEC, 1,2 for DSEC (refers to which half)\n",
    "\n",
    "# Directories\n",
    "data_directory = f\"/mvsec/indoor_flying{data_seq}\" if \"mvsec\" in dataset else \"dsec\"\n",
    "modality = \"monocular\" if \"mono\" in dataset else \"stereo\"\n",
    "dsi_directory = f\"{parent_dir}/data/{data_directory}/dsi_{modality}/\" #  Set your path here\n",
    "depthmap_directory = f\"{parent_dir}/data/{data_directory}/depthmaps/\" #  Set your path here\n",
    "\n",
    "# Number expressions of files\n",
    "dsi_num_expression = r\"\\d+\\.\\d+|\\d+\"  #  Decimal or integer number\n",
    "depthmap_num_expression = r\"\\d+\"\n",
    "\n",
    "# Start and end index of DSIs\n",
    "start_idx = 0\n",
    "end_idx = None\n",
    "\n",
    "# Filter parameters for pixel selection\n",
    "filter_size = None #  None automatically sets original value. We used 9 for training and testing on MVSEC and DSEC instead\n",
    "adaptive_threshold_c = None #  None automatically sets original value. We used -10 for MVSEC and -2 for DSEC instead\n",
    "\n",
    "# Sub-DSI sizes\n",
    "sub_frame_radius_h = 3\n",
    "sub_frame_radius_w = 3\n",
    "\n",
    "# Network version\n",
    "multi_pixel = False"
   ]
  },
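  {
   "cell_type": "markdown",
   "id": "sketch-numeric-sorting",
   "metadata": {},
   "source": [
    "<b>dsi_num_expression</b> and <b>depthmap_num_expression</b> are regular expressions for the numbers in the file names, which are used to sort the files. Here is a minimal sketch of such numeric sorting with Python's `re` module. It assumes the first match in a file name is the sort key; `sort_by_number` and the file names are hypothetical:\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "def sort_by_number(file_names, num_expression=r'\\d+\\.\\d+|\\d+'):\n",
    "    # Hypothetical helper: sort by the first number found in each name,\n",
    "    # so that 'dsi_9.25' precedes 'dsi_10.5' (plain string sorting would not)\n",
    "    def key(name):\n",
    "        match = re.search(num_expression, name)\n",
    "        return float(match.group()) if match else float('inf')\n",
    "    return sorted(file_names, key=key)\n",
    "\n",
    "files = ['dsi_10.5.npy', 'dsi_9.25.npy', 'dsi_100.0.npy']\n",
    "print(sort_by_number(files))  # ['dsi_9.25.npy', 'dsi_10.5.npy', 'dsi_100.0.npy']\n",
    "```\n",
    "\n",
    "Regex alternation tries its options from left to right, so the decimal pattern is listed first. Otherwise `10.5` would be matched as `10`."
   ]
  },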
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82563ba7-67b9-4f72-9617-957c38f9537e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If DSEC is selected as dataset, data_seq refers to the half of the zurich_city04a sequence that is visualized.\n",
    "# middle_idx is set to the middle index of the zurich_city04a sequence, but can be set to a different custom value as well.\n",
    "if \"dsec\" in dataset:\n",
    "    middle_idx = 174\n",
    "    # Assign the half for visualization based on the chosen data_seq.\n",
    "    if data_seq == 1:\n",
    "        # First half is used for visualization.\n",
    "        end_idx = middle_idx\n",
    "    elif data_seq == 2:\n",
    "        # Second half is used for visualization.\n",
    "        start_idx = middle_idx\n",
    "    else:\n",
    "        # Make sure that one of the two halves is selected.\n",
    "        raise ValueError(\"Select one of the two halves (data_seq=1 or data_seq=2) for visualization.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24c57c18-f9d5-4970-a150-da33800b53c6",
   "metadata": {},
   "source": [
    "# Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f0c0690-24c4-46cb-9ae1-f364f461d24a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Decide whether the progress of reading in the DSIs shall be printed for tracking\n",
    "print_progress = True\n",
    "\n",
    "# Create dataset\n",
    "data = DSI_Pixelswise_Dataset(visualization_mode=True,\n",
    "                              dataset=dataset,\n",
    "                              data_seq=data_seq,\n",
    "                              dsi_directory=dsi_directory,\n",
    "                              depthmap_directory=depthmap_directory,\n",
    "                              start_idx=start_idx, end_idx=end_idx,\n",
    "                              filter_size=filter_size,\n",
    "                              adaptive_threshold_c=adaptive_threshold_c,\n",
    "                              sub_frame_radius_h=sub_frame_radius_h,\n",
    "                              sub_frame_radius_w=sub_frame_radius_w,\n",
    "                              multi_pixel=multi_pixel,\n",
    "                              clip_targets=False,\n",
    "                              print_progress=print_progress\n",
    "                             )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d09bd2a-d491-4dba-a3e1-78ff271282e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap data into Dataloader\n",
    "batch_size = 2048\n",
    "dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b45cbccb-7f83-4297-bcd4-91869c90ca9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print data dimensions\n",
    "data_size = len(data)\n",
    "sub_dsi_size = data.data_list[0][1].shape\n",
    "\n",
    "print(\"data size:\", data_size)\n",
    "print(\"pixel number for inference:\", data.pixel_count)\n",
    "print(\"sub dsi size:\", sub_dsi_size)"
   ]
  },
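  {
   "cell_type": "markdown",
   "id": "sketch-sub-dsi-extraction",
   "metadata": {},
   "source": [
    "A Sub-DSI is the part of a DSI around a selected pixel. With <b>sub_frame_radius_h</b> = <b>sub_frame_radius_w</b> = 3, it spans 7 x 7 pixels across all depth planes. The sketch below assumes a DSI stored as a NumPy array of shape (depth_planes, height, width) and rejects pixels closer to the border than the radius. Both the axis order and the helper `extract_sub_dsi` are assumptions, and the dataset class may handle this differently.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def extract_sub_dsi(dsi, row, col, radius_h=3, radius_w=3):\n",
    "    # Hypothetical helper; assumes dsi has shape (depth_planes, height, width)\n",
    "    _, height, width = dsi.shape\n",
    "    inside = radius_h <= row < height - radius_h and radius_w <= col < width - radius_w\n",
    "    if not inside:\n",
    "        raise ValueError('pixel is too close to the image border')\n",
    "    # Keep all depth planes and cut a (2*radius_h+1) x (2*radius_w+1) window\n",
    "    return dsi[:, row - radius_h:row + radius_h + 1, col - radius_w:col + radius_w + 1]\n",
    "\n",
    "dsi = np.zeros((100, 260, 346))  # illustrative: 100 depth planes at MVSEC resolution\n",
    "print(extract_sub_dsi(dsi, 130, 173).shape)  # (100, 7, 7)\n",
    "```"
   ]
  },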
  {
   "cell_type": "markdown",
   "id": "fa693d60-aeef-4d18-b520-2bf7674d8c57",
   "metadata": {},
   "source": [
    "# Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "569c3b68-9238-4e83-a69e-139741a2c112",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of models in the ensemble\n",
    "num_models = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78121550-50af-4df6-a6ac-a4c677040b9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize models\n",
    "models = [PixelwiseConvGRU(sub_frame_radius_h, sub_frame_radius_w, multi_pixel=multi_pixel) for _ in range(num_models)]\n",
    "# Send to cuda\n",
    "if torch.cuda.is_available():\n",
    "    for model in models:\n",
    "        model.cuda()\n",
    "# Print architecture\n",
    "print(models[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b855934-47da-4fde-b9ed-0419ddfa7e14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set path to load models from directory\n",
    "model_directory = f\"/mvsec/indoor_flying{data_seq}\" if \"mvsec\" in dataset else f\"/dsec/dsec_half{data_seq}\"\n",
    "model_paths = [f\"{parent_dir}/models/{model_directory}/\"] * num_models # length has to be equal to num_models\n",
    "# Names of the model files\n",
    "prefix_sequence = f\"indoor_flying{data_seq}\" if \"mvsec\" in dataset else f\"dsec_half{data_seq}\"\n",
    "prefix_modality = modality if not multi_pixel else \"multipixel\"\n",
    "model_files = [f\"{prefix_sequence}_{prefix_modality}_even_model.pth\",\n",
    "               f\"{prefix_sequence}_{prefix_modality}_odd_model.pth\"][:num_models] # length has to be equal to num_models\n",
    "# Append the \".pth\" extension where it is missing\n",
    "for idx, model_file in enumerate(model_files):\n",
    "    if not model_file.endswith(\".pth\"):\n",
    "        model_files[idx] += \".pth\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "186321ba-b20d-4055-8f01-28224d2fd825",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model parameters\n",
    "for idx, model in enumerate(models):\n",
    "    model.load_parameters(model_files[idx], model_path=model_paths[idx], optimizer=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2c0cd9-c5a4-4f5f-a227-daaf685ac6c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use ensemble learning to create averaged model\n",
    "model = AveragedNetwork(models)"
   ]
  },
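  {
   "cell_type": "markdown",
   "id": "sketch-mean-ensemble",
   "metadata": {},
   "source": [
    "`AveragedNetwork` is defined in *Classes_and_Functions.ipynb*. If it averages the outputs of its member networks, a minimal PyTorch sketch of such an output-averaging ensemble could look as follows. `MeanEnsemble` is a hypothetical name, not the notebook's class:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class MeanEnsemble(nn.Module):\n",
    "    # Hypothetical sketch: average the predictions of several models\n",
    "    def __init__(self, models):\n",
    "        super().__init__()\n",
    "        # nn.ModuleList registers the members, so .eval() and .cuda() reach them\n",
    "        self.models = nn.ModuleList(models)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Stack member outputs along a new leading axis and average over it\n",
    "        return torch.stack([m(x) for m in self.models]).mean(dim=0)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "a, b = nn.Linear(2, 1), nn.Linear(2, 1)\n",
    "ensemble = MeanEnsemble([a, b]).eval()\n",
    "x = torch.randn(4, 2)\n",
    "with torch.no_grad():\n",
    "    print(torch.allclose(ensemble(x), (a(x) + b(x)) / 2))  # True\n",
    "```"
   ]
  },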
  {
   "cell_type": "markdown",
   "id": "d3104292-c865-449e-ba5a-eaad5fbd718d",
   "metadata": {},
   "source": [
    "# Apply Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46da5bfc-d9cf-4fce-bb28-4172aadf2c06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_network_to_data(neural_network, dataloader, start_idx, end_idx):\n",
    "    \"\"\"\n",
    "    Apply the network to the Sub-DSIs.\n",
    "    The output is a list of lists:\n",
    "    each inner list holds the data of the selected pixels of one DSI,\n",
    "    and the outer list represents the frames, each one associated with one DSI.\n",
    "\n",
    "    Args:\n",
    "       neural_network (nn.Module): Model to estimate depth.\n",
    "       dataloader (DataLoader): Iterator over the dataset created from the selected sequence.\n",
    "       start_idx, end_idx (int): Start and end index of the DSI sequence.\n",
    "    \"\"\"\n",
    "\n",
    "    # Create empty dictionary of frames for faster access\n",
    "    frames_dict = {frame_idx : [] for frame_idx in range(end_idx - start_idx)}\n",
    "\n",
    "    # Set the model to evaluation mode - important for batch normalization and dropout layers\n",
    "    neural_network.eval()\n",
    "\n",
    "    # Iterate over all data points, each one associated with a selected pixel from one DSI\n",
    "    with torch.no_grad():\n",
    "        for batch, batch_data in enumerate(dataloader):\n",
    "            # If available, use the GPU (device has to be set earlier)\n",
    "            batch_data = (tensor.to(device) for tensor in batch_data)\n",
    "\n",
    "            # Get batch data\n",
    "            pixel_positions, sub_dsis, true_depths, argmax_depths, dsi_idxs = batch_data\n",
    "            batch_size = true_depths.size(0)\n",
    "\n",
    "            # Get input\n",
    "            network_input = (pixel_positions, sub_dsis)\n",
    "\n",
    "            # Predict\n",
    "            network_based_estimates = neural_network(network_input)\n",
    "\n",
    "            # Move the results back to the CPU, since create_frames combines them with CPU tensors and NumPy arrays\n",
    "            pixel_positions, network_based_estimates, true_depths, argmax_depths, dsi_idxs = (\n",
    "                tensor.cpu() for tensor in (pixel_positions, network_based_estimates, true_depths, argmax_depths, dsi_idxs))\n",
    "\n",
    "            # Save the predicted information of each pixel to the associated frame\n",
    "            for i in range(batch_size):\n",
    "                # Get associated frame\n",
    "                frame_idx = dsi_idxs[i].item() - start_idx\n",
    "                # Append information\n",
    "                # The Sub-DSI is replaced by the network's prediction\n",
    "                frames_dict[frame_idx].append((pixel_positions[i],\n",
    "                                               network_based_estimates[i],\n",
    "                                               true_depths[i],\n",
    "                                               argmax_depths[i]))\n",
    "\n",
    "    # Transform dictionary to list of lists\n",
    "    frames_data = list(frames_dict.values())\n",
    "\n",
    "    return frames_data"
   ]
  },
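  {
   "cell_type": "markdown",
   "id": "sketch-group-by-frame",
   "metadata": {},
   "source": [
    "The bookkeeping in `apply_network_to_data` boils down to grouping per-pixel results by their DSI index, shifted by <b>start_idx</b>. Without the tensors and the network, it can be sketched in plain Python as follows (`group_by_frame` is a hypothetical name):\n",
    "\n",
    "```python\n",
    "def group_by_frame(dsi_idxs, values, start_idx, end_idx):\n",
    "    # One (possibly empty) list per DSI in [start_idx, end_idx)\n",
    "    frames = [[] for _ in range(end_idx - start_idx)]\n",
    "    for dsi_idx, value in zip(dsi_idxs, values):\n",
    "        # Shift the global DSI index so that the first frame has index 0\n",
    "        frames[dsi_idx - start_idx].append(value)\n",
    "    return frames\n",
    "\n",
    "print(group_by_frame([5, 5, 7], ['a', 'b', 'c'], start_idx=5, end_idx=8))  # [['a', 'b'], [], ['c']]\n",
    "```\n",
    "\n",
    "Pre-allocating one list per DSI keeps frames without any selected pixel as empty lists. This keeps the frame indices aligned with the DSI indices."
   ]
  },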
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b86124cb-0857-4403-bb83-d8af34cfc6e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute index of last DSI\n",
    "if end_idx is None:\n",
    "    end_idx = start_idx + len(data.ground_truths)\n",
    "# Apply network to data to predict depth for each pixel\n",
    "frames_data = apply_network_to_data(model, dataloader, start_idx, end_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e38f9269-f093-489d-baba-694df38e90f5",
   "metadata": {},
   "source": [
    "# Create Frames\n",
    "\n",
    "Having the information for each pixel, we now create the actual colored frames.\n",
    "Each frame consists of 6 sub-frames, arranged as follows:\n",
    "* Top row: Events | MC-EMVS Argmax | Dense Ground Truth\n",
    "* Bottom row: Confidence Map | Our Model Estimate | Masked Ground Truth\n",
    "\n",
    "All sub-frames except the confidence map are colored with the \"jet\" colormap, ranging from blue for close to red for distant objects; pixels without a value are shown in white. The confidence map is shown in reversed grayscale, so higher confidence appears darker."
   ]
  },
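  {
   "cell_type": "markdown",
   "id": "sketch-tile-2x3",
   "metadata": {},
   "source": [
    "Arranging the sub-frames only requires offsetting each one by multiples of the single-frame height h and width w. Below is a NumPy sketch of such a NaN-initialized 2x3 canvas, where NaN marks pixels without a value that are later drawn white. `tile_2x3` is a hypothetical helper:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def tile_2x3(subframes, h, w):\n",
    "    # subframes: six (h, w) arrays in row-major order, i.e.\n",
    "    # [events, argmax, ground truth, confidence, ours, masked ground truth]\n",
    "    canvas = np.full((2 * h, 3 * w), np.nan)\n",
    "    for k, sub in enumerate(subframes):\n",
    "        r, c = divmod(k, 3)  # row 0 or 1, column 0, 1 or 2\n",
    "        canvas[r * h:(r + 1) * h, c * w:(c + 1) * w] = sub\n",
    "    return canvas\n",
    "\n",
    "subs = [np.full((2, 3), float(k)) for k in range(6)]\n",
    "canvas = tile_2x3(subs, h=2, w=3)\n",
    "print(canvas.shape)  # (4, 9)\n",
    "print(canvas[2, 3])  # 4.0, the first pixel of the fifth sub-frame\n",
    "```"
   ]
  },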
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edaebf1d-585d-48b5-a0cb-86e2875e4626",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_frames(frames_data, data, thicken_pixels=True, frame_size=None):\n",
    "    \"\"\"\n",
    "    Creates colored list of numpy frames from data of selected pixels.\n",
    "\n",
    "    Args:\n",
    "        frames_data (list): List of estimated pixel data for each frame. Output from apply_network_to_data. \n",
    "        data (Dataset): Dataset created by the class DSI_Pixelswise_Dataset(Dataset).\n",
    "        thicken_pixels (bool): Thicken pixels by 1 in each direction. Becomes irrelevant for the multi-pixel network version.\n",
    "        frame_size (tuple): Frame size will be set by the choice of dataset, but can be set individually as well.    \n",
    "    \"\"\"\n",
    "\n",
    "    # Get frame size\n",
    "    if frame_size is None:\n",
    "        frame_size = (260, 346) if \"mvsec\" in data.dataset else (480,640)\n",
    "\n",
    "    # Color Maps\n",
    "    cmap_jet = plt.colormaps[\"jet\"]\n",
    "    cmap_jet.set_bad(color=\"white\")  # Set NaN color to white\n",
    "    cmap_gray = plt.colormaps[\"gray\"].reversed()\n",
    "\n",
    "    # Set start and end indices for each sub-frame\n",
    "    h_0, h_1, h_2 = [frame_size[0] * i for i in range(2+1)]\n",
    "    w_0, w_1, w_2, w_3 = [frame_size[1] * i for i in range(3+1)]\n",
    "\n",
    "    # Initialize list of frames\n",
    "    frames = []\n",
    "\n",
    "    # Iterate over frames\n",
    "    for frame_idx, frame_data in enumerate(frames_data):\n",
    "        # Create empty frame\n",
    "        frame = np.full((frame_size[0] * 2, frame_size[1] * 3), np.nan)\n",
    "        \n",
    "        # Confidence Map\n",
    "        confidence_map = data.confidence_maps[frame_idx]\n",
    "        frame[h_1:h_2, w_0:w_1] = confidence_map / 255\n",
    "        \n",
    "        # Dense ground-truth depth\n",
    "        frame[h_0:h_1, w_2:w_3] = data.ground_truths[frame_idx]\n",
    "\n",
    "        # Iterate over each pixel of frame:\n",
    "        for pixel_data in frame_data:\n",
    "            # Get data\n",
    "            pixel_position, network_depth, true_depth, argmax_depth = pixel_data\n",
    "            # Scale pixel position\n",
    "            if data.norm_pixel_pos:\n",
    "                pixel_position = pixel_position * torch.tensor(frame_size)\n",
    "                pixel_position = pixel_position.round().int()\n",
    "            pos_x, pos_y = pixel_position.tolist()\n",
    "\n",
    "            # Check if network is single- or multi-pixel version\n",
    "            if data.multi_pixel:\n",
    "                # Multi-pixel creates 3x3 grid for each pixel anyway\n",
    "                thicken_pixels = True\n",
    "                # Index to iterate over 3x3 grid of estimates per pixel\n",
    "                idx = 0\n",
    "                # Argmax depth estimation is not thickened in this case\n",
    "                frame[h_0 + pos_x, w_1 + pos_y] = argmax_depth\n",
    "\n",
    "            # Iterate over 3x3 grid around pixel to either thicken pixel visualization\n",
    "            # or apply multi-pixel version of network\n",
    "            for row in range(pos_x - thicken_pixels, pos_x + thicken_pixels + 1):\n",
    "                for col in range(pos_y - thicken_pixels, pos_y + thicken_pixels + 1):\n",
    "                    if data.multi_pixel:\n",
    "                        # Create 3x3 grid around selected pixel\n",
    "                        frame[h_1 + row, w_1 + col] = network_depth[idx]\n",
    "                        frame[h_1 + row, w_2 + col] = true_depth[idx]\n",
    "                        # Iterate over 3x3 grid\n",
    "                        idx += 1\n",
    "                    else:\n",
    "                        # Thicken selected pixel visualization by 3x3\n",
    "                        frame[h_1 + row, w_1 + col] = network_depth\n",
    "                        frame[h_1 + row, w_2 + col] = true_depth\n",
    "                        # Thicken the argmax depth estimation as well\n",
    "                        frame[h_0 + row, w_1 + col] = argmax_depth\n",
    "\n",
    "        # Apply colormaps\n",
    "        colored_frame = np.zeros((*frame.shape, 4)) \n",
    "        mask_jet = np.ones(frame.shape, dtype=bool)\n",
    "        mask_jet[h_1:h_2, w_0:w_1] = False\n",
    "        # Apply jet to all sub-frames but the confidence map\n",
    "        colored_frame[mask_jet] = cmap_jet(np.ma.masked_invalid(frame[mask_jet]))\n",
    "        # Apply gray scales to confidence map sub-frame\n",
    "        colored_frame[~mask_jet] = cmap_gray(frame[~mask_jet])\n",
    "\n",
    "        # Append colored frame to list of frames\n",
    "        frames.append(colored_frame)\n",
    "\n",
    "    return frames"
   ]
  },
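  {
   "cell_type": "markdown",
   "id": "sketch-paint-block",
   "metadata": {},
   "source": [
    "The nested loops in `create_frames` write each value into the 3x3 block around its pixel. A slice assignment achieves the same result. The sketch below also clips the block at the frame border, whereas the loops rely on selected pixels lying away from the border. `paint_block` is a hypothetical helper:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def paint_block(frame, row, col, value, radius=1):\n",
    "    # Hypothetical helper: write value into the (2*radius+1) x (2*radius+1) block\n",
    "    # centred on (row, col), clipped so that border blocks stay inside the frame\n",
    "    r0, r1 = max(row - radius, 0), min(row + radius + 1, frame.shape[0])\n",
    "    c0, c1 = max(col - radius, 0), min(col + radius + 1, frame.shape[1])\n",
    "    frame[r0:r1, c0:c1] = value\n",
    "    return frame\n",
    "\n",
    "frame = np.full((4, 4), np.nan)\n",
    "paint_block(frame, 0, 0, 2.5)\n",
    "print(int(np.isfinite(frame).sum()))  # 4: only the in-frame part of the 3x3 block\n",
    "```\n",
    "\n",
    "In the multi-pixel version, the loops consume the nine estimates row by row. They therefore correspond to `np.reshape(estimates, (3, 3))` placed on the same block."
   ]
  },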
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fac5c4d7-387f-4062-bc5d-29ded9eb9537",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create colored frames\n",
    "frames = create_frames(frames_data, data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7157c29-4484-4e88-aa86-56990d532ff4",
   "metadata": {},
   "source": [
    "# Create Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "839ae094-83c6-4485-a6b0-5908ec07ada8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_video(frames, output_file=None, frame_size=None):\n",
    "    \"\"\"\n",
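    "    Create a video player that displays the colored frames.\n",
    "    Each video frame consists of 2x3 sub-frames.\n",
    "    Note: uses the global variables data (for the default frame size) and multi_pixel (for the title of the model sub-frame).\n",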
    "    \n",
    "    Args:\n",
    "        frames (list): Colored frames as list of numpy arrays from create_frames.\n",
    "        output_file (str): If an output file is selected, the video will be saved under that file name. \n",
    "        frame_size (tuple): Frame size will be set by the choice of dataset, but can be set individually as well.\n",
    "    \"\"\"\n",
    "\n",
    "    # Get frame size\n",
    "    if frame_size is None:\n",
    "        frame_size = (260, 346) if \"mvsec\" in data.dataset else (480,640)\n",
    "    \n",
    "    # Calculate figure size to match the aspect ratio of frames\n",
    "    h, w = frame_size\n",
    "    title_space = 20\n",
    "    fig_width = 15  # inches\n",
    "    fig_height = 1.2 * fig_width * ((h * 2) / (w * 3)) # Maintain aspect ratio\n",
    "    line_width = 0.8\n",
    "\n",
    "    # Create figure\n",
    "    fig, ax = plt.subplots(figsize=(fig_width, fig_height))\n",
    "\n",
    "    # Remove padding and margins from the figure and axes\n",
    "    fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)\n",
    "    ax.set_xlim(0, w*3 + line_width)\n",
    "    ax.set_ylim(h*2 + line_width, 0)\n",
    "    ax.axis('off')  # Turn off axis\n",
    "\n",
    "    # Set up the lines to separate the subframes\n",
    "    for i in range(4):\n",
    "        ax.axvline(x=w * i, color='black', linewidth=line_width)\n",
    "    for i in range(3):\n",
    "        ax.axhline(y=h * i, color='black', linewidth=line_width)\n",
    "\n",
    "    # Add titles for each subframe\n",
    "    upper_titles = [\"Events\", \"MC-EMVS Denser Filter\", \"Ground Truth\"]\n",
    "    lower_titles = [\"Confidence Map\", \"Ours (DERD-Net)\", \"Masked Ground Truth\"]\n",
    "    if multi_pixel:\n",
    "        lower_titles[-2] = \"Ours (DERD-Net Multi-Pixel)\"\n",
    "\n",
    "    for i, title in enumerate(upper_titles):\n",
    "        ax.text(i*w + w//2, - title_space, title, horizontalalignment='center', verticalalignment='center', color='black', fontsize=20)\n",
    "    for i, title in enumerate(lower_titles):\n",
    "        ax.text(i*w + w//2, h*2 + title_space, title, horizontalalignment='center', verticalalignment='top', color='black', fontsize=20)\n",
    "\n",
    "    # Load images from frames\n",
    "    def updatefig(i):\n",
    "        im.set_array(frames[i])\n",
    "        return im,\n",
    "\n",
    "    # Create animation\n",
    "    im = ax.imshow(frames[0], animated=True)\n",
    "    ani = animation.FuncAnimation(fig, updatefig, frames=len(frames), interval=50, blit=True)\n",
    "\n",
    "    # Save video under output_file if one is selected\n",
    "    if output_file:\n",
    "        ani.save(output_file, fps=20, writer='ffmpeg')  # Adjust FPS as needed\n",
    "        print(f\"Video saved to {output_file}\")\n",
    "    \n",
    "    plt.close(fig)\n",
    "    return HTML(ani.to_html5_video())"
   ]
  },
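  {
   "cell_type": "markdown",
   "id": "sketch-figure-size",
   "metadata": {},
   "source": [
    "`display_video` chooses the figure height so that the 2x3 grid, which is 2h tall and 3w wide, keeps its aspect ratio: fig_height = 1.2 * fig_width * 2h / (3w). The factor 1.2 presumably leaves room for the titles. The arithmetic can be checked with a small helper (`figure_size` is hypothetical):\n",
    "\n",
    "```python\n",
    "def figure_size(h, w, fig_width=15.0, title_factor=1.2):\n",
    "    # Same arithmetic as in display_video: the 2x3 grid is 2*h tall and 3*w wide,\n",
    "    # and title_factor adds vertical room for the sub-frame titles\n",
    "    fig_height = title_factor * fig_width * (2 * h) / (3 * w)\n",
    "    return fig_width, fig_height\n",
    "\n",
    "print(figure_size(480, 640))  # DSEC: (15.0, 9.0)\n",
    "print(round(figure_size(260, 346)[1], 2))  # MVSEC: 9.02\n",
    "```"
   ]
  },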
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "720cb7d7-8b83-4b2e-99bc-274058c91269",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_file = \"example_file.mp4\"\n",
    "display_video(frames, output_file=output_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (stereo_depth_estimation)",
   "language": "python",
   "name": "stereo_depth_estimation"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
