{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a7dea1e6",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# About this Notebook\n",
    "\n",
    "In this notebook we will define \n",
    "- the data class to transform DSIs into a format that can be processed by the neural network\n",
    "- a streamlined data class used directly for inference\n",
    "- the neural network class\n",
    "- the metrics and loss function\n",
    "- the functions for the training process\n",
    "- the functions for the testing process\n",
    "\n",
    "These methods will then be used and executed in the other notebooks."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5fe841e",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Dependencies\n",
    "\n",
    "We will use PyTorch as the deep learning framework."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "007f0689",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library imports\n",
    "import random\n",
    "import os\n",
    "import time\n",
    "import gc\n",
    "import re\n",
    "\n",
    "# Third-party library imports\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2  # OpenCV for adaptive filtering\n",
    "import psutil  # For system resource management\n",
    "from scipy.ndimage import convolve  # To convolve filtering masks\n",
    "\n",
    "# PyTorch specific imports\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "644f8b82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the GPU if CUDA is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(f\"Using {device} device\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3347d9da-1ef2-4e3d-8508-250d6ac56cd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Working directory of the notebook\n",
    "current_dir = os.getcwd()\n",
    "print(f\"Current directory: {current_dir}\")\n",
    "\n",
    "# Directory above the working directory (the default data paths are built relative to it)\n",
    "parent_dir = os.path.split(current_dir)[0]\n",
    "print(f\"Parent directory: {parent_dir}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "883f7eba",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Dataclass \n",
    "\n",
    "For Training, Testing and Visualization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c526d26b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DSI_Pixelswise_Dataset(Dataset):\n",
    "    \"\"\"\n",
    "    A dataset class to transform DSIs into data for the neural network.\n",
    "    DSIs are filtered for confident pixels by applying an adaptive threshold filter over the maximum ray counts of each pixel.\n",
    "    For each of these pixels, a surrounding subregion of the DSI is stored and normalized as a Sub-DSI to be used as data instance.\n",
    "    The inputs will be these Sub-DSIs, while the targets will be the ground true depth at the according pixels.\n",
    "    \n",
    "    Args:\n",
    "        # DSI Selection Arguments\n",
    "        dataset (str): The dataset used.\n",
    "        data_seq (int): The specific sequence of the chosen dataset (1, 2 or 3 for MVSEC). Must be adjusted to user.\n",
    "        dsi_directory (str): Directory of the DSIs. Must be adjusted to user.\n",
    "        depthmap_directory (str): Directory of the groundtrue depths for each DSI.\n",
    "        dsi_num_expression (str): Regular expression that extracts the number from a DSI file name for sorting.\n",
    "        depthmap_num_expression (str): Regular expression that extracts the number from a depthmap file name for sorting.\n",
    "        dsi_split (str or int): Which DSIs shall be considered. Can be \"all\", \"even\", \"odd\" or a number between 0 and 9.\n",
    "        dsi_ratio (float): Between 0 and 1. Defines the proportion of (random) DSIs that shall be used.\n",
    "        start_idx, end_idx (int): Start and stop indices for which DSIs to consider. An end index of None means no upper limit.\n",
    "        start_row, end_row, start_col, end_col (int): Define the rows and columns to be considered within each DSI.\n",
    "\n",
    "        # DSI processing\n",
    "        neg_depth_axis (bool): States whether the depth axis of the DSI has been defined negatively upon creation.\n",
    "        normalize_dsi (bool): DSIs can be normalized beforehand instead of scaling each Sub-DSIs individually by its own highest ray count.\n",
    "        \n",
    "        # Pixel selection\n",
    "        filter_size (int): Determines the size of the neighbourhood area when applying the adaptive threshold filter.\n",
    "        adaptive_threshold_c (int): Constant that is subtracted from the mean of the neighbourhood pixels when apply the adaptive threshold filter.\n",
    "        visualization_mode (bool): Flag. If False only pixels with known ground truth depth as targets are stored as data points.\n",
    "\n",
    "        # Input creation (Sub-DSIs)\n",
    "        sub_frame_radius_h (int): Defines the radius of the frame at the height axis around the central pixel for the Sub-DSI.\n",
    "        sub_frame_radius_w (int): Defines the radius of the frame at the width axis around the central pixel for the Sub-DSI.\n",
    "        center_as_norm_ref (bool): Sub-DSIs can be normalized with respect to highest ray count of central pixel instead of total max value.\n",
    "        norm_pixel_pos (bool): Defines whether the pixel position should be normalized (to be used as additional input) or not (better for reconstruction).\n",
    "        \n",
    "        # Target creation (ground true depth)\n",
    "        multi_pixel (bool): Determines whether depth is predicted only for the central selected pixel or for the 8 neighbouring pixels as well.\n",
    "        inverse_space (bool): Defines if target depths and argmax estimates are normalized in linear or inverse space.\n",
    "        clip_targets (bool): Defines whether targets should be clipped to the believed min and max distance, thus being normalized to inbetween 0 and 1.\n",
    "    \n",
    "        # Execution and debugging\n",
    "        preload_data (bool): Defines whether the data should by loaded directly upon creating a class instance.\n",
    "        print_progress (bool): Decides whether the progress of creating the data for the class should by displayed.\n",
    "        debugging (bool): Debugging mode visually prints some data instances as images.\n",
    "    \n",
    "    Attributes:\n",
    "        # Parameters\n",
    "        frame_height, frame_width (int): The height and width of the frame, equal to the dimensions of the DSIs.\n",
    "        min_depth, max_depth (int): Set the estimated range of distance for which the DSIs were created.\n",
    "        max_confidence (int): The maximum relevant ray count within a dataset sequence.\n",
    "        distCoeffs, K, P (arr): Coefficients to describe the camera lens properties for (un)distortion.\n",
    "        # Data information\n",
    "        data_list (list): List of all data instances.\n",
    "        pixel_count (int): Denotes the total number of pixels for which depth would be predicted (includes pixels without available ground truth).\n",
    "    \n",
    "    The attributes are defined by the dataset sequence itself and the way the DSIs were created and should only be changed accordingly.\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 # DSI selection arguments\n",
    "                 dataset=\"mvsec_stereo\",\n",
    "                 data_seq=1,\n",
    "                 dsi_directory=None,\n",
    "                 depthmap_directory=None,\n",
    "                 dsi_num_expression=r\"\\d+\\.\\d+|\\d+\",\n",
    "                 depthmap_num_expression=r\"\\d+\",\n",
    "                 dsi_split=\"all\",\n",
    "                 dsi_ratio=1.0,\n",
    "                 start_idx=0, end_idx=None,\n",
    "                 start_row=0, end_row=None,\n",
    "                 start_col=0, end_col=None,\n",
    "                 # DSI processing\n",
    "                 neg_depth_axis=True,\n",
    "                 normalize_dsi=False,\n",
    "                 # Pixel selection\n",
    "                 filter_size=None,\n",
    "                 adaptive_threshold_c=None,\n",
    "                 visualization_mode=False,\n",
    "                 # Input creation (Sub-DSIs)\n",
    "                 sub_frame_radius_h=3,\n",
    "                 sub_frame_radius_w=3,\n",
    "                 center_as_norm_ref=False,\n",
    "                 norm_pixel_pos=False,\n",
    "                 # Target creation (ground true depth)\n",
    "                 multi_pixel=False,\n",
    "                 inverse_space=False,\n",
    "                 clip_targets=True,\n",
    "                 # Execution and debugging\n",
    "                 preload_data=True,\n",
    "                 print_progress=False,\n",
    "                 debugging=False\n",
    "                ):\n",
    "        \"\"\"\n",
    "        Store the arguments, resolve the data directories, set the sequence specific parameters,\n",
    "        collect the DSI and depthmap file names and, if preload_data is set, create the data.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If the dataset name contains neither 'mvsec' nor 'dsec' and a directory is missing.\n",
    "            AssertionError: If an index range is infeasible or the numbers of DSI and depthmap files differ.\n",
    "        \"\"\"\n",
    "\n",
    "        # Args\n",
    "        self.dataset = dataset\n",
    "        self.data_seq = data_seq\n",
    "        self.dsi_num_expression = dsi_num_expression\n",
    "        self.depthmap_num_expression = depthmap_num_expression\n",
    "        self.dsi_split = dsi_split\n",
    "        self.dsi_ratio = dsi_ratio\n",
    "        self.neg_depth_axis = neg_depth_axis\n",
    "        self.normalize_dsi = normalize_dsi\n",
    "        self.visualization_mode = visualization_mode\n",
    "        self.sub_frame_radius_h = sub_frame_radius_h\n",
    "        self.sub_frame_radius_w = sub_frame_radius_w\n",
    "        self.center_as_norm_ref = center_as_norm_ref\n",
    "        self.norm_pixel_pos = norm_pixel_pos\n",
    "        self.multi_pixel = multi_pixel\n",
    "        self.inverse_space = inverse_space\n",
    "        self.clip_targets = clip_targets\n",
    "        self.preload_data = preload_data\n",
    "        self.print_progress = print_progress\n",
    "        self.debugging = debugging\n",
    "\n",
    "        # DSI and depthmap directories: explicitly passed directories win,\n",
    "        # the dataset defaults are only derived for directories that were not specified\n",
    "        self.dsi_directory = dsi_directory\n",
    "        self.depthmap_directory = depthmap_directory\n",
    "        if dsi_directory is None or depthmap_directory is None:\n",
    "            if \"mvsec\" in self.dataset:\n",
    "                data_root = f\"{parent_dir}/data/mvsec/indoor_flying{self.data_seq}\"\n",
    "            elif \"dsec\" in self.dataset:\n",
    "                data_root = f\"{parent_dir}/data/dsec\"\n",
    "            else:\n",
    "                raise ValueError(\n",
    "                    f\"Unknown dataset '{self.dataset}': the name must contain 'mvsec' or 'dsec', \"\n",
    "                    \"or both dsi_directory and depthmap_directory must be specified.\"\n",
    "                )\n",
    "            if self.dsi_directory is None:\n",
    "                dsi_subdirectory = \"dsi_monocular\" if \"mono\" in self.dataset else \"dsi_stereo\"\n",
    "                self.dsi_directory = f\"{data_root}/{dsi_subdirectory}/\"\n",
    "            if self.depthmap_directory is None:\n",
    "                self.depthmap_directory = f\"{data_root}/depthmaps/\"\n",
    "\n",
    "        # Assert that file ranges are feasible\n",
    "        assert self.is_range_feasible(start_idx, end_idx)\n",
    "        self.start_idx, self.end_idx = start_idx, end_idx\n",
    "        assert self.is_range_feasible(start_row, end_row)\n",
    "        self.start_row, self.end_row = start_row, end_row\n",
    "        assert self.is_range_feasible(start_col, end_col)\n",
    "        self.start_col, self.end_col = start_col, end_col\n",
    "\n",
    "        # In debugging mode the range is cut off two DSIs after start_idx.\n",
    "        # An end_idx of None (no upper limit) cannot be passed to min(), so it is replaced directly.\n",
    "        if self.debugging:\n",
    "            debugging_end_idx = self.start_idx + 2\n",
    "            if self.end_idx is None:\n",
    "                self.end_idx = debugging_end_idx\n",
    "            else:\n",
    "                self.end_idx = min(self.end_idx, debugging_end_idx)\n",
    "\n",
    "        # Set filter parameters to default for the data sequence or to the value specified by the argument\n",
    "        self.filter_size = [5, 5, 5][self.data_seq-1] if \"mvsec\" in self.dataset else 5\n",
    "        if filter_size is not None:\n",
    "            self.filter_size = filter_size\n",
    "        self.adaptive_threshold_c = [-14, -14, -14][self.data_seq-1] if \"mvsec\" in self.dataset else -4\n",
    "        if adaptive_threshold_c is not None:\n",
    "            self.adaptive_threshold_c = adaptive_threshold_c\n",
    "\n",
    "        # Attributes:\n",
    "        self.frame_height, self.frame_width = None, None #  Will be updated with the first DSI\n",
    "        # estimated depth range\n",
    "        self.min_depth = [1, 1, 1][self.data_seq-1] if \"mvsec\" in self.dataset else 4\n",
    "        self.max_depth = [6.5, 6.5, 6.5][self.data_seq-1] if \"mvsec\" in self.dataset else 50\n",
    "        # maximum relevant ray count\n",
    "        self.max_confidence = [57.7, 78, 78.8][self.data_seq-1] if \"mvsec\" in self.dataset else 468\n",
    "        # camera lens coefficients for undistortion\n",
    "        self.distCoeffs = np.array([-0.048031442223833355, 0.011330957517194437, -0.055378166304281135, 0.021500973881459395])\n",
    "        self.K = np.reshape(\n",
    "            [226.38018519795807, 0.0, 173.6470807871759, 0.0, 226.15002947047415, 133.73271487507847, 0, 0, 1],\n",
    "            (3, 3)\n",
    "        )\n",
    "        self.P = np.reshape(\n",
    "            [199.6530123165822, 0.0, 177.43276376280926, 0.0, 0.0, 199.6530123165822, 126.81215684365904, 0.0, 0.0, 0.0, 1.0, 0.0],\n",
    "            (3, 4)\n",
    "        )\n",
    "\n",
    "        # Get file names of DSIs and ground true depths\n",
    "        self.dsi_files = self.get_files()\n",
    "        self.depthmap_files = self.get_files(depthmaps=True)\n",
    "        # DSIs and depthmaps are paired by their position in the sorted lists, so the lists must have equal length\n",
    "        assert len(self.dsi_files) == len(self.depthmap_files), (\n",
    "            f\"Found {len(self.dsi_files)} DSI files but {len(self.depthmap_files)} depthmap files in the selected index range.\"\n",
    "        )\n",
    "\n",
    "        # Initialize data list\n",
    "        self.data_list = []\n",
    "        self.pixel_count = 0\n",
    "        # Store ground truth, confidence maps and argmax estimation maps for possible visualization purposes\n",
    "        if self.visualization_mode:\n",
    "            self.ground_truths = []\n",
    "            self.confidence_maps = []\n",
    "            self.argmax_maps = []\n",
    "\n",
    "        # Create Data\n",
    "        if self.preload_data:\n",
    "            self.get_data()\n",
    "\n",
    "    \n",
    "    \"\"\"Special Methods:\"\"\"\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of data instances currently stored in self.data_list.\"\"\"\n",
    "        instance_count = len(self.data_list)\n",
    "        return instance_count\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the data instance stored at position idx of self.data_list.\"\"\"\n",
    "        # Tuple layout: pixel_index, sub_dsi, pixel_depth, argmax_depth, frame_idx\n",
    "        data_instance = self.data_list[idx]\n",
    "        return data_instance\n",
    "\n",
    "    \n",
    "    \"\"\"Utility Methods:\"\"\"\n",
    "    def get_files(self, depthmaps=False):\n",
    "        \"\"\"Return the numerically sorted file names of the DSIs (default) or of the true depthmaps within the selected index range.\"\"\"\n",
    "        # Directory, file suffix and number pattern depend on the requested kind of file\n",
    "        if depthmaps:\n",
    "            directory, suffix = self.depthmap_directory, \".npy\"\n",
    "            num_expression = self.depthmap_num_expression\n",
    "        else:\n",
    "            directory = self.dsi_directory\n",
    "            suffix = \"_0.npy\" if \"mono\" in self.dataset else \"_fused.npy\"\n",
    "            num_expression = self.dsi_num_expression\n",
    "\n",
    "        def number_in_name(file_name):\n",
    "            # The first match of the number pattern in the file name defines the sorting position\n",
    "            return float(re.findall(num_expression, file_name)[0])\n",
    "\n",
    "        # Keep only files with the expected suffix and order them by their number\n",
    "        matching_files = [name for name in os.listdir(directory) if name.endswith(suffix)]\n",
    "        ordered_files = sorted(matching_files, key=number_in_name)\n",
    "        # Cut down to the selected index range\n",
    "        return ordered_files[self.start_idx:self.end_idx]\n",
    "\n",
    "    def get_data(self):\n",
    "        \"\"\"Go through all DSI files and add the selected pixels of every DSI chosen for processing to the data list.\"\"\"\n",
    "        for file_position, file_name in enumerate(self.dsi_files):\n",
    "            # Skip DSIs that are excluded by the data split or the ratio of processed DSIs\n",
    "            if not self.should_process_dsi(file_position):\n",
    "                continue\n",
    "            # Select pixels within DSI, create the associated data instances and append them to self.data_list\n",
    "            self.get_pixels(file_position, file_name)\n",
    "\n",
    "    def get_pixels(self, idx, dsi_file):\n",
    "        \"\"\"\n",
    "        This is the main method for creating the data of a single DSI.\n",
    "        It loads the DSI and the associated depthmap.\n",
    "        Then, a selection filter is applied to the DSI to select pixels that shall become data instances.\n",
    "        The surrounding Sub-DSIs are created and the true depth values are selected as targets.\n",
    "        Additionally, the pixel position and the argmax estimation along the depth axis are stored.\n",
    "        Together, for every selected pixel, all this information is added as one data instance to self.data_list.\n",
    "\n",
    "        Args:\n",
    "            idx (int): Position of the DSI within self.dsi_files. Also used to look up the associated depthmap.\n",
    "            dsi_file (str): File name of the DSI.\n",
    "        \"\"\"\n",
    "        # Track progress in processing the DSIs\n",
    "        dsi_idx = self.start_idx + idx\n",
    "        if self.print_progress:\n",
    "            print(\"Load DSI\", dsi_idx)\n",
    "        # Load DSI with threshold_mask\n",
    "        dsi, threshold_mask = self.get_dsi(dsi_file)\n",
    "        # Update frame dimensions. The DSI at idx 0 can be skipped by dsi_split or dsi_ratio,\n",
    "        # so the dimensions are also taken from the first DSI that is actually processed.\n",
    "        if idx == 0 or self.frame_height is None or self.frame_width is None:\n",
    "            self.frame_height, self.frame_width = dsi.shape[1:]\n",
    "\n",
    "        # Load depthmap with target_mask\n",
    "        depthmap, target_mask = self.get_depthmap(idx)\n",
    "        # Create mask to unselect pixels too close to the border to create a Sub-DSI around it.\n",
    "        # Explicit end indices are used because a slice end of -0 would unselect every pixel for a radius of 0.\n",
    "        mask_height, mask_width = threshold_mask.shape\n",
    "        border_mask = np.zeros_like(threshold_mask)\n",
    "        border_mask[self.sub_frame_radius_h:mask_height - self.sub_frame_radius_h,\n",
    "                    self.sub_frame_radius_w:mask_width - self.sub_frame_radius_w] = True\n",
    "        # Combine both masks\n",
    "        selection_mask = border_mask & threshold_mask\n",
    "        # Count pixels\n",
    "        if self.multi_pixel:\n",
    "            # Expand mask to include adjacent pixels\n",
    "            kernel = np.ones((3,3))\n",
    "            expanded_mask = convolve(selection_mask, kernel, mode=\"constant\", cval=0)\n",
    "            # Add pixel count\n",
    "            self.pixel_count += np.sum(expanded_mask)\n",
    "        else:\n",
    "            self.pixel_count += np.sum(selection_mask)\n",
    "        # Add target available mask for training and numerical evaluation\n",
    "        if not self.visualization_mode:\n",
    "            selection_mask &= target_mask\n",
    "        # Deduce indices of selected pixels\n",
    "        selected_indices = list(zip(*np.where(selection_mask)))\n",
    "\n",
    "        # Get scaled argmax along depth axis as competitive estimate\n",
    "        argmax_estimates = (dsi.argmax(dim=0) + 1) / dsi.shape[0]\n",
    "        # If we use the linear instead of the inverse linear space, the argmax estimates have to be projected accordingly\n",
    "        if not self.inverse_space:\n",
    "            # Backproject into original depth space\n",
    "            argmax_estimates = 1 / (1/self.min_depth - argmax_estimates * (1/self.min_depth - 1/self.max_depth))\n",
    "            # Project into linear depth space\n",
    "            argmax_estimates = (argmax_estimates - self.min_depth) / (self.max_depth - self.min_depth)\n",
    "        # Store argmax estimates\n",
    "        if self.visualization_mode:\n",
    "            self.argmax_maps.append(argmax_estimates)\n",
    "\n",
    "        # Remains None if no pixel of this DSI is selected\n",
    "        sub_dsi = None\n",
    "        # Create data for selected pixels and add to self.data_list\n",
    "        for pixel_index in selected_indices:\n",
    "            # Get argmax depth estimate for pixel\n",
    "            argmax_depth = argmax_estimates[pixel_index].clone()\n",
    "            # Get ground true depth at pixel\n",
    "            if not self.multi_pixel:\n",
    "                pixel_depth = depthmap[pixel_index].clone()\n",
    "            else:\n",
    "                # For the multi-pixel version we also need the pixel depths at the direct neighbors\n",
    "                x, y = pixel_index\n",
    "                pixel_depth = depthmap[x-1 : x+2, y-1 : y+2].flatten()\n",
    "            # Get sub DSI around selected pixel\n",
    "            sub_dsi = self.get_sub_dsi(dsi, pixel_index)\n",
    "            # Normalize pixel position and convert to tensor\n",
    "            pixel_pos = torch.tensor(pixel_index)\n",
    "            if self.norm_pixel_pos:\n",
    "                # Divide through frame_x and frame_y size of DSI.\n",
    "                # Out of place, because the integer index tensor cannot store the fractions in place.\n",
    "                pixel_pos = pixel_pos / torch.tensor(dsi.shape[1:])\n",
    "            # Add data to list of data\n",
    "            pixel_data = (pixel_pos, sub_dsi, pixel_depth, argmax_depth, dsi_idx)\n",
    "            self.data_list.append(pixel_data)\n",
    "\n",
    "        # imshow all data steps for debugging\n",
    "        if self.debugging:\n",
    "            if sub_dsi is None:\n",
    "                # Without a selected pixel there is no Sub-DSI that could be plotted\n",
    "                print(\"No pixels selected in DSI\", dsi_idx, \"- skipping the debugging plots\")\n",
    "            else:\n",
    "                self.visualize_data_for_debugging(dsi, depthmap, threshold_mask, target_mask, border_mask, selection_mask, sub_dsi)\n",
    "        # delete DSI from memory\n",
    "        del dsi\n",
    "        gc.collect()\n",
    "\n",
    "    \n",
    "    \"\"\"Helper Methods:\"\"\"\n",
    "    def should_process_dsi(self, dsi_idx):\n",
    "        \"\"\"\n",
    "        Select whether a DSI should be processed based on the data split and the desired ratio of processed DSIs.\n",
    "\n",
    "        Args:\n",
    "            dsi_idx (int): Index of the DSI that is compared against the data split.\n",
    "\n",
    "        Returns:\n",
    "            bool: True if the DSI shall be processed.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If self.dsi_split is none of 'all', 'even', 'odd' or an integer between 0 and 9.\n",
    "        \"\"\"\n",
    "        # Only select a random subset of the DSIs: each DSI is dropped if the drawn number exceeds dsi_ratio\n",
    "        if random.random() > self.dsi_ratio:\n",
    "            return False\n",
    "\n",
    "        # Select either every DSI, only even/odd ones or only the DSIs whose index ends with the given digit\n",
    "        if self.dsi_split == \"all\":\n",
    "            return True\n",
    "        if self.dsi_split == \"even\":\n",
    "            return dsi_idx % 2 == 0\n",
    "        if self.dsi_split == \"odd\":\n",
    "            return dsi_idx % 2 == 1\n",
    "        if self.dsi_split in range(10):\n",
    "            return dsi_idx % 10 == self.dsi_split\n",
    "        raise ValueError(\n",
    "            \"Invalid value for dsi_split. Must be set to 'all', 'even', 'odd' or an integer between 0 and 9. \"\n",
    "            \"Current value: {}\".format(self.dsi_split)\n",
    "        )\n",
    "\n",
    "    def get_dsi(self, dsi_file):\n",
    "        \"\"\"\n",
    "        Load and process DSI. Create adaptive threshold mask based on maximum amount of rays per pixel.\n",
    "\n",
    "        Args:\n",
    "            dsi_file (str): File name of the DSI inside self.dsi_directory.\n",
    "\n",
    "        Returns:\n",
    "            tuple: The (optionally normalized and flipped) DSI as torch tensor and the threshold mask\n",
    "                returned by self.get_threshold_mask for the unnormalized DSI.\n",
    "        \"\"\"\n",
    "        # Load (specified area of) DSI as 3d-numpy array.\n",
    "        # os.path.join also works if the directory was specified without trailing separator.\n",
    "        dsi_path = os.path.join(self.dsi_directory, dsi_file)\n",
    "        dsi = np.load(dsi_path)[:, self.start_row:self.end_row, self.start_col:self.end_col]\n",
    "        # Compute threshold mask\n",
    "        threshold_mask = self.get_threshold_mask(dsi)\n",
    "        # Normalize DSI (alternatively normalize Sub-DSI)\n",
    "        if self.normalize_dsi:\n",
    "            dsi_max = np.max(dsi)\n",
    "            if dsi_max > 0:\n",
    "                # Out of place division, so that integer typed ray counts are promoted to float instead of raising\n",
    "                dsi = dsi / dsi_max\n",
    "        # Transform DSI to pytorch tensor\n",
    "        dsi = torch.from_numpy(dsi)\n",
    "        # Flip DSI along depth axis\n",
    "        if self.neg_depth_axis:\n",
    "            dsi = dsi.flip(dims=[0])\n",
    "        return dsi, threshold_mask\n",
    "\n",
    "    def get_threshold_mask(self, dsi):\n",
    "        \"\"\"Build the boolean pixel mask by adaptive thresholding of the highest ray count of each pixel.\"\"\"\n",
    "        # Highest ray count along the depth axis for every pixel\n",
    "        ray_count_peaks = np.max(dsi, axis=0)\n",
    "        # Normalize with the sequence wide maximum, or with the maximum of this DSI if it is even higher\n",
    "        reference_peak = max(self.max_confidence, np.max(ray_count_peaks))\n",
    "        # Rescale to the 8-bit range from 0 to 255\n",
    "        scaled_peaks = np.around(ray_count_peaks * 255 / reference_peak).astype('uint8')\n",
    "        # Keep the confidence map for later visualization\n",
    "        if self.visualization_mode:\n",
    "            self.confidence_maps.append(scaled_peaks)\n",
    "        # Adaptive threshold with Gaussian weighted neighbourhood\n",
    "        binary_mask = cv2.adaptiveThreshold(\n",
    "            scaled_peaks, 255,\n",
    "            cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,\n",
    "            self.filter_size, self.adaptive_threshold_c\n",
    "        )\n",
    "        return binary_mask.astype(bool)\n",
    "\n",
    "    def get_depthmap(self, idx):\n",
    "        \"\"\"\n",
    "        Load and preprocess depthmap and create mask for pixels with available ground true depth.\n",
    "\n",
    "        Args:\n",
    "            idx (int): Position of the depthmap within self.depthmap_files.\n",
    "\n",
    "        Returns:\n",
    "            tuple: The depthmap scaled with respect to min_depth and max_depth as torch tensor\n",
    "                (NaN where no ground true depth is available) and the boolean numpy mask of available targets.\n",
    "        \"\"\"\n",
    "        # Get depthmap file\n",
    "        depthmap_file = self.depthmap_files[idx]\n",
    "        # Load groundtrue depthmap as 2d-numpy array.\n",
    "        # os.path.join also works if the directory was specified without trailing separator.\n",
    "        depthmap = np.load(os.path.join(self.depthmap_directory, depthmap_file))\n",
    "        # Undistort and eliminate resulting zero values\n",
    "        if \"mvsec\" in self.dataset:\n",
    "            depthmap = cv2.fisheye.undistortImage(depthmap, self.K, self.distCoeffs, None, self.P)\n",
    "        # Zero values mean no ground true depth is available\n",
    "        depthmap[depthmap == 0] = np.nan\n",
    "        # Select specified area\n",
    "        depthmap = depthmap[self.start_row:self.end_row, self.start_col:self.end_col]\n",
    "        # Truncate values outside of predicted range\n",
    "        if self.clip_targets:\n",
    "            depthmap = depthmap.clip(self.min_depth, self.max_depth)\n",
    "        # Scale to inbetween 0 and 1\n",
    "        if self.inverse_space:\n",
    "            # If depth levels shall be projected linearly into inverse space\n",
    "            depthmap = (1/self.min_depth - 1/depthmap) / (1/self.min_depth - 1/self.max_depth)\n",
    "        else:\n",
    "            # If depth levels shall be projected into linear space\n",
    "            depthmap = (depthmap - self.min_depth) / (self.max_depth - self.min_depth)\n",
    "\n",
    "        # Get target mask (available ground true depth values)\n",
    "        target_mask = ~np.isnan(depthmap)\n",
    "\n",
    "        # Store ground truth depth map\n",
    "        if self.visualization_mode:\n",
    "            self.ground_truths.append(depthmap.copy())\n",
    "        # Transform to pytorch tensor\n",
    "        depthmap = torch.from_numpy(depthmap)\n",
    "\n",
    "        return depthmap, target_mask\n",
    "\n",
    "    def get_sub_dsi(self, dsi, pixel_index):\n",
    "        \"\"\"\n",
    "        Get sub DSI around selected pixel with frame size of 2*sub_frame_radius + 1.\n",
    "\n",
    "        Args:\n",
    "            dsi (torch.Tensor): DSI that is indexed as (depth, height, width).\n",
    "            pixel_index (tuple): Height and width index of the central pixel.\n",
    "\n",
    "        Returns:\n",
    "            torch.Tensor: Copy of the subregion, normalized on Sub-DSI level unless self.normalize_dsi is set.\n",
    "        \"\"\"\n",
    "        # Get frame borders\n",
    "        h, w = pixel_index\n",
    "        sub_frame_h = slice(h - self.sub_frame_radius_h, h + self.sub_frame_radius_h + 1)\n",
    "        sub_frame_w = slice(w - self.sub_frame_radius_w, w + self.sub_frame_radius_w + 1)\n",
    "        # Select subregion of DSI\n",
    "        sub_dsi = dsi[:,sub_frame_h, sub_frame_w].clone()\n",
    "        # If DSI has not been normalized, normalize on Sub-DSI level\n",
    "        if not self.normalize_dsi:\n",
    "            # Normalization can be done either with regards to the highest ray count at the central pixel or the total Sub-DSI\n",
    "            if self.center_as_norm_ref:\n",
    "                # Max value at central pixel: within a frame of size 2*radius + 1 the center sits at index radius\n",
    "                max_val = sub_dsi[:, self.sub_frame_radius_h, self.sub_frame_radius_w].max()\n",
    "            else:\n",
    "                # max value of entire Sub-DSI\n",
    "                max_val = sub_dsi.max()\n",
    "            # normalize (an all-zero reference is left unscaled to avoid a division by zero)\n",
    "            if max_val > 0:\n",
    "                sub_dsi /= max_val\n",
    "\n",
    "        return sub_dsi\n",
    "\n",
    "    def is_range_feasible(self, start, end):\n",
    "        \"\"\"\n",
    "        Check that start and end indices are feasible.\n",
    "\n",
    "        Args:\n",
    "            start (int): Non-negative start index.\n",
    "            end (int or None): End index that must not be smaller than start. None means no upper limit.\n",
    "\n",
    "        Returns:\n",
    "            bool: True if the indices describe a feasible range.\n",
    "        \"\"\"\n",
    "        if not isinstance(start, int) or start < 0:\n",
    "            return False\n",
    "        # Only None stands for an open range. An end of 0 has to be checked like any other index.\n",
    "        if end is None:\n",
    "            return True\n",
    "        return isinstance(end, int) and start <= end\n",
    "\n",
    "    def visualize_data_for_debugging(self, dsi, depthmap, threshold_mask, target_mask, border_mask, selection_mask, sub_dsi):\n",
    "        \"\"\"\n",
    "        Plot images of data for debugging.\n",
    "\n",
    "        Args:\n",
    "            dsi (torch.Tensor): The processed DSI.\n",
    "            depthmap (torch.Tensor): The processed ground true depthmap.\n",
    "            threshold_mask, target_mask, border_mask, selection_mask (arr): The masks used for the pixel selection.\n",
    "            sub_dsi (torch.Tensor or None): A Sub-DSI of the DSI. Its plot is skipped if None is passed.\n",
    "        \"\"\"\n",
    "        # DSI confidence map together with its flipped versions in one 2x2 figure\n",
    "        dsi_max_vals = dsi.numpy().max(axis=0)\n",
    "        confidence_views = [\n",
    "            (dsi_max_vals, 'DSI Confidence Map'),\n",
    "            (np.flip(dsi_max_vals, axis=-1), 'Horizontally Flipped'),\n",
    "            (np.flip(dsi_max_vals, axis=-2), 'Vertically Flipped'),\n",
    "            (np.flip(dsi_max_vals, axis=(-2, -1)), 'Rotated 180 Degrees'),\n",
    "        ]\n",
    "        plt.figure(figsize=(16, 12))\n",
    "        for position, (image, title) in enumerate(confidence_views, start=1):\n",
    "            plt.subplot(2, 2, position)\n",
    "            self._plot_debugging_map(image, title, 'Max Value', 'Greys')\n",
    "        plt.show()\n",
    "\n",
    "        # Argmax depth estimate along the depth axis, scaled by the number of depth levels\n",
    "        argmax_estimates = (dsi.argmax(dim=0) + 1) / dsi.shape[0]\n",
    "        # All remaining maps get a figure of their own: (image, title, colorbar label, colormap, figure size)\n",
    "        single_maps = [\n",
    "            (threshold_mask, 'Adaptive Threshold Mask', 'Pixel Selected', 'Greys', (8, 6)),\n",
    "            (argmax_estimates.numpy(), 'Argmax Depth Estimate', 'Depth', 'jet', (8, 6)),\n",
    "            (depthmap.numpy(), 'Undistorted Groundtrue Depthmap', 'Depth', 'jet', (8, 6)),\n",
    "            (target_mask, 'Available Groundtrue Depths', 'Pixel Selected', 'Greys', (10, 8)),\n",
    "            (border_mask, 'Pixels Within Border', 'Pixel Selected', 'Greys', (10, 8)),\n",
    "            (selection_mask, 'Selected Pixels', 'Pixel Selected', 'Greys', (10, 8)),\n",
    "        ]\n",
    "        # Sub-DSI confidence map (only available if a Sub-DSI was passed)\n",
    "        if sub_dsi is not None:\n",
    "            sub_dsi_max_vals = sub_dsi.numpy().max(axis=0)\n",
    "            single_maps.append((sub_dsi_max_vals, 'Sub DSI Certainty Map', 'Max Value', 'Greys', (8, 6)))\n",
    "        for image, title, colorbar_label, cmap, figsize in single_maps:\n",
    "            plt.figure(figsize=figsize)\n",
    "            self._plot_debugging_map(image, title, colorbar_label, cmap)\n",
    "            plt.show()\n",
    "\n",
    "    def _plot_debugging_map(self, image, title, colorbar_label, cmap):\n",
    "        \"\"\"Draw one image with colorbar, axis labels and title into the currently active figure or subplot.\"\"\"\n",
    "        plt.imshow(image, cmap=cmap)\n",
    "        plt.colorbar(label=colorbar_label)\n",
    "        plt.xlabel('Pixel X')\n",
    "        plt.ylabel('Pixel Y')\n",
    "        plt.title(title)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed611257",
   "metadata": {},
   "source": [
    "### Debugging\n",
    "\n",
    "Uncomment to use debugging method of data class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a5de36",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# debug_1 = DSI_Pixelswise_Dataset(data_seq=1, start_idx = 500, end_idx = 501, debugging=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe58a7b1",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# debug_2 = DSI_Pixelswise_Dataset(data_seq=2, start_idx = 400, end_idx = 401, debugging=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcdc1c00",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# debug_3 = DSI_Pixelswise_Dataset(data_seq=3, start_idx = 600, end_idx = 601, debugging=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e8b81d",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Inference-Dataclass\n",
    "\n",
    "Streamlined, soly for inference for a given list of DSIs and possibly given adpative gaussian threshold filters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "998e9180-7a86-46cb-be75-c3b15b0379e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def transform_dsi(dsi):\n",
    "    \"\"\"Transform DSI to torch tensor.\"\"\"\n",
    "    # Transform DSI to pytorch tensor\n",
    "    dsi = torch.from_numpy(dsi).to(device)\n",
    "    # Flip DSI along depth axis\n",
    "    dsi = dsi.flip(dims=[0])\n",
    "    return dsi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa44d847",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_dsi_list(dsi_directory, modality, dsi_num_expression, start_idx=0, end_idx=None):\n",
    "    \"\"\"\n",
    "    Function to load all DSIs from a directory, starting and ending at a given index.\n",
    "    Args:\n",
    "        modality (str): Either stereo or mono.\n",
    "        dsi_num_expression (str): Numerical expression of DSIs for sorting.\n",
    "        start_idx, end_idx (str): Start and end index of DSI sequence.\n",
    "    \"\"\"\n",
    "    # Name expression of files\n",
    "    suffix = \"_fused.npy\" if modality == \"stereo\" else \"_0.npy\"\n",
    "    # Load according list of file names from directory\n",
    "    files = [file for file in os.listdir(dsi_directory) if file.endswith(suffix)]\n",
    "    # Sort files based on number in name\n",
    "    files.sort(key=lambda file: float(re.findall(dsi_num_expression, file)[0]))\n",
    "    # Select for indices\n",
    "    files = files[start_idx:end_idx]\n",
    "    # Load DSIs from their files\n",
    "    dsi_list = [transform_dsi(np.load(f\"{dsi_directory}{dsi_file}\")) for dsi_file in files]\n",
    "\n",
    "    return dsi_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5d45a6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_threshold_mask(dsi, filter_size, adaptive_threshold_c, max_confidence):\n",
    "    \"\"\"\n",
    "    Function to create an adaptive threshold mask based on maximum amount of counted rays per pixel.\n",
    "    Args:\n",
    "        dsi (numpy arr): A DSI of dimensions depth, height, width\n",
    "        filter_size (int): Determines the size of the neighbourhood area when applying the adaptive threshold filter.\n",
    "        adaptive_threshold_c (int): Constant that is subtracted from the mean of the neighbourhood pixels when apply the adaptive threshold filter.\n",
    "        max_confidence (int): The maximum relevant ray count in the DSI sequence.\n",
    "    \"\"\"\n",
    "    # Take the max values of DSI along the depth axis\n",
    "    confidence_map = torch.max(dsi, dim=0)[0].cpu().numpy()  # np.max(dsi, axis=0) if DSI is numpy array\n",
    "    # Determine the maximum value for normalization\n",
    "    dsi_max_confidence = np.max(confidence_map)\n",
    "    normalization_max_confidence = max(max_confidence, dsi_max_confidence)\n",
    "    # Scale it to inbetween 0 and 255\n",
    "    confidence_map_normalized = np.around(confidence_map * 255 / normalization_max_confidence).astype('uint8')    \n",
    "    # Apply adaptive threshold\n",
    "    threshold_mask = cv2.adaptiveThreshold(confidence_map_normalized, 255,\n",
    "                                           cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,\n",
    "                                           filter_size, adaptive_threshold_c)\n",
    "    return threshold_mask.astype(bool)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dcf8fa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Estimated_Depthmaps():\n",
    "    \"\"\"\n",
    "    A dataset class to create the estimated depthmaps as inference from a given list of DSIs.\n",
    "    The adaptive gaussian threshold filter is either given or automatically computed along the way.\n",
    "    The depthmaps are then stored as a list of normalized numpy arrays under self.estimated_depths.\n",
    "    These depthmaps can be colored. For this, call the method self.create_colored_depth_estimations().\n",
    "    The colored depthmaps are then stored under self.colored_depth_estimations.\n",
    "\n",
    "    Args:\n",
    "        model (torch.nn.Module): A trained network model.\n",
    "        dsi_list (list of numpy arrs): A list of DSIs with dimensions (depth, height, width).\n",
    "        # Pixel selection\n",
    "        threshold_mask_list (list of numpy arrs): A list of adaptive gaussian threshold filters for each DSI.\n",
    "        filter_size (int): Determines the size of the neighbourhood area when applying the adaptive threshold filter.\n",
    "        adaptive_threshold_c (int): Constant that is subtracted from the mean of the neighbourhood pixels when apply the adaptive threshold filter.\n",
    "        max_confidence (int): The maximum relevant ray count in the DSI sequence.\n",
    "        # Input creation (Sub-DSIs)\n",
    "        sub_frame_radius_h (int): Defines the radius of the frame at the height axis around the central pixel for the Sub-DSI.\n",
    "        sub_frame_radius_w (int): Defines the radius of the frame at the width axis around the central pixel for the Sub-DSI.\n",
    "        batch_size (int): The batch size for the Dataloader when applying the network.\n",
    "        \n",
    "    Attributes:\n",
    "        # Parameters\n",
    "        frame_height, frame_width (int): The height and width of the frame, equal to the dimensions of the DSIs.\n",
    "        estimated_depths (list of numpy arrs): A list of the estimated depthmaps by the model for each DSI.\n",
    "        colored_depth_estimations (list of numpy arrs): A colored version of the estimated_depths list. Cmap is 'jet'.\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 model,\n",
    "                 dsi_list,\n",
    "                 threshold_mask_list,\n",
    "                 # Input creation (Sub-DSIs)\n",
    "                 sub_frame_radius_h=3,\n",
    "                 sub_frame_radius_w=3,\n",
    "                 batch_size=1024,\n",
    "                 ):\n",
    "        # Args and Attrbts\n",
    "        self.model = model\n",
    "        self.dsi_list = dsi_list\n",
    "        self.threshold_mask_list = threshold_mask_list\n",
    "        self.sub_frame_radius_h = sub_frame_radius_h\n",
    "        self.sub_frame_radius_w = sub_frame_radius_w\n",
    "        self.batch_size = batch_size\n",
    "        self.frame_height, self.frame_width = self.dsi_list[0].shape[1:]\n",
    "        # Automatically send model to the available device\n",
    "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        model.device = self.device\n",
    "        model.to(model.device)  # Send the model to the device\n",
    "        # Estimated Depthmaps\n",
    "        self.estimated_depths = [torch.full((self.frame_height, self.frame_width), float('nan'), device=self.device) for _ in self.dsi_list]\n",
    "        # Colored images\n",
    "        self.colored_depth_estimations = []\n",
    "        # Measure inference time to apply network to single Batch of Sub-DSIs\n",
    "        self.network_inference_times = []\n",
    "\n",
    "        # Iterate through list of DSIs, get threshold mask, transform to data, apply network, assign estimated depths to pixels\n",
    "        for dsi_idx, dsi in enumerate(self.dsi_list):\n",
    "            threshold_mask = threshold_mask_list[dsi_idx]\n",
    "            # Select pixels from DSI and create Sub-DSIs around them as data points\n",
    "            data_for_inference = DSI_Pixels_for_Inference(dsi, threshold_mask, self.sub_frame_radius_h, self.sub_frame_radius_w)\n",
    "            # Load data into DataLoader to better parallelization\n",
    "            dataloader = DataLoader(data_for_inference, batch_size=self.batch_size, shuffle=False)\n",
    "            with torch.no_grad():\n",
    "                # Iterate through batches to apply network and assign estimated depths\n",
    "                for batch, batch_data in enumerate(dataloader):\n",
    "                    pixel_position, network_depth = self.apply_network(batch_data)\n",
    "                    self.assign_pixel_depth(dsi_idx, pixel_position, network_depth)\n",
    "            \n",
    "    \n",
    "    \"\"\"Methods\"\"\"\n",
    "    def get_threshold_mask(self, dsi):\n",
    "        \"\"\"Create adaptive threshold mask based on maximum amount of counted rays per pixel.\"\"\"\n",
    "        # Take the max values of DSI along the depth axis\n",
    "        confidence_map = torch.max(dsi, dim=0)[0].cpu().numpy()  # np.max(dsi, axis=0) if DSI is numpy array\n",
    "        # Determine the maximum value for normalization\n",
    "        dsi_max_confidence = np.max(confidence_map)\n",
    "        normalization_max_confidence = max(self.max_confidence, dsi_max_confidence)\n",
    "        # Scale it to inbetween 0 and 255\n",
    "        confidence_map_normalized = np.around(confidence_map * 255 / normalization_max_confidence).astype('uint8')    \n",
    "        # Apply adaptive threshold\n",
    "        threshold_mask = cv2.adaptiveThreshold(confidence_map_normalized, 255,\n",
    "                                                cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY,\n",
    "                                                self.filter_size, self.adaptive_threshold_c)\n",
    "        return threshold_mask.astype(bool)\n",
    "\n",
    "    def apply_network(self, batch_data):\n",
    "        \"\"\"Apply network to batch data.\"\"\"\n",
    "        # Get batch data\n",
    "        batch_data = tuple(tensor.to(self.device) for tensor in batch_data)\n",
    "        pixel_position, sub_dsi = batch_data\n",
    "        # Normalize pixel position\n",
    "        norm_pixel_position = pixel_position / torch.tensor((self.frame_height, self.frame_width), device=self.device)\n",
    "        # Create input for model\n",
    "        input = (norm_pixel_position, sub_dsi)\n",
    "        # Compute prediction\n",
    "        st = time.time()\n",
    "        network_depth = self.model(input)\n",
    "        self.network_inference_times.append(time.time() - st)\n",
    "        # Clip network estimations to inbetween 0 and 1\n",
    "        network_depth = network_depth.clip(0,1)\n",
    "\n",
    "        return pixel_position, network_depth\n",
    "\n",
    "    def assign_pixel_depth(self, dsi_idx, pixel_position, network_depth):\n",
    "        \"\"\"Assign estimated depths to the given pixels of the current DSI.\"\"\"\n",
    "        # Assign estimated depths to each pixel position\n",
    "        if not self.model.multi_pixel:\n",
    "            h, w = pixel_position[:, 0], pixel_position[:, 1]\n",
    "            self.estimated_depths[dsi_idx][h, w] = network_depth\n",
    "        else:\n",
    "            for pixel_idx, pixel_depth in enumerate(network_depth):\n",
    "                # Get height and width position of individual pixel\n",
    "                h, w = pixel_position[pixel_idx]\n",
    "                i = 0\n",
    "                # Iterate over left, right, top and down neighbours\n",
    "                for row in range(h - 1, h + 2):\n",
    "                    for col in range(w - 1, w + 2):\n",
    "                        self.estimated_depths[dsi_idx][row, col] = pixel_depth[i].item()\n",
    "                        i += 1\n",
    "    \n",
    "    def create_colored_depth_estimations(self):\n",
    "        \"\"\"Color the estimated depthmaps with the cmap 'jet'.\"\"\"\n",
    "        # Copy the jet colormap and set color for NaN values\n",
    "        cmap = plt.colormaps[\"jet\"]\n",
    "        cmap.set_bad(color='white')\n",
    "        # Apply colormap\n",
    "        self.colored_depth_estimations = [cmap(np.ma.masked_invalid(depthmap)) for depthmap in self.estimated_depths]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56ea7618",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DSI_Pixels_for_Inference(Dataset):\n",
    "    \"\"\"\n",
    "    A dataset class to transform a single DSI into data for the neural network.\n",
    "    This class is for inference of a single DSI only. It is a streamlined version of the DSI_Pixelswise_Dataset class.\n",
    "    The DSI is filtered for confident pixels by applying an adaptive threshold filter, created over the maximum ray counts of each pixel.\n",
    "    For each of these pixels, a surrounding subregion of the DSI is stored and normalized as a Sub-DSI to be used as data instance.\n",
    "    These instances are the inputs to the network model.\n",
    "\n",
    "    Args:\n",
    "        dsi (numpy arr): The DSI with dimensions (depth, height, width).\n",
    "        threshold_mask (numpy arr): An adaptive gaussian threshold filter to select pixels with high confidence due to their maximum ray count.\n",
    "        sub_frame_radius_h (int): Defines the radius of the frame at the height axis around the central pixel for the Sub-DSI.\n",
    "        sub_frame_radius_w (int): Defines the radius of the frame at the width axis around the central pixel for the Sub-DSI.\n",
    "    \"\"\"\n",
    "    def __init__(self, dsi, threshold_mask, sub_frame_radius_h, sub_frame_radius_w):\n",
    "        # Args\n",
    "        self.threshold_mask = threshold_mask\n",
    "        self.sub_frame_radius_h = sub_frame_radius_h\n",
    "        self.sub_frame_radius_w = sub_frame_radius_w\n",
    "        \n",
    "        # Dataset\n",
    "        self.pixel_pos = self.get_indices()  # Filter for confident pixels\n",
    "        self.sub_dsis = self.get_sub_dsis(dsi)  # Get Sub-DSIs around them\n",
    "\n",
    "    \n",
    "    \"\"\"Special Methods\"\"\"\n",
    "    def __len__(self):\n",
    "        return len(self.sub_dsis)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # pixel_index, sub_dsi\n",
    "        return (self.pixel_pos[idx], self.sub_dsis[idx])\n",
    "        \n",
    "    \n",
    "    \"\"\"Utility Methods\"\"\"\n",
    "    def get_indices(self):\n",
    "        \"\"\"Select confident pixels.\"\"\"\n",
    "        # Create mask to unselect pixels too close to the border to create a Sub-DSI around it\n",
    "        border_mask = np.zeros_like(self.threshold_mask)\n",
    "        border_mask[self.sub_frame_radius_h:-self.sub_frame_radius_h, self.sub_frame_radius_w:-self.sub_frame_radius_w] = True\n",
    "        # Combine with threshold masks\n",
    "        selection_mask = border_mask & self.threshold_mask\n",
    "        # Deduce indices of selected pixels\n",
    "        selected_indices = list(zip(*np.where(selection_mask)))\n",
    "        # Transform to tensor\n",
    "        selected_indices = torch.tensor(selected_indices).to(device)\n",
    "        \n",
    "        return selected_indices\n",
    "    \n",
    "    def get_sub_dsis(self, dsi):\n",
    "        \"\"\"Get sub DSI around selected pixel with frame size of 2*sub_frame_radius + 1.\"\"\"\n",
    "        # Get dimensions\n",
    "        number_pixels, depth, height, width = len(self.pixel_pos), *dsi.shape\n",
    "        # Generate subregion dimensions\n",
    "        sub_height = 2 * self.sub_frame_radius_h + 1\n",
    "        sub_width = 2 * self.sub_frame_radius_w + 1\n",
    "        # Extract coordinates\n",
    "        h_indices = self.pixel_pos[:, 0]\n",
    "        w_indices = self.pixel_pos[:, 1]\n",
    "        # Create grid offsets\n",
    "        h_offsets = torch.arange(-self.sub_frame_radius_h, self.sub_frame_radius_h + 1, device=dsi.device)\n",
    "        w_offsets = torch.arange(-self.sub_frame_radius_w, self.sub_frame_radius_w + 1, device=dsi.device)\n",
    "        # Create meshgrid for offsets\n",
    "        h_grid, w_grid = torch.meshgrid(h_offsets, w_offsets, indexing='ij')  # Shape: (sub_height, sub_width)\n",
    "        h_grid = h_grid.flatten()  # Shape: (sub_height * sub_width)\n",
    "        w_grid = w_grid.flatten()\n",
    "        # Expand pixel indices for batched subregion extraction\n",
    "        h_indices = h_indices.unsqueeze(1) + h_grid  # Shape: (number_pixels, sub_height * sub_width)\n",
    "        w_indices = w_indices.unsqueeze(1) + w_grid\n",
    "         # Gather subregions\n",
    "        sub_dsis = dsi[:, h_indices, w_indices]  # Shape: (depth, number_pixels, sub_height * sub_width)\n",
    "        # Reshape to (number_pixels, depth, sub_height, sub_width)\n",
    "        sub_dsis = sub_dsis.transpose(0,1)\n",
    "        sub_dsis = sub_dsis.view(number_pixels, depth, sub_height, sub_width)\n",
    "        # Normalize each subregion and store data\n",
    "        sub_dsis = sub_dsis / sub_dsis.amax(dim=(1, 2, 3), keepdim=True).clamp(min=1e-8)\n",
    "\n",
    "        return sub_dsis"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd8f9d42",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Neural Network\n",
    "\n",
    "We define the neural network architecture as a class with two methods to save and load parameters.\n",
    "We also define an additional class that averages the estimates of several of our network to leverage ensemble learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37969bfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PixelwiseConvGRU(nn.Module):\n",
    "    \"\"\"\n",
    "    A neural network class to predict pixel-wise depth.\n",
    "    Input: Sub-DSI\n",
    "    Output: Depth estimate for central pixel and, if multi-pixel is set to True, also of the 8 dircetly neighboring pixels\n",
    "    Architecture: 3D-Convolution -> Flatten -> GRU -> Final hidden state -> Dense layer -> Output.\n",
    "    \n",
    "    Args:\n",
    "        sub_frame_radius_h (int): Radius at the length axis of the frame of the Sub-DSI.\n",
    "        sub_frame_radius_w (int): Radius at the length axis of the frame of the Sub-DSI.\n",
    "        out_channels (int): Number of output channels for the 3D-convolution.\n",
    "        multi_pixel (bool): Decides whether depth shall be estimated only for the central pixel or also at the 8 neighboring pixels.\n",
    "        use_pixel_pos (bool): An option to append the pixel coordinates to the data vector after the GRU for additional information.\n",
    "                                Pixel positions must be normalized herefor by the DSI_Pixelswise_Dataset.\n",
    "        hidden_size_scale (int): A scaling factor to scale the size of the inputs for the GRU to the size of the hidden states.\n",
    "        num_gru_layers (int): Defines how many GRU layers should be stacked sequentially.\n",
    "        bidirectional (bool): Defines whether the GRU layer(s) should work bidirectionally.\n",
    "        dropout_rate (float): Rate for dropout.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self,\n",
    "                 sub_frame_radius_h,\n",
    "                 sub_frame_radius_w,\n",
    "                 out_channels=4,\n",
    "                 multi_pixel=False,\n",
    "                 use_pixel_pos=False,\n",
    "                 hidden_size_scale=1,\n",
    "                 num_gru_layers=1,\n",
    "                 bidirectional=False,\n",
    "                 dropout_rate=0\n",
    "                ):\n",
    "        # Inherit\n",
    "        super(PixelwiseConvGRU, self).__init__()\n",
    "    \n",
    "        # Args\n",
    "        self.sub_frame_radius_h = sub_frame_radius_h\n",
    "        self.sub_frame_radius_w = sub_frame_radius_w        \n",
    "        # The size of the Sub-DSI frame is 2 times its radius plus the central pixel\n",
    "        self.sub_frame_size_h = 2 * sub_frame_radius_h + 1\n",
    "        self.sub_frame_size_w = 2 * sub_frame_radius_w + 1\n",
    "        self.out_channels = out_channels\n",
    "        self.multi_pixel = multi_pixel\n",
    "        self.use_pixel_pos = use_pixel_pos\n",
    "        self.hidden_size_scale = hidden_size_scale\n",
    "        self.num_gru_layers = num_gru_layers\n",
    "        self.bidirectional = bidirectional\n",
    "        self.dropout_rate = dropout_rate\n",
    "\n",
    "        # Deduct sizes\n",
    "        self.gru_input_size = self.out_channels * (self.sub_frame_size_h-2) * (self.sub_frame_size_w-2)  # Frame size is reduced since we do not apply padding\n",
    "        self.gru_hidden_size = self.gru_input_size * self.hidden_size_scale\n",
    "        self.output_dim = 1 if not self.multi_pixel else 9\n",
    "        \n",
    "        # 3D-convolution layer\n",
    "        self.conv3d = nn.Sequential(\n",
    "            nn.Conv3d(\n",
    "                in_channels=1,\n",
    "                out_channels=self.out_channels,\n",
    "                kernel_size=(3, 3, 3),\n",
    "                # Pad only along the depth dimension\n",
    "                # since ray counts are effectively zero for the padded depth levels\n",
    "                padding=(1, 0, 0), \n",
    "                stride=(2,1,1)\n",
    "            ),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(self.dropout_rate)\n",
    "            )\n",
    "        \n",
    "        # GRU layer\n",
    "        self.gru = nn.GRU(\n",
    "            input_size = self.gru_input_size,\n",
    "            hidden_size = self.gru_hidden_size,\n",
    "            num_layers = self.num_gru_layers,\n",
    "            dropout = self.dropout_rate,\n",
    "            bidirectional=self.bidirectional,\n",
    "            batch_first = True\n",
    "            )\n",
    "\n",
    "        # Output layer\n",
    "        self.dense_output = nn.Sequential(\n",
    "            nn.Linear(\n",
    "                # A bidircetional GRU would have double the output size\n",
    "                # and if the pixel position shall be considered, two entries will be appended\n",
    "                (1+self.bidirectional)*self.gru_hidden_size + 2*self.use_pixel_pos, self.gru_hidden_size),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(self.dropout_rate),\n",
    "            nn.Linear(self.gru_hidden_size, self.output_dim)\n",
    "            )\n",
    "\n",
    "        # Automatically send model to the available device\n",
    "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        self.to(self.device)  # Send the model to the device\n",
    "        \n",
    "    def forward(self, input):\n",
    "        # Preprocess input\n",
    "        pixel_position, sub_dsi = input\n",
    "        batch_size, depth_levels = sub_dsi.shape[:2]\n",
    "        \n",
    "        # Apply 3D-convolution\n",
    "        sub_dsi_conv = self.conv3d(sub_dsi.unsqueeze(dim=1))\n",
    "        # Flatten\n",
    "        sub_dsi_conv_flat = sub_dsi_conv.transpose(1,2).flatten(start_dim=2)\n",
    "        # Check whether dimensions match from 3D-convolution to GRU\n",
    "        batch_size, depth_levels, tensor_size = sub_dsi_conv_flat.size()\n",
    "        assert tensor_size == self.gru_input_size\n",
    "        \n",
    "        # Apply GRU\n",
    "        h_seq, _ = self.gru(sub_dsi_conv_flat)\n",
    "        # Take final hidden state\n",
    "        h_n = h_seq[:,-1,:]\n",
    "        # If selected, appenid pixel position\n",
    "        if self.use_pixel_pos:\n",
    "            h_n = torch.cat([pixel_position, h_n], dim=-1)\n",
    "        # Check whether dimensions match from GRU output to the final dense output-layer\n",
    "        assert h_n.size() == (batch_size, (1+self.bidirectional)*self.gru_hidden_size + 2*self.use_pixel_pos)\n",
    "\n",
    "        # Apply final dense layer to obtain final estimate\n",
    "        output = self.dense_output(h_n)\n",
    "        # Assert correct output dimension\n",
    "        assert output.size() == (batch_size, self.output_dim)\n",
    "        # Squeeze if single pixel\n",
    "        if not self.multi_pixel:\n",
    "            output = output.squeeze(dim=-1)\n",
    "        \n",
    "        return output\n",
    "\n",
    "    def save_model(self, optimizer, model_file, model_path=None, print_save=True):\n",
    "        \"\"\"Method to save model and optimizer parameters to model_path and model_file.\"\"\"\n",
    "        if model_path is None:\n",
    "            # Set default model path\n",
    "            model_path = f\"{parent_dir}/models/\"\n",
    "            \n",
    "        torch.save({\n",
    "            \"model_state_dict\": self.state_dict(),\n",
    "            \"optimizer_state_dict\": optimizer.state_dict()},\n",
    "            os.path.join(model_path, model_file)\n",
    "                  )\n",
    "        # Print success message\n",
    "        if print_save:\n",
    "            print(f\"Saved PyTorch Model and Optimizer State to {model_path}{model_file}\")\n",
    "\n",
    "    def load_parameters(self, model_file, model_path=None, optimizer=None):\n",
    "        \"\"\"Method to load model parameters from model_path and model_file.\n",
    "        If an optimizer is selected, its parameters are loaded, too.\n",
    "        \"\"\"\n",
    "        if model_path is None:\n",
    "            # Set default model path\n",
    "            model_path = f\"{parent_dir}/models/\"\n",
    "        checkpoint = torch.load(os.path.join(model_path, model_file), map_location=self.device)\n",
    "        self.load_state_dict(checkpoint[\"model_state_dict\"])\n",
    "        if optimizer is not None:\n",
    "            optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34097a0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class AveragedNetwork(nn.Module):\n",
    "    \"\"\"Neural network class to average predictions of several networks.\"\"\"\n",
    "    def __init__(self, neural_nets):\n",
    "        super(AveragedNetwork, self).__init__()\n",
    "        # Inherit\n",
    "        self.multi_pixel = neural_nets[0].multi_pixel\n",
    "        self.sub_frame_radius_h = neural_nets[0].sub_frame_radius_h\n",
    "        self.sub_frame_radius_w = neural_nets[0].sub_frame_radius_w\n",
    "        \n",
    "        # Automatically send averaged model to the available device\n",
    "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        self.to(self.device)  # Send the model to the device\n",
    "        \n",
    "        # List of neural networks as argument\n",
    "        if torch.cuda.device_count() > 1:\n",
    "            # Use DataParallel if multiple GPUs are available\n",
    "            self.neural_nets = nn.ModuleList([nn.DataParallel(net) for net in neural_nets])\n",
    "        else:\n",
    "            # List of neural networks as argument\n",
    "            self.neural_nets = neural_nets\n",
    "\n",
    "    def forward(self, input):\n",
    "        # Forward pass through all networks\n",
    "        outputs = [neural_net(input) for neural_net in self.neural_nets]\n",
    "        \n",
    "        # Calculate the average prediction\n",
    "        average_output = sum(outputs) / len(outputs)\n",
    "\n",
    "        return average_output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11c8b887",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Loss Function\n",
    "\n",
    "We write a loss function that ignores NaN-values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96f93cad",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomMAELoss(nn.Module):\n",
    "    \"\"\"Custom loss function to apply L1-loss but ignore Nan-values.\"\"\"\n",
    "    def __init__(self):\n",
    "        super(CustomMAELoss, self).__init__()\n",
    "\n",
    "    def forward(self, prediction, target):\n",
    "        # Check and flatten inputs if necessary\n",
    "        if prediction.dim() > 2:\n",
    "            prediction = prediction.flatten(start_dim=1)\n",
    "        if target.dim() > 2:\n",
    "            target = target.flatten(start_dim=1)\n",
    "\n",
    "        # Ensure the prediction and target tensors have the same shape\n",
    "        assert prediction.shape == target.shape, \"Prediction and target must have the same shape\"\n",
    "\n",
    "        # Compute mask to ignore NaNs\n",
    "        valid_mask = ~torch.isnan(target)\n",
    "        valid_predictions = prediction[valid_mask]\n",
    "        valid_targets = target[valid_mask]\n",
    "\n",
    "        # Calculate the absolute errors only on valid (non-NaN) entries\n",
    "        abs_errors = torch.abs(valid_predictions - valid_targets)\n",
    "\n",
    "        # Return the mean of these errors\n",
    "        return torch.mean(abs_errors)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "589232b5",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Metrics\n",
    "\n",
    "We define an evaluation method to measure performance by computing metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09c1d001",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_performance(data, network_estimates, argmax_estimates, true_depths):\n",
    "    \"\"\"\n",
    "    Print error metrics of the network and the argmax approach against true depths.\n",
    "\n",
    "    Entries whose true depth is NaN are dropped, the remaining values are\n",
    "    projected back to meters via data.min_depth / data.max_depth, and every\n",
    "    metric returned by compute_metrics is printed after scaling by 100.\n",
    "\n",
    "    Args:\n",
    "        data: Data instance providing min_depth, max_depth, pixel_count and\n",
    "            dataset (must contain \"mvsec\" or \"dsec\").\n",
    "        network_estimates: Network depth estimates, normalized like true_depths.\n",
    "        argmax_estimates: Argmax depth estimates, normalized like true_depths.\n",
    "        true_depths: Normalized true depths; NaN marks entries without a target.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If data.dataset contains neither \"mvsec\" nor \"dsec\".\n",
    "    \"\"\"\n",
    "    # Get hyperparameters for camera and scene. Unknown datasets fail here with\n",
    "    # a clear message instead of leaving b and f unbound further down.\n",
    "    # NOTE(review): b, f are presumably stereo baseline and focal length -- confirm\n",
    "    if \"mvsec\" in data.dataset:\n",
    "        b, f = 0.09988137641750752, 226.38018519795807\n",
    "    elif \"dsec\" in data.dataset:\n",
    "        b, f = 0.6, 557.2412109375\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            f\"Unsupported dataset {data.dataset!r}: expected it to contain 'mvsec' or 'dsec'\"\n",
    "        )\n",
    "\n",
    "    # Eliminate nan values\n",
    "    valid_mask = ~torch.isnan(true_depths)\n",
    "    network_estimates = network_estimates[valid_mask]\n",
    "    argmax_estimates = argmax_estimates[valid_mask]\n",
    "    true_depths = true_depths[valid_mask]\n",
    "\n",
    "    # Project values to original space in meters\n",
    "    depth_range = data.max_depth - data.min_depth\n",
    "    network_estimates = network_estimates * depth_range + data.min_depth\n",
    "    argmax_estimates = argmax_estimates * depth_range + data.min_depth\n",
    "    true_depths = true_depths * depth_range + data.min_depth\n",
    "\n",
    "    # Compute the metrics; every value is scaled by 100\n",
    "    # (distance errors to centimeters, quotients to percentages)\n",
    "    errors_names = [\"MAE\", \"MedAE\", \"Bad Pix\", \"SILog\", \"ARE\", \"log RMSE\", \"delta1\", \"delta2\", \"delta3\"]\n",
    "    network_errors = [100 * error_value for error_value in compute_metrics(network_estimates, true_depths, b, f)]\n",
    "    argmax_errors = [100 * error_value for error_value in compute_metrics(argmax_estimates, true_depths, b, f)]\n",
    "\n",
    "    # Create output strings\n",
    "    network_string = \"Network Test Error Performance:\\n\"\n",
    "    argmax_string = \"Argmax Test Error Performance:\\n\"\n",
    "    for error_name, network_error, argmax_error in zip(errors_names, network_errors, argmax_errors):\n",
    "        network_string += f\" {error_name}: {network_error.item():>0.2f} |\"\n",
    "        argmax_string += f\" {error_name}: {argmax_error.item():>0.2f} |\"\n",
    "    # Add number of points for inference (same label on both lines)\n",
    "    network_string += f\" #Pix: {data.pixel_count}\"\n",
    "    argmax_string += f\" #Pix: {data.pixel_count}\"\n",
    "\n",
    "    # Print performance\n",
    "    print(network_string)\n",
    "    print(argmax_string)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "572c1629",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(estimate, target, b, f):\n",
    "    \"\"\"\n",
    "    Compute error metrics given depth estimates and true depth targets.\n",
    "\n",
    "    The input tensors are not modified.\n",
    "\n",
    "    Args:\n",
    "        estimate: Tensor of estimated depths; len(estimate) is used as the\n",
    "            number of samples, so a 1-D tensor is expected.\n",
    "        target: Tensor of true depths with the same shape as estimate.\n",
    "        b, f: Factors of the bad-pixel metric; the inverse-depth error is\n",
    "            multiplied by b * f (presumably stereo baseline and focal length,\n",
    "            i.e. a disparity error -- TODO confirm).\n",
    "\n",
    "    Returns:\n",
    "        Tuple (MAE, MedAE, badp, SILog, ARE, lRMSE, delta1, delta2, delta3)\n",
    "        of scalar tensors.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If estimate is empty.\n",
    "    \"\"\"\n",
    "    # Data size\n",
    "    n = len(estimate)\n",
    "    if n == 0:\n",
    "        raise ValueError(\"compute_metrics requires at least one estimate\")\n",
    "    # Epsilon to avoid division by zero; added out-of-place so that the\n",
    "    # caller's tensors stay untouched (the target is reused across calls)\n",
    "    epsilon = 0.00000000001\n",
    "    estimate = estimate + epsilon\n",
    "    target = target + epsilon\n",
    "\n",
    "    abs_error = torch.abs(estimate - target)\n",
    "    # MAE\n",
    "    MAE = torch.mean(abs_error)\n",
    "    # MedAE\n",
    "    MedAE = torch.median(abs_error)\n",
    "    # Bad pix: inverse-depth error scaled by b * f must exceed 5 and,\n",
    "    # relative to the scaled inverse target depth, 5 percent\n",
    "    err = torch.abs(1 / estimate - 1 / target) * b * f\n",
    "    rel_err = err * target / b / f\n",
    "    badp = torch.sum((err > 5) & (rel_err > 0.05)) / n\n",
    "    # SILog\n",
    "    di = torch.log(target) - torch.log(estimate)\n",
    "    mean_squared_log_error = 1 / n * torch.sum(di ** 2)\n",
    "    SILog = mean_squared_log_error - 1 / (n * n) * torch.sum(di) ** 2\n",
    "    # Abs rel diff error\n",
    "    # NOTE(review): normalized by the estimate, not by the target -- confirm intended\n",
    "    ARE = 1 / n * torch.sum(abs_error / estimate)\n",
    "    # log RMSE\n",
    "    lRMSE = mean_squared_log_error ** 0.5\n",
    "    # Inlier ratios\n",
    "    delta = torch.max(estimate / target, target / estimate)\n",
    "    delta1 = torch.sum(delta < 1.25) / n\n",
    "    delta2 = torch.sum(delta < 1.25 ** 2) / n\n",
    "    delta3 = torch.sum(delta < 1.25 ** 3) / n\n",
    "\n",
    "    return MAE, MedAE, badp, SILog, ARE, lRMSE, delta1, delta2, delta3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d974adc",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Training\n",
    "\n",
    "We define the training process: `train` iterates over all batches of an epoch and calls `train_batch`, which performs a single optimization step on one batch. Both functions are defined in the two cells below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "035cce09",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataloader, data, model, loss_fn, optimizer, data_augmentation=False):\n",
    "    \"\"\"\n",
    "    Train the model on every batch of the dataloader and print epoch metrics.\n",
    "\n",
    "    Args:\n",
    "        dataloader: Yields batches (pixel_position, sub_dsi, true_depth,\n",
    "            argmax_depth, frame_idx) of tensors.\n",
    "        data: Data instance, needed to derive hyperparameters of the dataset\n",
    "            when evaluating.\n",
    "        model: Network to train; model.multi_pixel selects 9 or 1 estimates\n",
    "            per sample.\n",
    "        loss_fn: Loss applied to the prediction and the true depth.\n",
    "        optimizer: Optimizer updating the model parameters.\n",
    "        data_augmentation: Forwarded to train_batch.\n",
    "    \"\"\"\n",
    "    # Set device\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    # The batches are moved to this device, so the model has to live there too\n",
    "    # (as in test); Module.to moves the parameters in place\n",
    "    model.to(device)\n",
    "    # Set model to training mode\n",
    "    model.train()\n",
    "    # Get size of entire dataset\n",
    "    data_size = len(dataloader.dataset)\n",
    "    # Account for single or multi pixel network-version\n",
    "    num_estims = 9 if model.multi_pixel else 1\n",
    "    # Track estimates and true depths\n",
    "    epoch_network_estimates = torch.zeros(num_estims * data_size, dtype=torch.float32, device=device)\n",
    "    epoch_argmax_estimates = torch.zeros(num_estims * data_size, dtype=torch.float32, device=device)\n",
    "    epoch_true_depths = torch.zeros(num_estims * data_size, dtype=torch.float32, device=device)\n",
    "    # Track current index for these tensors\n",
    "    current_idx = 0\n",
    "\n",
    "    # Iterate over batches\n",
    "    for batch_data in dataloader:\n",
    "        # If available, use GPU\n",
    "        batch_data = tuple(tensor.to(device) for tensor in batch_data)\n",
    "        # Get batch data needed for the epoch statistics\n",
    "        _, _, true_depth, argmax_depth, _ = batch_data\n",
    "        batch_size = true_depth.size(0)\n",
    "        # Train on batch and return network prediction (without augmented predictions)\n",
    "        pred = train_batch(batch_data, model, loss_fn, optimizer, data_augmentation=data_augmentation)\n",
    "        # Detach so the statistics buffer does not keep autograd history of\n",
    "        # every batch, then clip network estimations to inbetween 0 and 1\n",
    "        network_depth = pred.detach().clip(0, 1)\n",
    "        # Update epoch estimates and target values\n",
    "        end_idx = current_idx + num_estims * batch_size\n",
    "        epoch_network_estimates[current_idx:end_idx] = network_depth.flatten()\n",
    "        epoch_argmax_estimates[current_idx:end_idx] = argmax_depth.repeat_interleave(num_estims)\n",
    "        epoch_true_depths[current_idx:end_idx] = true_depth.flatten()\n",
    "        # Update index\n",
    "        current_idx = end_idx\n",
    "        # Clear memory cache\n",
    "        gc.collect()\n",
    "\n",
    "    # Compute and print performance for epoch. Only the filled part of the\n",
    "    # buffers is evaluated: if the dataloader yields fewer samples than\n",
    "    # len(dataloader.dataset), the untouched zeros must not count as samples.\n",
    "    evaluate_performance(\n",
    "        data,\n",
    "        epoch_network_estimates[:current_idx],\n",
    "        epoch_argmax_estimates[:current_idx],\n",
    "        epoch_true_depths[:current_idx],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09a87343",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_batch(batch_data, model, loss_fn, optimizer, data_augmentation=False):\n",
    "    \"\"\"\n",
    "    Perform one optimization step on a single batch and return the prediction.\n",
    "\n",
    "    With data_augmentation enabled, each of the two axes is mirrored with\n",
    "    probability 0.5: the matching pixel-position column is replaced by\n",
    "    1 - value (on a copy) and the sub-DSI is flipped along the paired dimension.\n",
    "    \"\"\"\n",
    "    positions, dsi, depth_target, _, _ = batch_data\n",
    "\n",
    "    if data_augmentation:\n",
    "        # Work on a copy so the batch tensor of the caller stays untouched\n",
    "        positions = positions.clone()\n",
    "        # (pixel-position column, sub-DSI dimension): x-axis first, then y-axis\n",
    "        # NOTE(review): depth_target is not mirrored -- confirm this is intended\n",
    "        # for the multi-pixel network\n",
    "        for column, dsi_dim in ((0, -2), (1, -1)):\n",
    "            if random.random() > 0.5:\n",
    "                positions[:, column] = 1 - positions[:, column]\n",
    "                dsi = dsi.flip([dsi_dim])\n",
    "\n",
    "    # Forward pass and prediction loss\n",
    "    prediction = model((positions, dsi))\n",
    "    batch_loss = loss_fn(prediction, depth_target)\n",
    "\n",
    "    # Backpropagation, parameter update and gradient reset\n",
    "    batch_loss.backward()\n",
    "    optimizer.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    return prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61c24692",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# Testing\n",
    "\n",
    "We now define the testing process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fc1eecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(dataloader, data, model, flip_horizontal=False, flip_vertical=False, rotate=0):\n",
    "    \"\"\"\n",
    "    Test the performance of the model on the dataloader and print the metrics.\n",
    "\n",
    "    The data instance itself has to be given to derive hyperparameters.\n",
    "    Data augmentation can be applied by flipping the sub-DSIs horizontally or\n",
    "    vertically. To rotate them by 0, 90, 180 or 270 degrees, set rotate to\n",
    "    0, 1, 2 or 3.\n",
    "\n",
    "    Args:\n",
    "        dataloader: Yields batches (pixel_position, sub_dsi, true_depth,\n",
    "            argmax_depth, frame_idx) of tensors.\n",
    "        data: Data instance passed on to evaluate_performance.\n",
    "        model: Network to evaluate; model.multi_pixel selects 9 or 1 estimates\n",
    "            per sample. It is moved to the device and left in evaluation mode.\n",
    "        flip_horizontal: Flip the last dimension of each sub-DSI.\n",
    "        flip_vertical: Flip the second-to-last dimension of each sub-DSI.\n",
    "        rotate: Number of 90 degree rotations applied to each sub-DSI.\n",
    "    \"\"\"\n",
    "    # Set device\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model.to(device)\n",
    "    # Set the model to evaluation mode - important for batch normalization and dropout layers\n",
    "    model.eval()\n",
    "    # Get size of entire dataset\n",
    "    data_size = len(dataloader.dataset)\n",
    "    # Account for single or multi pixel network-version\n",
    "    num_estims = 9 if model.multi_pixel else 1\n",
    "    # Track estimates and true depths\n",
    "    epoch_network_estimates = torch.zeros(num_estims * data_size, dtype=torch.float32, device=device)\n",
    "    epoch_argmax_estimates = torch.zeros(num_estims * data_size, dtype=torch.float32, device=device)\n",
    "    epoch_true_depths = torch.zeros(num_estims * data_size, dtype=torch.float32, device=device)\n",
    "    # Track current index for these tensors\n",
    "    current_idx = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_data in dataloader:\n",
    "            # If available, use GPU\n",
    "            batch_data = tuple(tensor.to(device) for tensor in batch_data)\n",
    "            # Get batch data\n",
    "            pixel_position, sub_dsi, true_depth, argmax_depth, _ = batch_data\n",
    "            batch_size = true_depth.size(0)\n",
    "            # Rotate and/or mirror data\n",
    "            # NOTE(review): pixel_position and true_depth are not transformed\n",
    "            # along with sub_dsi -- confirm this is intended\n",
    "            if flip_horizontal:\n",
    "                sub_dsi = sub_dsi.flip([-1])\n",
    "            if flip_vertical:\n",
    "                sub_dsi = sub_dsi.flip([-2])\n",
    "            if rotate > 0:\n",
    "                sub_dsi = torch.rot90(sub_dsi, k=rotate, dims=[-1, -2])\n",
    "            # Compute prediction\n",
    "            network_depth = model((pixel_position, sub_dsi))\n",
    "            # Clip network estimations to inbetween 0 and 1\n",
    "            network_depth = network_depth.clip(0, 1)\n",
    "            # Update epoch estimates and target values\n",
    "            end_idx = current_idx + num_estims * batch_size\n",
    "            epoch_network_estimates[current_idx:end_idx] = network_depth.flatten()\n",
    "            epoch_argmax_estimates[current_idx:end_idx] = argmax_depth.repeat_interleave(num_estims)\n",
    "            epoch_true_depths[current_idx:end_idx] = true_depth.flatten()\n",
    "            # Update index\n",
    "            current_idx = end_idx\n",
    "\n",
    "    # Compute and print performance for epoch. Only the filled part of the\n",
    "    # buffers is evaluated: if the dataloader yields fewer samples than\n",
    "    # len(dataloader.dataset), the untouched zeros must not count as samples.\n",
    "    evaluate_performance(\n",
    "        data,\n",
    "        epoch_network_estimates[:current_idx],\n",
    "        epoch_argmax_estimates[:current_idx],\n",
    "        epoch_true_depths[:current_idx],\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
