{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "31dc98cc",
   "metadata": {},
   "source": [
    "# About this Notebook\n",
    "\n",
    "This notebook is used for inference. Given a list of DSIs and, optionally, the corresponding adaptive threshold filters, a trained model is loaded and applied to obtain depthmap estimates.\n",
    "\n",
    "It relies on the classes and functions defined in *Classes_and_Functions.ipynb*."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f90dccc",
   "metadata": {},
   "source": [
    "# Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "864c7190",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library imports\n",
    "import random\n",
    "import os\n",
    "import gc\n",
    "import re\n",
    "import time\n",
    "\n",
    "# Third-party library imports\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2  # OpenCV for adaptive filtering\n",
    "import psutil  # For system resource management\n",
    "from scipy.ndimage import convolve  # To convolve filtering masks\n",
    "\n",
    "# PyTorch specific imports\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, Subset\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e55f9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebooks\n",
    "import import_ipynb\n",
    "from Classes_and_Functions import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1f0762d",
   "metadata": {},
   "source": [
    "# Load Model\n",
    "\n",
    "First, we load our trained models. For a runtime analysis, load only a single model: set <b>num_models = 1</b>.\n",
    "\n",
    "To leverage ensemble learning instead, set <b>num_models</b> to a value greater than one and load the parameters for each individual model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a6d2ce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_models = 1  # Number of models; set to 1 for runtime analysis\n",
    "multi_pixel = False  # Use the single-pixel (False) or multi-pixel (True) version of the network\n",
    "sub_frame_radius_h = 3  # Radius of the Sub-DSIs along the height\n",
    "sub_frame_radius_w = 3  # Radius of the Sub-DSIs along the width"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19f0d299",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize models\n",
    "models = [PixelwiseConvGRU(sub_frame_radius_h, sub_frame_radius_w, multi_pixel=multi_pixel) for _ in range(num_models)]\n",
    "# Send to cuda\n",
    "if torch.cuda.is_available():\n",
    "    for model in models:\n",
    "        model.cuda()\n",
    "# Print architecture\n",
    "print(models[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d3ec4c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set path to load models from directory\n",
    "model_paths = [\"example_path_A.pth\", \"example_path_B.pth\"] # length has to be equal to num_models\n",
    "# Give names of model files\n",
    "model_files = [\"example_model_A.pth\", \"example_model_B.pth\"] # length has to be equal to num_models\n",
    "# Do not forget \".pth\"\n",
    "for idx, model_file in enumerate(model_files):\n",
    "    if not model_file.endswith(\".pth\"):\n",
    "        model_files[idx] += \".pth\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1bccb26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the parameters of each model\n",
    "for idx, model in enumerate(models):\n",
    "    model.load_parameters(model_files[idx], model_path=model_paths[idx], optimizer=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1d12896",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use ensemble learning to create averaged model\n",
    "if num_models > 1:\n",
    "    model = AveragedNetwork(models)\n",
    "else:\n",
    "    model = models[0]"
   ]
  },
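  {
   "cell_type": "markdown",
   "id": "a7e31f04",
   "metadata": {},
   "source": [
    "The idea behind ensemble averaging can be sketched as follows, assuming that the ensemble output is the element-wise mean of the member outputs. `MeanEnsemble` and the toy `nn.Linear` members are hypothetical stand-ins for illustration; the actual *AveragedNetwork* class is defined in *Classes_and_Functions.ipynb* and may work differently.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class MeanEnsemble(nn.Module):\n",
    "    # Hypothetical helper, not the AveragedNetwork class used in this notebook\n",
    "    def __init__(self, members):\n",
    "        super().__init__()\n",
    "        # nn.ModuleList registers the members so that .cuda() and .eval() reach them\n",
    "        self.members = nn.ModuleList(members)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Stack the member outputs and average over the ensemble dimension\n",
    "        outputs = torch.stack([member(x) for member in self.members], dim=0)\n",
    "        return outputs.mean(dim=0)\n",
    "\n",
    "\n",
    "# Toy usage with two small linear layers as placeholder members\n",
    "ensemble = MeanEnsemble([nn.Linear(4, 1), nn.Linear(4, 1)])\n",
    "with torch.no_grad():\n",
    "    prediction = ensemble(torch.randn(8, 4))\n",
    "print(prediction.shape)\n",
    "```\n",
    "\n",
    "In this sketch every member runs a full forward pass, so the cost of one prediction grows with the number of members."
   ]
  },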
  {
   "cell_type": "markdown",
   "id": "40d1e314",
   "metadata": {},
   "source": [
    "# Data\n",
    "\n",
    "We want to estimate depth for a given set of DSIs and adaptive Gaussian threshold filters. If the DSIs are not given yet, load them. If the threshold filters are not given yet, compute them."
   ]
  },
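  {
   "cell_type": "code",
   "execution_count": null,
   "id": "642865fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "dsi_list = None  # Provide a list of DSIs here, or leave as None to load them below\n",
    "threshold_mask_list = None  # Provide a list of threshold masks here, or leave as None to compute them below"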
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2ee5f6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give directory where DSIs are stored\n",
    "dsi_directory = \"example_dsi_directory\"\n",
    "modality = \"stereo\"  # Choose \"stereo\" or \"mono\"\n",
    "dsi_num_expression = r\"\\d+\\.\\d+|\\d+\"  # Regular expression matching the number in each filename, used for sorting\n",
    "# Set start and end index\n",
    "start_idx, end_idx = 0, None\n",
    "# Load the desired DSIs\n",
    "if dsi_list is None or dsi_list == []:\n",
    "    dsi_list = load_dsi_list(dsi_directory, modality, dsi_num_expression, start_idx, end_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d959523c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Original filter parameters for MVSEC indoor flying and DSEC zurich city 04a\n",
    "\"\"\"\n",
    "MVSEC indoor flying 1,2,3:\n",
    "    Original: filter_size = 5 | adaptive_threshold_c = -14\n",
    "    Denser: filter_size = 9 | adaptive_threshold_c = -10\n",
    "    max_confidence = [57.7, 78, 78.8]\n",
    "DSEC zurich city 04a:\n",
    "    Original: filter_size = 5 | adaptive_threshold_c = -4\n",
    "    Denser: filter_size = 9 | adaptive_threshold_c = -2\n",
    "    max_confidence = 468\n",
    "\"\"\"\n",
    "\n",
    "# Choose filter parameters\n",
    "filter_size = 5  # (int): Size of the neighbourhood area used by the adaptive threshold filter\n",
    "adaptive_threshold_c = -14  # (int): Constant subtracted from the mean of the neighbourhood pixels when applying the adaptive threshold filter\n",
    "max_confidence = 57.7  # (float): The maximum relevant ray count in the DSI sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90eae6af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create threshold masks\n",
    "if threshold_mask_list is None or threshold_mask_list == []:\n",
    "    threshold_mask_list = [get_threshold_mask(dsi, filter_size, adaptive_threshold_c, max_confidence) for dsi in dsi_list]"
   ]
  },
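  {
   "cell_type": "markdown",
   "id": "b92c6d1e",
   "metadata": {},
   "source": [
    "The following is a minimal sketch of how an adaptive Gaussian threshold mask could be computed with OpenCV. It assumes that each DSI is a NumPy array of shape (depth planes, height, width), that the confidence map is the maximum ray count along the depth axis, and that this map is scaled to 8 bit using `max_confidence`. `sketch_threshold_mask` is a hypothetical helper; the actual `get_threshold_mask` is defined in *Classes_and_Functions.ipynb* and may differ.\n",
    "\n",
    "```python\n",
    "import cv2\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sketch_threshold_mask(dsi, filter_size=5, adaptive_threshold_c=-14, max_confidence=57.7):\n",
    "    # Assumption: dsi has shape (depth_planes, height, width) and holds ray counts\n",
    "    confidence = dsi.max(axis=0)\n",
    "    # cv2.adaptiveThreshold expects an 8-bit single-channel image\n",
    "    scaled = np.clip(confidence / max_confidence, 0.0, 1.0) * 255.0\n",
    "    scaled = scaled.astype(np.uint8)\n",
    "    # A pixel passes if it exceeds the Gaussian-weighted neighbourhood mean minus C\n",
    "    mask = cv2.adaptiveThreshold(\n",
    "        scaled, 1, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,\n",
    "        filter_size, adaptive_threshold_c,\n",
    "    )\n",
    "    return mask.astype(bool)\n",
    "\n",
    "\n",
    "# Toy usage on random ray counts\n",
    "rng = np.random.default_rng(0)\n",
    "toy_dsi = rng.uniform(0.0, 57.7, size=(10, 32, 48))\n",
    "toy_mask = sketch_threshold_mask(toy_dsi)\n",
    "print(toy_mask.shape, toy_mask.dtype)\n",
    "```\n",
    "\n",
    "Because the constant is subtracted from the neighbourhood mean, a negative `adaptive_threshold_c` raises the local threshold, and a larger `filter_size` averages over a wider area."
   ]
  },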
  {
   "cell_type": "markdown",
   "id": "291dc8be",
   "metadata": {},
   "source": [
    "# Inference\n",
    "\n",
    "The estimated depthmaps are generated by creating an instance of the *Estimated_Depthmaps* class.\n",
    "\n",
    "They are stored under <b>self.estimated_depths</b>.\n",
    "\n",
    "For timing analysis, run the inference several times with <b>batch_size=1</b> to obtain an accurate estimate.\n",
    "\n",
    "Note that the *AveragedNetwork* class increases the inference time significantly. To analyze the runtime, it is therefore best to set <b>num_models=1</b>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf98aa29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run Inference\n",
    "inference = Estimated_Depthmaps(model,\n",
    "                                dsi_list,\n",
    "                                threshold_mask_list,\n",
    "                                # Input creation (Sub-DSIs)\n",
    "                                sub_frame_radius_h=sub_frame_radius_h,\n",
    "                                sub_frame_radius_w=sub_frame_radius_w,\n",
    "                                batch_size=1\n",
    "                               )\n",
    "# Get Inference Times\n",
    "network_inference_times = inference.network_inference_times\n",
    "# Take Average\n",
    "avg_inference_time = np.mean(network_inference_times)\n",
    "# Print average inference time of network application per Sub-DSI\n",
    "print(\"Inference Time [ms]:\", round(1000 * avg_inference_time, 2), \"\\n\")"
   ]
  },
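  {
   "cell_type": "markdown",
   "id": "c4d8e2a7",
   "metadata": {},
   "source": [
    "Repeated timing runs can be sketched as follows for a generic PyTorch module. `time_forward` and the toy `nn.Linear` model are hypothetical; the times reported above come from `inference.network_inference_times`, whose measurement is defined in *Classes_and_Functions.ipynb*. The sketch discards warm-up runs and synchronizes CUDA before reading the clock, because GPU kernels are launched asynchronously.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "def time_forward(model, example_input, warmup=3, runs=10):\n",
    "    # Hypothetical helper: returns the wall-clock time of each measured forward pass in seconds\n",
    "    model.eval()\n",
    "    use_cuda = example_input.is_cuda\n",
    "    timings = []\n",
    "    with torch.no_grad():\n",
    "        for i in range(warmup + runs):\n",
    "            if use_cuda:\n",
    "                torch.cuda.synchronize()  # Wait for pending GPU work before starting the clock\n",
    "            start = time.perf_counter()\n",
    "            model(example_input)\n",
    "            if use_cuda:\n",
    "                torch.cuda.synchronize()  # Wait until the forward pass has finished\n",
    "            elapsed = time.perf_counter() - start\n",
    "            if i >= warmup:\n",
    "                timings.append(elapsed)\n",
    "    return timings\n",
    "\n",
    "\n",
    "# Toy usage with a single-sample batch, mirroring batch_size=1\n",
    "toy_model = nn.Linear(16, 1)\n",
    "timings = time_forward(toy_model, torch.randn(1, 16))\n",
    "print('Mean forward time [ms]:', round(1000 * sum(timings) / len(timings), 3))\n",
    "```"
   ]
  },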
  {
   "cell_type": "markdown",
   "id": "551cdcc3-3c74-4bcb-9da4-40061f37877d",
   "metadata": {},
   "source": [
    "The depthmaps are stored under <b>self.estimated_depths</b>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6e30f0f-c447-4186-9719-d71965d83715",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Sequence Length:\", len(inference.estimated_depths))\n",
    "print(\"Depthmap Shape:\", inference.estimated_depths[0].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fbbb875",
   "metadata": {},
   "source": [
    "# Visualize\n",
    "\n",
    "The method <b>self.create_colored_depth_estimations()</b> applies the 'jet' colormap to the estimated depthmaps.\n",
    "\n",
    "The colored depthmaps are then stored under <b>self.colored_depth_estimations</b>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c494835b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Color depthmaps\n",
    "inference.create_colored_depth_estimations()"
   ]
  },
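  {
   "cell_type": "markdown",
   "id": "d5f1a3b9",
   "metadata": {},
   "source": [
    "Applying the 'jet' colormap to a single depthmap can be sketched with Matplotlib as follows, assuming a 2D NumPy array in which pixels without an estimate are NaN. `colorize_depth` is a hypothetical helper; the actual `create_colored_depth_estimations` method is defined in *Classes_and_Functions.ipynb* and may normalize or mask differently.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def colorize_depth(depthmap):\n",
    "    # Normalize valid depths to [0, 1]; nanmin and nanmax ignore NaN pixels\n",
    "    d_min, d_max = np.nanmin(depthmap), np.nanmax(depthmap)\n",
    "    normalized = (depthmap - d_min) / max(d_max - d_min, 1e-12)\n",
    "    # Replace NaN by 0 so the colormap only sees finite values\n",
    "    filled = np.nan_to_num(normalized, nan=0.0)\n",
    "    # Calling a colormap on floats in [0, 1] returns an RGBA array\n",
    "    rgba = plt.get_cmap('jet')(filled)\n",
    "    # Paint pixels without an estimate white\n",
    "    rgba[np.isnan(depthmap)] = 1.0\n",
    "    return rgba[..., :3]\n",
    "\n",
    "\n",
    "# Toy usage on a small synthetic depthmap with one missing pixel\n",
    "toy_depth = np.linspace(1.0, 5.0, 12).reshape(3, 4)\n",
    "toy_depth[0, 0] = np.nan\n",
    "colored = colorize_depth(toy_depth)\n",
    "print(colored.shape)\n",
    "```"
   ]
  },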
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db05f504",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the colored depthmap estimations, one figure per depthmap\n",
    "figsize = (10, 6)\n",
    "for depthmap in inference.colored_depth_estimations:\n",
    "    plt.figure(figsize=figsize)\n",
    "    plt.imshow(depthmap)\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (stereo_depth_estimation)",
   "language": "python",
   "name": "stereo_depth_estimation"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
