{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Reproducibility Challenge - Panoptic-DeepLab.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aiDAWgAYJQNs"
      },
      "source": [
        "#Notebook configurations"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KvtyejOxI_o-",
        "cellView": "form"
      },
      "source": [
        "CLONE = True #@param {type:\"boolean\"}\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Sr9jAlmQ3Nu-"
      },
      "source": [
        "PULL = False #@param {type:\"boolean\"}\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IdAX0rZ6hNwv",
        "cellView": "form"
      },
      "source": [
        "GDRIVE = False #@param {type:\"boolean\"}\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "oK9Bfq7bVmAG"
      },
      "source": [
        "MLFLOW = False #@param {type:\"boolean\"}\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5-sJ-u6hnxk0"
      },
      "source": [
        "# Configure DAGsHub, GitHub and Git"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3rJfZtcwnxlE"
      },
      "source": [
        "import requests\n",
        "import getpass\n",
        "import datetime\n",
        "import os"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6_CNtdeXnxlF"
      },
      "source": [
        "**Set Environment Variables - DAGsHub**\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NnjfCx4zrnr2",
        "cellView": "form"
      },
      "source": [
        "#@title Enter the DAGsHub repository owner name:\n",
        "\n",
        "DAGSHUB_REPO_OWNER= \"jinensetpal\" #@param {type:\"string\"}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_v3Pnf5XqgJS",
        "cellView": "form"
      },
      "source": [
        "#@title Enter the DAGsHub repository name:\n",
        "\n",
        "DAGSHUB_REPO_NAME= \"panoptic-reproducibility\" #@param {type:\"string\"}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Hx-CctUEqxlK",
        "cellView": "form"
      },
      "source": [
        "#@title Enter the username of your DAGsHub account:\n",
        "\n",
        "DAGSHUB_USER_NAME = \"\" #@param {type:\"string\"}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MNOtEiGwrfKM"
      },
      "source": [
        "**Set Environment Variables - GitHub**\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "FTliKFejry9N"
      },
      "source": [
        "#@title Enter the GitHub repository owner name:\n",
        "\n",
        "GITHUB_REPO_OWNER= \"jinensetpal\" #@param {type:\"string\"}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "pI0_N1YLry9R"
      },
      "source": [
        "#@title Enter the GitHub repository name:\n",
        "\n",
        "GITHUB_REPO_NAME= \"panoptic-reproducibility\" #@param {type:\"string\"}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WgRNMvxxp_iq"
      },
      "source": [
        "#@title Enter the GitHub branch name:\n",
        "\n",
        "BRANCH= \"\" #@param {type:\"string\"}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "cellView": "form",
        "id": "__FlIReIry9S"
      },
      "source": [
        "#@title Enter the username of your GitHub account:\n",
        "\n",
        "GITHUB_USER_NAME = \"\" #@param {type:\"string\"}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "coMAE_ZQry9S"
      },
      "source": [
        "#@title Enter the email for your GitHub account:\n",
        "\n",
        "GITHUB_EMAIL = \"\" #@param {type:\"string\"}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kXwRqKhLoDyu",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e1526050-d28d-4790-c4a8-45226cc1fb3d"
      },
      "source": [
        "GITHUB_TOKEN = getpass.getpass('Please enter your GitHub token or password: ')\n",
        "DAGSHUB_TOKEN = getpass.getpass('Please enter your DAGsHub token or password: ')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5c0u8u0_ZZwD"
      },
      "source": [
        "# Mount Google Drive and work from it only when GDRIVE is enabled.\n",
        "# google.colab is imported inside the guard so the cell also runs outside Colab when GDRIVE is False.\n",
        "if GDRIVE:\n",
        "  from google.colab import drive\n",
        "\n",
        "  drive.mount('/content/drive')\n",
        "  %cd /content/drive/MyDrive"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z-KHaV-h9CnD"
      },
      "source": [
        "**Configure Git**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Q1CPZP15Npnj"
      },
      "source": [
        "# Set the global git identity used for commits made from this notebook.\n",
        "# The values are quoted so that a name or e-mail containing spaces reaches git as a single argument.\n",
        "!git config --global user.email \"{GITHUB_EMAIL}\"\n",
        "!git config --global user.name \"{GITHUB_USER_NAME}\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WPNKFBEFTlkH"
      },
      "source": [
        "**Clone the Repository**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IZdQl7CgCf9x",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "6273caf5-ca56-4e4a-c974-e6135b7f9c54"
      },
      "source": [
        "from urllib.parse import quote\n",
        "\n",
        "if CLONE:\n",
        "  # Pass -b only when a branch was entered: with an empty BRANCH, a bare `-b` would take the URL as the branch name.\n",
        "  branch_opt = f\"-b {BRANCH}\" if BRANCH else \"\"\n",
        "  # Percent-encode the credentials so characters such as '@' or '/' cannot corrupt the URL.\n",
        "  auth = f\"{quote(GITHUB_USER_NAME, safe='')}:{quote(GITHUB_TOKEN, safe='')}\"\n",
        "  # NOTE(review): credentials embedded in the clone URL are stored in the clone's git config as the remote URL - confirm this is acceptable.\n",
        "  !git clone {branch_opt} https://{auth}@github.com/{GITHUB_REPO_OWNER}/{GITHUB_REPO_NAME}.git\n",
        "  %cd {GITHUB_REPO_NAME}\n",
        "if PULL:\n",
        "  !git pull"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'panoptic-reproducibility'...\n",
            "remote: Enumerating objects: 211, done.\u001b[K\n",
            "remote: Counting objects: 100% (211/211), done.\u001b[K\n",
            "remote: Compressing objects: 100% (166/166), done.\u001b[K\n",
            "remote: Total 211 (delta 61), reused 169 (delta 36), pack-reused 0\u001b[K\n",
            "Receiving objects: 100% (211/211), 27.64 MiB | 33.69 MiB/s, done.\n",
            "Resolving deltas: 100% (61/61), done.\n",
            "/content/panoptic-reproducibility\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Tdcm_muUCCz"
      },
      "source": [
        "**Install Requirements**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oZXyUglvswaw",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bc26d9bd-4e06-41e5-df50-2922726fadb2"
      },
      "source": [
        "# %pip (instead of !pip) installs into the environment of the running kernel.\n",
        "%pip install --upgrade pip --quiet\n",
        "\n",
        "# Install the project dependencies listed in the repository's requirements.txt\n",
        "%pip install -r requirements.txt --quiet"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\u001b[K     |████████████████████████████████| 1.6MB 9.0MB/s \n",
            "\u001b[K     |████████████████████████████████| 413 kB 6.7 MB/s \n",
            "\u001b[K     |████████████████████████████████| 472 kB 50.2 MB/s \n",
            "\u001b[K     |████████████████████████████████| 637 kB 56.5 MB/s \n",
            "\u001b[K     |████████████████████████████████| 8.3 MB 21.2 MB/s \n",
            "\u001b[K     |████████████████████████████████| 54 kB 2.2 MB/s \n",
            "\u001b[K     |████████████████████████████████| 109 kB 44.7 MB/s \n",
            "\u001b[K     |████████████████████████████████| 28.5 MB 30 kB/s \n",
            "\u001b[K     |████████████████████████████████| 352 kB 40.9 MB/s \n",
            "\u001b[K     |████████████████████████████████| 78 kB 6.2 MB/s \n",
            "\u001b[K     |████████████████████████████████| 46 kB 3.2 MB/s \n",
            "\u001b[K     |████████████████████████████████| 108 kB 69.0 MB/s \n",
            "\u001b[K     |████████████████████████████████| 296 kB 51.9 MB/s \n",
            "\u001b[K     |████████████████████████████████| 40 kB 4.7 MB/s \n",
            "\u001b[K     |████████████████████████████████| 115 kB 49.9 MB/s \n",
            "\u001b[K     |████████████████████████████████| 4.6 MB 43.7 MB/s \n",
            "\u001b[K     |████████████████████████████████| 44 kB 2.5 MB/s \n",
            "\u001b[K     |████████████████████████████████| 170 kB 53.1 MB/s \n",
            "\u001b[K     |████████████████████████████████| 207 kB 56.4 MB/s \n",
            "\u001b[K     |████████████████████████████████| 49 kB 5.6 MB/s \n",
            "\u001b[K     |████████████████████████████████| 76 kB 4.2 MB/s \n",
            "\u001b[K     |████████████████████████████████| 529 kB 44.2 MB/s \n",
            "\u001b[K     |████████████████████████████████| 428 kB 42.5 MB/s \n",
            "\u001b[K     |████████████████████████████████| 56 kB 4.2 MB/s \n",
            "\u001b[K     |████████████████████████████████| 389 kB 53.2 MB/s \n",
            "\u001b[K     |████████████████████████████████| 3.2 MB 40.1 MB/s \n",
            "\u001b[K     |████████████████████████████████| 63 kB 1.5 MB/s \n",
            "\u001b[K     |████████████████████████████████| 75 kB 3.9 MB/s \n",
            "\u001b[K     |████████████████████████████████| 112 kB 75.1 MB/s \n",
            "\u001b[K     |████████████████████████████████| 68 kB 6.1 MB/s \n",
            "\u001b[K     |████████████████████████████████| 56 kB 3.9 MB/s \n",
            "\u001b[K     |████████████████████████████████| 2.6 MB 27.8 MB/s \n",
            "\u001b[K     |████████████████████████████████| 201 kB 63.8 MB/s \n",
            "\u001b[K     |████████████████████████████████| 64 kB 2.7 MB/s \n",
            "\u001b[K     |████████████████████████████████| 51 kB 5.5 MB/s \n",
            "\u001b[K     |████████████████████████████████| 546 kB 49.1 MB/s \n",
            "\u001b[K     |████████████████████████████████| 86 kB 5.6 MB/s \n",
            "\u001b[?25h  Building wheel for brotlipy (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for pyreadline (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for configobj (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for dpath (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for flufl.lock (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for nanotime (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for pygtrie (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for atpublic (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for ftfy (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for mailchecker (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for typing (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "google-colab 1.0.0 requires tornado~=5.1.0; python_version >= \"3.0\", but you have tornado 6.1 which is incompatible.\n",
            "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n",
            "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3Ej5H3dJWBKj"
      },
      "source": [
        "**Configure DVC**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OKazlYv0rKoC",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "aed907ca-dfd6-44a5-bd81-58f7055da1f5"
      },
      "source": [
        "# Import DVC package - relevant only when working in a Colab environment\n",
        "import dvc\n",
        "\n",
        "if CLONE:\n",
        "  # Set DVC remote storage as 'DAGsHub storage'.\n",
        "  # --force overwrites an existing 'origin' entry, so re-running this cell no longer\n",
        "  # aborts with the \"remote 'origin' already exists\" configuration error.\n",
        "  !dvc remote add --local --force origin https://dagshub.com/{DAGSHUB_REPO_OWNER}/{DAGSHUB_REPO_NAME}.dvc\n",
        "\n",
        "  # General DVC configuration: basic auth with the DAGsHub credentials, written to the --local config.\n",
        "  !dvc remote modify --local origin auth basic\n",
        "  !dvc remote modify --local origin user {DAGSHUB_USER_NAME}\n",
        "  # Single quotes keep the shell from expanding or splitting special characters in the token.\n",
        "  !dvc remote modify --local origin password '{DAGSHUB_TOKEN}'\n",
        "\n",
        "if PULL:\n",
        "  !dvc pull -r origin #&> /dev/null\n",
        "\n",
        "  #Make sure that all files were pulled\n",
        "  !dvc pull"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "ERROR: configuration error - config file error: remote 'origin' already exists. Use `-f|--force` to overwrite it.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kUrH4z2HBYZY"
      },
      "source": [
        "**Configure MLflow**"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LrZBqtl1CKAy"
      },
      "source": [
        "# Optional experiment tracking: install MLflow and point it at the DAGsHub tracking server.\n",
        "if MLFLOW:\n",
        "  # %pip (instead of !pip) installs into the environment of the running kernel.\n",
        "  %pip install mlflow --quiet\n",
        "\n",
        "  import mlflow\n",
        "\n",
        "  # Tracking-server credentials are handed to MLflow through environment variables.\n",
        "  os.environ['MLFLOW_TRACKING_USERNAME'] = DAGSHUB_USER_NAME\n",
        "  os.environ['MLFLOW_TRACKING_PASSWORD'] = DAGSHUB_TOKEN\n",
        "\n",
        "  mlflow.set_tracking_uri(f'https://dagshub.com/{DAGSHUB_REPO_OWNER}/{DAGSHUB_REPO_NAME}.mlflow')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IKO4338nrQRB"
      },
      "source": [
        "# Playground"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2OwbvdaJe5AV"
      },
      "source": [
        "### Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "26eSb1Abe4vC"
      },
      "source": [
        "import glob\n",
        "import os\n",
        "\n",
        "import pandas as pd\n",
        "import tensorflow as tf\n",
        "from tensorflow.keras import Input\n",
        "from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
        "\n",
        "pd.set_option('display.max_rows', 500)\n",
        "pd.set_option('display.max_columns', 500)\n",
        "pd.set_option('display.width', 1000)\n",
        "pd.set_option('display.expand_frame_repr', False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AkoFzAl6ezGl"
      },
      "source": [
        "### const"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IsG3oCjR1XH5"
      },
      "source": [
        "BASE_PATH = os.getcwd()\n",
        "BASE_DATA_PATH = os.path.join(BASE_PATH, \"data/example/cityscapes\")\n",
        "\n",
        "IMG_SIZE = (1024, 2048)\n",
        "IMG_SHAPE = IMG_SIZE + (3,)\n",
        "SEED = 1"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_YopWqLSOXU3"
      },
      "source": [
        "### Loading images from directory\n",
        "\n",
        "**Assumptions**:\n",
        "- The data set is too big to load to the run time. Therefore, we will use a generator to yield the data from the directory in batches to the model.\n",
        "\n",
        "**Implementation:**\n",
        "- The implementation is based on the [TensorFlow docs](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator) (under \"Example of transforming images and masks together.\" section)\n",
        "- We can't use the `ImageDataGenerator.flow_from_directory` method due to the following reasons:\n",
        "  - For every image we fit the model the target image is combined out of 3 images (`gtFine_color`, `gtFine_instanceIds`, `gtFine_labelIds`). However, this method is capable of reading only one target image at a time.\n",
        "\n",
        "  - The structure of the dataset is not fitted to the needs of this method:\n",
        "    - It's not separated into categories\n",
        "    - The targe is combined out of three different images. The method is not capable of batching multiple targets for the same input image.\n",
        "\n",
        "- Due to the above, we will use the `flow_from_dataframe` method. We will create a pandas data frame with the path to the input images (`*_X_path_df`). Based on the name of the images, we will create additional three data frames (`*_gtFine_color_path_df`, `*_gtFine_instanceIds_path_df`, `*_gtFine_labelIds_path_df`) that will correspond with the index of `*_X_path_df` \n",
        "\n",
        "- TODO: EDIT: load the data - zip(X,zip(col,label,instance))\n",
        "\n",
        "<br>\n",
        "\n",
        "**Additional implementation options:**\n",
        "- [Custemize the data generator function](https://medium.com/analytics-vidhya/write-your-own-custom-data-generator-for-tensorflow-keras-1252b64e41c3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hdncclPQac9O"
      },
      "source": [
        "1. Get all the paths to images"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gbYJonvCOW2T"
      },
      "source": [
        "# TODO: write test that checks that all the df are correlated\n",
        "\n",
        "def list_image_paths(pattern):\n",
        "  \"\"\"Return a DataFrame with one `filename` column listing the files under BASE_DATA_PATH that match `pattern`.\n",
        "\n",
        "  glob yields paths in arbitrary order, so the paths are sorted to give every\n",
        "  frame a deterministic row order. The input and ground-truth generators are\n",
        "  later zipped row by row, which only pairs images correctly when all frames\n",
        "  list them in the same order (assumes an input image and its ground-truth\n",
        "  files share a file-name prefix, as in the Cityscapes naming scheme - TODO confirm).\n",
        "\n",
        "  Args:\n",
        "    pattern: glob pattern relative to BASE_DATA_PATH.\n",
        "\n",
        "  Returns:\n",
        "    pandas.DataFrame with the sorted matching paths in the `filename` column.\n",
        "  \"\"\"\n",
        "  return pd.DataFrame({\"filename\": sorted(glob.glob(os.path.join(BASE_DATA_PATH, pattern)))})\n",
        "\n",
        "# Input\n",
        "train_X_path_df = list_image_paths(\"leftImg8bit/train/*/*\")\n",
        "test_X_path_df = list_image_paths(\"leftImg8bit/test/*/*\")\n",
        "val_X_path_df = list_image_paths(\"leftImg8bit/val/*/*\")\n",
        "\n",
        "# GT Train\n",
        "train_gtFine_color_path_df = list_image_paths(\"gtFine/train/*/*color*\")\n",
        "train_gtFine_instanceIds_path_df = list_image_paths(\"gtFine/train/*/*instanceIds*\")\n",
        "train_gtFine_labelIds_path_df = list_image_paths(\"gtFine/train/*/*labelIds*\")\n",
        "\n",
        "# GT Test\n",
        "test_gtFine_color_path_df = list_image_paths(\"gtFine/test/*/*color*\")\n",
        "test_gtFine_instanceIds_path_df = list_image_paths(\"gtFine/test/*/*instanceIds*\")\n",
        "test_gtFine_labelIds_path_df = list_image_paths(\"gtFine/test/*/*labelIds*\")\n",
        "\n",
        "# GT Validation\n",
        "val_gtFine_color_path_df = list_image_paths(\"gtFine/val/*/*color*\")\n",
        "val_gtFine_instanceIds_path_df = list_image_paths(\"gtFine/val/*/*instanceIds*\")\n",
        "val_gtFine_labelIds_path_df = list_image_paths(\"gtFine/val/*/*labelIds*\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a-IZL8kIGLqR"
      },
      "source": [
        "2. Create a `ImageDataGenerator` instance for every sub set"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6krdhpRIOTny"
      },
      "source": [
        "from numpy import linspace\n",
        "\n",
        "# create ImageDataGenerator instances with the same arguments\n",
        "# The former `preprocessing_function` was removed: Keras calls that hook with an already loaded\n",
        "# image array and requires the output to keep the input shape, so reading a path with cv2\n",
        "# (which is never imported) and resizing inside the hook failed as soon as a batch was requested.\n",
        "# TODO: (1024, 2048) -> (1025, 2049); crop size > image for cityscapes; random padding to be implemented\n",
        "data_gen_args = dict(zoom_range=[0.5, 2.0], # random scaling; the factor is drawn from the continuous range 0.5 -> 2.0 (not in 0.1 steps)\n",
        "                     horizontal_flip=True) # random horizontal flip\n",
        "\n",
        "# Input\n",
        "train_X_datagen = ImageDataGenerator(**data_gen_args)\n",
        "test_X_datagen = ImageDataGenerator(**data_gen_args)\n",
        "val_X_datagen = ImageDataGenerator(**data_gen_args)\n",
        "\n",
        "# GT Train\n",
        "train_gtFine_color_datagen = ImageDataGenerator(**data_gen_args)\n",
        "train_gtFine_instanceIds_datagen = ImageDataGenerator(**data_gen_args)\n",
        "train_gtFine_labelIds_datagen = ImageDataGenerator(**data_gen_args)\n",
        "\n",
        "# GT Test\n",
        "test_gtFine_color_datagen = ImageDataGenerator(**data_gen_args)\n",
        "test_gtFine_instanceIds_datagen = ImageDataGenerator(**data_gen_args)\n",
        "test_gtFine_labelIds_datagen = ImageDataGenerator(**data_gen_args)\n",
        "\n",
        "# GT Validation\n",
        "val_gtFine_color_datagen = ImageDataGenerator(**data_gen_args)\n",
        "val_gtFine_instanceIds_datagen = ImageDataGenerator(**data_gen_args)\n",
        "val_gtFine_labelIds_datagen = ImageDataGenerator(**data_gen_args)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D9lZRLGIGjC7"
      },
      "source": [
        "3. Apply the `flow_from_dataframe` method for every sub set with its corresponding paths data frame"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oFX81x6CfSsA",
        "outputId": "a847e847-ec65-47b8-c169-d414867e690a"
      },
      "source": [
        "# Every generator is created with the same seed and keyword arguments,\n",
        "# so the shared settings are collected once and unpacked into each call.\n",
        "flow_kwargs = dict(y_col=None, target_size=IMG_SIZE, class_mode=None,\n",
        "                   batch_size=2, seed=SEED, weight_col=None)\n",
        "\n",
        "# Input images\n",
        "train_X_generator = train_X_datagen.flow_from_dataframe(train_X_path_df, **flow_kwargs)\n",
        "test_X_generator = test_X_datagen.flow_from_dataframe(test_X_path_df, **flow_kwargs)\n",
        "val_X_generator = val_X_datagen.flow_from_dataframe(val_X_path_df, **flow_kwargs)\n",
        "\n",
        "# Ground truth - train\n",
        "train_gtFine_color_generator = train_gtFine_color_datagen.flow_from_dataframe(\n",
        "    train_gtFine_color_path_df, **flow_kwargs)\n",
        "train_gtFine_instanceIds_generator = train_gtFine_instanceIds_datagen.flow_from_dataframe(\n",
        "    train_gtFine_instanceIds_path_df, **flow_kwargs)\n",
        "train_gtFine_labelIds_generator = train_gtFine_labelIds_datagen.flow_from_dataframe(\n",
        "    train_gtFine_labelIds_path_df, **flow_kwargs)\n",
        "\n",
        "# Ground truth - test\n",
        "test_gtFine_color_generator = test_gtFine_color_datagen.flow_from_dataframe(\n",
        "    test_gtFine_color_path_df, **flow_kwargs)\n",
        "test_gtFine_instanceIds_generator = test_gtFine_instanceIds_datagen.flow_from_dataframe(\n",
        "    test_gtFine_instanceIds_path_df, **flow_kwargs)\n",
        "test_gtFine_labelIds_generator = test_gtFine_labelIds_datagen.flow_from_dataframe(\n",
        "    test_gtFine_labelIds_path_df, **flow_kwargs)\n",
        "\n",
        "# Ground truth - validation\n",
        "val_gtFine_color_generator = val_gtFine_color_datagen.flow_from_dataframe(\n",
        "    val_gtFine_color_path_df, **flow_kwargs)\n",
        "val_gtFine_instanceIds_generator = val_gtFine_instanceIds_datagen.flow_from_dataframe(\n",
        "    val_gtFine_instanceIds_path_df, **flow_kwargs)\n",
        "val_gtFine_labelIds_generator = val_gtFine_labelIds_datagen.flow_from_dataframe(\n",
        "    val_gtFine_labelIds_path_df, **flow_kwargs)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n",
            "Found 4 validated image filenames.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pBDRbGCIGysb"
      },
      "source": [
        "4. Zip the GT, the target generators, that will be iterated by the model, e.g.:\n",
        "\n",
        "- train_gtFine_color_generator = [1,2,3]\n",
        "- train_gtFine_instanceIds_generator = [1,2,3] \n",
        "- train_gtFine_labelIds_generator = [1,2,3]\n",
        "\n",
        "generators will yield - (1,1,1), (2,2,2), (3,3,3)\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QdgxcT-MkJ9d"
      },
      "source": [
        "# combine generators of GT into one which yields 3 target images\n",
        "train_gt_generator = zip(train_gtFine_color_generator, train_gtFine_instanceIds_generator, train_gtFine_labelIds_generator)\n",
        "test_gt_generator = zip(test_gtFine_color_generator, test_gtFine_instanceIds_generator, test_gtFine_labelIds_generator)\n",
        "val_gt_generator = zip(val_gtFine_color_generator, val_gtFine_instanceIds_generator, val_gtFine_labelIds_generator)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BwukgXG2Hgrw"
      },
      "source": [
        "5. Zip the X set with the target.\n",
        "\n",
        "generator will yield - (x1,(1,1,1)), (x2,(2,2,2)), (x3,(3,3,3))"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pjoQtFuInzp7"
      },
      "source": [
        "# combine generators into one which yields image and masks\n",
        "train_generator = zip(train_X_generator, train_gt_generator)\n",
        "test_generator = zip(test_X_generator, test_gt_generator)\n",
        "val_generator = zip(val_X_generator, val_gt_generator)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jx_TKSCBIcui"
      },
      "source": [
        "### Build the model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "M_hAwuqk09VB"
      },
      "source": [
        "%env PYTHONPATH=/content/panoptic-reproducibility/src/models/\n",
        "\n",
        "import sys\n",
        "\n",
        "# %env only updates os.environ; the running kernel read PYTHONPATH at start-up, so its\n",
        "# import path has to be extended directly for the `models` import below to resolve.\n",
        "# NOTE(review): assumes the package lives in <repository root>/src/models and that BASE_PATH is the repository root - confirm.\n",
        "src_dir = os.path.join(BASE_PATH, \"src\")\n",
        "if src_dir not in sys.path:\n",
        "  sys.path.insert(0, src_dir)\n",
        "\n",
        "from tensorflow.keras import Model\n",
        "from models import loss, heads, encoder, decoder, convolutions #TODO: Add remaining .py libraries for test runs "
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "J8cIcdqK_DcO"
      },
      "source": [
        "def build_model(input_shape=(None, 1025, 2049, 3)):\n",
        "  \"\"\"Build the semantic and the instance model on top of one shared encoder.\n",
        "\n",
        "  Args:\n",
        "    input_shape: image shape as (height, width, channels). A 4-tuple with a\n",
        "      leading batch dimension, like the default, is also accepted; the batch\n",
        "      dimension is dropped.\n",
        "\n",
        "  Returns:\n",
        "    A list [semantic_model, instance_model] of Keras models that share the\n",
        "    same input tensor and encoder output.\n",
        "  \"\"\"\n",
        "  # keras.Input takes the per-sample shape (batch size excluded), so a leading\n",
        "  # batch dimension is stripped instead of producing a 5-D input tensor.\n",
        "  sample_shape = tuple(input_shape)\n",
        "  if len(sample_shape) == 4:\n",
        "    sample_shape = sample_shape[1:]\n",
        "\n",
        "  # Named `inputs` rather than `input` to avoid shadowing the builtin.\n",
        "  inputs = Input(shape=sample_shape)\n",
        "  enc, ups1, ups2 = encoder.xception(inputs) # ups1, ups2 for split connections\n",
        "\n",
        "  aspp_sem = convolutions.SwitchableAtrousConvolution(enc)\n",
        "  aspp_inst = convolutions.SwitchableAtrousConvolution(enc)\n",
        "\n",
        "  sem_decoder = decoder.get_semantic_decoder(aspp_sem, ups1, ups2)\n",
        "  # NOTE(review): `get_instacne_decoder` looks like a misspelling of \"instance\"; it is kept\n",
        "  # because it has to match the name defined in the project's decoder module - confirm.\n",
        "  inst_decoder = decoder.get_instacne_decoder(aspp_inst, ups1, ups2)\n",
        "\n",
        "  return [Model(inputs, sem_decoder), Model(inputs, inst_decoder)]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dcc2vMhc3Ue5"
      },
      "source": [
        "# Build both models, print their summaries and compile them.\n",
        "models = build_model(input_shape=(None, 1025, 2049, 3))\n",
        "for model in models: \n",
        "  model.summary()\n",
        "  # NOTE(review): the same semantic loss and MeanIoU metric are applied to both models - confirm this is intended for the instance model.\n",
        "  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),\n",
        "                loss=loss.loss_sem(),\n",
        "                metrics=[tf.keras.metrics.MeanIoU(num_classes=19)])"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kbtdA88pcKi4"
      },
      "source": [
        "sem = heads.get_semantic_head()\n",
        "kpt = heads.get_instance_center_head()\n",
        "rgr = heads.get_instance_regression_head()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bf-zax2ScTNg"
      },
      "source": [
        "model.fit() # TODO: Training regime"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w_z604f1zufl"
      },
      "source": [
        "# Commit Files"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pvym9JU-cytP"
      },
      "source": [
        "!git status"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pCRtuk5Rzt4r"
      },
      "source": [
        "# !dvc add\n",
        "\n",
        "# !git add\n",
        "\n",
        "# !git commit\n",
        "\n",
        "# !git status"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5CjjbKpyz4u5"
      },
      "source": [
        "# Push Files"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Loqm2x1Grhrj"
      },
      "source": [
        "# !git push https://{GITHUB_USER_NAME}:{GITHUB_TOKEN}@github.com/{GITHUB_REPO_OWNER}/{GITHUB_REPO_NAME}.git {BRANCH}"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-Fsal7eKxzt9"
      },
      "source": [
        "# !dvc push -r origin"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C9qx1MaKljEx"
      },
      "source": [
        ""
      ]
    }
  ]
}
