{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5576bea5",
   "metadata": {},
   "source": [
    "# About this Notebook\n",
    "\n",
    "This notebook can be used for training and testing the neural network.\n",
    "\n",
    "It relies on the classes and functions defined in *Classes_and_Functions.ipynb*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2794f840-e75a-4488-8976-f2d8918bc19c",
   "metadata": {},
   "source": [
    "# Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f737b80b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library imports\n",
    "import random\n",
    "import os\n",
    "import gc\n",
    "import re\n",
    "import time\n",
    "\n",
    "# Third-party library imports\n",
    "import numpy as np\n",
    "import cv2  # OpenCV for adaptive filtering\n",
    "import psutil  # For system resource management\n",
    "from scipy.ndimage import convolve  # To convolve filtering masks\n",
    "\n",
    "# PyTorch specific imports\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset, Subset\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eeaa1f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebooks\n",
    "import import_ipynb\n",
    "from Classes_and_Functions import *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0c05c14",
   "metadata": {},
   "source": [
    "# Hyperparameters\n",
    "\n",
    "First, define the hyperparameters of which dataset to use, what filter to apply, how Sub-DSIs shall be constructed and the whether to use the single or the multi-pixel version of the network. More options exist for the dataset, see *Classes_and_Functions.ipynb*\n",
    "\n",
    "Quick overview:\n",
    "* Everything can be left at default except the path for the <b>dsi_directory</b> and the <b>depthmap_directory</b>. \n",
    "* The default is the single-pixel version of the network, to use the multi-pixel version set <b>multi_pixel=True</b>.\n",
    "* The process is set to MVSEC stereo on default. If desired, switch to <b>dataset=\"mvsec_mono\"</b>, <b>dataset=\"dsec\"</b> or <b>dataset=\"dsec_mono\"</b>.\n",
    "* The filter parameters are set to default, but we used <b>filter_size=9</b> and an <b>adaptive_threshold_c=-10</b> for MVSEC and <b>adaptive_threshold_c=-10</b> for DSEC for training and testing instead. Feel free to replicate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a855fc74",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Hyperparameters for the dataset:\n",
    "    # DSI Selection Arguments\n",
    "    dataset (str): The dataset used.\n",
    "    test_seq (int): Sequence (MVSEC) or half (DSEC) for testing.\n",
    "    train_seq_A, train_seq_B (str): Sequences for training.\n",
    "    dsi_directory (str): Directory of the DSIs. Must be adjusted to user.\n",
    "    depthmap_directory (str): Directory of the groundtrue depths for each DSI.\n",
    "    dsi_num_expression (str): Number expression of DSI files for sorting.\n",
    "    depthmap_num_expression (str): Number expression of depthmap files for sorting.\n",
    "    dsi_split (str or int)\n",
    "    dsi_split (str or int): Which DSIs shall be considered.\n",
    "                            Can be \"all\", \"even\", \"odd\" or a number between 0 and 9, refering to the last digit of its id.\n",
    "    dsi_ratio (float): Between 0 and 1. Defines the proportion of (random) DSIs that shall be used.\n",
    "    start_idx, end_idx (str): Start and stop indices for which DSIs to consider. \n",
    "    start_row, end_row, start_col, end_col (str): Define the rows and columns to be considered within each DSI.\n",
    "\n",
    "    # Pixel selection\n",
    "    filter_size (int): Determines the size of the neighbourhood area when applying the adaptive threshold filter.\n",
    "    adaptive_threshold_c (int): Constant that is subtracted from the mean of the neighbourhood pixels when apply the adaptive threshold filter.\n",
    "\n",
    "    # Sub-DSIs sizes\n",
    "    sub_frame_radius_h (int): Defines the radius of the frame at the height axis around the central pixel for the Sub-DSI.\n",
    "    sub_frame_radius_w (int): Defines the radius of the frame at the width axis around the central pixel for the Sub-DSI.\n",
    "\n",
    "    # Network version\n",
    "    multi_pixel (bool): Determines whether depth is predicted only for the central selected pixel or for the 8 neighbouring pixels as well.\n",
    "\"\"\"\n",
    "\n",
    "# Dataset selection\n",
    "dataset = \"mvsec_stereo\" #  Options: mvsec_stereo, mvsec_mono, dsec_stereo, dsec_mono\n",
    "test_seq = 1  #  Options: 1,2,3 for MVSEC, 1,2 for DSEC (refers to which half)\n",
    "train_seq_A, train_seq_B = {1,2,3} - {test_seq}\n",
    "\n",
    "# Directories\n",
    "modality = \"monocular\" if \"mono\" in dataset else \"stereo\"\n",
    "if \"mvsec\" in dataset:\n",
    "    test_directory = f\"/mvsec/indoor_flying{test_seq}\"\n",
    "    train_A_directory = f\"/mvsec/indoor_flying{train_seq_A}\"\n",
    "    train_B_directory = f\"/mvsec/indoor_flying{train_seq_B}\"\n",
    "elif \"dsec\" in dataset:\n",
    "    test_directory = \"/dsec\"\n",
    "    train_A_directory = \"/dsec\"\n",
    "    train_B_directory = \"/dsec\"\n",
    "dsi_directory_test = f\"{parent_dir}/data/{test_directory}/dsi_{modality}/\" #  Set your path here\n",
    "dsi_directory_train_A = f\"{parent_dir}/data/{train_A_directory}/dsi_{modality}/\"\n",
    "dsi_directory_train_B = f\"{parent_dir}/data/{train_B_directory}/dsi_{modality}/\"\n",
    "depthmap_directory_test = f\"{parent_dir}/data/{test_directory}/depthmaps/\" #  Set your path here\n",
    "depthmap_directory_train_A = f\"{parent_dir}/data/{train_A_directory}/depthmaps/\"\n",
    "depthmap_directory_train_B = f\"{parent_dir}/data/{train_B_directory}/depthmaps/\"\n",
    "\n",
    "# Number expressions of files\n",
    "dsi_num_expression = \"\\d+\\.\\d+|d+\"\n",
    "depthmap_num_expression = \"\\d+\"\n",
    "\n",
    "# dsi_split\n",
    "dsi_split_test = \"all\" #  Options: all, even, odd, 0, 1, ..., 9\n",
    "dsi_split_train_A = \"all\"\n",
    "dsi_split_train_B = \"all\"\n",
    "# dsi_ratio\n",
    "dsi_ratio_test = 1.0 #  Options: 0 < dsi_ratio <= 1\n",
    "dsi_ratio_train_A = 1.0\n",
    "dsi_ratio_train_B = 1.0\n",
    "\n",
    "# start_idx and end_idx\n",
    "start_idx_test, end_idx_test = 0, None \n",
    "start_idx_train_A, end_idx_train_A = 0, None\n",
    "start_idx_train_B, end_idx_train_B = 0, None\n",
    "\n",
    "# start and end idx of rows and columns\n",
    "start_row_test, end_row_test = 0, None\n",
    "start_col_test, end_col_test = 0, None\n",
    "start_row_train_A, end_row_train_A = 0, None\n",
    "start_col_train_A, end_col_train_A = 0, None\n",
    "start_row_train_B, end_row_train_B = 0, None\n",
    "start_col_train_B, end_col_train_B = 0, None\n",
    "\n",
    "# Filter parameters for pixel selection\n",
    "filter_size_test = None #  None automatically sets original value. We used 9 for training and testing on MVSEC and DSEC instead\n",
    "filter_size_train_A = None #  9 for MVSEC and DSEC\n",
    "filter_size_train_B = None #  9 for MVSEC and DSEC\n",
    "adaptive_threshold_c_test = None #  None automatically sets original value. We used -10 for MVSEC and -2 for DSEC instead\n",
    "adaptive_threshold_c_train_A = None #  -10 for MVSEC and -2 for DSEC\n",
    "adaptive_threshold_c_train_B = None #  -10 for MVSEC and -2 for DSEC\n",
    "\n",
    "# Sub-DSI sizes\n",
    "sub_frame_radius_h = 3\n",
    "sub_frame_radius_w = 3\n",
    "\n",
    "# Network version\n",
    "multi_pixel = False #  Set to True for multi-pixel network version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82189bc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If DSEC was selected as dataset, test_seq will be used for testing and the other one for training.\n",
    "# middle_idx is set to the middle of the index for the zurich_city04a sequence, but can be set to a different custom value as well.\n",
    "if \"dsec\" in dataset:\n",
    "    # For DSEC training and test sets can be split by dividing one sequence\n",
    "    middle_idx = 174\n",
    "    # Assign one half for testing and one for training based on chosen test_seq.\n",
    "    if test_seq == 1:\n",
    "        # First half being used for testing and second half for training.\n",
    "        end_idx_test = middle_idx\n",
    "        start_idx_train_A = middle_idx\n",
    "    elif test_seq == 2:\n",
    "        # Second half being used for testing and first half for training.\n",
    "        start_idx_test = middle_idx\n",
    "        end_idx_train_A = middle_idx\n",
    "    else:\n",
    "        # Make sure that one half is selected.\n",
    "        raise Exception(\"Select one of two halfes for testing.\")\n",
    "    # No second training set is needed for DSEC\n",
    "    end_idx_train_B = start_idx_train_B"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "615cf4c8",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2723a1e4",
   "metadata": {},
   "source": [
    "### Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f124813f",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(0)\n",
    "# Decide whether the progress of reading in the DSIs shall be printed for tracking\n",
    "print_progress = True\n",
    "\n",
    "# Create training data\n",
    "training_data_A = DSI_Pixelswise_Dataset(dataset=dataset,\n",
    "                                         data_seq=train_seq_A,\n",
    "                                         dsi_directory=dsi_directory_train_A,\n",
    "                                         depthmap_directory=depthmap_directory_train_A,\n",
    "                                         dsi_num_expression=dsi_num_expression,\n",
    "                                         depthmap_num_expression=depthmap_num_expression,\n",
    "                                         dsi_split=dsi_split_train_A,\n",
    "                                         dsi_ratio=dsi_ratio_train_A,\n",
    "                                         start_idx=start_idx_train_A, end_idx=end_idx_train_A,\n",
    "                                         start_row=start_row_train_A, end_row=end_row_train_A,\n",
    "                                         start_col=start_col_train_A, end_col=end_col_train_A,\n",
    "                                         filter_size=filter_size_train_A,\n",
    "                                         adaptive_threshold_c=adaptive_threshold_c_train_A,\n",
    "                                         sub_frame_radius_h=sub_frame_radius_h,\n",
    "                                         sub_frame_radius_w=sub_frame_radius_w,\n",
    "                                         multi_pixel=multi_pixel,\n",
    "                                         clip_targets=True, #  Clip depths for training\n",
    "                                         print_progress=print_progress\n",
    "                                        )\n",
    "\n",
    "if print_progress: print(\"\")\n",
    "\n",
    "training_data_B = DSI_Pixelswise_Dataset(dataset=dataset,\n",
    "                                         data_seq=train_seq_B,\n",
    "                                         dsi_directory=dsi_directory_train_B,\n",
    "                                         depthmap_directory=depthmap_directory_train_B,\n",
    "                                         dsi_num_expression=dsi_num_expression,\n",
    "                                         depthmap_num_expression=depthmap_num_expression,\n",
    "                                         dsi_split=dsi_split_train_B,\n",
    "                                         dsi_ratio=dsi_ratio_train_B,\n",
    "                                         start_idx=start_idx_train_B, end_idx=end_idx_train_B,\n",
    "                                         start_row=start_row_train_B, end_row=end_row_train_B,\n",
    "                                         start_col=start_col_train_B, end_col=end_col_train_B,\n",
    "                                         filter_size=filter_size_train_B,\n",
    "                                         adaptive_threshold_c=adaptive_threshold_c_train_B,\n",
    "                                         sub_frame_radius_h=sub_frame_radius_h,\n",
    "                                         sub_frame_radius_w=sub_frame_radius_w,\n",
    "                                         multi_pixel=multi_pixel,\n",
    "                                         clip_targets=True,\n",
    "                                         print_progress=print_progress\n",
    "                                        )    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4172cd72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge training data\n",
    "training_data = ConcatDataset([training_data_A, training_data_B])\n",
    "# Inherit some attributes\n",
    "training_data.dataset = training_data_A.dataset\n",
    "training_data.pixel_count = training_data_A.pixel_count + training_data_B.pixel_count\n",
    "training_data.frame_height, training_data.frame_width = training_data_A.frame_height, training_data_A.frame_width\n",
    "training_data.min_depth, training_data. max_depth = training_data_A.min_depth, training_data_B. max_depth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b31ad55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap data into Dataloader\n",
    "batch_size = 64\n",
    "train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9de438c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print data dimensions\n",
    "training_data_A_size = len(training_data_A)\n",
    "training_data_B_size = len(training_data_B)\n",
    "training_data_size = len(training_data)\n",
    "sub_dsi_size = training_data_A.data_list[0][1].shape\n",
    "\n",
    "print(\"training data A size:\", training_data_A_size)\n",
    "print(\"training data B size:\", training_data_B_size)\n",
    "print(\"training data size:\", training_data_size)\n",
    "print(\"pixel number for inference:\", training_data.pixel_count)\n",
    "print(\"sub dsi size:\", sub_dsi_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "909dc749",
   "metadata": {},
   "source": [
    "### Initialize Model\n",
    "\n",
    "More options exist for the network architecture, see *Classes_and_Functions.ipynb*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c549f7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize model\n",
    "model = PixelwiseConvGRU(sub_frame_radius_h, sub_frame_radius_w, multi_pixel=multi_pixel)\n",
    "# Send to cuda\n",
    "if torch.cuda.is_available():\n",
    "    model.cuda()\n",
    "# Print architecture\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e9173cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define conditions for the training process\n",
    "epochs = 3\n",
    "data_augmentation = False #  data_augmentation randomly inverts DSIs on horizontally and/or vertically\n",
    "learning_rate = 1e-3\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "loss_fn = CustomMAELoss() if multi_pixel else torch.nn.L1Loss() # CustomMAELoss is L1Loss which ignores  NaN-values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc85b8e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set path to store model in directory\n",
    "model_directory = test_directory if \"mvsec\" in dataset else f\"/dsec/dsec_half{test_seq}\"\n",
    "model_path = f\"{parent_dir}/models/{model_directory}/\"\n",
    "# Define name of model file\n",
    "model_file = \"self_trained_model\"\n",
    "# In the training process, the current epoch will be added to each files name\n",
    "# Therefore do NOT set \".pth\"\n",
    "if model_file.endswith(\".pth\"):\n",
    "    model_file = model_file[:-4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5aa5737",
   "metadata": {},
   "outputs": [],
   "source": [
    "# If desired, a previous trained version of the model can be loaded. To do so, set the previous epoch.\n",
    "previous_epoch = 0\n",
    "previous_epoch > 0:\n",
    "    previous_model_path = model_path #  Set your path here\n",
    "    previous_model_file = f\"{model_file}_epoch_{previous_epoch}.pth\" #  Set your file name here\n",
    "    model.load_parameters(previous_model_file, model_path=previous_model_path, optimizer=optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d50ea58",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a7ec95c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the training process\n",
    "for epoch in range(epochs):\n",
    "    print(f\"Epoch {epoch+1}\\n-------------------------------\")\n",
    "    # Train and track time\n",
    "    st = time.time()\n",
    "    train(train_dataloader, training_data, model, loss_fn, optimizer, data_augmentation=data_augmentation)\n",
    "    ct = time.time()\n",
    "    print(\"\\n\", \"Training time [min]: \", (ct-st)//60, sep=\"\")\n",
    "    # Add epoch to file name\n",
    "    epoch_model_file = f\"{model_file}_epoch_{epoch+1}.pth\"\n",
    "    # Save\n",
    "    model.save_model(optimizer, epoch_model_file, model_path=model_path, print_save=True) #  Set print_save to False to not print message\n",
    "    print(\"\")\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "373d4695",
   "metadata": {},
   "source": [
    "# Testing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29982494",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8ad6ddb",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(0)\n",
    "# Decide whether the progress of reading in the DSIs shall be printed for tracking\n",
    "print_progress = True\n",
    "\n",
    "# Create testset\n",
    "test_data = DSI_Pixelswise_Dataset(dataset=dataset,\n",
    "                                   data_seq=test_seq,\n",
    "                                   dsi_directory=dsi_directory_test,\n",
    "                                   depthmap_directory=depthmap_directory_test,\n",
    "                                   dsi_num_expression=dsi_num_expression,\n",
    "                                   depthmap_num_expression=depthmap_num_expression,\n",
    "                                   dsi_split=dsi_split_test,\n",
    "                                   dsi_ratio=dsi_ratio_test,\n",
    "                                   start_idx=start_idx_test, end_idx=end_idx_test,\n",
    "                                   start_row=start_row_test, end_row=end_row_test,\n",
    "                                   start_col=start_col_test, end_col=end_col_test,\n",
    "                                   filter_size=filter_size_test,\n",
    "                                   adaptive_threshold_c=adaptive_threshold_c_test,\n",
    "                                   sub_frame_radius_h=sub_frame_radius_h,\n",
    "                                   sub_frame_radius_w=sub_frame_radius_w,\n",
    "                                   multi_pixel=multi_pixel,\n",
    "                                   clip_targets=False, #  Do not clip depths for testing\n",
    "                                   print_progress=print_progress\n",
    "                                  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b3656aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap data into Dataloader\n",
    "batch_size = 64\n",
    "test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5f376b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print data dimensions\n",
    "test_data_size = len(test_data)\n",
    "sub_dsi_size = test_data.data_list[0][1].shape\n",
    "\n",
    "print(\"test data size:\", test_data_size)\n",
    "print(\"pixel number for inference:\", test_data.pixel_count)\n",
    "print(\"sub dsi size:\", sub_dsi_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fa1ef31",
   "metadata": {},
   "source": [
    "### Load Models\n",
    "\n",
    "We load our trained models. Since we leverage ensemble learning, the parameters for each individual model have to be loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb8cd8d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How many models\n",
    "num_models = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13de715e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize models\n",
    "models = [PixelwiseConvGRU(sub_frame_radius_h, sub_frame_radius_w, multi_pixel=multi_pixel) for _ in range(num_models)]\n",
    "# Send to cuda\n",
    "if torch.cuda.is_available():\n",
    "    for model in models:\n",
    "        model.cuda()\n",
    "# Print architecture\n",
    "print(models[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cc3c548",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set path to load models from directory\n",
    "model_directory = f\"/mvsec/indoor_flying{test_seq}\" if \"mvsec\" in dataset else f\"/dsec/dsec_half{test_seq}\"\n",
    "model_paths = [f\"{parent_dir}/models/{model_directory}/\"] * num_models # length has to be equal to num_models\n",
    "# Give names of model files\n",
    "prefix_sequence = f\"indoor_flying{test_seq}\" if \"mvsec\" in dataset else f\"dsec_half{test_seq}\"\n",
    "prefix_modality = modality if not multi_pixel else \"multipixel\"\n",
    "model_files = [f\"{prefix_sequence}_{prefix_modality}_even_model.pth\",\n",
    "               f\"{prefix_sequence}_{prefix_modality}_odd_model.pth\"][:num_models] # length has to be equal to num_models\n",
    "# Do not forget \".pth\"\n",
    "for idx, model_file in enumerate(model_files):\n",
    "    if not model_file.endswith(\".pth\"):\n",
    "        model_files[idx] += \".pth\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0d98adb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load models parameters\n",
    "for idx, model in enumerate(models):\n",
    "    model.load_parameters(model_files[idx], model_path=model_paths[idx], optimizer=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e31c3669",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use ensemble learning to create averaged model\n",
    "model = AveragedNetwork(models)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11434000",
   "metadata": {},
   "source": [
    "### Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65e54270",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decide whether test data should be inverted horizontally\n",
    "flip_horizontal = False\n",
    "# Decide whether test data should be inverted vertically\n",
    "flip_vertical = False\n",
    "# To rotate the data by 0, 90, 180 or 270 degrees, set rotate to 0, 1, 2 or 3.\n",
    "rotate = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3017c51c",
   "metadata": {},
   "outputs": [],
   "source": [
    "test(test_dataloader, test_data, model, flip_horizontal=flip_horizontal, flip_vertical=flip_vertical, rotate=rotate)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (stereo_depth_estimation)",
   "language": "python",
   "name": "stereo_depth_estimation"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
