{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F88mignPnalS"
   },
   "source": [
    "# Introduction\n",
    "\n",
    "This Colab is designed to run the pretrained models from [GeoDiff](https://github.com/MinkaiXu/GeoDiff).\n",
    "The visualization code is inspired by this PyMol [colab](https://colab.research.google.com/gist/iwatobipen/2ec7faeafe5974501e69fcc98c122922/pymol.ipynb#scrollTo=Hm4kY7CaZSlw).\n",
    "\n",
    "The goal is to generate physically accurate molecules. The input is a molecule graph (atoms and bonds with their connectivity, in the form of a 2D graph), and the output is a stable 3D structure of the molecule.\n",
    "\n",
    "This Colab uses the GEOM datasets, which have multiple 3D targets per configuration and so provide more compelling targets for generative methods.\n",
    "\n",
    "> Colab made by [natolambert](https://twitter.com/natolambert).\n",
    "\n",
    "![diffusers_library](https://github.com/huggingface/diffusers/raw/main/docs/source/imgs/diffusers_library.jpg)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7cnwXMocnuzB"
   },
   "source": [
    "## Installations\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ff9SxWnaNId9"
   },
   "source": [
    "### Install Conda"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1g_6zOabItDk"
   },
   "source": [
    "Here we check the CUDA version in Colab. When this notebook was built, the version was 11.1 (the saved output below reports release 11.2), which impacts some installation decisions below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "K0ofXobG5Y-X",
    "outputId": "572c3d25-6f19-4c1e-83f5-a1d084a3207f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nvcc: NVIDIA (R) Cuda compiler driver\n",
      "Copyright (c) 2005-2021 NVIDIA Corporation\n",
      "Built on Sun_Feb_14_21:12:58_PST_2021\n",
      "Cuda compilation tools, release 11.2, V11.2.152\n",
      "Build cuda_11.2.r11.2/compiler.29618528_0\n"
     ]
    }
   ],
   "source": [
    "!nvcc --version"
   ]
  },
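  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to branch on the CUDA release programmatically, the text printed by `nvcc --version` can be parsed. This is only a minimal sketch: the helper name `parse_cuda_release` is hypothetical, and `sample` copies two lines from the output recorded above.\n",
    "\n",
    "```python\n",
    "def parse_cuda_release(nvcc_output):\n",
    "    # Look for the line of the form '..., release 11.2, V11.2.152'.\n",
    "    for line in nvcc_output.splitlines():\n",
    "        if 'release' in line:\n",
    "            # Keep the text between 'release' and the next comma.\n",
    "            return line.split('release')[1].split(',')[0].strip()\n",
    "    return None  # no release line found\n",
    "\n",
    "\n",
    "sample = '''nvcc: NVIDIA (R) Cuda compiler driver\n",
    "Cuda compilation tools, release 11.2, V11.2.152'''\n",
    "print(parse_cuda_release(sample))  # 11.2\n",
    "```\n",
    "\n",
    "In Colab you could pass in the captured text of the command instead of `sample`."
   ]
  },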
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VfthW90vI0nw"
   },
   "source": [
    "Install Conda to handle some of the more complex dependencies of the geometric networks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "2WNFzSnbiE0k",
    "outputId": "690d0d4d-9d0a-4ead-c6dc-086f113f532f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install -q condacolab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NUsbWYCUI7Km"
   },
   "source": [
    "Set up Conda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FZelreINdmd0",
    "outputId": "635f0cb8-0af4-499f-e0a4-b3790cb12e9f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✨🍰✨ Everything looks OK!\n"
     ]
    }
   ],
   "source": [
    "import condacolab\n",
    "\n",
    "\n",
    "condacolab.install()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JzDHaPU7I9Sn"
   },
   "source": [
    "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JMxRjHhL7w8V",
    "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting package metadata (current_repodata.json): done\n",
      "Solving environment: done\n",
      "\n",
      "## Package Plan ##\n",
      "\n",
      "  environment location: /usr/local\n",
      "\n",
      "  added / updated specs:\n",
      "    - cudatoolkit=11.1\n",
      "    - pytorch\n",
      "    - torchaudio\n",
      "    - torchvision\n",
      "\n",
      "\n",
      "The following packages will be downloaded:\n",
      "\n",
      "    package                    |            build\n",
      "    ---------------------------|-----------------\n",
      "    conda-22.9.0               |   py37h89c1867_1         960 KB  conda-forge\n",
      "    ------------------------------------------------------------\n",
      "                                           Total:         960 KB\n",
      "\n",
      "The following packages will be UPDATED:\n",
      "\n",
      "  conda                               4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n",
      "\n",
      "\n",
      "\n",
      "Downloading and Extracting Packages\n",
      "conda-22.9.0         | 960 KB    | : 100% 1.0/1 [00:00<00:00,  4.15it/s]\n",
      "Preparing transaction: / \b\bdone\n",
      "Verifying transaction: \\ \b\bdone\n",
      "Executing transaction: / \b\bdone\n",
      "Retrieving notices: ...working... done\n"
     ]
    }
   ],
   "source": [
    "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n",
    "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QDS6FPZ0Tu5b"
   },
   "source": [
    "Remove the conda `pinned` file, which in Colab specifies the incorrect CUDA version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dq1lxR10TtrR",
    "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n"
     ]
    }
   ],
   "source": [
    "!rm /usr/local/conda-meta/pinned"
   ]
  },
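  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The saved output above shows `rm` reporting an error because the file was not there. A Python alternative that tolerates a missing file is sketched below; `remove_if_present` is a hypothetical helper, and the demonstration runs on a temporary file rather than on `/usr/local/conda-meta/pinned`.\n",
    "\n",
    "```python\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def remove_if_present(path):\n",
    "    # Delete the file if it exists; report whether anything was removed.\n",
    "    target = Path(path)\n",
    "    if target.exists():\n",
    "        target.unlink()\n",
    "        return True\n",
    "    return False\n",
    "\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    pinned = Path(tmp) / 'pinned'\n",
    "    pinned.write_text('placeholder')\n",
    "    first = remove_if_present(pinned)   # file existed, so it is removed\n",
    "    second = remove_if_present(pinned)  # already gone, nothing to do\n",
    "\n",
    "print(first, second)  # True False\n",
    "```\n",
    "\n",
    "In the notebook itself, the call would be `remove_if_present('/usr/local/conda-meta/pinned')`."
   ]
  },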
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z1L3DdZOJB30"
   },
   "source": [
    "Install PyTorch Geometric (used in the model later)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "D5ukfCOWfjzK",
    "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting package metadata (current_repodata.json): done\n",
      "Solving environment: done\n",
      "\n",
      "## Package Plan ##\n",
      "\n",
      "  environment location: /usr/local\n",
      "\n",
      "  added / updated specs:\n",
      "    - pytorch-geometric=1.7.2\n",
      "\n",
      "\n",
      "The following packages will be downloaded:\n",
      "\n",
      "    package                    |            build\n",
      "    ---------------------------|-----------------\n",
      "    decorator-4.4.2            |             py_0          11 KB  conda-forge\n",
      "    googledrivedownloader-0.4  |     pyhd3deb0d_1           7 KB  conda-forge\n",
      "    jinja2-3.1.2               |     pyhd8ed1ab_1          99 KB  conda-forge\n",
      "    joblib-1.2.0               |     pyhd8ed1ab_0         205 KB  conda-forge\n",
      "    markupsafe-2.1.1           |   py37h540881e_1          22 KB  conda-forge\n",
      "    networkx-2.5.1             |     pyhd8ed1ab_0         1.2 MB  conda-forge\n",
      "    pandas-1.2.3               |   py37hdc94413_0        11.8 MB  conda-forge\n",
      "    pyparsing-3.0.9            |     pyhd8ed1ab_0          79 KB  conda-forge\n",
      "    python-dateutil-2.8.2      |     pyhd8ed1ab_0         240 KB  conda-forge\n",
      "    python-louvain-0.15        |     pyhd8ed1ab_1          13 KB  conda-forge\n",
      "    pytorch-cluster-1.5.9      |py37_torch_1.8.0_cu111         1.2 MB  rusty1s\n",
      "    pytorch-geometric-1.7.2    |py37_torch_1.8.0_cu111         445 KB  rusty1s\n",
      "    pytorch-scatter-2.0.8      |py37_torch_1.8.0_cu111         6.1 MB  rusty1s\n",
      "    pytorch-sparse-0.6.12      |py37_torch_1.8.0_cu111         2.9 MB  rusty1s\n",
      "    pytorch-spline-conv-1.2.1  |py37_torch_1.8.0_cu111         736 KB  rusty1s\n",
      "    pytz-2022.4                |     pyhd8ed1ab_0         232 KB  conda-forge\n",
      "    scikit-learn-1.0.2         |   py37hf9e9bfc_0         7.8 MB  conda-forge\n",
      "    scipy-1.7.3                |   py37hf2a6cf1_0        21.8 MB  conda-forge\n",
      "    setuptools-59.8.0          |   py37h89c1867_1         1.0 MB  conda-forge\n",
      "    threadpoolctl-3.1.0        |     pyh8a188c0_0          18 KB  conda-forge\n",
      "    ------------------------------------------------------------\n",
      "                                           Total:        55.9 MB\n",
      "\n",
      "The following NEW packages will be INSTALLED:\n",
      "\n",
      "  decorator          conda-forge/noarch::decorator-4.4.2-py_0 None\n",
      "  googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n",
      "  jinja2             conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n",
      "  joblib             conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n",
      "  markupsafe         conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n",
      "  networkx           conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n",
      "  pandas             conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n",
      "  pyparsing          conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n",
      "  python-dateutil    conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n",
      "  python-louvain     conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n",
      "  pytorch-cluster    rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n",
      "  pytorch-geometric  rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n",
      "  pytorch-scatter    rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n",
      "  pytorch-sparse     rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n",
      "  pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n",
      "  pytz               conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n",
      "  scikit-learn       conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n",
      "  scipy              conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n",
      "  threadpoolctl      conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n",
      "\n",
      "The following packages will be DOWNGRADED:\n",
      "\n",
      "  setuptools                          65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n",
      "\n",
      "\n",
      "\n",
      "Downloading and Extracting Packages\n",
      "scikit-learn-1.0.2   | 7.8 MB    | : 100% 1.0/1 [00:01<00:00,  1.37s/it]              \n",
      "pytorch-scatter-2.0. | 6.1 MB    | : 100% 1.0/1 [00:06<00:00,  6.18s/it]\n",
      "pytorch-geometric-1. | 445 KB    | : 100% 1.0/1 [00:02<00:00,  2.53s/it]\n",
      "scipy-1.7.3          | 21.8 MB   | : 100% 1.0/1 [00:03<00:00,  3.06s/it]\n",
      "python-dateutil-2.8. | 240 KB    | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n",
      "pytorch-spline-conv- | 736 KB    | : 100% 1.0/1 [00:01<00:00,  1.00s/it]\n",
      "pytorch-sparse-0.6.1 | 2.9 MB    | : 100% 1.0/1 [00:07<00:00,  7.51s/it]\n",
      "pyparsing-3.0.9      | 79 KB     | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n",
      "pytorch-cluster-1.5. | 1.2 MB    | : 100% 1.0/1 [00:02<00:00,  2.78s/it]\n",
      "jinja2-3.1.2         | 99 KB     | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n",
      "decorator-4.4.2      | 11 KB     | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n",
      "joblib-1.2.0         | 205 KB    | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n",
      "pytz-2022.4          | 232 KB    | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n",
      "python-louvain-0.15  | 13 KB     | : 100% 1.0/1 [00:00<00:00,  3.34it/s]\n",
      "googledrivedownloade | 7 KB      | : 100% 1.0/1 [00:00<00:00,  3.33it/s]\n",
      "threadpoolctl-3.1.0  | 18 KB     | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n",
      "markupsafe-2.1.1     | 22 KB     | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n",
      "pandas-1.2.3         | 11.8 MB   | : 100% 1.0/1 [00:02<00:00,  2.08s/it]               \n",
      "networkx-2.5.1       | 1.2 MB    | : 100% 1.0/1 [00:01<00:00,  1.39s/it]\n",
      "setuptools-59.8.0    | 1.0 MB    | : 100% 1.0/1 [00:00<00:00,  4.25it/s]\n",
      "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n",
      "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
      "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
      "Retrieving notices: ...working... done\n"
     ]
    }
   ],
   "source": [
    "!conda install -c rusty1s pytorch-geometric=1.7.2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ppxv6Mdkalbc"
   },
   "source": [
    "### Install Diffusers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "mgQA_XN-XGY2",
    "outputId": "85392615-b6a4-4052-9d2a-79604be62c94"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/content\n",
      "Cloning into 'diffusers'...\n",
      "remote: Enumerating objects: 9298, done.\u001b[K\n",
      "remote: Counting objects: 100% (40/40), done.\u001b[K\n",
      "remote: Compressing objects: 100% (23/23), done.\u001b[K\n",
      "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n",
      "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n",
      "Resolving deltas: 100% (6168/6168), done.\n",
      "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
      "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
      "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25h  Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "%cd /content\n",
    "\n",
    "# install latest HF diffusers (will update to the release once added)\n",
    "!git clone https://github.com/huggingface/diffusers.git\n",
    "!pip install -q /content/diffusers\n",
    "\n",
    "# dependencies for diffusers\n",
    "!pip install -q datasets transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LZO6AJKuJKO8"
   },
   "source": [
    "Check that PyTorch is installed correctly and can use the GPU in Colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 53
    },
    "id": "gZt7BNi1e1PA",
    "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    },
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'1.8.2'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "print(torch.cuda.is_available())\n",
    "torch.__version__"
   ]
  },
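  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If later code needs a device string, a common pattern is to pick it once and fall back to the CPU when no GPU is visible. The sketch below relies only on `torch.cuda.is_available()`, which the cell above already uses; `pick_device` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "def pick_device():\n",
    "    # Fall back to the CPU if PyTorch is missing or no GPU is visible.\n",
    "    try:\n",
    "        import torch\n",
    "    except ImportError:\n",
    "        return 'cpu'\n",
    "    return 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "\n",
    "device = pick_device()\n",
    "print(device)\n",
    "```\n",
    "\n",
    "With the GPU runtime checked above, this should select `cuda`; on a CPU-only runtime it selects `cpu`."
   ]
  },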
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KLE7CqlfJNUO"
   },
   "source": [
    "### Install Chemistry-specific Dependencies\n",
    "\n",
    "Install RDKit, a tool for working with and visualizing chemistry in Python (you will use it to visualize the generated molecules later)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0CPv_NvehRz3",
    "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
      "Collecting rdkit\n",
      "  Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n",
      "Installing collected packages: rdkit\n",
      "Successfully installed rdkit-2022.3.5\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install rdkit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "88GaDbDPxJ5I"
   },
   "source": [
    "### Get viewer from nglview\n",
    "\n",
    "The model you will use outputs a position matrix tensor. The PyTorch Geometric data object it works with holds many features (positions, known features, edge features), all of them tensors.\n",
    "The data we give to the model also has an `rdmol` object (from which features can be extracted for PyTorch Geometric if needed).\n",
    "The `rdmol` in this object is a source of ground truth for the generated molecules.\n",
    "\n",
    "You will use one rendering function from nglview later!\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "jcl8GCS2mz6t",
    "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
      "Collecting nglview\n",
      "  Downloading nglview-3.0.3.tar.gz (5.7 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25h  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
      "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
      "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
      "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n",
      "Collecting jupyterlab-widgets\n",
      "  Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting ipywidgets>=7\n",
      "  Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting widgetsnbextension~=4.0\n",
      "  Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting ipython>=6.1.0\n",
      "  Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting ipykernel>=4.5.1\n",
      "  Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting traitlets>=4.3.1\n",
      "  Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n",
      "Collecting pyzmq>=17\n",
      "  Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting matplotlib-inline>=0.1\n",
      "  Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n",
      "Collecting tornado>=6.1\n",
      "  Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting nest-asyncio\n",
      "  Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n",
      "Collecting debugpy>=1.0\n",
      "  Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting psutil\n",
      "  Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting jupyter-client>=6.1.12\n",
      "  Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting pickleshare\n",
      "  Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n",
      "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n",
      "Collecting backcall\n",
      "  Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n",
      "Collecting pexpect>4.3\n",
      "  Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting pygments\n",
      "  Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting jedi>=0.16\n",
      "  Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n",
      "  Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n",
      "Collecting parso<0.9.0,>=0.8.0\n",
      "  Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n",
      "Collecting entrypoints\n",
      "  Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n",
      "Collecting jupyter-core>=4.9.2\n",
      "  Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting ptyprocess>=0.5\n",
      "  Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n",
      "Collecting wcwidth\n",
      "  Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n",
      "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n",
      "Building wheels for collected packages: nglview\n",
      "  Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n",
      "  Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n",
      "Successfully built nglview\n",
      "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n",
      "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    },
    {
     "data": {
      "application/vnd.colab-display-data+json": {
       "pip_warning": {
        "packages": [
         "pexpect",
         "pickleshare",
         "wcwidth"
        ]
       }
      }
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "!pip install nglview"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8t8_e_uVLdKB"
   },
   "source": [
    "## Create a diffusion model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G0rMncVtNSqU"
   },
   "source": [
    "### Model class(es)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L5FEXz5oXkzt"
   },
   "source": [
    "Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-3-P4w5sXkRU"
   },
   "outputs": [],
   "source": [
    "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n",
    "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n",
    "from dataclasses import dataclass\n",
    "from typing import Callable, Tuple, Union\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor, nn\n",
    "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n",
    "from torch_geometric.nn import MessagePassing, radius, radius_graph\n",
    "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n",
    "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n",
    "from torch_scatter import scatter_add\n",
    "from torch_sparse import SparseTensor, coalesce\n",
    "\n",
    "from diffusers.configuration_utils import ConfigMixin, register_to_config\n",
    "from diffusers.modeling_utils import ModelMixin\n",
    "from diffusers.utils import BaseOutput"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EzJQXPN_XrMX"
   },
   "source": [
    "Helper classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oR1Y56QiLY90"
   },
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class MoleculeGNNOutput(BaseOutput):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n",
    "            Hidden states output. Output of last layer of model.\n",
    "    \"\"\"\n",
    "\n",
    "    sample: torch.Tensor\n",
    "\n",
    "\n",
    "class MultiLayerPerceptron(nn.Module):\n",
    "    \"\"\"\n",
    "    Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n",
    "    Args:\n",
    "        input_dim (int): input dimension\n",
    "        hidden_dim (list of int): hidden dimensions\n",
    "        activation (str or function, optional): activation function\n",
    "        dropout (float, optional): dropout rate\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n",
    "        super(MultiLayerPerceptron, self).__init__()\n",
    "\n",
    "        self.dims = [input_dim] + hidden_dims\n",
    "        if isinstance(activation, str):\n",
    "            self.activation = getattr(F, activation)\n",
    "        else:\n",
    "            print(f\"Warning, activation passed {activation} is not string and ignored\")\n",
    "            self.activation = None\n",
    "        if dropout > 0:\n",
    "            self.dropout = nn.Dropout(dropout)\n",
    "        else:\n",
    "            self.dropout = None\n",
    "\n",
    "        self.layers = nn.ModuleList()\n",
    "        for i in range(len(self.dims) - 1):\n",
    "            self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\"\"\"\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            x = layer(x)\n",
    "            if i < len(self.layers) - 1:\n",
    "                if self.activation:\n",
    "                    x = self.activation(x)\n",
    "                if self.dropout:\n",
    "                    x = self.dropout(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class ShiftedSoftplus(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ShiftedSoftplus, self).__init__()\n",
    "        self.shift = torch.log(torch.tensor(2.0)).item()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return F.softplus(x) - self.shift\n",
    "\n",
    "\n",
    "class CFConv(MessagePassing):\n",
    "    def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n",
    "        super(CFConv, self).__init__(aggr=\"add\")\n",
    "        self.lin1 = Linear(in_channels, num_filters, bias=False)\n",
    "        self.lin2 = Linear(num_filters, out_channels)\n",
    "        self.nn = mlp\n",
    "        self.cutoff = cutoff\n",
    "        self.smooth = smooth\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        torch.nn.init.xavier_uniform_(self.lin1.weight)\n",
    "        torch.nn.init.xavier_uniform_(self.lin2.weight)\n",
    "        self.lin2.bias.data.fill_(0)\n",
    "\n",
    "    def forward(self, x, edge_index, edge_length, edge_attr):\n",
    "        if self.smooth:\n",
    "            C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n",
    "            C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0)  # Modification: cutoff\n",
    "        else:\n",
    "            C = (edge_length <= self.cutoff).float()\n",
    "        W = self.nn(edge_attr) * C.view(-1, 1)\n",
    "\n",
    "        x = self.lin1(x)\n",
    "        x = self.propagate(edge_index, x=x, W=W)\n",
    "        x = self.lin2(x)\n",
    "        return x\n",
    "\n",
    "    def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n",
    "        return x_j * W\n",
    "\n",
    "\n",
    "class InteractionBlock(torch.nn.Module):\n",
    "    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n",
    "        super(InteractionBlock, self).__init__()\n",
    "        mlp = Sequential(\n",
    "            Linear(num_gaussians, num_filters),\n",
    "            ShiftedSoftplus(),\n",
    "            Linear(num_filters, num_filters),\n",
    "        )\n",
    "        self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n",
    "        self.act = ShiftedSoftplus()\n",
    "        self.lin = Linear(hidden_channels, hidden_channels)\n",
    "\n",
    "    def forward(self, x, edge_index, edge_length, edge_attr):\n",
    "        x = self.conv(x, edge_index, edge_length, edge_attr)\n",
    "        x = self.act(x)\n",
    "        x = self.lin(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class SchNetEncoder(Module):\n",
    "    def __init__(\n",
    "        self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n",
    "    ):\n",
    "        super().__init__()\n",
    "\n",
    "        self.hidden_channels = hidden_channels\n",
    "        self.num_filters = num_filters\n",
    "        self.num_interactions = num_interactions\n",
    "        self.cutoff = cutoff\n",
    "\n",
    "        self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n",
    "\n",
    "        self.interactions = ModuleList()\n",
    "        for _ in range(num_interactions):\n",
    "            block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n",
    "            self.interactions.append(block)\n",
    "\n",
    "    def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n",
    "        if embed_node:\n",
    "            assert z.dim() == 1 and z.dtype == torch.long\n",
    "            h = self.embedding(z)\n",
    "        else:\n",
    "            h = z\n",
    "        for interaction in self.interactions:\n",
    "            h = h + interaction(h, edge_index, edge_length, edge_attr)\n",
    "\n",
    "        return h\n",
    "\n",
    "\n",
    "class GINEConv(MessagePassing):\n",
    "    \"\"\"\n",
    "    Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n",
    "    https://huggingface.co/papers/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n",
    "        super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n",
    "        self.nn = mlp\n",
    "        self.initial_eps = eps\n",
    "\n",
    "        if isinstance(activation, str):\n",
    "            self.activation = getattr(F, activation)\n",
    "        else:\n",
    "            self.activation = None\n",
    "\n",
    "        if train_eps:\n",
    "            self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n",
    "        else:\n",
    "            self.register_buffer(\"eps\", torch.Tensor([eps]))\n",
    "\n",
    "    def forward(\n",
    "        self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n",
    "    ) -> torch.Tensor:\n",
    "        \"\"\"\"\"\"\n",
    "        if isinstance(x, torch.Tensor):\n",
    "            x: OptPairTensor = (x, x)\n",
    "\n",
    "        # Node and edge feature dimensionalites need to match.\n",
    "        if isinstance(edge_index, torch.Tensor):\n",
    "            assert edge_attr is not None\n",
    "            assert x[0].size(-1) == edge_attr.size(-1)\n",
    "        elif isinstance(edge_index, SparseTensor):\n",
    "            assert x[0].size(-1) == edge_index.size(-1)\n",
    "\n",
    "        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n",
    "        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n",
    "\n",
    "        x_r = x[1]\n",
    "        if x_r is not None:\n",
    "            out += (1 + self.eps) * x_r\n",
    "\n",
    "        return self.nn(out)\n",
    "\n",
    "    def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n",
    "        if self.activation:\n",
    "            return self.activation(x_j + edge_attr)\n",
    "        else:\n",
    "            return x_j + edge_attr\n",
    "\n",
    "    def __repr__(self):\n",
    "        return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n",
    "\n",
    "\n",
    "class GINEncoder(torch.nn.Module):\n",
    "    def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n",
    "        super().__init__()\n",
    "\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.num_convs = num_convs\n",
    "        self.short_cut = short_cut\n",
    "        self.concat_hidden = concat_hidden\n",
    "        self.node_emb = nn.Embedding(100, hidden_dim)\n",
    "\n",
    "        if isinstance(activation, str):\n",
    "            self.activation = getattr(F, activation)\n",
    "        else:\n",
    "            self.activation = None\n",
    "\n",
    "        self.convs = nn.ModuleList()\n",
    "        for i in range(self.num_convs):\n",
    "            self.convs.append(\n",
    "                GINEConv(\n",
    "                    MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n",
    "                    activation=activation,\n",
    "                )\n",
    "            )\n",
    "\n",
    "    def forward(self, z, edge_index, edge_attr):\n",
    "        \"\"\"\n",
    "        Input:\n",
    "            data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n",
    "            hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n",
    "        Output:\n",
    "            node_feature: graph feature\n",
    "        \"\"\"\n",
    "\n",
    "        node_attr = self.node_emb(z)  # (num_node, hidden)\n",
    "\n",
    "        hiddens = []\n",
    "        conv_input = node_attr  # (num_node, hidden)\n",
    "\n",
    "        for conv_idx, conv in enumerate(self.convs):\n",
    "            hidden = conv(conv_input, edge_index, edge_attr)\n",
    "            if conv_idx < len(self.convs) - 1 and self.activation is not None:\n",
    "                hidden = self.activation(hidden)\n",
    "            assert hidden.shape == conv_input.shape\n",
    "            if self.short_cut and hidden.shape == conv_input.shape:\n",
    "                hidden += conv_input\n",
    "\n",
    "            hiddens.append(hidden)\n",
    "            conv_input = hidden\n",
    "\n",
    "        if self.concat_hidden:\n",
    "            node_feature = torch.cat(hiddens, dim=-1)\n",
    "        else:\n",
    "            node_feature = hiddens[-1]\n",
    "\n",
    "        return node_feature\n",
    "\n",
    "\n",
    "class MLPEdgeEncoder(Module):\n",
    "    def __init__(self, hidden_dim=100, activation=\"relu\"):\n",
    "        super().__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n",
    "        self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n",
    "\n",
    "    @property\n",
    "    def out_channels(self):\n",
    "        return self.hidden_dim\n",
    "\n",
    "    def forward(self, edge_length, edge_type):\n",
    "        \"\"\"\n",
    "        Input:\n",
    "            edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n",
    "        Returns:\n",
    "            edge_attr: The representation of edges. (E, 2 * num_gaussians)\n",
    "        \"\"\"\n",
    "        d_emb = self.mlp(edge_length)  # (num_edge, hidden_dim)\n",
    "        edge_attr = self.bond_emb(edge_type)  # (num_edge, hidden_dim)\n",
    "        return d_emb * edge_attr  # (num_edge, hidden)\n",
    "\n",
    "\n",
    "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n",
    "    h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n",
    "    h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1)  # (E, 2H)\n",
    "    return h_pair\n",
    "\n",
    "\n",
    "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        num_nodes:  Number of atoms.\n",
    "        edge_index: Bond indices of the original graph.\n",
    "        edge_type:  Bond types of the original graph.\n",
    "        order:  Extension order.\n",
    "    Returns:\n",
    "        new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n",
    "    \"\"\"\n",
    "\n",
    "    def binarize(x):\n",
    "        return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n",
    "\n",
    "    def get_higher_order_adj_matrix(adj, order):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            adj:        (N, N)\n",
    "            type_mat:   (N, N)\n",
    "        Returns:\n",
    "            Following attributes will be updated:\n",
    "              - edge_index\n",
    "              - edge_type\n",
    "            Following attributes will be added to the data object:\n",
    "              - bond_edge_index: Original edge_index.\n",
    "        \"\"\"\n",
    "        adj_mats = [\n",
    "            torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n",
    "            binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n",
    "        ]\n",
    "\n",
    "        for i in range(2, order + 1):\n",
    "            adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n",
    "        order_mat = torch.zeros_like(adj)\n",
    "\n",
    "        for i in range(1, order + 1):\n",
    "            order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n",
    "\n",
    "        return order_mat\n",
    "\n",
    "    num_types = 22\n",
    "    # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n",
    "    # from rdkit.Chem.rdchem import BondType as BT\n",
    "    N = num_nodes\n",
    "    adj = to_dense_adj(edge_index).squeeze(0)\n",
    "    adj_order = get_higher_order_adj_matrix(adj, order)  # (N, N)\n",
    "\n",
    "    type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0)  # (N, N)\n",
    "    type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n",
    "    assert (type_mat * type_highorder == 0).all()\n",
    "    type_new = type_mat + type_highorder\n",
    "\n",
    "    new_edge_index, new_edge_type = dense_to_sparse(type_new)\n",
    "    _, edge_order = dense_to_sparse(adj_order)\n",
    "\n",
    "    # data.bond_edge_index = data.edge_index  # Save original edges\n",
    "    new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N)  # modify data\n",
    "\n",
    "    return new_edge_index, new_edge_type\n",
    "\n",
    "\n",
    "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n",
    "    assert edge_type.dim() == 1\n",
    "    N = pos.size(0)\n",
    "\n",
    "    bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n",
    "\n",
    "    if is_sidechain is None:\n",
    "        rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch)  # (2, E_r)\n",
    "    else:\n",
    "        # fetch sidechain and its batch index\n",
    "        is_sidechain = is_sidechain.bool()\n",
    "        dummy_index = torch.arange(pos.size(0), device=pos.device)\n",
    "        sidechain_pos = pos[is_sidechain]\n",
    "        sidechain_index = dummy_index[is_sidechain]\n",
    "        sidechain_batch = batch[is_sidechain]\n",
    "\n",
    "        assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n",
    "        r_edge_index_x = assign_index[1]\n",
    "        r_edge_index_y = assign_index[0]\n",
    "        r_edge_index_y = sidechain_index[r_edge_index_y]\n",
    "\n",
    "        rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y))  # (2, E)\n",
    "        rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x))  # (2, E)\n",
    "        rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1)  # (2, 2E)\n",
    "        # delete self loop\n",
    "        rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n",
    "\n",
    "    rgraph_adj = torch.sparse.LongTensor(\n",
    "        rgraph_edge_index,\n",
    "        torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n",
    "        torch.Size([N, N]),\n",
    "    )\n",
    "\n",
    "    composed_adj = (bgraph_adj + rgraph_adj).coalesce()  # Sparse (N, N, T)\n",
    "\n",
    "    new_edge_index = composed_adj.indices()\n",
    "    new_edge_type = composed_adj.values().long()\n",
    "\n",
    "    return new_edge_index, new_edge_type\n",
    "\n",
    "\n",
    "def extend_graph_order_radius(\n",
    "    num_nodes,\n",
    "    pos,\n",
    "    edge_index,\n",
    "    edge_type,\n",
    "    batch,\n",
    "    order=3,\n",
    "    cutoff=10.0,\n",
    "    extend_order=True,\n",
    "    extend_radius=True,\n",
    "    is_sidechain=None,\n",
    "):\n",
    "    if extend_order:\n",
    "        edge_index, edge_type = _extend_graph_order(\n",
    "            num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n",
    "        )\n",
    "\n",
    "    if extend_radius:\n",
    "        edge_index, edge_type = _extend_to_radius_graph(\n",
    "            pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n",
    "        )\n",
    "\n",
    "    return edge_index, edge_type\n",
    "\n",
    "\n",
    "def get_distance(pos, edge_index):\n",
    "    return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n",
    "\n",
    "\n",
    "def graph_field_network(score_d, pos, edge_index, edge_length):\n",
    "    \"\"\"\n",
    "    Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n",
    "    5-7 of the GeoDiff Paper https://huggingface.co/papers/2203.02923\n",
    "    \"\"\"\n",
    "    N = pos.size(0)\n",
    "    dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]])  # (E, 3)\n",
    "    score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n",
    "        -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n",
    "    )  # (N, 3)\n",
    "    return score_pos\n",
    "\n",
    "\n",
    "def clip_norm(vec, limit, p=2):\n",
    "    norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n",
    "    denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n",
    "    return vec * denom\n",
    "\n",
    "\n",
    "def is_local_edge(edge_type):\n",
    "    return edge_type > 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QWrHJFcYXyUB"
   },
   "source": [
    "Main model class!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MCeZA1qQXzoK"
   },
   "outputs": [],
   "source": [
    "class MoleculeGNN(ModelMixin, ConfigMixin):\n",
    "    @register_to_config\n",
    "    def __init__(\n",
    "        self,\n",
    "        hidden_dim=128,\n",
    "        num_convs=6,\n",
    "        num_convs_local=4,\n",
    "        cutoff=10.0,\n",
    "        mlp_act=\"relu\",\n",
    "        edge_order=3,\n",
    "        edge_encoder=\"mlp\",\n",
    "        smooth_conv=True,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.cutoff = cutoff\n",
    "        self.edge_encoder = edge_encoder\n",
    "        self.edge_order = edge_order\n",
    "\n",
    "        \"\"\"\n",
    "        edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n",
    "        in SchNetEncoder\n",
    "        \"\"\"\n",
    "        self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act)  # get_edge_encoder(config)\n",
    "        self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act)  # get_edge_encoder(config)\n",
    "\n",
    "        \"\"\"\n",
    "        The graph neural network that extracts node-wise features.\n",
    "        \"\"\"\n",
    "        self.encoder_global = SchNetEncoder(\n",
    "            hidden_channels=hidden_dim,\n",
    "            num_filters=hidden_dim,\n",
    "            num_interactions=num_convs,\n",
    "            edge_channels=self.edge_encoder_global.out_channels,\n",
    "            cutoff=cutoff,\n",
    "            smooth=smooth_conv,\n",
    "        )\n",
    "        self.encoder_local = GINEncoder(\n",
    "            hidden_dim=hidden_dim,\n",
    "            num_convs=num_convs_local,\n",
    "        )\n",
    "\n",
    "        \"\"\"\n",
    "        `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n",
    "            gradients w.r.t. edge_length (out_dim = 1).\n",
    "        \"\"\"\n",
    "        self.grad_global_dist_mlp = MultiLayerPerceptron(\n",
    "            2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
    "        )\n",
    "\n",
    "        self.grad_local_dist_mlp = MultiLayerPerceptron(\n",
    "            2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
    "        )\n",
    "\n",
    "        \"\"\"\n",
    "        Incorporate parameters together\n",
    "        \"\"\"\n",
    "        self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n",
    "        self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n",
    "\n",
    "    def _forward(\n",
    "        self,\n",
    "        atom_type,\n",
    "        pos,\n",
    "        bond_index,\n",
    "        bond_type,\n",
    "        batch,\n",
    "        time_step,  # NOTE, model trained without timestep performed best\n",
    "        edge_index=None,\n",
    "        edge_type=None,\n",
    "        edge_length=None,\n",
    "        return_edges=False,\n",
    "        extend_order=True,\n",
    "        extend_radius=True,\n",
    "        is_sidechain=None,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            atom_type:  Types of atoms, (N, ).\n",
    "            bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n",
    "            bond_type:  Bond types, (E, ).\n",
    "            batch:      Node index to graph index, (N, ).\n",
    "        \"\"\"\n",
    "        N = atom_type.size(0)\n",
    "        if edge_index is None or edge_type is None or edge_length is None:\n",
    "            edge_index, edge_type = extend_graph_order_radius(\n",
    "                num_nodes=N,\n",
    "                pos=pos,\n",
    "                edge_index=bond_index,\n",
    "                edge_type=bond_type,\n",
    "                batch=batch,\n",
    "                order=self.edge_order,\n",
    "                cutoff=self.cutoff,\n",
    "                extend_order=extend_order,\n",
    "                extend_radius=extend_radius,\n",
    "                is_sidechain=is_sidechain,\n",
    "            )\n",
    "            edge_length = get_distance(pos, edge_index).unsqueeze(-1)  # (E, 1)\n",
    "        local_edge_mask = is_local_edge(edge_type)  # (E, )\n",
    "\n",
    "        # with the parameterization of NCSNv2\n",
    "        # DDPM loss implicit handle the noise variance scale conditioning\n",
    "        sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device)  # (E, 1)\n",
    "\n",
    "        # Encoding global\n",
    "        edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type)  # Embed edges\n",
    "\n",
    "        # Global\n",
    "        node_attr_global = self.encoder_global(\n",
    "            z=atom_type,\n",
    "            edge_index=edge_index,\n",
    "            edge_length=edge_length,\n",
    "            edge_attr=edge_attr_global,\n",
    "        )\n",
    "        # Assemble pairwise features\n",
    "        h_pair_global = assemble_atom_pair_feature(\n",
    "            node_attr=node_attr_global,\n",
    "            edge_index=edge_index,\n",
    "            edge_attr=edge_attr_global,\n",
    "        )  # (E_global, 2H)\n",
    "        # Invariant features of edges (radius graph, global)\n",
    "        edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge)  # (E_global, 1)\n",
    "\n",
    "        # Encoding local\n",
    "        edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type)  # Embed edges\n",
    "        # edge_attr += temb_edge\n",
    "\n",
    "        # Local\n",
    "        node_attr_local = self.encoder_local(\n",
    "            z=atom_type,\n",
    "            edge_index=edge_index[:, local_edge_mask],\n",
    "            edge_attr=edge_attr_local[local_edge_mask],\n",
    "        )\n",
    "        # Assemble pairwise features\n",
    "        h_pair_local = assemble_atom_pair_feature(\n",
    "            node_attr=node_attr_local,\n",
    "            edge_index=edge_index[:, local_edge_mask],\n",
    "            edge_attr=edge_attr_local[local_edge_mask],\n",
    "        )  # (E_local, 2H)\n",
    "\n",
    "        # Invariant features of edges (bond graph, local)\n",
    "        if isinstance(sigma_edge, torch.Tensor):\n",
    "            edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n",
    "                1.0 / sigma_edge[local_edge_mask]\n",
    "            )  # (E_local, 1)\n",
    "        else:\n",
    "            edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge)  # (E_local, 1)\n",
    "\n",
    "        if return_edges:\n",
    "            return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n",
    "        else:\n",
    "            return edge_inv_global, edge_inv_local\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        sample,\n",
    "        timestep: Union[torch.Tensor, float, int],\n",
    "        return_dict: bool = True,\n",
    "        sigma=1.0,\n",
    "        global_start_sigma=0.5,\n",
    "        w_global=1.0,\n",
    "        extend_order=False,\n",
    "        extend_radius=True,\n",
    "        clip_local=None,\n",
    "        clip_global=1000.0,\n",
    "    ) -> Union[MoleculeGNNOutput, Tuple]:\n",
    "        r\"\"\"\n",
    "        Args:\n",
    "            sample: packed torch geometric object\n",
    "            timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n",
    "            return_dict (`bool`, *optional*, defaults to `True`):\n",
    "                Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n",
    "        Returns:\n",
    "            [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n",
    "            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n",
    "        \"\"\"\n",
    "\n",
    "        # unpack sample\n",
    "        atom_type = sample.atom_type\n",
    "        bond_index = sample.edge_index\n",
    "        bond_type = sample.edge_type\n",
    "        num_graphs = sample.num_graphs\n",
    "        pos = sample.pos\n",
    "\n",
    "        timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n",
    "\n",
    "        edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n",
    "            atom_type=atom_type,\n",
    "            pos=sample.pos,\n",
    "            bond_index=bond_index,\n",
    "            bond_type=bond_type,\n",
    "            batch=sample.batch,\n",
    "            time_step=timesteps,\n",
    "            return_edges=True,\n",
    "            extend_order=extend_order,\n",
    "            extend_radius=extend_radius,\n",
    "        )  # (E_global, 1), (E_local, 1)\n",
    "\n",
    "        # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n",
    "        node_eq_local = graph_field_network(\n",
    "            edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n",
    "        )\n",
    "        if clip_local is not None:\n",
    "            node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n",
    "\n",
    "        # Global\n",
    "        if sigma < global_start_sigma:\n",
    "            edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n",
    "            node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n",
    "            node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n",
    "        else:\n",
    "            node_eq_global = 0\n",
    "\n",
    "        # Sum\n",
    "        eps_pos = node_eq_local + node_eq_global * w_global\n",
    "\n",
    "        if not return_dict:\n",
    "            return (-eps_pos,)\n",
    "\n",
    "        return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CCIrPYSJj9wd"
   },
   "source": [
    "### Load pretrained model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YdrAr6Ch--Ab"
   },
   "source": [
    "#### Load a model\n",
    "The model used is a design an\n",
    "equivariant convolutional layer, named graph field network (GFN).\n",
    "\n",
    "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 172,
     "referenced_widgets": [
      "d90f304e9560472eacfbdd11e46765eb",
      "1c6246f15b654f4daa11c9bcf997b78c",
      "c2321b3bff6f490ca12040a20308f555",
      "b7feb522161f4cf4b7cc7c1a078ff12d",
      "e2d368556e494ae7ae4e2e992af2cd4f",
      "bbef741e76ec41b7ab7187b487a383df",
      "561f742d418d4721b0670cc8dd62e22c",
      "872915dd1bb84f538c44e26badabafdd",
      "d022575f1fa2446d891650897f187b4d",
      "fdc393f3468c432aa0ada05e238a5436",
      "2c9362906e4b40189f16d14aa9a348da",
      "6010fc8daa7a44d5aec4b830ec2ebaa1",
      "7e0bb1b8d65249d3974200686b193be2",
      "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
      "6526646be5ed415c84d1245b040e629b",
      "24d31fc3576e43dd9f8301d2ef3a37ab",
      "2918bfaadc8d4b1a9832522c40dfefb8",
      "a4bfdca35cc54dae8812720f1b276a08",
      "e4901541199b45c6a18824627692fc39",
      "f915cf874246446595206221e900b2fe",
      "a9e388f22a9742aaaf538e22575c9433",
      "42f6c3db29d7484ba6b4f73590abd2f4"
     ]
    },
    "id": "DyCo0nsqjbml",
    "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d90f304e9560472eacfbdd11e46765eb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/3.27M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6010fc8daa7a44d5aec4b830ec2ebaa1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/401 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The config attributes {'type': 'diffusion', 'network': 'dualenc', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'num_diffusion_timesteps': 5000} were passed to MoleculeGNN, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
      "Some weights of the model checkpoint at fusing/gfn-molecule-gen-drugs were not used when initializing MoleculeGNN: ['betas', 'alphas']\n",
      "- This IS expected if you are initializing MoleculeGNN from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing MoleculeGNN from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "DEVICE = \"cuda\"\n",
    "model = MoleculeGNN.from_pretrained(\"fusing/gfn-molecule-gen-drugs\").to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HdclRaqoUWUD"
   },
   "source": [
    "The warnings above are because the pre-trained model was uploaded before cleaning the code!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PlOkPySoJ1m9"
   },
   "source": [
    "#### Create scheduler\n",
    "Note, other schedulers are used in the paper for slightly improved performance over DDPM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nNHnIk9CkAb2"
   },
   "outputs": [],
   "source": [
    "from diffusers import DDPMScheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RnDJdDBztjFF"
   },
   "outputs": [],
   "source": [
    "num_timesteps = 1000\n",
    "scheduler = DDPMScheduler(\n",
    "    num_train_timesteps=num_timesteps, beta_schedule=\"sigmoid\", beta_start=1e-7, beta_end=2e-3, clip_sample=False\n",
    ")"
   ]
  },
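  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model's `forward` only adds the global term when `sigma < global_start_sigma`, so it helps to see how the noise level varies with the timestep. Below is a minimal sketch, not the scheduler's actual code: it assumes the `\"sigmoid\"` schedule maps `torch.linspace(-6, 6, T)` through a sigmoid and rescales it to `[beta_start, beta_end]`. All `toy_` names are introduced here for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "# Hypothetical stand-ins mirroring the values passed to the scheduler above.\n",
    "toy_T, toy_beta_start, toy_beta_end = 1000, 1e-7, 2e-3\n",
    "toy_global_start_sigma = 0.5  # default of `global_start_sigma` in `forward`\n",
    "\n",
    "# Assumed form of the sigmoid beta schedule (see the note above).\n",
    "toy_betas = torch.sigmoid(torch.linspace(-6, 6, toy_T)) * (toy_beta_end - toy_beta_start) + toy_beta_start\n",
    "toy_alphas_cumprod = torch.cumprod(1.0 - toy_betas, dim=0)\n",
    "\n",
    "# Noise level per timestep: sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t).\n",
    "toy_sigmas = ((1.0 - toy_alphas_cumprod) / toy_alphas_cumprod).sqrt()\n",
    "\n",
    "# Count the timesteps at which the global term would be switched on.\n",
    "toy_num_global = int((toy_sigmas < toy_global_start_sigma).sum())\n",
    "print(f\"sigma range: {toy_sigmas[0].item():.4f} to {toy_sigmas[-1].item():.4f}\")\n",
    "print(f\"{toy_num_global} of {toy_T} timesteps have sigma < {toy_global_start_sigma}\")"
   ]
  },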
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1vh3fpSAflkL"
   },
   "source": [
    "### Get a dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B6qzaGjVKFVk"
   },
   "source": [
    "Grab a google tool so we can upload our data directly. Note you need to download the data from ***this [file](https://huggingface.co/datasets/fusing/geodiff-example-data/blob/main/data/molecules.pkl)***\n",
    "\n",
    "(direct downloading from the hub does not yet work for this datatype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jbLl3EJdgj3x"
   },
   "outputs": [],
   "source": [
    "# from google.colab import files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "E591lVuTgxPE"
   },
   "outputs": [],
   "source": [
    "# uploaded = files.upload()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KUNxfK3ln98Q"
   },
   "source": [
    "Load the dataset with torch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "7L4iOShTpcQX",
    "outputId": "7f2dcd29-493e-44de-98d1-3ad50f109a4a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2022-10-12 18:32:19--  https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
      "Resolving huggingface.co (huggingface.co)... 44.195.102.200, 52.5.54.249, 54.210.225.113, ...\n",
      "Connecting to huggingface.co (huggingface.co)|44.195.102.200|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 127774 (125K) [application/octet-stream]\n",
      "Saving to: ‘molecules.pkl’\n",
      "\n",
      "molecules.pkl       100%[===================>] 124.78K   180KB/s    in 0.7s    \n",
      "\n",
      "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
    "dataset = torch.load(\"/content/molecules.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QZcmy1EvKQRk"
   },
   "source": [
    "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JVjz6iH_H6Eh",
    "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=<rdkit.Chem.rdchem.Mol object at 0x7f707d2cb130>, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vHNiZAUxNgoy"
   },
   "source": [
    "## Run the diffusion process"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jZ1KZrxKqENg"
   },
   "source": [
    "#### Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "s240tYueqKKf"
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "import os\n",
    "\n",
    "from torch_geometric.data import Batch, Data\n",
    "from torch_scatter import scatter_mean\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def repeat_data(data: Data, num_repeat) -> Batch:\n",
    "    datas = [copy.deepcopy(data) for i in range(num_repeat)]\n",
    "    return Batch.from_data_list(datas)\n",
    "\n",
    "\n",
    "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n",
    "    datas = batch.to_data_list()\n",
    "    new_data = []\n",
    "    for i in range(num_repeat):\n",
    "        new_data += copy.deepcopy(datas)\n",
    "    return Batch.from_data_list(new_data)"
   ]
  },
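  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of what `Batch.from_data_list` produces, the sketch below batches two copies of a made-up three-atom graph. The `toy_` names and values are hypothetical. The point is the `batch` vector, which assigns every atom to its graph and is what the sampling loop uses to recenter each molecule separately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch_geometric.data import Batch, Data\n",
    "\n",
    "\n",
    "# Made-up graph with three atoms; only `pos` and `atom_type` are set.\n",
    "toy_data = Data(pos=torch.zeros(3, 3), atom_type=torch.tensor([6, 6, 8]))\n",
    "\n",
    "# Same pattern as `repeat_data`: copy the graph, then collate the copies.\n",
    "toy_batch = Batch.from_data_list([toy_data.clone() for _ in range(2)])\n",
    "\n",
    "print(toy_batch.num_graphs)  # number of graphs in the batch\n",
    "print(toy_batch.batch)  # graph index of every atom\n",
    "print(toy_batch.pos.shape)  # positions of all atoms, stacked"
   ]
  },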
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AMnQTk0eqT7Z"
   },
   "source": [
    "#### Constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WYGkzqgzrHmF"
   },
   "outputs": [],
   "source": [
    "num_samples = 1  # solutions per molecule\n",
    "num_molecules = 3\n",
    "\n",
    "DEVICE = \"cuda\"\n",
    "sampling_type = \"ddpm_noisy\"  #'' # paper also uses \"generalize\" and \"ld\"\n",
    "# constants for inference\n",
    "w_global = 0.5  # 0,.3 for qm9\n",
    "global_start_sigma = 0.5\n",
    "eta = 1.0\n",
    "clip_local = None\n",
    "clip_pos = None\n",
    "\n",
    "# constants for data handling\n",
    "save_traj = False\n",
    "save_data = False\n",
    "output_dir = \"/content/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-xD5bJ3SqM7t"
   },
   "source": [
    "#### Generate samples!\n",
    "Note that the 3d representation of a molecule is referred to as the **conformation**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "x9xuLUNg26z1",
    "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  after removing the cwd from sys.path.\n",
      "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "\n",
    "results = []\n",
    "\n",
    "# define sigmas\n",
    "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n",
    "sigmas = sigmas.to(DEVICE)\n",
    "\n",
    "for count, data in enumerate(tqdm(dataset)):\n",
    "    num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n",
    "\n",
    "    data_input = data.clone()\n",
    "    data_input[\"pos_ref\"] = None\n",
    "    batch = repeat_data(data_input, num_samples).to(DEVICE)\n",
    "\n",
    "    # initial configuration\n",
    "    pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n",
    "\n",
    "    # for logging animation of denoising\n",
    "    pos_traj = []\n",
    "    with torch.no_grad():\n",
    "        # scale initial sample\n",
    "        pos = pos_init * sigmas[-1]\n",
    "        for t in scheduler.timesteps:\n",
    "            batch.pos = pos\n",
    "\n",
    "            # generate geometry with model, then filter it\n",
    "            epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n",
    "\n",
    "            # Update\n",
    "            reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n",
    "\n",
    "            pos = reconstructed_pos\n",
    "\n",
    "            if torch.isnan(pos).any():\n",
    "                print(\"NaN detected. Please restart.\")\n",
    "                raise FloatingPointError()\n",
    "\n",
    "            # recenter graph of positions for next iteration\n",
    "            pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n",
    "\n",
    "            # optional clipping\n",
    "            if clip_pos is not None:\n",
    "                pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n",
    "            pos_traj.append(pos.clone().cpu())\n",
    "\n",
    "    pos_gen = pos.cpu()\n",
    "    if save_traj:\n",
    "        pos_gen_traj = pos_traj.cpu()\n",
    "        data.pos_gen = torch.stack(pos_gen_traj)\n",
    "    else:\n",
    "        data.pos_gen = pos_gen\n",
    "    results.append(data)\n",
    "\n",
    "\n",
    "if save_data:\n",
    "    save_path = os.path.join(output_dir, \"samples_all.pkl\")\n",
    "\n",
    "    with open(save_path, \"wb\") as f:\n",
    "        pickle.dump(results, f)"
   ]
  },
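  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The recentering step in the loop above subtracts each molecule's mean position from its atoms. A minimal sketch with plain `torch` instead of `torch_scatter`, using made-up `toy_` tensors (two graphs with two and three atoms), shows the same computation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "# Made-up positions for two graphs with 2 and 3 atoms.\n",
    "toy_pos = torch.tensor(\n",
    "    [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [1.0, 1.0, 1.0], [3.0, 1.0, 1.0], [5.0, 1.0, 1.0]]\n",
    ")\n",
    "toy_index = torch.tensor([0, 0, 1, 1, 1])  # graph index of every atom\n",
    "\n",
    "# Per-graph mean: sum positions per graph, then divide by the atom count.\n",
    "toy_num_graphs = int(toy_index.max()) + 1\n",
    "toy_sums = torch.zeros(toy_num_graphs, 3).index_add_(0, toy_index, toy_pos)\n",
    "toy_counts = torch.bincount(toy_index, minlength=toy_num_graphs).unsqueeze(1)\n",
    "toy_centers = toy_sums / toy_counts\n",
    "\n",
    "# Same pattern as `pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]`.\n",
    "toy_centered = toy_pos - toy_centers[toy_index]\n",
    "\n",
    "# Every graph should now have (numerically) zero mean position.\n",
    "toy_check = torch.zeros(toy_num_graphs, 3).index_add_(0, toy_index, toy_centered)\n",
    "assert torch.allclose(toy_check, torch.zeros_like(toy_check), atol=1e-6)\n",
    "print(toy_centered)"
   ]
  },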
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fSApwSaZNndW"
   },
   "source": [
    "## Render the results!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "d47Zxo2OKdgZ"
   },
   "source": [
    "This function allows us to render 3d in colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "e9Cd0kCAv9b8"
   },
   "outputs": [],
   "source": [
    "from google.colab import output\n",
    "\n",
    "\n",
    "output.enable_custom_widget_manager()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RjaVuR15NqzF"
   },
   "source": [
    "### Helper functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "28rBYa9NKhlz"
   },
   "source": [
    "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LKdKdwxcyTQ6"
   },
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "def set_rdmol_positions(rdkit_mol, pos):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        rdkit_mol:  An `rdkit.Chem.rdchem.Mol` object.\n",
    "        pos: (N_atoms, 3)\n",
    "    \"\"\"\n",
    "    mol = deepcopy(rdkit_mol)\n",
    "    set_rdmol_positions_(mol, pos)\n",
    "    return mol\n",
    "\n",
    "\n",
    "def set_rdmol_positions_(mol, pos):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        rdkit_mol:  An `rdkit.Chem.rdchem.Mol` object.\n",
    "        pos: (N_atoms, 3)\n",
    "    \"\"\"\n",
    "    for i in range(pos.shape[0]):\n",
    "        mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n",
    "    return mol"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NuE10hcpKmzK"
   },
   "source": [
    "Process the generated data to make it easy to view."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KieVE1vc0_Vs",
    "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "collect 5 generated molecules in `mols`\n"
     ]
    }
   ],
   "source": [
    "# the model can generate multiple conformations per 2d geometry\n",
    "num_gen = results[0][\"pos_gen\"].shape[0]\n",
    "\n",
    "# init storage objects\n",
    "mols_gen = []\n",
    "mols_orig = []\n",
    "for to_process in results:\n",
    "    # store the reference 3d position\n",
    "    to_process[\"pos_ref\"] = to_process[\"pos_ref\"].reshape(-1, to_process[\"rdmol\"].GetNumAtoms(), 3)\n",
    "\n",
    "    # store the generated 3d position\n",
    "    to_process[\"pos_gen\"] = to_process[\"pos_gen\"].reshape(-1, to_process[\"rdmol\"].GetNumAtoms(), 3)\n",
    "\n",
    "    # copy data to new object\n",
    "    new_mol = set_rdmol_positions(to_process.rdmol, to_process[\"pos_gen\"][0])\n",
    "\n",
    "    # append results\n",
    "    mols_gen.append(new_mol)\n",
    "    mols_orig.append(to_process.rdmol)\n",
    "\n",
    "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tin89JwMKp4v"
   },
   "source": [
    "Import tools to visualize the 2d chemical diagram of the molecule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yqV6gllSZn38"
   },
   "outputs": [],
   "source": [
    "from IPython.display import SVG, display\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem.Draw import rdMolDraw2D as MD2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TFNKmGddVoOk"
   },
   "source": [
    "Select molecule to visualize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KzuwLlrrVaGc"
   },
   "outputs": [],
   "source": [
    "idx = 0\n",
    "assert idx < len(results), \"selected molecule that was not generated\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hkb8w0_SNtU8"
   },
   "source": [
    "### Viewing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "I3R4QBQeKttN"
   },
   "source": [
    "This 2D rendering is the equivalent of the **input to the model**!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 321
    },
    "id": "gkQRWjraaKex",
    "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47"
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<svg baseProfile=\"full\" height=\"300px\" version=\"1.1\" viewBox=\"0 0 450 300\" width=\"450px\" xml:space=\"preserve\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:rdkit=\"http://www.rdkit.org/xml\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<!-- END OF HEADER -->\n",
       "<rect height=\"300.0\" style=\"opacity:1.0;fill:#FFFFFF;stroke:none\" width=\"450.0\" x=\"0.0\" y=\"0.0\"> </rect>\n",
       "<path class=\"bond-0 atom-0 atom-1\" d=\"M 20.5,147.6 L 57.8,136.7\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-1 atom-1 atom-2\" d=\"M 57.8,136.7 L 67.1,98.9\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-2 atom-2 atom-3\" d=\"M 67.1,98.9 L 104.4,88.1\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-3 atom-3 atom-4\" d=\"M 104.4,88.1 L 132.5,115.0\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-4 atom-4 atom-5\" d=\"M 132.5,115.0 L 128.7,130.5\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-4 atom-4 atom-5\" d=\"M 128.7,130.5 L 124.9,146.0\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-5 atom-5 atom-6\" d=\"M 128.7,158.0 L 140.0,168.8\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-5 atom-5 atom-6\" d=\"M 140.0,168.8 L 151.3,179.7\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-6 atom-6 atom-7\" d=\"M 155.1,180.6 L 151.3,196.1\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-6 atom-6 atom-7\" d=\"M 151.3,196.1 L 147.5,211.5\" style=\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-6 atom-6 atom-7\" d=\"M 147.5,178.8 L 143.7,194.2\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-6 atom-6 atom-7\" d=\"M 143.7,194.2 L 139.9,209.7\" style=\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-7 atom-6 atom-8\" d=\"M 151.3,179.7 L 188.7,168.8\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-8 atom-8 atom-9\" d=\"M 188.7,168.8 L 216.7,195.8\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-9 atom-9 atom-10\" d=\"M 216.7,195.8 L 254.1,184.9\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-10 atom-10 atom-11\" d=\"M 254.1,184.9 L 257.9,169.4\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-10 atom-10 atom-11\" d=\"M 257.9,169.4 L 261.7,153.9\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-11 atom-11 atom-12\" d=\"M 268.8,145.5 L 282.4,141.6\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-11 atom-11 atom-12\" d=\"M 282.4,141.6 L 295.9,137.7\" style=\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-12 atom-12 atom-13\" d=\"M 295.0,130.6 L 291.6,118.8\" style=\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-12 atom-12 atom-13\" d=\"M 291.6,118.8 L 288.2,107.0\" style=\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-12 atom-12 atom-13\" d=\"M 302.5,128.4 L 299.1,116.6\" style=\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-12 atom-12 atom-13\" d=\"M 299.1,116.6 L 295.6,104.9\" style=\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-13 atom-12 atom-14\" d=\"M 306.5,142.3 L 309.9,154.0\" style=\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-13 atom-12 atom-14\" d=\"M 309.9,154.0 L 313.3,165.7\" style=\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-13 atom-12 atom-14\" d=\"M 299.0,144.4 L 302.4,156.1\" style=\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-13 atom-12 atom-14\" d=\"M 302.4,156.1 L 305.8,167.9\" style=\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-14 atom-12 atom-15\" d=\"M 305.5,134.9 L 321.8,130.1\" style=\"fill:none;fill-rule:evenodd;stroke:#CCCC00;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-14 atom-12 atom-15\" d=\"M 321.8,130.1 L 338.1,125.4\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-15 atom-15 atom-16\" d=\"M 338.1,125.4 L 347.4,87.6\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-15 atom-15 atom-16\" d=\"M 347.0,121.6 L 353.5,95.2\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-16 atom-16 atom-17\" d=\"M 347.4,87.6 L 384.7,76.8\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-17 atom-17 atom-18\" d=\"M 384.7,76.8 L 412.8,103.7\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-17 atom-17 atom-18\" d=\"M 383.5,86.4 L 403.2,105.3\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-18 atom-18 atom-19\" d=\"M 412.8,103.7 L 403.5,141.5\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-19 atom-19 atom-20\" d=\"M 403.5,141.5 L 412.1,154.2\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-19 atom-19 atom-20\" d=\"M 412.1,154.2 L 420.8,166.9\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-19 atom-19 atom-20\" d=\"M 399.7,149.7 L 405.7,158.6\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-19 atom-19 atom-20\" d=\"M 405.7,158.6 L 411.7,167.4\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-20 atom-20 atom-21\" d=\"M 420.1,180.5 L 413.5,189.0\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-20 atom-20 atom-21\" d=\"M 413.5,189.0 L 406.8,197.5\" style=\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-21 atom-21 atom-22\" d=\"M 395.2,202.1 L 382.8,197.7\" style=\"fill:none;fill-rule:evenodd;stroke:#FF0000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-21 atom-21 atom-22\" d=\"M 382.8,197.7 L 370.4,193.2\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-22 atom-22 atom-23\" d=\"M 365.1,184.4 L 365.6,168.4\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-22 atom-22 atom-23\" d=\"M 365.6,168.4 L 366.2,152.3\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-22 atom-22 atom-23\" d=\"M 373.1,179.9 L 373.4,168.6\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-22 atom-22 atom-23\" d=\"M 373.4,168.6 L 373.8,157.4\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-23 atom-11 atom-24\" d=\"M 257.9,141.9 L 246.6,131.1\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-23 atom-11 atom-24\" d=\"M 246.6,131.1 L 235.3,120.2\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-24 atom-24 atom-25\" d=\"M 235.3,120.2 L 197.9,131.1\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-25 atom-5 atom-26\" d=\"M 117.8,154.4 L 101.8,159.0\" style=\"fill:none;fill-rule:evenodd;stroke:#0000FF;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-25 atom-5 atom-26\" d=\"M 101.8,159.0 L 85.9,163.6\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-26 atom-26 atom-1\" d=\"M 85.9,163.6 L 57.8,136.7\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-27 atom-25 atom-8\" d=\"M 197.9,131.1 L 188.7,168.8\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-28 atom-23 atom-15\" d=\"M 366.2,152.3 L 338.1,125.4\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"bond-29 atom-23 atom-19\" d=\"M 366.2,152.3 L 403.5,141.5\" style=\"fill:none;fill-rule:evenodd;stroke:#000000;stroke-width:2.0px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1\"/>\n",
       "<path class=\"atom-5\" d=\"M 120.8 147.3 L 124.4 153.1 Q 124.8 153.7, 125.3 154.7 Q 125.9 155.8, 126.0 155.8 L 126.0 147.3 L 127.4 147.3 L 127.4 158.3 L 125.9 158.3 L 122.0 151.9 Q 121.6 151.2, 121.1 150.3 Q 120.6 149.4, 120.5 149.2 L 120.5 158.3 L 119.1 158.3 L 119.1 147.3 L 120.8 147.3 \" fill=\"#0000FF\"/>\n",
       "<path class=\"atom-7\" d=\"M 137.0 217.5 Q 137.0 214.9, 138.3 213.4 Q 139.6 211.9, 142.0 211.9 Q 144.5 211.9, 145.8 213.4 Q 147.1 214.9, 147.1 217.5 Q 147.1 220.2, 145.8 221.7 Q 144.4 223.2, 142.0 223.2 Q 139.6 223.2, 138.3 221.7 Q 137.0 220.2, 137.0 217.5 M 142.0 222.0 Q 143.7 222.0, 144.6 220.8 Q 145.5 219.7, 145.5 217.5 Q 145.5 215.3, 144.6 214.2 Q 143.7 213.1, 142.0 213.1 Q 140.4 213.1, 139.4 214.2 Q 138.5 215.3, 138.5 217.5 Q 138.5 219.7, 139.4 220.8 Q 140.4 222.0, 142.0 222.0 \" fill=\"#FF0000\"/>\n",
       "<path class=\"atom-11\" d=\"M 260.9 141.6 L 264.5 147.5 Q 264.9 148.0, 265.5 149.1 Q 266.1 150.1, 266.1 150.2 L 266.1 141.6 L 267.5 141.6 L 267.5 152.6 L 266.0 152.6 L 262.2 146.3 Q 261.7 145.5, 261.2 144.7 Q 260.8 143.8, 260.6 143.5 L 260.6 152.6 L 259.2 152.6 L 259.2 141.6 L 260.9 141.6 \" fill=\"#0000FF\"/>\n",
       "<path class=\"atom-12\" d=\"M 297.6 140.1 Q 297.7 140.1, 298.2 140.3 Q 298.8 140.5, 299.3 140.7 Q 299.9 140.8, 300.5 140.8 Q 301.5 140.8, 302.1 140.3 Q 302.7 139.8, 302.7 138.9 Q 302.7 138.3, 302.4 137.9 Q 302.1 137.6, 301.6 137.3 Q 301.2 137.1, 300.4 136.9 Q 299.4 136.6, 298.8 136.3 Q 298.2 136.1, 297.8 135.5 Q 297.4 134.9, 297.4 133.9 Q 297.4 132.5, 298.4 131.6 Q 299.3 130.8, 301.2 130.8 Q 302.4 130.8, 303.9 131.4 L 303.5 132.6 Q 302.2 132.0, 301.2 132.0 Q 300.1 132.0, 299.6 132.5 Q 299.0 132.9, 299.0 133.7 Q 299.0 134.3, 299.3 134.6 Q 299.6 135.0, 300.0 135.2 Q 300.5 135.4, 301.2 135.6 Q 302.2 135.9, 302.8 136.3 Q 303.4 136.6, 303.8 137.2 Q 304.3 137.8, 304.3 138.9 Q 304.3 140.4, 303.2 141.3 Q 302.2 142.1, 300.5 142.1 Q 299.5 142.1, 298.8 141.8 Q 298.1 141.6, 297.2 141.3 L 297.6 140.1 \" fill=\"#CCCC00\"/>\n",
       "<path class=\"atom-13\" d=\"M 284.8 99.0 Q 284.8 96.3, 286.1 94.8 Q 287.4 93.4, 289.9 93.4 Q 292.3 93.4, 293.6 94.8 Q 294.9 96.3, 294.9 99.0 Q 294.9 101.6, 293.6 103.2 Q 292.3 104.7, 289.9 104.7 Q 287.4 104.7, 286.1 103.2 Q 284.8 101.6, 284.8 99.0 M 289.9 103.4 Q 291.5 103.4, 292.5 102.3 Q 293.4 101.2, 293.4 99.0 Q 293.4 96.8, 292.5 95.7 Q 291.5 94.6, 289.9 94.6 Q 288.2 94.6, 287.3 95.7 Q 286.4 96.8, 286.4 99.0 Q 286.4 101.2, 287.3 102.3 Q 288.2 103.4, 289.9 103.4 \" fill=\"#FF0000\"/>\n",
       "<path class=\"atom-14\" d=\"M 306.5 173.7 Q 306.5 171.0, 307.8 169.5 Q 309.1 168.1, 311.6 168.1 Q 314.0 168.1, 315.3 169.5 Q 316.6 171.0, 316.6 173.7 Q 316.6 176.3, 315.3 177.9 Q 314.0 179.4, 311.6 179.4 Q 309.1 179.4, 307.8 177.9 Q 306.5 176.4, 306.5 173.7 M 311.6 178.1 Q 313.3 178.1, 314.2 177.0 Q 315.1 175.9, 315.1 173.7 Q 315.1 171.5, 314.2 170.4 Q 313.3 169.3, 311.6 169.3 Q 309.9 169.3, 309.0 170.4 Q 308.1 171.5, 308.1 173.7 Q 308.1 175.9, 309.0 177.0 Q 309.9 178.1, 311.6 178.1 \" fill=\"#FF0000\"/>\n",
       "<path class=\"atom-20\" d=\"M 422.9 168.2 L 426.5 174.0 Q 426.9 174.6, 427.5 175.6 Q 428.1 176.6, 428.1 176.7 L 428.1 168.2 L 429.5 168.2 L 429.5 179.2 L 428.0 179.2 L 424.2 172.8 Q 423.7 172.0, 423.2 171.2 Q 422.8 170.3, 422.6 170.1 L 422.6 179.2 L 421.2 179.2 L 421.2 168.2 L 422.9 168.2 \" fill=\"#0000FF\"/>\n",
       "<path class=\"atom-21\" d=\"M 396.5 204.4 Q 396.5 201.8, 397.8 200.3 Q 399.1 198.8, 401.5 198.8 Q 404.0 198.8, 405.3 200.3 Q 406.6 201.8, 406.6 204.4 Q 406.6 207.1, 405.3 208.6 Q 403.9 210.1, 401.5 210.1 Q 399.1 210.1, 397.8 208.6 Q 396.5 207.1, 396.5 204.4 M 401.5 208.9 Q 403.2 208.9, 404.1 207.8 Q 405.0 206.6, 405.0 204.4 Q 405.0 202.3, 404.1 201.2 Q 403.2 200.1, 401.5 200.1 Q 399.8 200.1, 398.9 201.2 Q 398.0 202.2, 398.0 204.4 Q 398.0 206.7, 398.9 207.8 Q 399.8 208.9, 401.5 208.9 \" fill=\"#FF0000\"/>\n",
       "<path class=\"atom-22\" d=\"M 362.5 185.7 L 366.1 191.5 Q 366.5 192.1, 367.0 193.2 Q 367.6 194.2, 367.6 194.3 L 367.6 185.7 L 369.1 185.7 L 369.1 196.7 L 367.6 196.7 L 363.7 190.4 Q 363.3 189.6, 362.8 188.7 Q 362.3 187.9, 362.2 187.6 L 362.2 196.7 L 360.8 196.7 L 360.8 185.7 L 362.5 185.7 \" fill=\"#0000FF\"/>\n",
       "</svg>"
      ],
      "text/plain": [
       "<IPython.core.display.SVG object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Draw the 2D structure of the first dataset entry from its SMILES string\n",
    "mc = Chem.MolFromSmiles(dataset[0][\"smiles\"])\n",
    "molSize = (450, 300)\n",
    "drawer = MD2.MolDraw2DSVG(molSize[0], molSize[1])\n",
    "drawer.DrawMolecule(mc)\n",
    "drawer.FinishDrawing()\n",
    "svg = drawer.GetDrawingText()\n",
    "display(SVG(svg.replace(\"svg:\", \"\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z4FDMYMxKw2I"
   },
   "source": [
    "Visualize the generated 3D molecule!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 17,
     "referenced_widgets": [
      "695ab5bbf30a4ab19df1f9f33469f314",
      "eac6a8dcdc9d4335a2e51031793ead29"
     ]
    },
    "id": "aT1Bkb8YxJfV",
    "outputId": "b98870ae-049d-4386-b676-166e9526bda2"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "695ab5bbf30a4ab19df1f9f33469f314",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": []
     },
     "metadata": {
      "application/vnd.jupyter.widget-view+json": {
       "colab": {
        "custom_widget_manager": {
         "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
        }
       }
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from nglview import show_rdkit as show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 337,
     "referenced_widgets": [
      "be446195da2b4ff2aec21ec5ff963a54",
      "c6596896148b4a8a9c57963b67c7782f",
      "2489b5e5648541fbbdceadb05632a050",
      "01e0ba4e5da04914b4652b8d58565d7b",
      "c30e6c2f3e2a44dbbb3d63bd519acaa4",
      "f31c6e40e9b2466a9064a2669933ecd5",
      "19308ccac642498ab8b58462e3f1b0bb",
      "4a081cdc2ec3421ca79dd933b7e2b0c4",
      "e5c0d75eb5e1447abd560c8f2c6017e1",
      "5146907ef6764654ad7d598baebc8b58",
      "144ec959b7604a2cabb5ca46ae5e5379",
      "abce2a80e6304df3899109c6d6cac199",
      "65195cb7a4134f4887e9dd19f3676462"
     ]
    },
    "id": "pxtq8I-I18C-",
    "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be446195da2b4ff2aec21ec5ff963a54",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "NGLWidget()"
      ]
     },
     "metadata": {
      "application/vnd.jupyter.widget-view+json": {
       "colab": {
        "custom_widget_manager": {
         "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
        }
       }
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Render the generated molecule at index idx in 3D\n",
    "show(mols_gen[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KJr4h2mwXeTo"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "01e0ba4e5da04914b4652b8d58565d7b": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1",
       "IPY_MODEL_5146907ef6764654ad7d598baebc8b58"
      ],
      "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379"
     }
    },
    "144ec959b7604a2cabb5ca46ae5e5379": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "19308ccac642498ab8b58462e3f1b0bb": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "1c6246f15b654f4daa11c9bcf997b78c": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df",
      "placeholder": "​",
      "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c",
      "value": "Downloading: 100%"
     }
    },
    "2489b5e5648541fbbdceadb05632a050": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ButtonModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ButtonModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ButtonView",
      "button_style": "",
      "description": "",
      "disabled": false,
      "icon": "compress",
      "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199",
      "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462",
      "tooltip": ""
     }
    },
    "24d31fc3576e43dd9f8301d2ef3a37ab": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "2918bfaadc8d4b1a9832522c40dfefb8": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "2c9362906e4b40189f16d14aa9a348da": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "42f6c3db29d7484ba6b4f73590abd2f4": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "4a081cdc2ec3421ca79dd933b7e2b0c4": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "SliderStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "SliderStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": "",
      "handle_color": null
     }
    },
    "5146907ef6764654ad7d598baebc8b58": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "IntSliderModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "IntSliderModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "IntSliderView",
      "continuous_update": true,
      "description": "",
      "description_tooltip": null,
      "disabled": false,
      "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb",
      "max": 0,
      "min": 0,
      "orientation": "horizontal",
      "readout": true,
      "readout_format": "d",
      "step": 1,
      "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4",
      "value": 0
     }
    },
    "561f742d418d4721b0670cc8dd62e22c": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "6010fc8daa7a44d5aec4b830ec2ebaa1": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2",
       "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
       "IPY_MODEL_6526646be5ed415c84d1245b040e629b"
      ],
      "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab"
     }
    },
    "65195cb7a4134f4887e9dd19f3676462": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ButtonStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ButtonStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "button_color": null,
      "font_weight": ""
     }
    },
    "6526646be5ed415c84d1245b040e629b": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433",
      "placeholder": "​",
      "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4",
      "value": " 401/401 [00:00&lt;00:00, 13.5kB/s]"
     }
    },
    "695ab5bbf30a4ab19df1f9f33469f314": {
     "model_module": "nglview-js-widgets",
     "model_module_version": "3.0.1",
     "model_name": "ColormakerRegistryModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "nglview-js-widgets",
      "_model_module_version": "3.0.1",
      "_model_name": "ColormakerRegistryModel",
      "_msg_ar": [],
      "_msg_q": [],
      "_ready": false,
      "_view_count": null,
      "_view_module": "nglview-js-widgets",
      "_view_module_version": "3.0.1",
      "_view_name": "ColormakerRegistryView",
      "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29"
     }
    },
    "7e0bb1b8d65249d3974200686b193be2": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8",
      "placeholder": "​",
      "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08",
      "value": "Downloading: 100%"
     }
    },
    "872915dd1bb84f538c44e26badabafdd": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a4bfdca35cc54dae8812720f1b276a08": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "a9e388f22a9742aaaf538e22575c9433": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "abce2a80e6304df3899109c6d6cac199": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": "34px"
     }
    },
    "b7feb522161f4cf4b7cc7c1a078ff12d": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436",
      "placeholder": "​",
      "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da",
      "value": " 3.27M/3.27M [00:01&lt;00:00, 3.25MB/s]"
     }
    },
    "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39",
      "max": 401,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_f915cf874246446595206221e900b2fe",
      "value": 401
     }
    },
    "bbef741e76ec41b7ab7187b487a383df": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "be446195da2b4ff2aec21ec5ff963a54": {
     "model_module": "nglview-js-widgets",
     "model_module_version": "3.0.1",
     "model_name": "NGLModel",
     "state": {
      "_camera_orientation": [
       -15.519693580202304,
       -14.065056548036177,
       -23.53197484807691,
       0,
       -23.357853515109753,
       20.94055073042662,
       2.888695042134944,
       0,
       14.352363398292775,
       18.870825741878015,
       -20.744689572909344,
       0,
       0.2724999189376831,
       0.6940000057220459,
       -0.3734999895095825,
       1
      ],
      "_camera_str": "orthographic",
      "_dom_classes": [],
      "_gui_theme": null,
      "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050",
      "_igui": null,
      "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b",
      "_model_module": "nglview-js-widgets",
      "_model_module_version": "3.0.1",
      "_model_name": "NGLModel",
      "_ngl_color_dict": {},
      "_ngl_coordinate_resource": {},
      "_ngl_full_stage_parameters": {
       "ambientColor": 14540253,
       "ambientIntensity": 0.2,
       "backgroundColor": "white",
       "cameraEyeSep": 0.3,
       "cameraFov": 40,
       "cameraType": "perspective",
       "clipDist": 10,
       "clipFar": 100,
       "clipNear": 0,
       "fogFar": 100,
       "fogNear": 50,
       "hoverTimeout": 0,
       "impostor": true,
       "lightColor": 14540253,
       "lightIntensity": 1,
       "mousePreset": "default",
       "panSpeed": 1,
       "quality": "medium",
       "rotateSpeed": 2,
       "sampleLevel": 0,
       "tooltip": true,
       "workerDefault": true,
       "zoomSpeed": 1.2
      },
      "_ngl_msg_archive": [
       {
        "args": [
         {
          "binary": false,
          "data": "HETATM    1  C1  UNL     1      -0.025   3.128   2.316  1.00  0.00           C  \nHETATM    2  H1  UNL     1       0.183   3.657   2.823  1.00  0.00           H  \nHETATM    3  C2  UNL     1       0.590   3.559   0.963  1.00  0.00           C  \nHETATM    4  C3  UNL     1       0.056   4.479   0.406  1.00  0.00           C  \nHETATM    5  C4  UNL     1      -0.219   4.802  -1.065  1.00  0.00           C  \nHETATM    6  H2  UNL     1       0.686   4.431  -1.575  1.00  0.00           H  \nHETATM    7  H3  UNL     1      -0.524   5.217  -1.274  1.00  0.00           H  \nHETATM    8  C5  UNL     1      -1.284   3.766  -1.342  1.00  0.00           C  \nHETATM    9  N1  UNL     1      -1.073   2.494  -0.580  1.00  0.00           N  \nHETATM   10  C6  UNL     1      -1.909   1.494  -0.964  1.00  0.00           C  \nHETATM   11  O1  UNL     1      -2.487   1.531  -2.092  1.00  0.00           O  \nHETATM   12  C7  UNL     1      -2.232   0.242  -0.130  1.00  0.00           C  \nHETATM   13  C8  UNL     1      -2.161  -1.057  -1.037  1.00  0.00           C  \nHETATM   14  C9  UNL     1      -0.744  -1.111  -1.610  1.00  0.00           C  \nHETATM   15  N2  UNL     1       0.290  -0.917  -0.628  1.00  0.00           N  \nHETATM   16  S1  UNL     1       1.717  -1.597  -0.914  1.00  0.00           S  \nHETATM   17  O2  UNL     1       1.960  -1.671  -2.338  1.00  0.00           O  \nHETATM   18  O3  UNL     1       2.713  -0.968  -0.082  1.00  0.00           O  \nHETATM   19  C10 UNL     1       1.425  -3.170  -0.345  1.00  0.00           C  \nHETATM   20  C11 UNL     1       1.225  -4.400  -1.271  1.00  0.00           C  \nHETATM   21  C12 UNL     1       1.314  -5.913  -0.895  1.00  0.00           C  \nHETATM   22  C13 UNL     1       1.823  -6.229   0.386  1.00  0.00           C  \nHETATM   23  C14 UNL     1       2.031  -5.110   1.365  1.00  0.00           C  \nHETATM   24  N3  UNL     1       1.850  -5.267   2.712  1.00  0.00           N  \nHETATM   25  
O4  UNL     1       1.382  -4.029   3.126  1.00  0.00           O  \nHETATM   26  N4  UNL     1       1.300  -3.023   2.154  1.00  0.00           N  \nHETATM   27  C15 UNL     1       1.731  -3.672   1.032  1.00  0.00           C  \nHETATM   28  H4  UNL     1       2.380  -6.874   0.436  1.00  0.00           H  \nHETATM   29  H5  UNL     1       0.704  -6.526  -1.420  1.00  0.00           H  \nHETATM   30  H6  UNL     1       1.144  -4.035  -2.291  1.00  0.00           H  \nHETATM   31  C16 UNL     1       0.044  -0.371   0.685  1.00  0.00           C  \nHETATM   32  C17 UNL     1      -1.352  -0.045   1.077  1.00  0.00           C  \nHETATM   33  H7  UNL     1      -1.395   0.770   1.768  1.00  0.00           H  \nHETATM   34  H8  UNL     1      -1.792  -0.941   1.582  1.00  0.00           H  \nHETATM   35  H9  UNL     1       0.583  -1.035   1.393  1.00  0.00           H  \nHETATM   36  H10 UNL     1       0.664   0.613   0.663  1.00  0.00           H  \nHETATM   37  H11 UNL     1      -0.631  -0.267  -2.335  1.00  0.00           H  \nHETATM   38  H12 UNL     1      -0.571  -2.046  -2.098  1.00  0.00           H  \nHETATM   39  H13 UNL     1      -2.872  -0.992  -1.826  1.00  0.00           H  \nHETATM   40  H14 UNL     1      -2.370  -1.924  -0.444  1.00  0.00           H  \nHETATM   41  H15 UNL     1      -3.258   0.364   0.197  1.00  0.00           H  \nHETATM   42  C18 UNL     1       0.276   2.337  -0.078  1.00  0.00           C  \nHETATM   43  H16 UNL     1       0.514   1.371   0.252  1.00  0.00           H  \nHETATM   44  H17 UNL     1       0.988   2.413  -0.949  1.00  0.00           H  \nHETATM   45  H18 UNL     1      -1.349   3.451  -2.379  1.00  0.00           H  \nHETATM   46  H19 UNL     1      -2.224   4.055  -0.958  1.00  0.00           H  \nHETATM   47  H20 UNL     1       0.793   5.486   0.669  1.00  0.00           H  \nHETATM   48  H21 UNL     1      -0.849   4.974   0.937  1.00  0.00           H  \nHETATM   49  H22 UNL     1       1.667   
3.431   1.070  1.00  0.00           H  \nHETATM   50  H23 UNL     1       0.379   2.143   2.689  1.00  0.00           H  \nHETATM   51  H24 UNL     1      -1.094   2.983   2.223  1.00  0.00           H  \nCONECT    1    2    3   50   51\nCONECT    3    4   42   49\nCONECT    4    5   47   48\nCONECT    5    6    7    8\nCONECT    8    9   45   46\nCONECT    9   10   42\nCONECT   10   11   11   12\nCONECT   12   13   32   41\nCONECT   13   14   39   40\nCONECT   14   15   37   38\nCONECT   15   16   31\nCONECT   16   17   17   18   18\nCONECT   16   19\nCONECT   19   20   20   27\nCONECT   20   21   30\nCONECT   21   22   22   29\nCONECT   22   23   28\nCONECT   23   24   24   27\nCONECT   24   25\nCONECT   25   26\nCONECT   26   27   27\nCONECT   31   32   35   36\nCONECT   32   33   34\nCONECT   42   43   44\nEND\n",
          "type": "blob"
         }
        ],
        "kwargs": {
         "defaultRepresentation": true,
         "ext": "pdb"
        },
        "methodName": "loadFile",
        "reconstruc_color_scheme": false,
        "target": "Stage",
        "type": "call_method"
       }
      ],
      "_ngl_original_stage_parameters": {
       "ambientColor": 14540253,
       "ambientIntensity": 0.2,
       "backgroundColor": "white",
       "cameraEyeSep": 0.3,
       "cameraFov": 40,
       "cameraType": "perspective",
       "clipDist": 10,
       "clipFar": 100,
       "clipNear": 0,
       "fogFar": 100,
       "fogNear": 50,
       "hoverTimeout": 0,
       "impostor": true,
       "lightColor": 14540253,
       "lightIntensity": 1,
       "mousePreset": "default",
       "panSpeed": 1,
       "quality": "medium",
       "rotateSpeed": 2,
       "sampleLevel": 0,
       "tooltip": true,
       "workerDefault": true,
       "zoomSpeed": 1.2
      },
      "_ngl_repr_dict": {
       "0": {
        "0": {
         "params": {
          "aspectRatio": 1.5,
          "assembly": "default",
          "bondScale": 0.3,
          "bondSpacing": 0.75,
          "clipCenter": {
           "x": 0,
           "y": 0,
           "z": 0
          },
          "clipNear": 0,
          "clipRadius": 0,
          "colorMode": "hcl",
          "colorReverse": false,
          "colorScale": "",
          "colorScheme": "element",
          "colorValue": 9474192,
          "cylinderOnly": false,
          "defaultAssembly": "",
          "depthWrite": true,
          "diffuse": 16777215,
          "diffuseInterior": false,
          "disableImpostor": false,
          "disablePicking": false,
          "flatShaded": false,
          "interiorColor": 2236962,
          "interiorDarkening": 0,
          "lazy": false,
          "lineOnly": false,
          "linewidth": 2,
          "matrix": {
           "elements": [
            1,
            0,
            0,
            0,
            0,
            1,
            0,
            0,
            0,
            0,
            1,
            0,
            0,
            0,
            0,
            1
           ]
          },
          "metalness": 0,
          "multipleBond": "off",
          "opacity": 1,
          "openEnded": true,
          "quality": "high",
          "radialSegments": 20,
          "radiusData": {},
          "radiusScale": 2,
          "radiusSize": 0.15,
          "radiusType": "size",
          "roughness": 0.4,
          "sele": "",
          "side": "double",
          "sphereDetail": 2,
          "useInteriorColor": true,
          "visible": true,
          "wireframe": false
         },
         "type": "ball+stick"
        }
       },
       "1": {
        "0": {
         "params": {
          "aspectRatio": 1.5,
          "assembly": "default",
          "bondScale": 0.3,
          "bondSpacing": 0.75,
          "clipCenter": {
           "x": 0,
           "y": 0,
           "z": 0
          },
          "clipNear": 0,
          "clipRadius": 0,
          "colorMode": "hcl",
          "colorReverse": false,
          "colorScale": "",
          "colorScheme": "element",
          "colorValue": 9474192,
          "cylinderOnly": false,
          "defaultAssembly": "",
          "depthWrite": true,
          "diffuse": 16777215,
          "diffuseInterior": false,
          "disableImpostor": false,
          "disablePicking": false,
          "flatShaded": false,
          "interiorColor": 2236962,
          "interiorDarkening": 0,
          "lazy": false,
          "lineOnly": false,
          "linewidth": 2,
          "matrix": {
           "elements": [
            1,
            0,
            0,
            0,
            0,
            1,
            0,
            0,
            0,
            0,
            1,
            0,
            0,
            0,
            0,
            1
           ]
          },
          "metalness": 0,
          "multipleBond": "off",
          "opacity": 1,
          "openEnded": true,
          "quality": "high",
          "radialSegments": 20,
          "radiusData": {},
          "radiusScale": 2,
          "radiusSize": 0.15,
          "radiusType": "size",
          "roughness": 0.4,
          "sele": "",
          "side": "double",
          "sphereDetail": 2,
          "useInteriorColor": true,
          "visible": true,
          "wireframe": false
         },
         "type": "ball+stick"
        }
       }
      },
      "_ngl_serialize": false,
      "_ngl_version": "",
      "_ngl_view_id": [
       "FB989FD1-5B9C-446B-8914-6B58AF85446D"
      ],
      "_player_dict": {},
      "_scene_position": {},
      "_scene_rotation": {},
      "_synced_model_ids": [],
      "_synced_repr_model_ids": [],
      "_view_count": null,
      "_view_height": "",
      "_view_module": "nglview-js-widgets",
      "_view_module_version": "3.0.1",
      "_view_name": "NGLView",
      "_view_width": "",
      "background": "white",
      "frame": 0,
      "gui_style": null,
      "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f",
      "max_frame": 0,
      "n_components": 2,
      "picked": {}
     }
    },
    "c2321b3bff6f490ca12040a20308f555": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd",
      "max": 3271865,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d",
      "value": 3271865
     }
    },
    "c30e6c2f3e2a44dbbb3d63bd519acaa4": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c6596896148b4a8a9c57963b67c7782f": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d022575f1fa2446d891650897f187b4d": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "d90f304e9560472eacfbdd11e46765eb": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c",
       "IPY_MODEL_c2321b3bff6f490ca12040a20308f555",
       "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d"
      ],
      "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f"
     }
    },
    "e2d368556e494ae7ae4e2e992af2cd4f": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "e4901541199b45c6a18824627692fc39": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "e5c0d75eb5e1447abd560c8f2c6017e1": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "PlayModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "PlayModel",
      "_playing": false,
      "_repeat": false,
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "PlayView",
      "description": "",
      "description_tooltip": null,
      "disabled": false,
      "interval": 100,
      "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4",
      "max": 0,
      "min": 0,
      "show_repeat": true,
      "step": 1,
      "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5",
      "value": 0
     }
    },
    "eac6a8dcdc9d4335a2e51031793ead29": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f31c6e40e9b2466a9064a2669933ecd5": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "f915cf874246446595206221e900b2fe": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "fdc393f3468c432aa0ada05e238a5436": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
