{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "_l0tU8gum9eT"
      },
      "source": [
        "# A Gentle Introduction to Geometric Graph Neural Networks\n",
        "\n",
        "Graph Neural Networks (GNNs) are a part of an emerging research paradigm called **Geometric Deep Learning** -- devising neural network architectures that respect the invariances and symmetries in data. This practical aims to be a gentle introduction into the world of Geometric Deep Learning.\n",
        "\n",
        "We recommend opening a copy of this notebook via **Google Collab** (working locally is also easy):\n",
        "\n",
        "<a target=\"_blank\" href=\"https://colab.research.google.com/github/chaitjo/geometric-gnn-dojo/blob/main/geometric_gnn_101.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "The **aims** of this practical are as follows:\n",
        "\n",
        "* Understanding the concepts of **invariance** and **equivariance** to symmetries, which are fundamental properties of Graph Neural Networks. We will cover theory and proofs, as well as programming and unit testing.\n",
        "* Becoming hands-on with [**PyTorch Geometric**](https://pytorch-geometric.readthedocs.io/en/latest/) (PyG), a popular libary for developing state-of-the-art GNNs and Geometric Deep Learning models. In particular, gaining familiarity with the `MessagePassing` base class for designing novel GNN layers and the `Data` object for representing graph datasets.\n",
        "* Gaining an appreciation of the fundamental principles behind constructing GNN layers that take advantage of **geometric information** for graphs embedded in **3D space**, such as biomolecules, materials, and other physical systems.\n",
        "\n",
        "## Authors and Acknowledgements\n",
        "\n",
        "**Here are the authors**: do not hesitate to reach out to us for any queries and feedback on your solutions!\n",
        "\n",
        "- Chaitanya K. Joshi (ckj24@cl.cam.ac.uk)\n",
        "- Charlie Harris (cch57@cam.ac.uk)\n",
        "- Ramon Viñas Torné (rv340@cam.ac.uk)\n",
        "\n",
        "This notebook was initially developed for students taking the following courses:\n",
        "- [Representation Learning on Graphs and Networks](https://www.cl.cam.ac.uk/teaching/2122/L45), at University of Cambridge's Department of Computer Science and Technology (instructors: Prof. Pietro Liò, Dr. Petar Veličković).\n",
        "- [Geometric Deep Learning](https://aimsammi.org/), at the African Master’s in Machine Intelligence (instructors: Prof. Michael Bronstein, Prof. Joan Bruna, Dr. Taco Cohen, Dr. Petar Veličković).\n",
        "\n",
        "---\n",
        "---\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N0oFTe7ZiO8i"
      },
      "source": [
        "# ⚙️ Part 0: Installation and Setup\n",
        "\n",
        "**❗️Note:** You will need a GPU to complete this practical. If using Collab, remember to click `Runtime -> Change runtime type`, and set the `hardware accelerator` to **GPU**."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "HwOueIhIQPBd"
      },
      "outputs": [],
      "source": [
        "#@title [RUN] Install required python libraries\n",
        "import os\n",
        "\n",
        "# Install PyTorch Geometric and other libraries\n",
        "if 'IS_GRADESCOPE_ENV' not in os.environ:\n",
        "    !pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html\n",
        "    !pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html\n",
        "    !pip install -q torch-geometric==2.0.3\n",
        "    !pip install -q rdkit-pypi==2021.9.4\n",
        "    !pip install -q py3Dmol==1.8.0\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "mvIHO8B_RjeG"
      },
      "outputs": [],
      "source": [
        "#@title [RUN] Import python modules\n",
        "\n",
        "import os\n",
        "import time\n",
        "import random\n",
        "import numpy as np\n",
        "\n",
        "from scipy.stats import ortho_group\n",
        "\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential\n",
        "\n",
        "import torch_geometric\n",
        "from torch_geometric.data import Data\n",
        "from torch_geometric.data import Batch\n",
        "from torch_geometric.datasets import QM9\n",
        "import torch_geometric.transforms as T\n",
        "from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse\n",
        "from torch_geometric.loader import DataLoader\n",
        "from torch_geometric.nn import MessagePassing, global_mean_pool\n",
        "from torch_geometric.datasets import QM9\n",
        "from torch_scatter import scatter\n",
        "\n",
        "import rdkit.Chem as Chem\n",
        "from rdkit.Geometry.rdGeometry import Point3D\n",
        "from rdkit.Chem import QED, Crippen, rdMolDescriptors, rdmolops\n",
        "from rdkit.Chem.Draw import IPythonConsole\n",
        "\n",
        "import py3Dmol\n",
        "from rdkit.Chem import AllChem\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import pandas as pd\n",
        "\n",
        "from google.colab import files\n",
        "from IPython.display import HTML\n",
        "\n",
        "print(\"PyTorch version {}\".format(torch.__version__))\n",
        "print(\"PyG version {}\".format(torch_geometric.__version__))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "voXIXbkVOeZm"
      },
      "outputs": [],
      "source": [
        "#@title [RUN] Set random seed for deterministic results\n",
        "\n",
        "def seed(seed=0):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "seed(0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "rFNre1NLdMvT"
      },
      "outputs": [],
      "source": [
        "#@title [RUN] Helper functions for data preparation\n",
        "\n",
        "class SetTarget(object):\n",
        "    \"\"\"\n",
        "    This transform mofifies the labels vector per data sample to only keep \n",
        "    the label for a specific target (there are 19 targets in QM9).\n",
        "\n",
        "    Note: for this practical, we have hardcoded the target to be target #0,\n",
        "    i.e. the electric dipole moment of a drug-like molecule.\n",
        "    (https://en.wikipedia.org/wiki/Electric_dipole_moment)\n",
        "    \"\"\"\n",
        "    def __call__(self, data):\n",
        "        target = 0 # we hardcoded choice of target  \n",
        "        data.y = data.y[:, target]\n",
        "        return data\n",
        "\n",
        "\n",
        "class CompleteGraph(object):\n",
        "    \"\"\"\n",
        "    This transform adds all pairwise edges into the edge index per data sample, \n",
        "    then removes self loops, i.e. it builds a fully connected or complete graph\n",
        "    \"\"\"\n",
        "    def __call__(self, data):\n",
        "        device = data.edge_index.device\n",
        "\n",
        "        row = torch.arange(data.num_nodes, dtype=torch.long, device=device)\n",
        "        col = torch.arange(data.num_nodes, dtype=torch.long, device=device)\n",
        "\n",
        "        row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)\n",
        "        col = col.repeat(data.num_nodes)\n",
        "        edge_index = torch.stack([row, col], dim=0)\n",
        "\n",
        "        edge_attr = None\n",
        "        if data.edge_attr is not None:\n",
        "            idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]\n",
        "            size = list(data.edge_attr.size())\n",
        "            size[0] = data.num_nodes * data.num_nodes\n",
        "            edge_attr = data.edge_attr.new_zeros(size)\n",
        "            edge_attr[idx] = data.edge_attr\n",
        "\n",
        "        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)\n",
        "        data.edge_attr = edge_attr\n",
        "        data.edge_index = edge_index\n",
        "\n",
        "        return data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ExJ0b3xcQl5n"
      },
      "outputs": [],
      "source": [
        "#@title [RUN] Helper functions for visualization\n",
        "\n",
        "allowable_atoms = [\n",
        "    \"H\",\n",
        "    \"C\",\n",
        "    \"N\",\n",
        "    \"O\",\n",
        "    \"F\",\n",
        "    \"C\",\n",
        "    \"Cl\",\n",
        "    \"Br\",\n",
        "    \"I\",\n",
        "    \"H\", \n",
        "    \"Unknown\",\n",
        "]\n",
        "\n",
        "def to_atom(t):\n",
        "    try:\n",
        "        return allowable_atoms[int(t.argmax())]\n",
        "    except:\n",
        "        return \"C\"\n",
        "\n",
        "\n",
        "def to_bond_index(t):\n",
        "    t_s = t.squeeze()\n",
        "    return [1, 2, 3, 4][\n",
        "        int(\n",
        "            torch.dot(\n",
        "                t_s,\n",
        "                torch.tensor(\n",
        "                    range(t_s.size()[0]), dtype=torch.float, device=t.device\n",
        "                ),\n",
        "            ).item()\n",
        "        )\n",
        "    ]\n",
        "\n",
        "def to_rdkit(data, device=None):\n",
        "    has_pos = False\n",
        "    node_list = []\n",
        "    for i in range(data.x.size()[0]):\n",
        "        node_list.append(to_atom(data.x[i][:5]))\n",
        "\n",
        "    # create empty editable mol object\n",
        "    mol = Chem.RWMol()\n",
        "    # add atoms to mol and keep track of index\n",
        "    node_to_idx = {}\n",
        "    invalid_idx = set([])\n",
        "    for i in range(len(node_list)):\n",
        "        if node_list[i] == \"Stop\" or node_list[i] == \"H\":\n",
        "            invalid_idx.add(i)\n",
        "            continue\n",
        "        a = Chem.Atom(node_list[i])\n",
        "        molIdx = mol.AddAtom(a)\n",
        "        node_to_idx[i] = molIdx\n",
        "\n",
        "    added_bonds = set([])\n",
        "    for i in range(0, data.edge_index.size()[1]):\n",
        "        ix = data.edge_index[0][i].item()\n",
        "        iy = data.edge_index[1][i].item()\n",
        "        bond = to_bond_index(data.edge_attr[i])  # <font color='red'>TODO</font> fix this\n",
        "        # bond = 1\n",
        "        # add bonds between adjacent atoms\n",
        "\n",
        "        if data.edge_attr[i].sum() == 0:\n",
        "          continue\n",
        "\n",
        "        if (\n",
        "            (str((ix, iy)) in added_bonds)\n",
        "            or (str((iy, ix)) in added_bonds)\n",
        "            or (iy in invalid_idx or ix in invalid_idx)\n",
        "        ):\n",
        "            continue\n",
        "        # add relevant bond type (there are many more of these)\n",
        "\n",
        "        if bond == 0:\n",
        "            continue\n",
        "        elif bond == 1:\n",
        "            bond_type = Chem.rdchem.BondType.SINGLE\n",
        "            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)\n",
        "        elif bond == 2:\n",
        "            bond_type = Chem.rdchem.BondType.DOUBLE\n",
        "            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)\n",
        "        elif bond == 3:\n",
        "            bond_type = Chem.rdchem.BondType.TRIPLE\n",
        "            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)\n",
        "        elif bond == 4:\n",
        "            bond_type = Chem.rdchem.BondType.SINGLE\n",
        "            mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)\n",
        "\n",
        "        added_bonds.add(str((ix, iy)))\n",
        "\n",
        "    if has_pos:\n",
        "        conf = Chem.Conformer(mol.GetNumAtoms())\n",
        "        for i in range(data.pos.size(0)):\n",
        "            if i in invalid_idx:\n",
        "                continue\n",
        "            p = Point3D(\n",
        "                data.pos[i][0].item(),\n",
        "                data.pos[i][1].item(),\n",
        "                data.pos[i][2].item(),\n",
        "            )\n",
        "            conf.SetAtomPosition(node_to_idx[i], p)\n",
        "        conf.SetId(0)\n",
        "        mol.AddConformer(conf)\n",
        "\n",
        "    # Convert RWMol to Mol object\n",
        "    mol = mol.GetMol()\n",
        "    mol_frags = rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)\n",
        "    largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())\n",
        "    return largest_mol\n",
        "\n",
        "\n",
        "def MolTo3DView(mol, size=(300, 300), style=\"stick\", surface=False, opacity=0.5):\n",
        "    \"\"\"Draw molecule in 3D\n",
        "    \n",
        "    Args:\n",
        "    ----\n",
        "        mol: rdMol, molecule to show\n",
        "        size: tuple(int, int), canvas size\n",
        "        style: str, type of drawing molecule\n",
        "               style can be 'line', 'stick', 'sphere', 'carton'\n",
        "        surface, bool, display SAS\n",
        "        opacity, float, opacity of surface, range 0.0-1.0\n",
        "    Return:\n",
        "    ----\n",
        "        viewer: py3Dmol.view, a class for constructing embedded 3Dmol.js views in ipython notebooks.\n",
        "    \"\"\"\n",
        "    assert style in ('line', 'stick', 'sphere', 'carton')\n",
        "\n",
        "    mol = Chem.AddHs(mol)\n",
        "    AllChem.EmbedMolecule(mol)\n",
        "    AllChem.MMFFOptimizeMolecule(mol, maxIters=200)\n",
        "    mblock = Chem.MolToMolBlock(mol)\n",
        "    viewer = py3Dmol.view(width=size[0], height=size[1])\n",
        "    viewer.addModel(mblock, 'mol')\n",
        "    viewer.setStyle({style:{}})\n",
        "    if surface:\n",
        "        viewer.addSurface(py3Dmol.SAS, {'opacity': opacity})\n",
        "    viewer.zoomTo()\n",
        "    return viewer\n",
        "\n",
        "def smi2conf(smiles):\n",
        "    '''Convert SMILES to rdkit.Mol with 3D coordinates'''\n",
        "    mol = Chem.MolFromSmiles(smiles)\n",
        "    if mol is not None:\n",
        "        mol = Chem.AddHs(mol)\n",
        "        AllChem.EmbedMolecule(mol)\n",
        "        AllChem.MMFFOptimizeMolecule(mol, maxIters=200)\n",
        "        return mol\n",
        "    else:\n",
        "        return None"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gY7foToFoo8Q"
      },
      "outputs": [],
      "source": [
        "# For storing experimental results over the course of the practical\n",
        "RESULTS = {}\n",
        "DF_RESULTS = pd.DataFrame(columns=[\"Test MAE\", \"Val MAE\", \"Epoch\", \"Model\"])"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "O_e6GlqvRQn9"
      },
      "source": [
        "Great! We are ready to dive into the practical!\n",
        "\n",
        "---\n",
        "---\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "67UjWJ4RibHt"
      },
      "source": [
        "# 🧪 Part 0: Introduction to Molecular Property Prediction with PyTorch Geometric\n",
        "\n",
        "This section covers the fundamentals. We will study how Graph Neural Networks (GNNs) can be employed for predicting chemical properties of molecules, an impactful real-world application of Geometric Deep Learning. To achieve this, we will first introduce PyTorch Geometric, a widely-used Python library that facilitates the implementation of GNNs.\n",
        "\n",
        "## PyTorch Geometric\n",
        "\n",
        "[PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) (PyG) is an excellent library for graph representation learning research and development:\n",
        "\n",
        "> PyTorch Geometric (PyG) consists of various methods for deep learning on graphs and other irregular structures, also known as Geometric Deep Learning, from a variety of published papers. In addition, it provides easy-to-use mini-batch loaders for operating on many small and single giant graphs, multi GPU-support, distributed graph learning, a large number of common benchmark datasets, and helpful transforms, both for learning on arbitrary graphs as well as on 3D meshes or point clouds.\n",
        "\n",
        "In this practical, we will make extensive use of PyG. If you have never worked with PyG before, do not worry, we will provide you with some examples and guide you through all the fundamentals in a detailed manner. We also highly recommend [this self-contained official tutorial](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html), which will help you get started. Among other things, you will learn how to implement state-of-the-art GNN layers via the generic PyG [Message Passing](https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html) class (more on this later).\n",
        "\n",
        "Now, let's turn our attention to the problem of predicting molecular properties.\n",
        "\n",
        "## Molecular Property Prediction\n",
        "\n",
        "Molecules are a great example of an object from nature that can easily be represented as a graph of atoms (nodes) connected by bonds (edges). \n",
        "A popular application of GNNs in chemistry is the task of **Molecular Property Prediction**. The goal is to train a GNN model from historical experimental data that can predict useful properties of drug-like molecules. The model's predictions can then be used to guide the drug design process.\n",
        "\n",
        "<!-- ![](https://drive.google.com/uc?id=1Hs6fMSZ6a0WdjKqzbmBME0RYoSwxMaYQ) -->\n",
        "<img src=\"https://github.com/chaitjo/dump/raw/main/molproppred.png\">\n",
        "\n",
        "One famous example of GNNs being used in molecular property prediction is in the world of **antibiotic discovery**, an area with a potentially massive impact on humanity and infamously little innovation. A GNN trained to predict how much a molecule would inhibit a bacteria was able identify the previously overlooked compound [**Halicin**](https://www.wikiwand.com/en/Halicin) (below) during virtual screening. Not only did halicin show powerful results during *in vitro* (in cell) testing but it also had a completely novel mechanism of action that no bacteria has developed resistance to (yet).\n",
        "\n",
        "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Halicin.svg/440px-Halicin.svg.png\" width=\"30%\">\n",
        "\n",
        "## The QM9 Dataset\n",
        "\n",
        "QM9 (Quantum Mechanics dataset 9) is a dataset consisting of about **130,000 small molecules** with 19 regression targets. Since being used by [MoleculeNet](https://arxiv.org/abs/1703.00564), it has become a popular dataset to benchmark new architectures for molecular property prediction.\n",
        "\n",
        "Specifically, we will be predicting the [electric dipole moment](https://en.wikipedia.org/wiki/Electric_dipole_moment) of drug-like molecules. According to Wikipedia:\n",
        "> \"The electric dipole moment is a measure of the separation of positive and negative electrical charges within a system, that is, a measure of the system's overall polarity.\"\n",
        "\n",
        "We can visualize this concept via the water molecule H<sub>2</sub>0, which forms a dipole due to its slightly different distribution of negative (blue) and postive (red) charge.\n",
        "\n",
        "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/1/15/Water-elpot-transparent-3D-balls.png/500px-Water-elpot-transparent-3D-balls.png\" width=\"25%\">\n",
        "\n",
        "You do not need to worry about the exact physical and chemical principles that underpin dipole moments. As you might imagine, writing the equations from first priciples to predict a property like this, espeically for complex molecules (e.g. proteins), is very difficult. All you need know (for the sake of this practical anyway) is that these molecules can be representated as graphs with node and edge features as well as **spatial information** that we can use to train a GNN model using the ground truth labels.\n",
        "\n",
        "Now let us load the QM9 dataset and explore how molecular graphs are represented. PyG makes this extremely convinient.\n",
        "\n",
        "(The dataset may take a few minutes to download.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ptep7GJJR8jJ"
      },
      "outputs": [],
      "source": [
        "if 'IS_GRADESCOPE_ENV' not in os.environ:\n",
        "    path = './qm9'\n",
        "    target = 0\n",
        "\n",
        "    # Transforms which are applied during data loading:\n",
        "    # (1) Fully connect the graphs, (2) Select the target/label\n",
        "    transform = T.Compose([CompleteGraph(), SetTarget()])\n",
        "    \n",
        "    # Load the QM9 dataset with the transforms defined\n",
        "    dataset = QM9(path, transform=transform)\n",
        "\n",
        "    # Normalize targets per data sample to mean = 0 and std = 1.\n",
        "    mean = dataset.data.y.mean(dim=0, keepdim=True)\n",
        "    std = dataset.data.y.std(dim=0, keepdim=True)\n",
        "    dataset.data.y = (dataset.data.y - mean) / std\n",
        "    mean, std = mean[:, target].item(), std[:, target].item()"
      ]
    },
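    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell above keeps `mean` and `std` after standardising the targets. Here is a minimal sketch of one plausible use, assuming errors are to be reported in the original units: an error measured on standardised targets can be rescaled by multiplying it with `std`. The tensors `toy_y` and `toy_pred_norm` are made-up values for illustration, not QM9 data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Made-up regression targets (illustrative only, not QM9 values)\n",
        "toy_y = torch.tensor([1.0, 2.0, 3.0, 6.0])\n",
        "toy_mean, toy_std = toy_y.mean(), toy_y.std()  # torch.std applies Bessel's correction by default\n",
        "toy_y_norm = (toy_y - toy_mean) / toy_std\n",
        "\n",
        "# Pretend a model produced these predictions in standardised units\n",
        "toy_pred_norm = toy_y_norm + 0.1\n",
        "\n",
        "# Mean absolute error in standardised units, then rescaled to original units\n",
        "mae_norm = (toy_pred_norm - toy_y_norm).abs().mean()\n",
        "mae_orig = mae_norm * toy_std\n",
        "\n",
        "# Equivalent route: un-normalise the predictions first, then compute the error\n",
        "toy_pred = toy_pred_norm * toy_std + toy_mean\n",
        "assert torch.allclose(mae_orig, (toy_pred - toy_y).abs().mean())\n",
        "\n",
        "print(f'MAE (standardised units): {mae_norm.item():.4f}')\n",
        "print(f'MAE (original units):     {mae_orig.item():.4f}')"
      ]
    },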
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p98OyJsOZUKn"
      },
      "source": [
        "## Data Preparation and Splitting\n",
        "\n",
        "The QM9 dataset has over **130,000** molecular graphs! \n",
        "\n",
        "Let us create a more tractable sub-set of **3,000** molecular graphs for the purposes of this practical and separate it into training, validation, and test sets. We shall use 1,000 graphs each for training, validation, and testing.\n",
        "\n",
        "Towards the end of this practical, you will get to experiment with the full/larger sub-sets of the QM9 dataset, too."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RFaBnrDoS2K-"
      },
      "outputs": [],
      "source": [
        "print(f\"Total number of samples: {len(dataset)}.\")\n",
        "\n",
        "# Split datasets (in case of using the full dataset)\n",
        "# test_dataset = dataset[:10000]\n",
        "# val_dataset = dataset[10000:20000]\n",
        "# train_dataset = dataset[20000:]\n",
        "\n",
        "# Split datasets (our 3K subset)\n",
        "train_dataset = dataset[:1000]\n",
        "val_dataset = dataset[1000:2000]\n",
        "test_dataset = dataset[2000:3000]\n",
        "print(f\"Created dataset splits with {len(train_dataset)} training, {len(val_dataset)} validation, {len(test_dataset)} test samples.\")\n",
        "\n",
        "# Create dataloaders with batch size = 32\n",
        "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
        "val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)\n",
        "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qa-fb3-JNDzk"
      },
      "source": [
        "## Visualising Molecular Graphs\n",
        "\n",
        "To get a better understanding of how the QM9 molecular graphs look like, let's visualise a few samples from the training set along with their corresponding target (their dipole moment).\n",
        "\n",
        "In the following plot we visualise **sparse graphs** where edges represent physical connections (i.e. bonds). In this practical, however, we will use **fully-connected graphs** and encode the graph structure in the attributes of each. Later in this practical, we will study the advantages and downsides of both approaches.\n",
        "\n",
        "**❗️Note:** we have implemented some code for you to convert the PyG graph into a Molecule object that can be used by RDKit, a python package for chemistry and visualing molecules. It is not important for you to understand RDKit beyond visualisation purposes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hrg5oQ48MmnB"
      },
      "outputs": [],
      "source": [
        "num_viz = 50\n",
        "mols = [to_rdkit(train_dataset[i]) for i in range(num_viz)]\n",
        "values = [str(round(float(train_dataset[i].y), 3)) for i in range(num_viz)]\n",
        "\n",
        "Chem.Draw.MolsToGridImage(mols, legends=[f\"y = {value}\" for value in values], molsPerRow=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "byQrx71Udlv5"
      },
      "source": [
        "## Understanding PyG Data Objects\n",
        "\n",
        "Each graph in our dataset is encapsulated in a PyG `Data` object, a convient way of representing all structured data for use in Geometric Deep Learning (including graphs, point clouds, and meshes)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wyEVg0zhMOTn"
      },
      "outputs": [],
      "source": [
        "data = train_dataset[0] # one data sample, i.e. molecular graph\n",
        "print(\"Let us print all the attributes (along with their shapes) that our PyG molecular graph contains:\")\n",
        "print(data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tmnIaRv2MP3j"
      },
      "source": [
        "Within an instance of a `Data` object, individual `Torch.Tensor` attributes (or any other variable type) can be easily dot accessed within a neural network layer. The graphs from PyG come with a number of pre-computed features which we describe below (do not worry if you are unfamiliar with the chemistry terms here):\n",
        "\n",
        "**Atom features (`data.x`)** - $\\mathbb{R}^{|V| \\times 11}$\n",
        "- 1st-5th features: Atom type (one-hot: H, C, N, O, F)\n",
        "- 6th feature (also `data.z`): Atomic number (number of protons).\n",
        "- 7th feature: Aromatic (binary)\n",
        "- 8th-10th features: Electron orbital hybridization (one-hot: sp, sp2, sp3)\n",
        "- 11th feature: Number of hydrogens\n",
        "\n",
        "**Edge Index (`data.edge_index`)** - $\\mathbb{R}^{2×|E|}$\n",
        "- A tensor of dimensions 2 x `num_edges` that describe the edge connectivity of the graph\n",
        "\n",
        "**Edge features (`data.edge_attr`)** - $\\mathbb{R}^{|E|\\times 4}$\n",
        "- 1st-4th features: bond type (one-hot: single, double, triple, aromatic)\n",
        "\n",
        "**Atom positions (`data.pos`)** - $\\mathbb{R}^{|V|\\times 3}$\n",
        "- 3D coordinates of each atom . (We will talk about their importance later in the practical.)\n",
        "\n",
        "**Target (`data.y`)** - $\\mathbb{R}^{1}$\n",
        "- A scalar value corresponding to the molecules electric dipole moment\n",
        "\n",
        "**❗️Note:** We will be using **fully-connected graphs** (i.e. all atoms in a molecule are connected to each other, except self-loops). The information about the molecule structures will be available to the models through the edge features (`data.edge_attr`) as follows:\n",
        "- When two atoms are physically connected, the edge attributes indicate the **bond type** (single, double, triple, or aromatic) through a one-hot vector.\n",
        "- When two atoms are not physically connected, **all edge attributes** are **zero**.\n",
        "We will later study the advantages/downsides of fully-connected adjacency matrices versus sparse adjacency matrices (where an edge between two atoms is present only when there exists a physical connection between them)."
      ]
    },
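    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The note above implies that the sparse, bonds-only graph can be recovered from the fully-connected one by keeping only the edges whose attribute vector is non-zero. Here is a minimal sketch on a hypothetical 3-atom toy molecule (the tensors `toy_edge_index` and `toy_edge_attr` are made up for illustration and are not taken from QM9):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Fully-connected toy graph on 3 atoms without self-loops: 3 * 2 = 6 directed edges\n",
        "toy_edge_index = torch.tensor([[0, 0, 1, 1, 2, 2],\n",
        "                               [1, 2, 0, 2, 0, 1]])\n",
        "\n",
        "# One-hot bond types (single, double, triple, aromatic); an all-zero row means 'no bond'\n",
        "toy_edge_attr = torch.tensor([\n",
        "    [1., 0., 0., 0.],  # 0 -> 1: single bond\n",
        "    [0., 0., 0., 0.],  # 0 -> 2: not bonded\n",
        "    [1., 0., 0., 0.],  # 1 -> 0: single bond\n",
        "    [0., 1., 0., 0.],  # 1 -> 2: double bond\n",
        "    [0., 0., 0., 0.],  # 2 -> 0: not bonded\n",
        "    [0., 1., 0., 0.],  # 2 -> 1: double bond\n",
        "])\n",
        "\n",
        "# An edge is a physical bond exactly when its one-hot vector is non-zero\n",
        "bond_mask = toy_edge_attr.sum(dim=-1) > 0\n",
        "sparse_edge_index = toy_edge_index[:, bond_mask]\n",
        "sparse_edge_attr = toy_edge_attr[bond_mask]\n",
        "\n",
        "print(f'{toy_edge_index.shape[1]} edges in total, {int(bond_mask.sum())} of them are bonds')\n",
        "print(sparse_edge_index)"
      ]
    },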
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ng5PwJW9debD"
      },
      "outputs": [],
      "source": [
        "print(f\"\\nThis molecule has {data.x.shape[0]} atoms, and {data.edge_attr.shape[0]} edges.\")\n",
        "\n",
        "print(f\"\\nFor each atom, we are given a feature vector with {data.x.shape[1]} entries (described above).\")\n",
        "\n",
        "print(f\"\\nFor each edge, we are given a feature vector with {data.edge_attr.shape[1]} entries (also described above).\")\n",
        "\n",
        "print(f\"\\nIn the next section, we will learn how to build a GNN in the Message Passing flavor to process the node and edge features of molecular graphs and predict their properties.\")\n",
        "\n",
        "print(f\"\\nEach atom also has a {data.pos.shape[1]}-dimensional coordinate associated with it. We will talk about their importance later in the practical.\")\n",
        "\n",
        "print(f\"\\nFinally, we have {data.y.shape[0]} regression target for the entire molecule.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nDnWN7YvnaYs"
      },
      "source": [
        "## Using PyG for batching\n",
        "\n",
        "As you might remember from the previous practical, **batching graphs** can be quite a tedious and fiddly process. Thankfully, using PyG makes this super simple! Given a list of `Data` objects, we can easily batch this into a PyG `Batch` object as well as unbatch back into a list of graphs. Furthermore, in simple cases like ours, the PyG `DataLoader` object (different from the vanilla PyTorch one) handles all of the batching under the hood for us!\n",
        "\n",
        "Lets quicky batch and unbatch some graphs anyway as a demonstration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xd3Yn7jOgNbY"
      },
      "outputs": [],
      "source": [
        "# Toy graph 1\n",
        "edge_index_1 = torch.tensor(\n",
        "    [[0, 1, 1, 2], [1, 0, 2, 1]], \n",
        "    dtype=torch.long\n",
        ")\n",
        "x_1 = torch.tensor([[-1], [0], [1]], dtype=torch.float)\n",
        "\n",
        "data_1 = Data(x=x_1, edge_index=edge_index_1)\n",
        "\n",
        "# Toy graph 2\n",
        "edge_index_2 = torch.tensor(\n",
        "    [[0, 2, 1, 0], [2, 0, 0, 1]], \n",
        "    dtype=torch.long\n",
        ")\n",
        "x_2 = torch.tensor([[1], [0], [-1]], dtype=torch.float)\n",
        "\n",
        "data_2 = Data(x=x_2, edge_index=edge_index_2)\n",
        "\n",
        "# Create batch from toy graphs\n",
        "data_list = [data_1, data_2]\n",
        "batch = Batch.from_data_list(data_list)\n",
        "\n",
        "assert (batch[0].x == data_1.x).all() and (batch[1].x == data_2.x).all()\n",
        "\n",
        "# Create DataLoader\n",
        "loader = DataLoader(data_list, batch_size=1, shuffle=False)\n",
        "it = iter(loader)\n",
        "batch_1 = next(it)\n",
        "batch_2 = next(it)\n",
        "\n",
        "assert (batch_1.x == data_1.x).all() and (batch_2.x == data_2.x).all()"
      ]
    },
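    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the batching logic less of a black box, the cell below is a minimal sketch, in plain PyTorch, of what batching amounts to conceptually for the two toy graphs above: node features are stacked, the edge indices of each later graph are shifted by the number of nodes that came before it, and a `batch` vector records which graph each node belongs to. The helper `manual_batch` and the `toy_*` variables are our own illustrative names and not part of PyG, whose `Batch` class handles many more attributes than this."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "def manual_batch(xs, edge_indices):\n",
        "    # Hypothetical helper (not a PyG function): merge graphs into one disconnected graph\n",
        "    x_all, ei_all, batch_vec = [], [], []\n",
        "    offset = 0\n",
        "    for graph_id, (x, ei) in enumerate(zip(xs, edge_indices)):\n",
        "        x_all.append(x)\n",
        "        ei_all.append(ei + offset)  # shift node ids so they stay unique\n",
        "        batch_vec.append(torch.full((x.shape[0],), graph_id, dtype=torch.long))\n",
        "        offset += x.shape[0]\n",
        "    return torch.cat(x_all, dim=0), torch.cat(ei_all, dim=1), torch.cat(batch_vec)\n",
        "\n",
        "# Same toy graphs as above, re-created here so the cell is self-contained\n",
        "toy_x_a = torch.tensor([[-1.0], [0.0], [1.0]])\n",
        "toy_ei_a = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n",
        "toy_x_b = torch.tensor([[1.0], [0.0], [-1.0]])\n",
        "toy_ei_b = torch.tensor([[0, 2, 1, 0], [2, 0, 0, 1]])\n",
        "\n",
        "toy_x, toy_ei, toy_batch = manual_batch([toy_x_a, toy_x_b], [toy_ei_a, toy_ei_b])\n",
        "print(toy_ei)     # edges of the second graph now refer to nodes 3, 4, 5\n",
        "print(toy_batch)  # graph id of each node"
      ]
    },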
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zuv-ofkecJgU"
      },
      "source": [
        "Awesome! We have downloaded and prepared the QM9 dataset, visualised some samples, understood the attributes associated with each molecular graph, and reviewed how batching works in PyG. Now, we are ready to understand how we can develop GNNs in PyG for molecular property prediction.\n",
        "\n",
        "---\n",
        "---\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "563IbrHKihGQ"
      },
      "source": [
        "# 📩 Part 0: Introduction to Message Passing Neural Networks in PyTorch Geometric\n",
        "\n",
        "As a gentle introduction to PyTorch Geometric, we will walk you through the first steps of developing a GNN in the **Message Passing** flavor.\n",
        "\n",
        "<!-- ![](https://drive.google.com/uc?id=1Wdgdq606XW1MelvcU1nW5CxWe1rHsWt1) -->\n",
        "<img src=\"https://github.com/chaitjo/dump/raw/main/gnn-layers.png\">"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TAwoHyWs452X"
      },
      "source": [
        "## Formalism\n",
        "\n",
        "Firstly, let us formalise our molecular property prediction pipeline. (Our notation will mostly follow what has been introduced in the lectures, but we do make some different choices for variable names.)\n",
        "\n",
        "### Graph \n",
        "Consider a molecular graph $\\mathcal{G} = \\left( \\mathcal{V}, \\mathcal{E} \\right)$, where $\\mathcal{V}$ is a set of $n$ nodes, and $\\mathcal{E}$ is a set of edges associated with the nodes. For each node $i \\in \\mathcal{V}$, we are given a $d_n$-dimensional initial feature vector $h_i \\in \\mathbb{R}^{d_n}$.\n",
        "For each edge $(i, j) \\in \\mathcal{E}$, we are given a $d_e$-dimensional initial feature vector $e_{ij} \\in \\mathbb{R}^{d_e}$. For QM9 graphs, $d_n = 11, d_e = 4$.\n",
        "\n",
        "### Label/target \n",
        "Associated with each graph $\\mathcal{G}$ is a scalar target or label $y \\in \\mathbb{R}^{1}$, which we would like to predict. \n",
        "\n",
        "We will design a Message Passing Neural Network for graph property prediction to do this. Our MPNN will consist of several layers of message passing, followed by a global pooling and prediction head.\n",
        "\n",
        "### MPNN Layer \n",
        "The Message Passing operation iteratively updates node features $h_i^{\\ell} \\in \\mathbb{R}^d$ from layer $\\ell$ to layer $\\ell+1$ via the following equation:\n",
        "$$\n",
        "h_i^{\\ell+1} = \\phi \\Bigg( h_i^{\\ell}, \\oplus_{j \\in \\mathcal{N}_i} \\Big( \\psi \\left( h_i^{\\ell}, h_j^{\\ell}, e_{ij} \\right) \\Big) \\Bigg),\n",
        "$$\n",
        "where $\\psi, \\phi$ are Multi-Layer Perceptrons (MLPs), and $\\oplus$ is a permutation-invariant local neighborhood aggregation function such as summation, maximization, or averaging.\n",
        "\n",
        "Let us break down the MPNN layer into three pedagogical steps:\n",
        "- **Step (1): Message.** For each pair of linked nodes $i, j$, the network first computes a message $m_{ij} =  \\psi \\left( h_i^{\\ell}, h_j^{\\ell}, e_{ij} \\right)$. The MLP $\\psi: \\mathbb{R}^{2d + d_e} → \\mathbb{R}^d$ takes as input the concatenation of the feature vectors from the source node, destination node, and edge.\n",
        "    - Note that for the first layer $\\ell=0$, $h_i^{\\ell=0} = W_{in} \\left( h_i \\right)$, where $W_{in} \\in \\mathbb{R}^{d_n}  \\rightarrow \\mathbb{R}^{d}$ is a simple linear projection (`torch.nn.Linear`) for the initial node features to hidden dimension $d$.\n",
        "- **Step (2): Aggregate.** At each node $i$, the incoming messages from all its neighbors are then aggregated as $m_{i} = \\oplus_{{j \\in \\mathcal{N}_i}} \\left( m_{ij} \\right)$, where $\\oplus$ is a permutation-invariant function. We will use summation, i.e. $\\oplus_{{j \\in \\mathcal{N}_i}} = \\sum_{{j \\in \\mathcal{N}_i}}$.\n",
        "- **Step (3): Update.** Finally, the network updates the node feature vector $h_i^{\\ell+1} = \\phi \\left( h_i^{\\ell}, m_i \\right)$, by concatenating the aggregated message $m_i$ and the previous node feature vector $h_i^{\\ell}$, and passing them through an MLP $\\phi: \\mathbb{R}^{2d} → \\mathbb{R}^{d}$.\n",
        "\n",
        "### Global Pooling and Prediction Head\n",
        "After $L$ layers of message passing, we obtain the final node features $h_i^{\\ell=L}$. As we have a single target $y$ per graph, we must pool all node features into a single graph feature or graph embedding $h_G \\in \\mathbb{R}^d$ via another permutation-invariant function $R$, sometimes called the 'readout' function, as follows:\n",
        "$$\n",
        "h_G = R_{i \\in \\mathcal{V}} \\left( h_i^{\\ell=L} \\right). \n",
        "$$ \n",
        "We will use global average pooling over all node features, i.e.\n",
        "$$\n",
        "h_G = \\frac{1}{|\\mathcal{V}|} \\sum_{i \\in \\mathcal{V}} h_i^{\\ell=L}. \n",
        "$$\n",
        "\n",
        "The graph embedding $h_G$ is passed through a linear prediction head $W_{pred} \\in \\mathbb{R}^{d} \\rightarrow \\mathbb{R}^1$ to obtain the overall prediction $\\hat y \\in \\mathbb{R}^1$:\n",
        "$$\n",
        "\\hat y = W_{pred} \\left( h_G \\right).\n",
        "$$ \n",
        "\n",
        "### Loss Function\n",
        "Our MPNN graph property prediction model can be trained end-to-end via minimizing the standard mean-squared error loss for regression:\n",
        "$$\n",
        "\\mathcal{L}_{MSE} = \\lVert y - \\hat y \\rVert^2_2.\n",
        "$$"
      ]
    },
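    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the three steps above, the next cell sketches a single round of message passing in plain PyTorch, without PyG. We assume the PyG convention that the first row of `edge_index` holds source nodes $j$ and the second row holds destination nodes $i$. The function `toy_mpnn_step`, the `toy_*` variables, the toy sizes, and the single-layer MLPs are hypothetical choices made for brevity, not the layer we implement later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "def toy_mpnn_step(h, edge_index, edge_attr, psi, phi):\n",
        "    # Illustrative single round of message passing (not the PyG implementation)\n",
        "    src, dst = edge_index  # row 0: source nodes j, row 1: destination nodes i\n",
        "    # Step (1) Message: m_ij = psi(h_i, h_j, e_ij), one row per edge\n",
        "    msgs = psi(torch.cat([h[dst], h[src], edge_attr], dim=-1))\n",
        "    # Step (2) Aggregate: sum the messages arriving at each destination node i\n",
        "    m = torch.zeros_like(h).index_add_(0, dst, msgs)\n",
        "    # Step (3) Update: h_i^{l+1} = phi(h_i, m_i)\n",
        "    return phi(torch.cat([h, m], dim=-1))\n",
        "\n",
        "toy_n, toy_d, toy_d_e = 3, 4, 2  # toy sizes, chosen arbitrarily\n",
        "toy_h = torch.randn(toy_n, toy_d)\n",
        "toy_edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n",
        "toy_edge_attr = torch.randn(toy_edge_index.shape[1], toy_d_e)\n",
        "toy_psi = torch.nn.Sequential(torch.nn.Linear(2 * toy_d + toy_d_e, toy_d), torch.nn.ReLU())\n",
        "toy_phi = torch.nn.Sequential(torch.nn.Linear(2 * toy_d, toy_d), torch.nn.ReLU())\n",
        "\n",
        "toy_h_new = toy_mpnn_step(toy_h, toy_edge_index, toy_edge_attr, toy_psi, toy_phi)\n",
        "toy_h_graph = toy_h_new.mean(dim=0)  # mean readout gives the graph embedding h_G\n",
        "print(toy_h_new.shape, toy_h_graph.shape)"
      ]
    },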
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2xcV8Yb148Kq"
      },
      "source": [
        "## Coding the basic Message Passing Neural Network Layer\n",
        "\n",
        "We are now ready to define a basic MPNN layer which implements what we have described above. In particular, we will code up the **MPNN Layer** first. (We will code up the other parts subsequently.)\n",
        "\n",
        "To do so, we will inherit from the `MessagePassing` base class, which automatically takes care of message propagation and is extremely useful to develop advanced GNN models. To implement a custom MPNN, the user only needs to define the behaviour of the `message` (i.e. $\\psi$), the `aggregate`(i.e. $\\oplus$), and `update` (i.e. $\\phi$) functions. You may also refer to the [PyG documentation](https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html) for implementing custom message passing layers.\n",
        "\n",
        "Below, we provide the implementation of a standard MPNN layer as an example, with extensive inline comments to help you figure out what is going on."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wO5GskvZnZ1j"
      },
      "outputs": [],
      "source": [
        "class MPNNLayer(MessagePassing):\n",
        "    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):\n",
        "        \"\"\"Message Passing Neural Network Layer\n",
        "\n",
        "        Args:\n",
        "            emb_dim: (int) - hidden dimension `d`\n",
        "            edge_dim: (int) - edge feature dimension `d_e`\n",
        "            aggr: (str) - aggregation function `\\oplus` (sum/mean/max)\n",
        "        \"\"\"\n",
        "        # Set the aggregation function\n",
        "        super().__init__(aggr=aggr)\n",
        "\n",
        "        self.emb_dim = emb_dim\n",
        "        self.edge_dim = edge_dim\n",
        "\n",
        "        # MLP `\\psi` for computing messages `m_ij`\n",
        "        # Implemented as a stack of Linear->BN->ReLU->Linear->BN->ReLU\n",
        "        # dims: (2d + d_e) -> d\n",
        "        self.mlp_msg = Sequential(\n",
        "            Linear(2*emb_dim + edge_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),\n",
        "            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()\n",
        "          )\n",
        "        \n",
        "        # MLP `\\phi` for computing updated node features `h_i^{l+1}`\n",
        "        # Implemented as a stack of Linear->BN->ReLU->Linear->BN->ReLU\n",
        "        # dims: 2d -> d\n",
        "        self.mlp_upd = Sequential(\n",
        "            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), \n",
        "            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()\n",
        "          )\n",
        "\n",
        "    def forward(self, h, edge_index, edge_attr):\n",
        "        \"\"\"\n",
        "        The forward pass updates node features `h` via one round of message passing.\n",
        "\n",
        "        As our MPNNLayer class inherits from the PyG MessagePassing parent class,\n",
        "        we simply need to call the `propagate()` function which starts the \n",
        "        message passing procedure: `message()` -> `aggregate()` -> `update()`.\n",
        "        \n",
        "        The MessagePassing class handles most of the logic for the implementation.\n",
        "        To build custom GNNs, we only need to define our own `message()`, \n",
        "        `aggregate()`, and `update()` functions (defined subsequently).\n",
        "\n",
        "        Args:\n",
        "            h: (n, d) - initial node features\n",
        "            edge_index: (e, 2) - pairs of edges (i, j)\n",
        "            edge_attr: (e, d_e) - edge features\n",
        "\n",
        "        Returns:\n",
        "            out: (n, d) - updated node features\n",
        "        \"\"\"\n",
        "        out = self.propagate(edge_index, h=h, edge_attr=edge_attr)\n",
        "        return out\n",
        "\n",
        "    def message(self, h_i, h_j, edge_attr):\n",
        "        \"\"\"Step (1) Message\n",
        "\n",
        "        The `message()` function constructs messages from source nodes j \n",
        "        to destination nodes i for each edge (i, j) in `edge_index`.\n",
        "\n",
        "        The arguments can be a bit tricky to understand: `message()` can take \n",
        "        any arguments that were initially passed to `propagate`. Additionally, \n",
        "        we can differentiate destination nodes and source nodes by appending \n",
        "        `_i` or `_j` to the variable name, e.g. for the node features `h`, we\n",
        "        can use `h_i` and `h_j`. \n",
        "        \n",
        "        This part is critical to understand as the `message()` function\n",
        "        constructs messages for each edge in the graph. The indexing of the\n",
        "        original node features `h` (or other node variables) is handled under\n",
        "        the hood by PyG.\n",
        "\n",
        "        Args:\n",
        "            h_i: (e, d) - destination node features\n",
        "            h_j: (e, d) - source node features\n",
        "            edge_attr: (e, d_e) - edge features\n",
        "        \n",
        "        Returns:\n",
        "            msg: (e, d) - messages `m_ij` passed through MLP `\\psi`\n",
        "        \"\"\"\n",
        "        msg = torch.cat([h_i, h_j, edge_attr], dim=-1)\n",
        "        return self.mlp_msg(msg)\n",
        "    \n",
        "    def aggregate(self, inputs, index):\n",
        "        \"\"\"Step (2) Aggregate\n",
        "\n",
        "        The `aggregate` function aggregates the messages from neighboring nodes,\n",
        "        according to the chosen aggregation function ('sum' by default).\n",
        "\n",
        "        Args:\n",
        "            inputs: (e, d) - messages `m_ij` from destination to source nodes\n",
        "            index: (e, 1) - list of source nodes for each edge/message in `input`\n",
        "\n",
        "        Returns:\n",
        "            aggr_out: (n, d) - aggregated messages `m_i`\n",
        "        \"\"\"\n",
        "        return scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)\n",
        "    \n",
        "    def update(self, aggr_out, h):\n",
        "        \"\"\"\n",
        "        Step (3) Update\n",
        "\n",
        "        The `update()` function computes the final node features by combining the \n",
        "        aggregated messages with the initial node features.\n",
        "\n",
        "        `update()` takes the first argument `aggr_out`, the result of `aggregate()`, \n",
        "        as well as any optional arguments that were initially passed to \n",
        "        `propagate()`. E.g. in this case, we additionally pass `h`.\n",
        "\n",
        "        Args:\n",
        "            aggr_out: (n, d) - aggregated messages `m_i`\n",
        "            h: (n, d) - initial node features\n",
        "\n",
        "        Returns:\n",
        "            upd_out: (n, d) - updated node features passed through MLP `\\phi`\n",
        "        \"\"\"\n",
        "        upd_out = torch.cat([h, aggr_out], dim=-1)\n",
        "        return self.mlp_upd(upd_out)\n",
        "\n",
        "    def __repr__(self) -> str:\n",
        "        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OPB24QlmU_46"
      },
      "source": [
        "Great! We have defined a **Message Passing layer** following the equation we had introduced previously. Let us use this layer to code up the full **MPNN graph property prediction model**. This model will take as input molecular graphs, process them via multiple MPNN layers, and predict a single property for each of them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q0vqS1NAU_ZN"
      },
      "outputs": [],
      "source": [
        "class MPNNModel(Module):\n",
        "    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):\n",
        "        \"\"\"Message Passing Neural Network model for graph property prediction\n",
        "\n",
        "        Args:\n",
        "            num_layers: (int) - number of message passing layers `L`\n",
        "            emb_dim: (int) - hidden dimension `d`\n",
        "            in_dim: (int) - initial node feature dimension `d_n`\n",
        "            edge_dim: (int) - edge feature dimension `d_e`\n",
        "            out_dim: (int) - output dimension (fixed to 1)\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        \n",
        "        # Linear projection for initial node features\n",
        "        # dim: d_n -> d\n",
        "        self.lin_in = Linear(in_dim, emb_dim)\n",
        "        \n",
        "        # Stack of MPNN layers\n",
        "        self.convs = torch.nn.ModuleList()\n",
        "        for layer in range(num_layers):\n",
        "            self.convs.append(MPNNLayer(emb_dim, edge_dim, aggr='add'))\n",
        "        \n",
        "        # Global pooling/readout function `R` (mean pooling)\n",
        "        # PyG handles the underlying logic via `global_mean_pool()`\n",
        "        self.pool = global_mean_pool\n",
        "\n",
        "        # Linear prediction head\n",
        "        # dim: d -> out_dim\n",
        "        self.lin_pred = Linear(emb_dim, out_dim)\n",
        "        \n",
        "    def forward(self, data):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            data: (PyG.Data) - batch of PyG graphs\n",
        "\n",
        "        Returns: \n",
        "            out: (batch_size, out_dim) - prediction for each graph\n",
        "        \"\"\"\n",
        "        h = self.lin_in(data.x) # (n, d_n) -> (n, d)\n",
        "        \n",
        "        for conv in self.convs:\n",
        "            h = h + conv(h, data.edge_index, data.edge_attr) # (n, d) -> (n, d)\n",
        "            # Note that we add a residual connection after each MPNN layer\n",
        "\n",
        "        h_graph = self.pool(h, data.batch) # (n, d) -> (batch_size, d)\n",
        "\n",
        "        out = self.lin_pred(h_graph) # (batch_size, d) -> (batch_size, 1)\n",
        "\n",
        "        return out.view(-1)"
      ]
    },
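    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The readout step relies on the `batch` vector to know which nodes belong to which graph. The following sketch, under the assumption that `batch` assigns a graph index to every node, re-implements mean pooling in plain PyTorch to show the idea. `toy_mean_pool` and the `toy_*` tensors are hypothetical names used for illustration only; the model above uses PyG's `global_mean_pool`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "def toy_mean_pool(h, batch_vec, num_graphs):\n",
        "    # Illustrative re-implementation; not the PyG function\n",
        "    sums = torch.zeros(num_graphs, h.shape[1]).index_add_(0, batch_vec, h)\n",
        "    counts = torch.bincount(batch_vec, minlength=num_graphs).unsqueeze(-1)\n",
        "    return sums / counts  # per-graph average of node features\n",
        "\n",
        "toy_nodes = torch.tensor([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])\n",
        "toy_batch_vec = torch.tensor([0, 0, 1])  # first two nodes: graph 0, last node: graph 1\n",
        "toy_pooled = toy_mean_pool(toy_nodes, toy_batch_vec, num_graphs=2)\n",
        "print(toy_pooled)  # one row per graph"
      ]
    },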
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W4xlvC8bZEv_"
      },
      "source": [
        "Awesome! We are done defining our first MPNN model for graph property prediction.\n",
        "\n",
        "But wait! Before we dive into training and evaluation this model, let us write some sanity checks for a **fundamental property** of the model and the layer.\n",
        "\n",
        "## Unit tests for Permutation Invariance and Equivariance\n",
        "\n",
        "The lectures have repeatedly indicated on certain fundamental properties for machine learning on graphs:\n",
        "- A **GNN <ins>layer**</ins> is **equivariant** to permutations of the set of nodes in the graph; i.e. as we permute the nodes, the node features produced by the GNN must permute accordingly.\n",
        "- A **GNN <ins>model**</ins> for graph-level property prediction is **invariant** to the permutations of the set of nodes in the graph; i.e. as we permute the nodes, the graph-level propery remains unchanged.\n",
        "\n",
        "(But wait...**What is a permutation?** Essentially, it is an **ordering of the nodes** in a graph. In general, there is **no canonical way** of assigning an ordering of the nodes, unlike textual or image data. However, graphs need to be stored and processed on computers in order to perform machine learning on them (which is what this course is about!). Thus, we need to ensure that our models are able to principaly handle this **lack of canonical ordering** or permutation of graph nodes. This is what the above statements are trying to say.)\n",
        "\n",
        "### Formalism\n",
        "\n",
        "Let us try to formalise these notions of permutation invariance and equivariance via matrix notation (it is easier that way).\n",
        "\n",
        "- Let $\\mathbf{H} \\in \\mathbb{R}^{n \\times d}$ be a matrix of node features for a given molecular graph, where $n$ is the number of nodes/atoms and each row $h_i$ is the $d$-dimensional feature for node $i$.\n",
        "- Let $\\mathbf{A} \\in \\mathbb{R}^{n \\times n}$ be the adjacency matrix where each entry denotes $a_{ij}$ the presence or absence of an edge between nodes $i$ and $j$.\n",
        "- Let $\\mathbf{F}(\\mathbf{H}, \\mathbf{A}): \\mathbb{R}^{n \\times d} \\times \\mathbb{R}^{n \\times n} \\rightarrow \\mathbb{R}^{n \\times d}$ be a **GNN <ins>layer**</ins> that takes as input the node features and adjacency matrix, and returns the **updated node features**.\n",
        "- Let $f(\\mathbf{H}, \\mathbf{A}): \\mathbb{R}^{n \\times d} \\times \\mathbb{R}^{n \\times n} \\rightarrow \\mathbb{R}$ be a **GNN <ins>model**</ins> that takes as input the node features and adjacency matrix, and returns the **predicted graph-level property**.\n",
        "- Let $\\mathbf{P} \\in \\mathbb{R}^{n \\times n}$ be a **[permutation matrix](https://en.wikipedia.org/wiki/Permutation_matrix)** which has exactly one 1 in every row and column, and 0s elsewhere. Left-multipying $\\mathbf{P}$ with a matrix changes the ordering of the rows of the matrix.\n",
        "\n",
        "### Permuation Equivariance\n",
        "\n",
        "The GNN <ins>layer</ins> $\\mathbf{F}$ is **permuation equivariant** as follows:\n",
        "$$ \n",
        "\\mathbf{F}(\\mathbf{PH}, \\mathbf{PAP^T}) = \\mathbf{P} \\ \\mathbf{F}(\\mathbf{H}, \\mathbf{A}).\n",
        "$$\n",
        "\n",
        "Another way to formulate the above could be: (1) Consider the updated node features $\\mathbf{H'} = \\mathbf{F}(\\mathbf{H}, \\mathbf{A})$. (2) Applying any permutation matrix $\\mathbf{P}$ to the input of the GNN layer $\\mathbf{F}$ should produce the same result as applying the same permutation on $\\mathbf{H'}$:\n",
        "$$\n",
        "\\mathbf{F}(\\mathbf{PH}, \\mathbf{PAP^T}) = \\mathbf{P} \\ \\mathbf{H'}\n",
        "$$\n",
        "\n",
        "### Permuation Invariance\n",
        "\n",
        "The GNN <ins>model</ins> $f$ for graph-level prediction is **permutation invariant** as follows:\n",
        "$$ \n",
        "f(\\mathbf{PH}, \\mathbf{PAP^T}) = f(\\mathbf{H}, \\mathbf{A}).\n",
        "$$\n",
        "\n",
        "Another way to formulate the above could be: (1) Consider the predicted molecular property $\\mathbf{\\hat y} = f(\\mathbf{H}, \\mathbf{A})$. (2) Applying any permutation matrix $\\mathbf{P}$ to the input of the GNN model $f$ should produce the same result as not applying it:\n",
        "$$ \n",
        "f(\\mathbf{PH}, \\mathbf{PAP^T}) = \\mathbf{\\hat y}.\n",
        "$$\n",
        "\n",
        "With that formalism out of the way, let us write some unit tests to confirm that our `MPNNModel` and `MPNNLayer` are indeed permutation invariant and equivariant, respectively."
      ]
    },
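    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before testing the actual PyG modules, it may help to check the two identities numerically on a much simpler stand-in. The sketch below assumes a toy layer $\\mathbf{F}(\\mathbf{H}, \\mathbf{A}) = \\mathrm{ReLU}((\\mathbf{A} + \\mathbf{I}) \\mathbf{H} \\mathbf{W})$ and a toy model that mean-pools its output. These functions (`toy_layer`, `toy_model`, `check_toy_symmetries`) are illustrative inventions and not our `MPNNLayer` or `MPNNModel`; the permutation matrix is built explicitly by shuffling the rows of the identity."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "def toy_layer(H, A, W):\n",
        "    # Sum-aggregation layer chosen purely for illustration: F(H, A) = ReLU((A + I) H W)\n",
        "    return torch.relu((A + torch.eye(A.shape[0])) @ H @ W)\n",
        "\n",
        "def toy_model(H, A, W):\n",
        "    # Mean readout over nodes, then a sum over features to obtain one scalar\n",
        "    return toy_layer(H, A, W).mean(dim=0).sum()\n",
        "\n",
        "def check_toy_symmetries(n_nodes=5, dim=3, seed=0):\n",
        "    g = torch.Generator().manual_seed(seed)  # local RNG, leaves the global seed untouched\n",
        "    H = torch.randn(n_nodes, dim, generator=g)\n",
        "    A = torch.triu((torch.rand(n_nodes, n_nodes, generator=g) < 0.5).float(), diagonal=1)\n",
        "    A = A + A.T  # symmetric adjacency without self-loops\n",
        "    W = torch.randn(dim, dim, generator=g)\n",
        "    P = torch.eye(n_nodes)[torch.randperm(n_nodes, generator=g)]  # shuffled rows of I\n",
        "    # Equivariance: F(PH, PAP^T) == P F(H, A)\n",
        "    equivariant = torch.allclose(toy_layer(P @ H, P @ A @ P.T, W), P @ toy_layer(H, A, W), atol=1e-4)\n",
        "    # Invariance: f(PH, PAP^T) == f(H, A)\n",
        "    invariant = torch.allclose(toy_model(P @ H, P @ A @ P.T, W), toy_model(H, A, W), atol=1e-4)\n",
        "    return equivariant, invariant\n",
        "\n",
        "print(check_toy_symmetries())"
      ]
    },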
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vXuzZIqpZVqS"
      },
      "outputs": [],
      "source": [
        "def permute_graph(data, perm):\n",
        "    \"\"\"Helper function for permuting PyG Data object attributes consistently.\n",
        "    \"\"\"\n",
        "    # Permute the node attribute ordering\n",
        "    data.x = data.x[perm]\n",
        "    data.pos = data.pos[perm]\n",
        "    data.z = data.z[perm]\n",
        "    data.batch = data.batch[perm]\n",
        "\n",
        "    # Permute the edge index\n",
        "    adj = to_dense_adj(data.edge_index)\n",
        "    adj = adj[:, perm, :]\n",
        "    adj = adj[:, :, perm]\n",
        "    data.edge_index = dense_to_sparse(adj)[0]\n",
        "\n",
        "    # Note: \n",
        "    # (1) While we originally defined the permutation matrix P as only having \n",
        "    #     entries 0 and 1, its implementation via `perm` uses indexing into \n",
        "    #     torch tensors, instead. \n",
        "    # (2) It is cumbersome to permute the edge_attr, so we set it to constant \n",
        "    #     dummy values. For any experiments beyond unit testing, all GNN models \n",
        "    #     use the original edge_attr.\n",
        "\n",
        "    return data\n",
        "\n",
        "def permutation_invariance_unit_test(module, dataloader):\n",
        "    \"\"\"Unit test for checking whether a module (GNN model) is \n",
        "    permutation invariant.\n",
        "    \"\"\"\n",
        "    it = iter(dataloader)\n",
        "    data = next(it)\n",
        "\n",
        "    # Set edge_attr to dummy values (for simplicity)\n",
        "    data.edge_attr = torch.zeros(data.edge_attr.shape)\n",
        "\n",
        "    # Forward pass on original example\n",
        "    out_1 = module(data)\n",
        "\n",
        "    # Create random permutation\n",
        "    perm = torch.randperm(data.x.shape[0])\n",
        "    data = permute_graph(data, perm)\n",
        "\n",
        "    # Forward pass on permuted example\n",
        "    out_2 = module(data)\n",
        "\n",
        "    # Check whether output varies after applying transformations\n",
        "    return torch.allclose(out_1, out_2, atol=1e-04)\n",
        "\n",
        "\n",
        "def permutation_equivariance_unit_test(module, dataloader):\n",
        "    \"\"\"Unit test for checking whether a module (GNN layer) is \n",
        "    permutation equivariant.\n",
        "    \"\"\"\n",
        "    it = iter(dataloader)\n",
        "    data = next(it)\n",
        "\n",
        "    # Set edge_attr to dummy values (for simplicity)\n",
        "    data.edge_attr = torch.zeros(data.edge_attr.shape)\n",
        "\n",
        "    # Forward pass on original example\n",
        "    out_1 = module(data.x, data.edge_index, data.edge_attr)\n",
        "\n",
        "    # Create random permutation\n",
        "    perm = torch.randperm(data.x.shape[0])\n",
        "    data = permute_graph(data, perm)\n",
        "\n",
        "    # Forward pass on permuted example\n",
        "    out_2 = module(data.x, data.edge_index, data.edge_attr)\n",
        "\n",
        "    # Check whether output varies after applying transformations\n",
        "    return torch.allclose(out_1[perm], out_2, atol=1e-04)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kj_EER0Mg5YQ"
      },
      "source": [
        "Now that we have defined the unit tests for permutation invariance (for the full MPNN model) and permutation equivariance (for the MPNN layer), let us perform the sanity check:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sNAxOxMQkDwf"
      },
      "outputs": [],
      "source": [
        "# Instantiate temporary model, layer, and dataloader for unit testing\n",
        "layer = MPNNLayer(emb_dim=11, edge_dim=4)\n",
        "model = MPNNModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)\n",
        "dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)\n",
        "\n",
        "# Permutation invariance unit test for MPNN model\n",
        "print(f\"Is {type(model).__name__} permutation invariant? --> {permutation_invariance_unit_test(model, dataloader)}!\")\n",
        "\n",
        "# Permutation equivariance unit for MPNN layer\n",
        "print(f\"Is {type(layer).__name__} permutation equivariant? --> {permutation_equivariance_unit_test(layer, dataloader)}!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g-cmASemh0wA"
      },
      "source": [
        "## Training and Evaluating Models\n",
        "\n",
        "Great! We are finally ready to train and evaluate our model on QM9. We have provided a **basic experiment loop** which takes as input the model and dataloaders, performs training, and returns the final performance on the **validation** and **test set**.\n",
        "\n",
        "We will be training a `MPNNModel` consisting of 4 layers of message passing with a hidden dimension of 64."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "FrYb8xr5iZQM"
      },
      "outputs": [],
      "source": [
        "#@title [RUN] Helper functions for managing experiments, training, and evaluating models.\n",
        "\n",
        "def train(model, train_loader, optimizer, device):\n",
        "    model.train()\n",
        "    loss_all = 0\n",
        "\n",
        "    for data in train_loader:\n",
        "        data = data.to(device)\n",
        "        optimizer.zero_grad()\n",
        "        y_pred = model(data)\n",
        "        loss = F.mse_loss(y_pred, data.y)\n",
        "        loss.backward()\n",
        "        loss_all += loss.item() * data.num_graphs\n",
        "        optimizer.step()\n",
        "    return loss_all / len(train_loader.dataset)\n",
        "\n",
        "\n",
        "def eval(model, loader, device):\n",
        "    model.eval()\n",
        "    error = 0\n",
        "\n",
        "    for data in loader:\n",
        "        data = data.to(device)\n",
        "        with torch.no_grad():\n",
        "            y_pred = model(data)\n",
        "            # Mean Absolute Error using std (computed when preparing data)\n",
        "            error += (y_pred * std - data.y * std).abs().sum().item()\n",
        "    return error / len(loader.dataset)\n",
        "\n",
        "\n",
        "def run_experiment(model, model_name, train_loader, val_loader, test_loader, n_epochs=100):\n",
        "    \n",
        "    print(f\"Running experiment for {model_name}, training on {len(train_loader.dataset)} samples for {n_epochs} epochs.\")\n",
        "    \n",
        "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "    print(\"\\nModel architecture:\")\n",
        "    print(model)\n",
        "    total_param = 0\n",
        "    for param in model.parameters():\n",
        "        total_param += np.prod(list(param.data.size()))\n",
        "    print(f'Total parameters: {total_param}')\n",
        "    model = model.to(device)\n",
        "\n",
        "    # Adam optimizer with LR 1e-3\n",
        "    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
        "\n",
        "    # LR scheduler which decays LR when validation metric doesn't improve\n",
        "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
        "        optimizer, mode='min', factor=0.9, patience=5, min_lr=0.00001)\n",
        "    \n",
        "    print(\"\\nStart training:\")\n",
        "    best_val_error = None\n",
        "    perf_per_epoch = [] # Track Test/Val MAE vs. epoch (for plotting)\n",
        "    t = time.time()\n",
        "    for epoch in range(1, n_epochs+1):\n",
        "        # Call LR scheduler at start of each epoch\n",
        "        lr = scheduler.optimizer.param_groups[0]['lr']\n",
        "\n",
        "        # Train model for one epoch, return avg. training loss\n",
        "        loss = train(model, train_loader, optimizer, device)\n",
        "        \n",
        "        # Evaluate model on validation set\n",
        "        val_error = eval(model, val_loader, device)\n",
        "        \n",
        "        if best_val_error is None or val_error <= best_val_error:\n",
        "            # Evaluate model on test set if validation metric improves\n",
        "            test_error = eval(model, test_loader, device)\n",
        "            best_val_error = val_error\n",
        "\n",
        "        if epoch % 10 == 0:\n",
        "            # Print and track stats every 10 epochs\n",
        "            print(f'Epoch: {epoch:03d}, LR: {lr:5f}, Loss: {loss:.7f}, '\n",
        "                  f'Val MAE: {val_error:.7f}, Test MAE: {test_error:.7f}')\n",
        "        \n",
        "        scheduler.step(val_error)\n",
        "        perf_per_epoch.append((test_error, val_error, epoch, model_name))\n",
        "    \n",
        "    t = time.time() - t\n",
        "    train_time = t/60\n",
        "    print(f\"\\nDone! Training took {train_time:.2f} mins. Best validation MAE: {best_val_error:.7f}, corresponding test MAE: {test_error:.7f}.\")\n",
        "    \n",
        "    return best_val_error, test_error, train_time, perf_per_epoch"
      ]
    },
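    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how the learning rate schedule in `run_experiment` behaves, here is a small sketch that feeds `ReduceLROnPlateau` a validation error that never improves. It reuses the same scheduler settings but a dummy parameter in place of a model; the `toy_*` names are ours. With `patience=5`, we would expect the learning rate to stay at its initial value while the metric stalls for up to five epochs and to be multiplied by `factor` once the stall lasts longer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "toy_param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter standing in for a model\n",
        "toy_opt = torch.optim.Adam([toy_param], lr=0.001)\n",
        "toy_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
        "    toy_opt, mode='min', factor=0.9, patience=5, min_lr=0.00001)\n",
        "\n",
        "toy_lrs = []\n",
        "for toy_epoch in range(1, 9):\n",
        "    toy_opt.step()       # no gradients exist, so this is only a placeholder for training\n",
        "    toy_sched.step(1.0)  # pretend the validation MAE is stuck at 1.0\n",
        "    toy_lrs.append(toy_opt.param_groups[0]['lr'])\n",
        "print(toy_lrs)  # learning rate after each epoch"
      ]
    },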
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JsGM5LWmiZQM"
      },
      "outputs": [],
      "source": [
        "model = MPNNModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)\n",
        "model_name = type(model).__name__\n",
        "best_val_error, test_error, train_time, perf_per_epoch = run_experiment(\n",
        "    model, \n",
        "    model_name, \n",
        "    train_loader,\n",
        "    val_loader, \n",
        "    test_loader,\n",
        "    n_epochs=100\n",
        ")\n",
        "RESULTS[model_name] = (best_val_error, test_error, train_time)\n",
        "df_temp = pd.DataFrame(perf_per_epoch, columns=[\"Test MAE\", \"Val MAE\", \"Epoch\", \"Model\"])\n",
        "DF_RESULTS = DF_RESULTS.append(df_temp, ignore_index=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e4NWM5CbptmE"
      },
      "outputs": [],
      "source": [
        "RESULTS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v7rtvD0zvmpF"
      },
      "outputs": [],
      "source": [
        "p = sns.lineplot(x=\"Epoch\", y=\"Val MAE\", hue=\"Model\", data=DF_RESULTS)\n",
        "p.set(ylim=(0, 2));"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r1pjr7brqKCz"
      },
      "outputs": [],
      "source": [
        "p = sns.lineplot(x=\"Epoch\", y=\"Test MAE\", hue=\"Model\", data=DF_RESULTS)\n",
        "p.set(ylim=(0, 1));"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SQjp21t7Z2UF"
      },
      "source": [
        "Super! Everything up to this point has already been covered in the lectures, and we hope that the practical so far has been a useful recap along with the acompanying code.\n",
        "\n",
        "Now for the fun part, where you will be required to think what you have studied so far!\n",
        "\n",
        "---\n",
        "---\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bReLCZx9teYT"
      },
      "source": [
        "# 🧊 Part 1: Geometric Graphs and Message Passing with 3D Coordinates\n",
        "\n",
        "Remember that we were given **3D coordinates** with each atom in our molecular graph?\n",
        "\n",
        "Molecular graphs, and other structured data occurring in nature, do not simply exist on flat planes. Instead, molecules have an **inherent 3D structure** that influences their properties and functions.\n",
        "\n",
        "Let us visualize a molecule from QM9 in all of its 3D glory! \n",
        "\n",
        "Go ahead and try move this molecule with your mouse cursor!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9yFToebQcxo4"
      },
      "outputs": [],
      "source": [
        "MolTo3DView(smi2conf(Chem.MolToSmiles(to_rdkit(train_dataset[48]))))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bzDu7IH2H3qI"
      },
      "source": [
        "## 💻**Task 1.1:** Develop a Message Passing Neural Network that incorporates the atom coordinates as node features **(0.5 Marks)**.\n",
        "\n",
        "\n",
        "Our initial and somewhat **'vanilla' MPNN** `MPNNModel` ignores the atom coordiantes and only uses the node features to perform message passing. This means that the model is **not** leveraging useful **3D structural information** to predict the target property.\n",
        "\n",
        "Your first task is to modify the original `MPNNModel` to incorporate **atom coordinates** into the **node features**.\n",
        "\n",
        "We have defined most of the new `CoordMPNNModel` class for you, and you have to fill in the `YOUR CODE HERE` sections.\n",
        "\n",
        "🤔 *Hint: As reminder, the 3D atom positions are stored in `data.pos`. You don't have to do something very smart right now (that will come later). A **simple** solution is okay to get started, e.g. concatenation or summation.*\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nZu6pELwvph9"
      },
      "outputs": [],
      "source": [
        "class CoordMPNNModel(MPNNModel):\n",
        "    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):\n",
        "        \"\"\"Message Passing Neural Network model for graph property prediction\n",
        "\n",
        "        This model uses both node features and coordinates as inputs.\n",
        "\n",
        "        Args:\n",
        "            num_layers: (int) - number of message passing layers `L`\n",
        "            emb_dim: (int) - hidden dimension `d`\n",
        "            in_dim: (int) - initial node feature dimension `d_n`\n",
        "            edge_dim: (int) - edge feature dimension `d_e`\n",
        "            out_dim: (int) - output dimension (fixed to 1)\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        \n",
        "        # ============ YOUR CODE HERE ==============\n",
        "        # Adapt the input linear layer or add new input layers \n",
        "        # to account for the atom positions.\n",
        "        #\n",
        "        # Linear projection for initial node features and coordinates\n",
        "        # dim: ??? -> d\n",
        "        # self.lin_in = ...\n",
        "        # ==========================================\n",
        "        \n",
        "        # Stack of MPNN layers\n",
        "        self.convs = torch.nn.ModuleList()\n",
        "        for layer in range(num_layers):\n",
        "            self.convs.append(MPNNLayer(emb_dim, edge_dim, aggr='add'))\n",
        "        \n",
        "        # Global pooling/readout function `R` (mean pooling)\n",
        "        # PyG handles the underlying logic via `global_mean_pool()`\n",
        "        self.pool = global_mean_pool\n",
        "\n",
        "        # Linear prediction head\n",
        "        # dim: d -> out_dim\n",
        "        self.lin_pred = Linear(emb_dim, out_dim)\n",
        "        \n",
        "    def forward(self, data):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            data: (PyG.Data) - batch of PyG graphs\n",
        "\n",
        "        Returns: \n",
        "            out: (batch_size, out_dim) - prediction for each graph\n",
        "        \"\"\"\n",
        "        # ============ YOUR CODE HERE ==============\n",
        "        # Incorporate the atom positions along with the features.\n",
        "        #\n",
        "        # h = ...\n",
        "        # ==========================================\n",
        "        \n",
        "        for conv in self.convs:\n",
        "            h = h + conv(h, data.edge_index, data.edge_attr) # (n, d) -> (n, d)\n",
        "            # Note that we add a residual connection after each MPNN layer\n",
        "\n",
        "        h_graph = self.pool(h, data.batch) # (n, d) -> (batch_size, d)\n",
        "\n",
        "        out = self.lin_pred(h_graph) # (batch_size, d) -> (batch_size, 1)\n",
        "\n",
        "        return out.view(-1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2GjBznG2x96s"
      },
      "source": [
        "## 💻**Task 1.2:** Test the permutation invariance and equivariance properties of your new `CoordMPNNModel` with node features and coordinates, as well as the constituent `MPNNLayer`. **(0.5 Marks)**\n",
        "\n",
        "Super! You have successfully implemented an MPNN which utilises both the **atom features** as well as **coordinates** to predict molecular properties. \n",
        "\n",
        "Before we evaluate it, let us once again run the permutation sanity checks again to make sure the model and layer have the desired properties that constitute every basic GNN:\n",
        "- The `MPNNLayer` should be permutation equivariant (we have already shown this previously, but we want you to repeat the exercise in order to **thoroughly** understand it).\n",
        "- The `CoordMPNNModel` should be permutation invariant.\n",
        "\n",
        "Your task is to fill in the `YOUR CODE HERE` sections to run the required unit tests. You do not need to write new unit tests yet, the ones we defined previously can be re-used."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "34Besg_qKo7A"
      },
      "outputs": [],
      "source": [
        "# ============ YOUR CODE HERE ==============\n",
        "# Instantiate temporary model, layer, and dataloader for unit testing.\n",
        "# Remember that we are now unit testing the CoordMPNNModel, which is different\n",
        "# than the previous model but still composed of the MPNNLayer.\n",
        "#\n",
        "# layer = ...\n",
        "# model = ...\n",
        "# ==========================================\n",
        "dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)\n",
        "\n",
        "# Permutation invariance unit test for MPNN model\n",
        "print(f\"Is {type(model).__name__} permutation invariant? --> {permutation_invariance_unit_test(model, dataloader)}!\")\n",
        "\n",
        "# Permutation equivariance unit for MPNN layer\n",
        "print(f\"Is {type(layer).__name__} permutation equivariant? --> {permutation_equivariance_unit_test(layer, dataloader)}!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V03t1f51tW2w"
      },
      "source": [
        "## 💻**Task 1.3.** Prove that your new `CoordMPNNModel` is invariant to permutations of both the node features as well as node coordinates. **(0.5 Marks)**\n",
        "\n",
        "🤔 *Hint: We are looking for simple statements that follow how we formalised permuation invariance for the vanilla MPNN model. We expect you to be copy-pasting most of the formalism and accounting for how your MPNN incorporates both the node features and coordinates. You can additionally introduce $\\mathbf{X} \\in \\mathbb{R}^{n \\times 3}$ as the matrix of node coordinates for a given molecular graph.*\n",
        "\n",
        "\n",
        "---\n",
        "\n",
        "<font color='red'>❗️YOUR ANSWER HERE</font>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TtVdT-uoP5cn"
      },
      "source": [
        "## 💻**Task 1.4.** Train and evaluate your `CoordMPNNModel` with node features and coordinates on QM9. **(0.5 Marks)**\n",
        "\n",
        "Awesome! You are now ready to train and evaluate our new MPNN with node features and coordinates on QM9.\n",
        "\n",
        "Re-use the experiment loop we have provided and fill in the `YOUR CODE HERE` sections to run the experiment.\n",
        "\n",
        "You will be training a `CoordMPNNModel` consisting of 4 layers of message passing with a hidden dimension of 64, in order to compare your result fairly to the previous vanilla `MPNNModel`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4NCZinUVPozM"
      },
      "outputs": [],
      "source": [
        "# ============ YOUR CODE HERE ==============\n",
        "# Instantiate your CoordMPNNModel with the appropriate arguments.\n",
        "#\n",
        "# model = CoordMPNNModel(...)\n",
        "# ==========================================\n",
        "\n",
        "model_name = type(model).__name__\n",
        "best_val_error, test_error, train_time, perf_per_epoch = run_experiment(\n",
        "    model, \n",
        "    model_name, # \"MPNN w/ Features and Coordinates\", \n",
        "    train_loader,\n",
        "    val_loader, \n",
        "    test_loader,\n",
        "    n_epochs=100\n",
        ")\n",
        "\n",
        "RESULTS[model_name] = (best_val_error, test_error, train_time)\n",
        "df_temp = pd.DataFrame(perf_per_epoch, columns=[\"Test MAE\", \"Val MAE\", \"Epoch\", \"Model\"])\n",
        "DF_RESULTS = DF_RESULTS.append(df_temp, ignore_index=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jNU5ISpKsHOR"
      },
      "outputs": [],
      "source": [
        "RESULTS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rbTC_wWPr_qW"
      },
      "outputs": [],
      "source": [
        "p = sns.lineplot(x=\"Epoch\", y=\"Val MAE\", hue=\"Model\", data=DF_RESULTS)\n",
        "p.set(ylim=(0, 2));"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PC5Cd0FKwxoD"
      },
      "outputs": [],
      "source": [
        "p = sns.lineplot(x=\"Epoch\", y=\"Test MAE\", hue=\"Model\", data=DF_RESULTS)\n",
        "p.set(ylim=(0, 1));"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f_2h-Ki6RKSX"
      },
      "source": [
        "Hmm... If you've implemented the `CoordMPNNModel` correctly up till now, you may see a very curious result -- the performance of `CoordMPNNModel` is about equal or marginally worse than the vanilla `MPNNModel`!\n",
        "\n",
        "<font color='red'>This is because the `CoordMPNNModel` is not using 3D structural information in a principled manner.</font>\n",
        "\n",
        "The next sections will help us formalise and understand why this is happening.\n",
        "\n",
        "---\n",
        "---\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "89EPWJ8fGwHe"
      },
      "source": [
        "# 🔄 Part 2: Invariance to 3D Symmetries: Rotation and Translation\n",
        "\n",
        "We saw that the performance of `CoordMPNNModel` is unexpectedly mediocre compared to `MPNNModel` despite using both node features and coordinates. (But please do not panic if your results say otherwise.) In order to determine why, we must understand the concept of **3D symmetries**.\n",
        "\n",
        "### Geometric Invariance\n",
        "\n",
        "Recall that molecular graphs have 3D coordinates for each atom. A key detail which we have purposely withheld from you up till this point (😈) is that these 3D coordinates are **not  inherently fixed** or **permanent**. Instead, they were **experimentally determined** relative to a **frame of reference**.\n",
        "\n",
        "To fully grasp these statements, here is GIF of a drug-like molecules moving around in 3D space...\n",
        "\n",
        "<!-- ![](https://drive.google.com/uc?id=1QcQcF91TD-CTKFaR4NN8YyXbtSTcGny8) -->\n",
        "<img src=\"https://github.com/chaitjo/dump/raw/main/3d-molecule-moving.gif\">\n",
        "\n",
        "The atoms' 3D coordinates are constantly **rotating** and **translating**. However, the **properties** of this molecule will always remain the same no matter how we rotate or translate it. In other words, the molecule's properties are **invariant** to 3D rotations and translations.\n",
        "\n",
        "In this block we will study how to design GNN layers and models that respect these regularities.\n",
        "\n",
        "### Formalism\n",
        "\n",
        "Let us try to formalise the notion of invariance to 3D rotations and translations in GNNs via matrix notation.\n",
        "\n",
        "- Let $\\mathbf{H} \\in \\mathbb{R}^{n \\times d}$ be a matrix of node features for a given molecular graph, where $n$ is the number of nodes/atoms and each row $h_i$ is the $d$-dimensional feature for node $i$.\n",
        "- Let $\\mathbf{X} \\in \\mathbb{R}^{n \\times 3}$ be a matrix of node coordinates for a given molecular graph, where $n$ is the number of nodes/atoms and each row $x_i$ is the 3D coordinate for node $i$.\n",
        "- Let $\\mathbf{A} \\in \\mathbb{R}^{n \\times n}$ be the adjacency matrix where each entry denotes $a_{ij}$ the presence or absence of an edge between nodes $i$ and $j$.\n",
        "- Let $\\mathbf{F}(\\mathbf{H}, \\mathbf{X}, \\mathbf{A}): \\mathbb{R}^{n \\times d} \\times \\mathbb{R}^{n \\times 3} \\times \\mathbb{R}^{n \\times n} \\rightarrow \\mathbb{R}^{n \\times d}$ be a **GNN <ins>layer**</ins> that takes as input the node features, node coordinates, and adjacency matrix, and returns the **updated node features**.\n",
        "- Let $f(\\mathbf{H}, \\mathbf{X}, \\mathbf{A}): \\mathbb{R}^{n \\times d} \\times \\mathbb{R}^{n \\times 3} \\times \\mathbb{R}^{n \\times n} \\rightarrow \\mathbb{R}$ be a **GNN <ins>model**</ins> that takes as input the node features, node coordinates, and adjacency matrix, and returns the **predicted graph-level property**.\n",
        "\n",
        "(Notice that we have updated the notation for the GNN layer $\\mathbf{F}$ and GNN model $\\mathbf{f}$ to include the matrix of node coordinates $\\mathbf{X}$ as an additional input.) "
      ]
    },
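    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a notational aid for the tasks that follow, here is a minimal sketch of how a rigid transformation acts on the coordinate matrix. The symbols $\\mathbf{Q}$, $\\mathbf{t}$ and $\\mathbf{1}_n$ are assumptions introduced here for illustration and are not part of the definitions above: $\\mathbf{Q} \\in \\mathbb{R}^{3 \\times 3}$ is an orthogonal matrix, $\\mathbf{t} \\in \\mathbb{R}^3$ is a translation vector, and $\\mathbf{1}_n \\in \\mathbb{R}^n$ is the all-ones column vector.\n",
        "\n",
        "$$\n",
        "x_i \\mapsto \\mathbf{Q} x_i + \\mathbf{t} \\quad \\text{for every node } i, \\qquad \\text{or equivalently} \\qquad \\mathbf{X} \\mapsto \\mathbf{X} \\mathbf{Q}^\\top + \\mathbf{1}_n \\mathbf{t}^\\top, \\qquad \\text{where } \\mathbf{Q}^\\top \\mathbf{Q} = \\mathbf{I}.\n",
        "$$\n",
        "\n",
        "Since each coordinate $x_i$ is stored as a row of $\\mathbf{X}$, the rotation appears as a right-multiplication by $\\mathbf{Q}^\\top$ and the translation is added to every row. Assuming the node features are scalar quantities (as is the case for the atom features used in this practical), $\\mathbf{H}$ and $\\mathbf{A}$ are left untouched by this transformation."
      ]
    },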
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0fZTuGyVNC6n"
      },
      "source": [
        "## 💻**Task 2.1:** What does it mean for the GNN <ins>model</ins> $f$ and the GNN <ins>layer</ins> $\\mathbf{F}$ to be invariant to 3D rotations and translations? Express this _mathematically_ using the definitions above. **(0.5 Mark)**\n",
        "\n",
        "🤔 *Hint: Revisit the formalisms for permutation invariance and equivariance to get an idea of how to go about this. You should use the matrix notation we have provided above. Similar to the permuatation matrix $\\mathbf{P}$, you may now define an orthogonal [**rotation matrix**](https://en.wikipedia.org/wiki/Rotation_matrix) $\\mathbf{Q} \\in \\mathbb{R}^{3 \\times 3}$ and a [**translation vector**](https://en.wikipedia.org/wiki/Translation_(geometry)) $\\mathbf{t} \\in \\mathbb{R}^3$ in your answer. These would operate on the matrix of node coordinates $\\mathbf{X} \\in \\mathbb{R}^{n \\times 3}$*.\n",
        "\n",
        "---\n",
        "\n",
        "<font color='red'>❗️YOUR ANSWER HERE</font>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kY6J_vL3hgCN"
      },
      "source": [
        "Before you start coding up a more principled MPNN model, we would like you to take a moment to think about why invariance to 3D rotations and translations is something desirable for GNNs predicting molecular properties...\n",
        "\n",
        "## 💻**Task 2.2:** Is invariance to 3D rotations and translations a desirable property for GNNs? Explain why. **(0.5 Marks)**\n",
        "\n",
        "🤔 *Hint: We are not looking for an essay, a few sentences will suffice here.*\n",
        "\n",
        "---\n",
        "\n",
        "<font color='red'>❗️YOUR ANSWER HERE</font>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yfOU10W1O729"
      },
      "source": [
        "## 💻**Task 2.3:** Write the unit test to check your `CoordMPNNModel` for 3D rotation and translation invariance. **(0.5 Mark)**\n",
        "\n",
        "\n",
        "🤔 *Hint: Show that the output of the model varies when:*\n",
        "1. All the atom coordinates in `data.pos` are multiplied by any random _orthogonal_ rotation matrix $Q \\in \\mathbb{R}^{3 \\times 3}$. (We have provided a helper function for creating rotation matrices.)\n",
        "2. All the atom coordinates in `data.pos` are displaced by any random translation vector $\\mathbf{t} \\in \\mathbb{R}^3$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C017NYCReaRy"
      },
      "outputs": [],
      "source": [
        "def random_orthogonal_matrix(dim=3):\n",
        "  \"\"\"Helper function to build a random orthogonal matrix of shape (dim, dim)\n",
        "  \"\"\"\n",
        "  Q = torch.tensor(ortho_group.rvs(dim=dim)).float()\n",
        "  return Q\n",
        "\n",
        "\n",
        "def rot_trans_invariance_unit_test(module, dataloader):\n",
        "    \"\"\"Unit test for checking whether a module (GNN model/layer) is \n",
        "    rotation and translation invariant.\n",
        "    \"\"\"\n",
        "    it = iter(dataloader)\n",
        "    data = next(it)\n",
        "\n",
        "    # Forward pass on original example\n",
        "    # Note: We have written a conditional forward pass so that the same unit\n",
        "    #       test can be used for both the GNN model as well as the layer.\n",
        "    #       The functionality for layers will be useful subsequently. \n",
        "    if isinstance(module, MPNNModel):\n",
        "        out_1 = module(data)\n",
        "    else: # if ininstance(module, MessagePassing):\n",
        "        out_1 = module(data.x, data.pos, data.edge_index, data.edge_attr)\n",
        "\n",
        "    Q = random_orthogonal_matrix(dim=3)\n",
        "    t = torch.rand(3)\n",
        "    # ============ YOUR CODE HERE ==============\n",
        "    # Perform random rotation + translation on data.\n",
        "    #\n",
        "    # data.pos = ...\n",
        "    # ==========================================\n",
        "\n",
        "    # Forward pass on rotated + translated example\n",
        "    if isinstance(module, MPNNModel):\n",
        "        out_2 = module(data)\n",
        "    else: # if ininstance(module, MessagePassing):\n",
        "        out_2 = module(data.x, data.pos, data.edge_index, data.edge_attr)\n",
        "    \n",
        "    # ============ YOUR CODE HERE ==============\n",
        "    # Check whether output varies after applying transformations.\n",
        "    #\n",
        "    # return torch.allclose(..., atol=1e-04)\n",
        "    # =========================================="
      ]
    },
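    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A note on the `random_orthogonal_matrix()` helper, offered as a hedged sketch rather than as part of the exercise: SciPy's `ortho_group` samples from the full orthogonal group, so the returned matrix may be a reflection (determinant $-1$) as well as a proper rotation (determinant $+1$). The snippet below illustrates the difference using `special_ortho_group`, which samples proper rotations only. The variable names and the fixed seed are our own choices for illustration.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "from scipy.stats import ortho_group, special_ortho_group\n",
        "\n",
        "# Fixed seeds are used only to make this sketch reproducible.\n",
        "Q_any = ortho_group.rvs(dim=3, random_state=0)          # sampled from O(3)\n",
        "Q_rot = special_ortho_group.rvs(dim=3, random_state=0)  # sampled from SO(3)\n",
        "\n",
        "# Both matrices satisfy Q^T Q = I.\n",
        "is_orthogonal_any = np.allclose(Q_any.T @ Q_any, np.eye(3))\n",
        "is_orthogonal_rot = np.allclose(Q_rot.T @ Q_rot, np.eye(3))\n",
        "\n",
        "# The determinant separates proper rotations (+1) from reflections (-1).\n",
        "det_any = np.linalg.det(Q_any)  # either +1 or -1, depending on the sample\n",
        "det_rot = np.linalg.det(Q_rot)  # +1 up to floating point error\n",
        "\n",
        "print(is_orthogonal_any, is_orthogonal_rot)\n",
        "print(f'det(Q_any) = {det_any:+.3f}, det(Q_rot) = {det_rot:+.3f}')\n",
        "```\n",
        "\n",
        "Both kinds of matrix satisfy $\\mathbf{Q}^\\top \\mathbf{Q} = \\mathbf{I}$, so the helper can be used as provided; `special_ortho_group` is an option if you specifically want to exclude reflections."
      ]
    },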
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hxVHYY_ahxrX"
      },
      "source": [
        "Now that you have defined the unit tests for rotation and translation invariance, perform the sanity check on your `CoordMPNNModel`:\n",
        "\n",
        "(Spoiler alert: if you have implemented things as expected, the unit test should return `False` for the `CoordMPNNModel`.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U2Gv4m3ahcya"
      },
      "outputs": [],
      "source": [
        "# Instantiate temporary model, layer, and dataloader for unit testing\n",
        "model = CoordMPNNModel(num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1)\n",
        "dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)\n",
        "\n",
        "# Rotation and translation invariance unit test for MPNN model\n",
        "print(f\"Is {type(model).__name__} rotation and translation invariant? --> {rot_trans_invariance_unit_test(model, dataloader)}!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lYqLfZpFj0Pf"
      },
      "source": [
        "In this part, you have formalised how a GNN can be 3D rotation and translation invariant, thought about why this is desirable for molecular property prediction, and shown that the `CoordMPNNModel` was not rotation and translation invariant.\n",
        "\n",
        "At this point, you should have a concrete understanding of why the performance of `CoordMPNNModel` is equal or worse than the vanilla `MPNNModel`, and what we meant by our initial statement before we began this part: \n",
        ">\"The `CoordMPNNModel` is not using 3D structural information in a principled manner\" \n",
        "\n",
        "Let us try fixing this in the next part!\n",
        "\n",
        "---\n",
        "---\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "byXbNG9lGizB"
      },
      "source": [
        "# ✈️ Part 3: Message Passing with Invariance to 3D Rotations and Translations\n",
        "\n",
        "This section will dive into how we may design GNN models which operate on graphs with 3D coordinates in a more theoretically sound way."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BWceJ74iH7lV"
      },
      "source": [
        "## 💻**Task 3.1:** Design a new Message Passing Layer as well as the accompanying MPNN Model that are both <ins>invariant</ins> to 3D rotations and translations. **(2 Marks)**\n",
        "\n",
        "**❗️ Note:** There is no single correct answer to this question.\n",
        "\n",
        "Our initial **'vanilla' MPNN** `MPNNModel` and `MPNNLayer` ignored the atom coordiantes and only uses the node features to perform message passing. This means that the model was **not** leveraging **3D structural information** to predict the target property.\n",
        "\n",
        "Our second **'naive' coordinate MPNN** `CoordMPNNModel` used the node features along with the atom coordinates in an unprincipled manner, resulting in the model not being invariant to 3D rotations and translations of the coordinates (which was a desirable property, as we saw in the previous part).\n",
        "\n",
        "Your task is to define a new `InvariantMPNNLayer` which utilise both **atom coordinates** and **node features**.\n",
        "\n",
        "We have defined most of the new `InvariantMPNNLayer`, and you have to fill in the `YOUR CODE HERE` sections. We have also already defined the `InvariantMPNNModel` that instantiates your new layer to compose the model. You only need to define the new layer.\n",
        "\n",
        "🤔 *Hint 1: Unlike the previous `CoordMPNNModel`, we would suggest using the coodinate information to constuct the messages as opposed to incorporating it into the node features. In particular, we would like you to think about **how** to use the coordinates in a principled manner to constuct the messages: What is a measurement that we can computer using a pair of coordinates that will be invariant to rotating and translating them?*\n",
        "\n",
        "🤔 *Hint 2:  tensors passed to `propagate()` can be mapped to the respective nodes  and  by appending `_i` or `_j` to the variable name, e.g. `h_i` and `h_j` for the node features `h`. Note that we generally refer to `_i` as the central nodes that aggregates information, and refer to `_j` as the neighboring nodes.*"
      ]
    },
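    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the `_i` / `_j` convention concrete, here is a minimal sketch in plain PyTorch of the indexing that happens behind the scenes. It assumes PyG's default `flow='source_to_target'`, where the first row of `edge_index` holds the source nodes $j$ and the second row holds the destination nodes $i$. The toy graph and its feature values are made up for illustration.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "# Toy graph with 3 nodes and 2 directed edges: 0 -> 1 and 2 -> 1.\n",
        "# Row 0 of `edge_index` holds source nodes j, row 1 holds destination nodes i.\n",
        "edge_index = torch.tensor([[0, 2],\n",
        "                           [1, 1]])\n",
        "h = torch.tensor([[10.0], [20.0], [30.0]])  # (n, d) node features with d = 1\n",
        "\n",
        "# Equivalent of what `message()` receives as `h_j` and `h_i`:\n",
        "h_j = h[edge_index[0]]  # (e, d) features of the sending (neighboring) nodes\n",
        "h_i = h[edge_index[1]]  # (e, d) features of the receiving (central) nodes\n",
        "\n",
        "print(h_j.view(-1).tolist())  # [10.0, 30.0]\n",
        "print(h_i.view(-1).tolist())  # [20.0, 20.0]\n",
        "```\n",
        "\n",
        "The same mapping applies to any other per-node tensor passed to `propagate()`, so a keyword argument named `pos` would become available in `message()` as `pos_i` and `pos_j`."
      ]
    },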
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cZ2f0wOEK-Qx"
      },
      "outputs": [],
      "source": [
        "class InvariantMPNNLayer(MessagePassing):\n",
        "    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):\n",
        "        \"\"\"Message Passing Neural Network Layer\n",
        "\n",
        "        This layer is invariant to 3D rotations and translations.\n",
        "\n",
        "        Args:\n",
        "            emb_dim: (int) - hidden dimension `d`\n",
        "            edge_dim: (int) - edge feature dimension `d_e`\n",
        "            aggr: (str) - aggregation function `\\oplus` (sum/mean/max)\n",
        "        \"\"\"\n",
        "        # Set the aggregation function\n",
        "        super().__init__(aggr=aggr)\n",
        "\n",
        "        self.emb_dim = emb_dim\n",
        "        self.edge_dim = edge_dim\n",
        "\n",
        "        # ============ YOUR CODE HERE ==============\n",
        "        # MLP `\\psi` for computing messages `m_ij`\n",
        "        # dims: (???) -> d\n",
        "        #\n",
        "        # self.mlp_msg = Sequential(...)\n",
        "        # ==========================================\n",
        "        \n",
        "        # MLP `\\phi` for computing updated node features `h_i^{l+1}`\n",
        "        # dims: 2d -> d\n",
        "        self.mlp_upd = Sequential(\n",
        "            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), \n",
        "            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()\n",
        "          )\n",
        "\n",
        "    def forward(self, h, pos, edge_index, edge_attr):\n",
        "        \"\"\"\n",
        "        The forward pass updates node features `h` via one round of message passing.\n",
        "\n",
        "        Args:\n",
        "            h: (n, d) - initial node features\n",
        "            pos: (n, 3) - initial node coordinates\n",
        "            edge_index: (e, 2) - pairs of edges (i, j)\n",
        "            edge_attr: (e, d_e) - edge features\n",
        "\n",
        "        Returns:\n",
        "            out: (n, d) - updated node features\n",
        "        \"\"\"\n",
        "        # ============ YOUR CODE HERE ==============\n",
        "        # Notice that the `forward()` function has a new argument \n",
        "        # `pos` denoting the initial node coordinates. Your task is\n",
        "        # to update the `propagate()` function in order to pass `pos`\n",
        "        # to the `message()` function along with the other arguments.\n",
        "        #\n",
        "        # out = self.propagate(...)\n",
        "        # return out\n",
        "        # ==========================================\n",
        "\n",
        "    # ============ YOUR CODE HERE ==============\n",
        "    # Write a custom `message()` function that takes as arguments the\n",
        "    # source and destination node features, node coordiantes, and `edge_attr`.\n",
        "    # Incorporate the coordinates `pos` into the message computation such\n",
        "    # that the messages are invariant to rotations and translations.\n",
        "    # This will ensure that the overall layer is also invariant.\n",
        "    #\n",
        "    # def message(self, ...):\n",
        "    # \"\"\"The `message()` function constructs messages from source nodes j \n",
        "    #    to destination nodes i for each edge (i, j) in `edge_index`.\n",
        "    #\n",
        "    #    Args:\n",
        "    #        ...\n",
        "    #    \n",
        "    #    Returns:\n",
        "    #        ...\n",
        "    # \"\"\"\n",
        "    #   ...  \n",
        "    #   msg = ...\n",
        "    #   return self.mlp_msg(msg)\n",
        "    # ==========================================\n",
        "    \n",
        "    def aggregate(self, inputs, index):\n",
        "        \"\"\"The `aggregate` function aggregates the messages from neighboring nodes,\n",
        "        according to the chosen aggregation function ('sum' by default).\n",
        "\n",
        "        Args:\n",
        "            inputs: (e, d) - messages `m_ij` from destination to source nodes\n",
        "            index: (e, 1) - list of source nodes for each edge/message in `input`\n",
        "\n",
        "        Returns:\n",
        "            aggr_out: (n, d) - aggregated messages `m_i`\n",
        "        \"\"\"\n",
        "        return scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)\n",
        "    \n",
        "    def update(self, aggr_out, h):\n",
        "        \"\"\"The `update()` function computes the final node features by combining the \n",
        "        aggregated messages with the initial node features.\n",
        "\n",
        "        Args:\n",
        "            aggr_out: (n, d) - aggregated messages `m_i`\n",
        "            h: (n, d) - initial node features\n",
        "\n",
        "        Returns:\n",
        "            upd_out: (n, d) - updated node features passed through MLP `\\phi`\n",
        "        \"\"\"\n",
        "        upd_out = torch.cat([h, aggr_out], dim=-1)\n",
        "        return self.mlp_upd(upd_out)\n",
        "\n",
        "    def __repr__(self) -> str:\n",
        "        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')\n",
        "\n",
        "\n",
        "class InvariantMPNNModel(MPNNModel):\n",
        "    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):\n",
        "        \"\"\"Message Passing Neural Network model for graph property prediction\n",
        "\n",
        "        This model uses both node features and coordinates as inputs, and\n",
        "        is invariant to 3D rotations and translations.\n",
        "\n",
        "        Args:\n",
        "            num_layers: (int) - number of message passing layers `L`\n",
        "            emb_dim: (int) - hidden dimension `d`\n",
        "            in_dim: (int) - initial node feature dimension `d_n`\n",
        "            edge_dim: (int) - edge feature dimension `d_e`\n",
        "            out_dim: (int) - output dimension (fixed to 1)\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        \n",
        "        # Linear projection for initial node features\n",
        "        # dim: d_n -> d\n",
        "        self.lin_in = Linear(in_dim, emb_dim)\n",
        "        \n",
        "        # Stack of invariant MPNN layers\n",
        "        self.convs = torch.nn.ModuleList()\n",
        "        for layer in range(num_layers):\n",
        "            self.convs.append(InvariantMPNNLayer(emb_dim, edge_dim, aggr='add'))\n",
        "        \n",
        "        # Global pooling/readout function `R` (mean pooling)\n",
        "        # PyG handles the underlying logic via `global_mean_pool()`\n",
        "        self.pool = global_mean_pool\n",
        "\n",
        "        # Linear prediction head\n",
        "        # dim: d -> out_dim\n",
        "        self.lin_pred = Linear(emb_dim, out_dim)\n",
        "        \n",
        "    def forward(self, data):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            data: (PyG.Data) - batch of PyG graphs\n",
        "\n",
        "        Returns: \n",
        "            out: (batch_size, out_dim) - prediction for each graph\n",
        "        \"\"\"\n",
        "        h = self.lin_in(data.x) # (n, d_n) -> (n, d)\n",
        "        \n",
        "        for conv in self.convs:\n",
        "            h = h + conv(h, data.pos, data.edge_index, data.edge_attr) # (n, d) -> (n, d)\n",
        "            # Note that we add a residual connection after each MPNN layer\n",
        "\n",
        "        h_graph = self.pool(h, data.batch) # (n, d) -> (batch_size, d)\n",
        "\n",
        "        out = self.lin_pred(h_graph) # (batch_size, d) -> (batch_size, 1)\n",
        "\n",
        "        return out.view(-1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wqcJy1HtPgj6"
      },
      "source": [
        "Super! You have now defined a more geometrically principled message passing layer and used it to construct an MPNN model with is invariant to 3D rotations and translations.\n",
        "\n",
        "## 💻**Task 3.2:** Write down the update equation of your new `InvariantMPNNLayer` and use that to prove that the layer and model are invariant to 3D rotations and translations. **(1 Mark)**\n",
        "\n",
        "\n",
        "---\n",
        "\n",
        "<font color='red'>❗️YOUR ANSWER HERE</font>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CslS3yJBR9jl"
      },
      "source": [
        "Great! You have successfully written the update equation for your new `InvariantMPNNLayer` and shown how it is indeed invariant to 3D rotations and translations.\n",
        "\n",
        "Let us just perform some sanity checks to verify this.\n",
        "\n",
        "## 💻**Task 3.3:** Perform unit tests for your `InvariantMPNNLayer` and `InvariantMPNNModel`. Show that the layer and model are both invariant to 3D rotations and translations. **(0.5 Mark)**\n",
        "\n",
        "🤔 *Hint: Run the unit tests defined previously.*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "epMyUUnaSolt"
      },
      "outputs": [],
      "source": [
        "# ============ YOUR CODE HERE ==============\n",
        "# Instantiate temporary model, layer, and dataloader for unit testing.\n",
        "# Remember that we are now unit testing the InvariantMPNNModel, \n",
        "# which is  composed of the InvariantMPNNLayer.\n",
        "#\n",
        "# layer = ...\n",
        "# model = ...\n",
        "# ==========================================\n",
        "dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)\n",
        "\n",
        "# Rotation and translation invariance unit test for MPNN model\n",
        "print(f\"Is {type(model).__name__} rotation and translation invariant? --> {rot_trans_invariance_unit_test(model, dataloader)}!\")\n",
        "\n",
        "# Rotation and translation invariance unit test for MPNN layer\n",
        "print(f\"Is {type(layer).__name__} rotation and translation invariant? --> {rot_trans_invariance_unit_test(layer, dataloader)}!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qx5XBblEryId"
      },
      "source": [
        "Good job! You have defined the `InvariantMPNNLayer` and `InvariantMPNNModel`, after which you have proved and experimentally verified their invariance to 3D rotations and translations.\n",
        "\n",
        "It is finally time to run an experiment with our geometrically principled model!\n",
        "\n",
        "## 💻**Task 3.4:** Train and evaluate your `InvariantMPNNModel`. Additionally, provide a few sentences explaining the model's results compared to the basic `MPNNModel` and the naive `CoordMPNNModel` defined previously. Is the new model better? By a significant margin or only minorly better? **(0.5 Mark)**\n",
        "\n",
        "Re-use the experiment loop we have provided and fill in the `YOUR CODE HERE` sections to run the experiment.\n",
        "\n",
        "You will be training an `InvariantMPNNModel` consisting of 4 layers of message passing with a hidden dimension of 64, in order to compare your result fairly to the previous vanilla `MPNNModel` and naive `CoordMPNNModel`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TDOR0aYRshZW"
      },
      "outputs": [],
      "source": [
        "# ============ YOUR CODE HERE ==============\n",
        "# Instantiate your InvariantMPNNModel with the appropriate arguments.\n",
        "#\n",
        "# model = InvariantMPNNModel(...)\n",
        "# ==========================================\n",
        "\n",
        "model_name = type(model).__name__\n",
        "best_val_error, test_error, train_time, perf_per_epoch = run_experiment(\n",
        "    model, \n",
        "    model_name, # \"MPNN w/ Features and Coordinates (Invariant Layers)\", \n",
        "    train_loader,\n",
        "    val_loader, \n",
        "    test_loader,\n",
        "    n_epochs=100\n",
        ")\n",
        "\n",
        "RESULTS[model_name] = (best_val_error, test_error, train_time)\n",
        "df_temp = pd.DataFrame(perf_per_epoch, columns=[\"Test MAE\", \"Val MAE\", \"Epoch\", \"Model\"])\n",
        "DF_RESULTS = DF_RESULTS.append(df_temp, ignore_index=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2A0UoiyawMqw"
      },
      "outputs": [],
      "source": [
        "RESULTS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-eJEGsolyHBN"
      },
      "outputs": [],
      "source": [
        "p = sns.lineplot(x=\"Epoch\", y=\"Val MAE\", hue=\"Model\", data=DF_RESULTS)\n",
        "p.set(ylim=(0, 2));"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fnkXCeabtk4a"
      },
      "outputs": [],
      "source": [
        "p = sns.lineplot(x=\"Epoch\", y=\"Test MAE\", hue=\"Model\", data=DF_RESULTS)\n",
        "p.set(ylim=(0, 1));"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z7ZIvMCDrBR7"
      },
      "source": [
        "---\n",
        "\n",
        "<font color='red'>❗️YOUR ANSWER HERE</font>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9C1XPXAf_Tsx"
      },
      "source": [
        "Awesome! You have now gone from a vanilla `MPNNModel`, to a naive use of coodinate information in `CoordMPNNModel`, to a more geometrically principled approach in `InvariantMPNN` model.\n",
        "\n",
        "In the next part, we will try to further push the limits of how much information we can derive from the geometry of molecules!\n",
        "\n",
        "---\n",
        "---\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g3abj3akCFcp"
      },
      "source": [
        "# 🚀 Part 4: Message Passing with Equivariance to 3D Rotations and Translations\n",
        "\n",
        "In the previous part of the practical, we studied the concepts of **3D rotation** and **translation** invariance. Now, we will go one step further. We will consider a GNN for molecular property prediction that is composed of message passing layers that are **equivariant** to 3D rotations and translations.\n",
        "\n",
        "But why...you may ask. Let us take a step back.\n",
        "\n",
        "### Why Geometric Equivariance over Invariance?\n",
        "\n",
        "In order to motivate the need for geometric equivariance and symmetries, we would like to take you back to the notion of permutation symmetries in GNNs for graphs, as well as translation symmetries in ConvNets for 2D images.\n",
        "\n",
        "#### Permutation Symmetry in GNNs vs. DeepSets\n",
        "\n",
        "Earlier in the practical, we reviewed the concept of **permutation invariance** and **equivariance**. Fundamentally, a GNN layer must be a permutation <ins>equivariant</ins> operation on the graph nodes, i.e. changing the node ordering of the graph results in the same permutation applied to the node outputs of the layer. However, the overall GNN model for graph-level property prediction is still a **permutation <ins>invariant</ins>** function on the graph nodes, i.e. changing the node ordering does not impact the predicted graph property.\n",
        "\n",
        "Recall from the lectures that the **[DeepSets model](https://arxiv.org/abs/1703.06114)** is yet another permutation <ins>invariant</ins> architecture over sets of nodes, and is a perfectly reasonable option for predicting graph-level properties (which are also permutation invariant, as we just stated). This raises a critical question: **why did we build permutation <ins>invariant</ins> GNN models composed of permutation <ins>equivariant</ins> GNN layers?**\n",
        "\n",
        "The answer is that permutation <ins>equivariant</ins> GNN layers enable the model to better leverage the **relational structure** of the underlying nodes, as well as construct more powerful node representations by **stacking several layers** of these permutation <ins>equivariant</ins> operations. (You can try running a DeepSets model for QM9 yourself and see the performance reduce.)\n",
        "\n",
        "Now, consider the same analogy for 3D rotation and translation symmetries for your molecular property prediction models. Consider your `InvariantMPNNModel` so far -- it is composed of `InvariantMPNNLayer` which are merely <ins>invariant</ins> to 3D rotations and translations.\n",
        "\n",
        "Analogous to how permutation <ins>equivariant</ins> layer enabled GNNs to leverage relational structure in a more principled manner, a **3D rotation** and **translation <ins>equivariant</ins> layer** may enable your model to **leverage geometric structure** in a more principled manner, too.\n",
        "\n",
        "#### Translation Symmetry in ConvNets for 2D Images\n",
        "\n",
        "Yet another example where <ins>invariant</ins> models are composed of <ins>equivariant</ins> layers is the ubiquitous **Convolutional Neural Network** for 2D images.\n",
        "\n",
        "The ConvNet model is <ins>invariant</ins> to **translations**, in the sense that it will detect a cat in an image, regardless of where the cat is positioned in the image.\n",
        "\n",
        "Importantly, the ConvNet is composed of **convolution filters** which are akin to sliding a rectangular window over the input image. Convolution filters are matching low level patterns within the image. Intuitively, one of these filters may be a cat detection filter, in that it will fire whenever it comes across cat-like pixels. Thus, convolution filters are translation <ins>equivariant</ins> functions since their output translates along with their input.\n",
        "\n",
        "<!-- <img src=\"https://drive.google.com/uc?id=1vgTAG_n5r3H2nqo5vaZPyC60hbZMTEkN\" width=\"100%\"> -->\n",
        "<img src=\"https://github.com/chaitjo/dump/raw/main/symmetry.png\">\n",
        "\n",
        "([Source](https://bernhard-kainz.com/))\n",
        "\n",
        "Translation <ins>invariant</ins> ConvNets are composed of translation <ins>equivariant</ins> convolution filters in order to build **heirarchical features** across multiple layers. Stacking deep ConvNets enables the features across layers to interact in a **compositional** manner and enables the overall network to learn increasingly **complex visual concepts**.\n",
        "\n",
        "The following video shows yet another demonstration of the **translational equivariance** of convolution filters: a shift to the input image directly corresponds to a shift of the output features.\n",
        "\n",
        "([Source](https://fabianfuchsml.github.io/equivariance1of2/))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3JURWoA2Cfsc"
      },
      "outputs": [],
      "source": [
        "HTML('<iframe width=\"560\" height=\"315\" src=\"https://edwag.github.io/video/translation_equivariance.mp4\" allowfullscreen></iframe>')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yO7iwRjvDD16"
      },
      "source": [
        "### Formalism\n",
        "\n",
        "Hopefully, we have sufficiently motivated the need for 3D rotation and translation equivariant GNN layers. Let us now try to formalise the notion of equivariance to 3D rotations and translations via matrix notation.\n",
        "\n",
        "- Let $\\mathbf{H} \\in \\mathbb{R}^{n \\times d}$ be a matrix of node features for a given molecular graph, where $n$ is the number of nodes/atoms and each row $h_i$ is the $d$-dimensional feature for node $i$.\n",
        "- Let $\\mathbf{X} \\in \\mathbb{R}^{n \\times 3}$ be a matrix of node coordinates for a given molecular graph, where $n$ is the number of nodes/atoms and each row $x_i$ is the 3D coordinate for node $i$.\n",
        "- Let $\\mathbf{A} \\in \\mathbb{R}^{n \\times n}$ be the adjacency matrix where each entry denotes $a_{ij}$ the presence or absence of an edge between nodes $i$ and $j$.\n",
        "- Let $\\mathbf{F}(\\mathbf{H}, \\mathbf{X}, \\mathbf{A}): \\mathbb{R}^{n \\times d} \\times \\mathbb{R}^{n \\times 3} \\times \\mathbb{R}^{n \\times n} \\rightarrow \\mathbb{R}^{n \\times d}\\times \\mathbb{R}^{n \\times 3}$ be a **GNN <ins>layer**</ins> that takes as input the node features, node coordinates, and adjacency matrix, and returns the **updated node features** as well as **updated node coordinates**.\n",
        "- Let $f(\\mathbf{H}, \\mathbf{X}, \\mathbf{A}): \\mathbb{R}^{n \\times d} \\times \\mathbb{R}^{n \\times 3} \\times \\mathbb{R}^{n \\times n} \\rightarrow \\mathbb{R}$ be a **GNN <ins>model**</ins> that takes as input the node features, node coordinates, and adjacency matrix, and returns the **predicted graph-level property**.\n",
        "\n",
        "Our GNN <ins>model</ins> $f(\\mathbf{H}, \\mathbf{X}, \\mathbf{A})$ is composed of multiple rotation and translation equivariant GNN <ins>layers</ins> $\\mathbf{F}^{\\ell}(\\mathbf{H}^{\\ell}, \\mathbf{X}^{\\ell}, \\mathbf{A}), \\ell = 1, 2, \\dots, L$. \n",
        "\n",
        "### How is this different from Geometrically Invariant Message Passing?\n",
        "\n",
        "Importantly, and in contrast to rotation and translation invariant message passing layers, each round of equivariant message passing updates both the **node features** as well as the **node coordinates**:\n",
        "$$\n",
        "\\mathbf{H}^{\\ell+1}, \\mathbf{X}^{\\ell+1} = \\mathbf{F}^{\\ell} (\\mathbf{H}^{\\ell}, \\mathbf{X}^{\\ell}, \\mathbf{A}).\n",
        "$$\n",
        "\n",
        "Such a formulation is highly beneficial for GNNs to learn useful node features in settings where we are modelling a **dynamical system** and have reason to believe that the node coordinates are continuously being updated, e.g. by the action of **intermolecular forces**.\n",
        "\n",
        "Do note the following nuances about geometrically equivariant message passing layers $\\mathbf{F}$:\n",
        "- The updated **node coodinates** $\\mathbf{X'}$ are **equivariant** to 3D rotations and translations of the input coordinates $\\mathbf{X}$.\n",
        "- The updated **node features** $\\mathbf{H'}$ are still **invariant** to 3D rotations and translations of the input coordinates $\\mathbf{X}$ (similar to the geometrically invariant message passing layer).\n",
        "- The overall **MPNN model** $f$ will still be **invariant** to 3D rotations and translations. This is because we are predicting a **single scalar quantity** (the electric dipole moment) per molecule, which remains unchanged under any rotations and translations of the atoms' coordinates. Thus, the final node feature vectors after $L$ layers of message passing are aggregated into a graph embedding (and the final node coordinates are ignored). The graph embedding is then used to predict the target.\n",
        "\n",
        "The following figure aims to succinctly capture these nuances about geometrically equivariant message passing layers $\\mathbf{F}$ which are used to compose a geometrically invariant GNN $\\mathbf{f}$:\n",
        "\n",
        "<img src=\"https://drive.google.com/uc?id=1rRsjM8AdxiU-uJ7C5t1JDMkC19QKdGPg\" width=\"100%\">\n",
        "<!-- <img src=\"https://github.com/chaitjo/dump/raw/main/gnn-symmetry.png\"> -->\n",
        "\n",
        "What we want you to investigate in this part is how we may improve a **GNN model** that is **invariant** to 3D rotations and translations by using **message passing layers** that are **equivariant** to these **3D symmetries**.\n",
        "\n",
        "Let us get started!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c-az-clhTLLv"
      },
      "source": [
        "## 💻**Task 4.1:** What does it mean for a GNN <ins>layer</ins> $\\mathbf{F}$ to be equivariant to 3D rotations and translations? Express this _mathematically_ using the definitions above. **(0.5 Marks)**\n",
        "\n",
        "🤔 *Hint: Revisit the formalisms introduced previously for permutation invariance and equivariance, as well as 3D rotation and traslation invariance.*\n",
        "\n",
        "---\n",
        "\n",
        "<font color='red'>❗️YOUR ANSWER HERE</font>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NPMwl3y1C-LD"
      },
      "source": [
        "## 💻**Task 4.2:** Design a new Message Passing Layer that is <ins>equivariant</ins> to 3D rotations and translations. **(2.5 Marks)**\n",
        "\n",
        "🤔 *Hint 1: To ensure equivariance to 3D rotations and translations, your message passing layer should now update both the node features as well as the node coordinates. This means that each of the `message()`, `aggregate()`, and `update()` functions will be passing around a tuple of outputs, consisting of the node features and node coordinates.*\n",
        "\n",
        "🤔 *Hint 2: Certain quantities that can be computed among a pair of node coordinates do not change when the coordinates are rotated or translated -- these are **invariant quantities**. On the other hand, certain quantities may rotate or translate along with the coordinates -- these are **equivariant quantities**. We want you to think about how you can set up the message passing in a way that messages for the node feature updates are <ins>invariant</ins> to 3D rotations and translations, while messages for the node coordinates are <ins>equivariant</ins> to the same.*\n",
        " \n",
        "**❗️Note:** This task has multiple possible approaches for acheiving. Directly importing or copying implementations from PyG will not be accepted as a valid answer.\n",
        "\n",
        "**❗️Note:** The trivial solution $\\mathbf{X}^{\\ell+1} = \\mathbf{X}^{\\ell}$ will not be accepted as a valid answer. A general intuition about GNNs is that each node learns how to **borrow information** from its neighbours — here, this holds true for both node feature information as well as node coordinate information. Thus, we want you to use message passing to update the node coordinates by aggregating from the node coordinates of the neighbours. The ‘game’ here is about how to design a coordinate message function such that it is equivariant to 3D symmetries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oKraX_u9MXCS"
      },
      "outputs": [],
      "source": [
        "class EquivariantMPNNLayer(MessagePassing):\n",
        "    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):\n",
        "        \"\"\"Message Passing Neural Network Layer\n",
        "\n",
        "        This layer is equivariant to 3D rotations and translations.\n",
        "\n",
        "        Args:\n",
        "            emb_dim: (int) - hidden dimension `d`\n",
        "            edge_dim: (int) - edge feature dimension `d_e`\n",
        "            aggr: (str) - aggregation function `\\oplus` (sum/mean/max)\n",
        "        \"\"\"\n",
        "        # Set the aggregation function\n",
        "        super().__init__(aggr=aggr)\n",
        "\n",
        "        self.emb_dim = emb_dim\n",
        "        self.edge_dim = edge_dim\n",
        "\n",
        "        # ============ YOUR CODE HERE ==============\n",
        "        # Define the MLPs constituting your new layer.\n",
        "        # At the least, you will need `\\psi` and `\\phi` \n",
        "        # (but their definitions may be different from what\n",
        "        # we used previously).\n",
        "        #\n",
        "        # self.mlp_msg = ...  # MLP `\\psi`\n",
        "        # self.mlp_upd = ...  # MLP `\\phi`\n",
        "        # ===========================================\n",
        "\n",
        "    def forward(self, h, pos, edge_index, edge_attr):\n",
        "        \"\"\"\n",
        "        The forward pass updates node features `h` via one round of message passing.\n",
        "\n",
        "        Args:\n",
        "            h: (n, d) - initial node features\n",
        "            pos: (n, 3) - initial node coordinates\n",
        "            edge_index: (e, 2) - pairs of edges (i, j)\n",
        "            edge_attr: (e, d_e) - edge features\n",
        "\n",
        "        Returns:\n",
        "            out: [(n, d),(n,3)] - updated node features\n",
        "        \"\"\"\n",
        "        # ============ YOUR CODE HERE ==============\n",
        "        # Notice that the `forward()` function has a new argument \n",
        "        # `pos` denoting the initial node coordinates. Your task is\n",
        "        # to update the `propagate()` function in order to pass `pos`\n",
        "        # to the `message()` function along with the other arguments.\n",
        "        #\n",
        "        # out = self.propagate(...)\n",
        "        # return out\n",
        "        # ==========================================\n",
        "\n",
        "    # ============ YOUR CODE HERE ==============\n",
        "    # Write custom `message()`, `aggregate()`, and `update()` functions\n",
        "    # which ensure that the layer is 3D rotation and translation equivariant.\n",
        "    #\n",
        "    # def message(self, ...):\n",
        "    #   ...  \n",
        "    #\n",
        "    # def aggregate(self, ...):\n",
        "    #   ...\n",
        "    #\n",
        "    # def update(self, ...):\n",
        "    #   ...\n",
        "    #\n",
        "    # ==========================================\n",
        "\n",
        "    def __repr__(self) -> str:\n",
        "        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')\n",
        "\n",
        "\n",
        "class FinalMPNNModel(MPNNModel):\n",
        "    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):\n",
        "        \"\"\"Message Passing Neural Network model for graph property prediction\n",
        "\n",
        "        This model uses both node features and coordinates as inputs, and\n",
        "        is invariant to 3D rotations and translations (the constituent MPNN layers\n",
        "        are equivariant to 3D rotations and translations).\n",
        "\n",
        "        Args:\n",
        "            num_layers: (int) - number of message passing layers `L`\n",
        "            emb_dim: (int) - hidden dimension `d`\n",
        "            in_dim: (int) - initial node feature dimension `d_n`\n",
        "            edge_dim: (int) - edge feature dimension `d_e`\n",
        "            out_dim: (int) - output dimension (fixed to 1)\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        \n",
        "        # Linear projection for initial node features\n",
        "        # dim: d_n -> d\n",
        "        self.lin_in = Linear(in_dim, emb_dim)\n",
        "        \n",
        "        # Stack of MPNN layers\n",
        "        self.convs = torch.nn.ModuleList()\n",
        "        for layer in range(num_layers):\n",
        "            self.convs.append(EquivariantMPNNLayer(emb_dim, edge_dim, aggr='add'))\n",
        "        \n",
        "        # Global pooling/readout function `R` (mean pooling)\n",
        "        # PyG handles the underlying logic via `global_mean_pool()`\n",
        "        self.pool = global_mean_pool\n",
        "\n",
        "        # Linear prediction head\n",
        "        # dim: d -> out_dim\n",
        "        self.lin_pred = Linear(emb_dim, out_dim)\n",
        "        \n",
        "    def forward(self, data):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            data: (PyG.Data) - batch of PyG graphs\n",
        "\n",
        "        Returns: \n",
        "            out: (batch_size, out_dim) - prediction for each graph\n",
        "        \"\"\"\n",
        "        h = self.lin_in(data.x) # (n, d_n) -> (n, d)\n",
        "        pos = data.pos\n",
        "        \n",
        "        for conv in self.convs:\n",
        "            # Message passing layer\n",
        "            h_update, pos_update = conv(h, pos, data.edge_index, data.edge_attr)\n",
        "            \n",
        "            # Update node features\n",
        "            h = h + h_update # (n, d) -> (n, d)\n",
        "            # Note that we add a residual connection after each MPNN layer\n",
        "            \n",
        "            # Update node coordinates\n",
        "            pos = pos_update # (n, 3) -> (n, 3)\n",
        "\n",
        "        h_graph = self.pool(h, data.batch) # (n, d) -> (batch_size, d)\n",
        "\n",
        "        out = self.lin_pred(h_graph) # (batch_size, d) -> (batch_size, 1)\n",
        "\n",
        "        return out.view(-1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8Xo7x-U2DH-8"
      },
      "source": [
        "Awesome! You have now defined a new message passing layer that is equivariant to 3D rotations and translations, and used it to construct your final MPNN model for molecular property prediction.\n",
        "\n",
        "## 💻**Task 4.3:** Write down the update equation of your new `EquivariantMPNNLayer` and use that to prove that the layer is equivariant to 3D rotations and translations. **(1 Mark)**\n",
        "\n",
        "---\n",
        "\n",
        "<font color='red'>❗️YOUR ANSWER HERE</font>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "En3oGF0QDZtf"
      },
      "source": [
        "Great! You have successfully written the update equation for your new `EquivariantMPNNLayer` and shown how it is indeed equivariant to 3D rotations and translations.\n",
        "\n",
        "Let us just perform some sanity checks to verify this.\n",
        "\n",
        "## 💻**Task 4.4:** Perform unit tests for your `EquivariantMPNNLayer` and `FinalMPNNModel`. Firstly, write the unit test for 3D rotation and translation equivariance for the layer. Then, show that the layer is equivariant to 3D rotations and translations, and that the model is invariant to 3D rotations and translations. **(1 Mark)**\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jW0FATN-P6CJ"
      },
      "outputs": [],
      "source": [
        "def rot_trans_equivariance_unit_test(module, dataloader):\n",
        "    \"\"\"Unit test for checking whether a module (GNN layer) is \n",
        "    rotation and translation equivariant.\n",
        "    \"\"\"\n",
        "    it = iter(dataloader)\n",
        "    data = next(it)\n",
        "\n",
        "    out_1, pos_1 = module(data.x, data.pos, data.edge_index, data.edge_attr)\n",
        "\n",
        "    Q = random_orthogonal_matrix(dim=3)\n",
        "    t = torch.rand(3)\n",
        "    # ============ YOUR CODE HERE ==============\n",
        "    # Perform random rotation + translation on data.\n",
        "    #\n",
        "    # data.pos = ...\n",
        "    # ==========================================\n",
        "\n",
        "    # Forward pass on rotated + translated example\n",
        "    out_2, pos_2 = module(data.x, data.pos, data.edge_index, data.edge_attr)\n",
        "    \n",
        "    # ============ YOUR CODE HERE ==============\n",
        "    # Check whether output varies after applying transformations.\n",
        "    # return ...\n",
        "    # =========================================="
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hgxcp6WpP6OF"
      },
      "outputs": [],
      "source": [
        "# ============ YOUR CODE HERE ==============\n",
        "# Instantiate temporary model, layer, and dataloader for unit testing.\n",
        "# Remember that we are now unit testing the FinalMPNNModel, \n",
        "# which is  composed of the EquivariantMPNNLayer.\n",
        "#\n",
        "# layer = ...\n",
        "# model = ...\n",
        "# ==========================================\n",
        "dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)\n",
        "\n",
        "# Rotation and translation invariance unit test for MPNN model\n",
        "print(f\"Is {type(model).__name__} rotation and translation invariant? --> {rot_trans_invariance_unit_test(model, dataloader)}!\")\n",
        "\n",
        "# Rotation and translation invariance unit test for MPNN layer\n",
        "print(f\"Is {type(layer).__name__} rotation and translation equivariant? --> {rot_trans_equivariance_unit_test(layer, dataloader)}!\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qUdwCBgiDl3t"
      },
      "source": [
        "At last! You have defined the `EquivariantMPNNLayer` and `FinalMPNNModel`, after which you have proved and experimentally verified the new layer is equivariant to 3D rotations and translations.\n",
        "\n",
        "It is finally time to run an experiment with our final geometrically principled model!\n",
        "\n",
        "## 💻**Task 4.5:** Train and evaluate your `FinalMPNNModel`. Additionally, provide a few sentences explaining the model's results compared to the basic `MPNNModel`, the naive `CoordMPNNModel`, and the `InvariantMPNNModel` defined previously. Is the new model better? By a significant margin or only minorly better? **(0.5 Mark)**\n",
        "\n",
        "Re-use the experiment loop we have provided and fill in the `YOUR CODE HERE` sections to run the experiment.\n",
        "\n",
        "You will be training an `EquivariantMPNNModel` consisting of 4 layers of message passing with a hidden dimension of 64, in order to compare your result fairly to the previous vanilla `MPNNModel`, naive `CoordMPNNModel`, and `InvariantMPNNModel`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lIm8C-4wRVpd"
      },
      "outputs": [],
      "source": [
        "# ============ YOUR CODE HERE ==============\n",
        "# Instantiate your FinalMPNNModel with the appropriate arguments.\n",
        "#\n",
        "# model = FinalMPNNModel(...)\n",
        "# ==========================================\n",
        "\n",
        "model_name = type(model).__name__\n",
        "best_val_error, test_error, train_time, perf_per_epoch = run_experiment(\n",
        "    model, \n",
        "    model_name, # \"MPNN w/ Features and Coordinates (Equivariant Layers)\", \n",
        "    train_loader,\n",
        "    val_loader, \n",
        "    test_loader,\n",
        "    n_epochs=100\n",
        ")\n",
        "\n",
        "RESULTS[model_name] = (best_val_error, test_error, train_time)\n",
        "df_temp = pd.DataFrame(perf_per_epoch, columns=[\"Test MAE\", \"Val MAE\", \"Epoch\", \"Model\"])\n",
        "DF_RESULTS = DF_RESULTS.append(df_temp, ignore_index=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6Ai6EhSASjP5"
      },
      "outputs": [],
      "source": [
        "RESULTS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sGjNj20A0JcO"
      },
      "outputs": [],
      "source": [
        "p = sns.lineplot(x=\"Epoch\", y=\"Val MAE\", hue=\"Model\", data=DF_RESULTS)\n",
        "p.set(ylim=(0, 2));"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4xZFggF9ulrT"
      },
      "outputs": [],
      "source": [
        "p = sns.lineplot(x=\"Epoch\", y=\"Test MAE\", hue=\"Model\", data=DF_RESULTS)\n",
        "p.set(ylim=(0, 1));"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G_n4dNz1UVpj"
      },
      "source": [
        "---\n",
        "\n",
        "<font color='red'>❗️YOUR ANSWER HERE</font>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2SyqO97bUZpS"
      },
      "source": [
        "Congratulations! You have now gone from a vanilla `MPNNModel`, to a naive use of coodinate information in `CoordMPNNModel`, to a more geometrically principled approach in `InvariantMPNNModel`, and finally arrived at `FinalMPNNModel`, a **GNN that is invariant** to 3D rotations and translations while consisting of **message passing layers that are equivariant** to these 3D symmetries.\n",
        "\n",
        "In the next parts, we will compare these models under two different settings.\n",
        "\n",
        "---\n",
        "---\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uO-84rI6Exu5"
      },
      "source": [
        "# 🌯 Part 5: Wrapping up\n",
        "\n",
        "In this section, we will wrap up the practical by analysing two important aspects of the models that we have studied so far: **sample efficiency** and choice of **graph structure**.\n",
        "\n",
        "❗️**Note:** Ideally, **you do not need to write any new code** for the tasks in this part. You are only required to run the cells in the notebook and infer the empirical results that you see. This is an exercise to simulate how you may need to infer tables and figures when reading or writing your own research papers."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P0q0JS7DQiaN"
      },
      "source": [
        "### Sample Efficiency\n",
        "\n",
        "We firstly want you to think about sample efficiency -- model A is more sample efficient than model B if it can get the most out of every sample in the sense that it can reach better performance with lesser data."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LYC7tnauPA1B"
      },
      "source": [
        "## 💻**Task 5.1:** Study all the models' performance across the number of training epochs. What do you observe? Explain your findings. (1 Mark)\n",
        "\n",
        "You can consider the number of **training epochs** as a proxy for the number of **training samples**, i.e. a model is more sample efficient if it converges to better performance within fewer epochs.\n",
        "\n",
        "Compare the models' performance across the number of training samples. How do the different modelling assumptions of the standard `MPNNModel`, the `CoordMPNNModel`, the `InvariantMPNNModel`, and the `FinalMPNN` influence sample efficiency? Which models perform best in low-sample regimes? What happens as we increase the sample size?\n",
        "\n",
        "Use the `sns.lineplot()` function provided along with the results from `DF_RESULTS` to visualise the validation and test set MAE w.r.t. the number of training epochs in order to answer this question.\n",
        "\n",
        "**❗️Note:** It is highly encouraged that you attempt this task even if you have not been successful in implementing all of the models in the practical. Just answer based on the models you did understand and implement!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1i7In-ObMEkL"
      },
      "outputs": [],
      "source": [
        "p = sns.lineplot(x=\"Epoch\", y=\"Val MAE\", hue=\"Model\", data=DF_RESULTS)\n",
        "p.set(ylim=(0, 2));"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8NzfBLV5MXy5"
      },
      "outputs": [],
      "source": [
        "p = sns.lineplot(x=\"Epoch\", y=\"Test MAE\", hue=\"Model\", data=DF_RESULTS)\n",
        "p.set(ylim=(0, 1));"
      ]
    },
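    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an optional aid for reading the curves above, sample efficiency can also be summarised as the first logged epoch at which each model reaches a chosen validation MAE. The cell below is a minimal sketch, assuming a dataframe with the `Epoch`, `Val MAE`, and `Model` columns used in the plots. The helper name `epochs_to_reach`, the threshold, and the toy numbers are hypothetical and not part of the practical.\n",
        "\n",
        "With your own results you could call `epochs_to_reach(DF_RESULTS, threshold=...)` with a threshold of your choice. A model that never reaches the threshold shows up as `NaN`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "\n",
        "\n",
        "def epochs_to_reach(df, threshold, metric=\"Val MAE\"):\n",
        "    \"\"\"First logged epoch per model at which `metric` is <= `threshold` (NaN if never).\"\"\"\n",
        "    reached = df[df[metric] <= threshold]\n",
        "    first_epoch = reached.groupby(\"Model\")[\"Epoch\"].min()\n",
        "    # Re-insert models that never reached the threshold as NaN\n",
        "    return first_epoch.reindex(df[\"Model\"].unique())\n",
        "\n",
        "\n",
        "# Toy stand-in for DF_RESULTS (made-up numbers, for illustration only)\n",
        "toy = pd.DataFrame({\n",
        "    \"Epoch\": [10, 20, 30, 10, 20, 30],\n",
        "    \"Val MAE\": [1.5, 0.9, 0.7, 1.0, 0.6, 0.5],\n",
        "    \"Test MAE\": [1.6, 1.0, 0.8, 1.1, 0.7, 0.6],\n",
        "    \"Model\": [\"ModelA\"] * 3 + [\"ModelB\"] * 3,\n",
        "})\n",
        "\n",
        "print(epochs_to_reach(toy, threshold=0.8))"
      ]
    },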
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vvVdmNAkMJGJ"
      },
      "source": [
        "---\n",
        "\n",
        "<font color='red'>❗️YOUR ANSWER HERE</font>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "04WVO1Mnukdm"
      },
      "source": [
        "### Dense vs. Sparse Graphs\n",
        "\n",
        "Now, let's turn our attention to the choice of the **underlying graph structure**.\n",
        "\n",
        "In this practical we have been using fully-connected adjacency matrices to represent molecules (i.e. all atoms in a molecule are connected to each other, except self-loops). Note, however, that the information about the molecular graph has always been available to the models through the edge attributes `data.edge_attr`:\n",
        "- When two atoms are physically connected, the edge attributes indicate the bond type (single, double, triple, or aromatic) through a one-hot vector.\n",
        "- When two atoms are **not** physically connected, all edge attributes are zero.\n",
        "\n",
        "In the following task, we will study the advantages/downsides of fully-connected adjacency matrices versus sparse adjacency matrices (where an edge between two atoms is present only when there exists a physical connection between them)."
      ]
    },
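    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To get a feel for the difference in size between the two graph structures, the cell below counts directed edges for a toy molecule. It is a minimal sketch in plain Python, assuming each bond is stored as two directed edges. The helper names `fully_connected_edges` and `bond_edges` and the methane example are hypothetical and not part of the practical.\n",
        "\n",
        "The dense graph has $n(n-1)$ directed edges and so grows quadratically with the number of atoms $n$, while the sparse graph grows with the number of bonds."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from itertools import permutations\n",
        "\n",
        "\n",
        "def fully_connected_edges(num_atoms):\n",
        "    \"\"\"All ordered pairs (i, j) with i != j: a dense directed graph without self-loops.\"\"\"\n",
        "    return list(permutations(range(num_atoms), 2))\n",
        "\n",
        "\n",
        "def bond_edges(bonds):\n",
        "    \"\"\"Store each undirected bond (i, j) as the two directed edges (i, j) and (j, i).\"\"\"\n",
        "    return [(i, j) for i, j in bonds] + [(j, i) for i, j in bonds]\n",
        "\n",
        "\n",
        "# Toy molecule: methane, with atom 0 = C and atoms 1-4 = H\n",
        "methane_bonds = [(0, 1), (0, 2), (0, 3), (0, 4)]\n",
        "dense_edges = fully_connected_edges(5)\n",
        "sparse_edges = bond_edges(methane_bonds)\n",
        "\n",
        "print(f\"Dense edges: {len(dense_edges)}, sparse edges: {len(sparse_edges)}\")"
      ]
    },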
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "twtGyU5TMmb7"
      },
      "source": [
        "## 💻**Task 5.2:** Compare the models' performance in the two scenarios (fully-connected versus sparse graphs). Explain your findings. (1 Mark)\n",
        "\n",
        "The code to load datasets in the sparse format is provided to you. You may need to wait for some time to let all the models finish training with the sparse format. \n",
        "\n",
        "Grab a coffee/tea! ☕️\n",
        "\n",
        "**❗️Note:** Once again, it is highly encouraged that you attempt this task even if you have not been successful in implementing all of the models in the practical. Just answer based on the models you did understand and implement!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KCXierFguw7b"
      },
      "outputs": [],
      "source": [
        "# Load QM9 dataset with sparse graphs (by removing the full graphs transform)\n",
        "sparse_dataset = QM9(path, transform=SetTarget())\n",
        "\n",
        "# Normalize targets per data sample to mean = 0 and std = 1.\n",
        "mean = sparse_dataset.data.y.mean(dim=0, keepdim=True)\n",
        "std = sparse_dataset.data.y.std(dim=0, keepdim=True)\n",
        "sparse_dataset.data.y = (sparse_dataset.data.y - mean) / std\n",
        "mean, std = mean[:, target].item(), std[:, target].item()\n",
        "\n",
        "# Split datasets (3K subset)\n",
        "train_dataset_sparse = sparse_dataset[:1000]\n",
        "val_dataset_sparse = sparse_dataset[1000:2000]\n",
        "test_dataset_sparse = sparse_dataset[2000:]\n",
        "print(f\"Created sparse dataset splits with {len(train_dataset_sparse)} training, {len(val_dataset_sparse)} validation, {len(test_dataset_sparse)} test samples.\")\n",
        "\n",
        "# Create dataloaders with batch size = 32\n",
        "train_loader_sparse = DataLoader(train_dataset_sparse, batch_size=32, shuffle=True)\n",
        "test_loader_sparse = DataLoader(test_dataset_sparse, batch_size=32, shuffle=False)\n",
        "val_loader_sparse = DataLoader(val_dataset_sparse, batch_size=32, shuffle=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r5Hzvjjlu8Po"
      },
      "source": [
        "Let's now check that the sparse dataset is actually more sparse than the fully-connected dataset that we have been using throughout the practical:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KcoOtUlRu-7y"
      },
      "outputs": [],
      "source": [
        "val_batch_sparse = next(iter(val_loader_sparse))\n",
        "val_batch_dense = next(iter(val_loader))\n",
        "\n",
        "# These two batches should correspond to the same molecules. Let's add a sanity check\n",
        "assert torch.allclose(val_batch_sparse.y, val_batch_dense.y, atol=1e-4)\n",
        "\n",
        "print(f\"Number of edges in sparse batch {val_batch_sparse.edge_index.shape[-1]}. Number of edges in dense batch {val_batch_dense.edge_index.shape[-1]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZmEYl8nUvoiy"
      },
      "source": [
        "Let's now compare the models under the two scenarios:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rA2iXAeg0c1s"
      },
      "outputs": [],
      "source": [
        "sparse_results = {}\n",
        "dense_results = RESULTS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hphYbE9WvnQg"
      },
      "outputs": [],
      "source": [
        "# ============ YOUR CODE HERE ==============\n",
        "# Instantiate your models\n",
        "models = [MPNNModel(), CoordMPNNModel(), InvariantMPNNModel(), FinalMPNNModel()]\n",
        "# ==========================================\n",
        "\n",
        "for model in models:\n",
        "  model_name = type(model).__name__\n",
        "\n",
        "  if model_name not in sparse_results:\n",
        "    sparse_results[model_name] = run_experiment(\n",
        "        model, \n",
        "        model_name, \n",
        "        train_loader_sparse,\n",
        "        val_loader_sparse, \n",
        "        test_loader_sparse,\n",
        "        n_epochs=100\n",
        "    )\n",
        "  \n",
        "  if model_name not in dense_results:\n",
        "    dense_results[model_name] = run_experiment(\n",
        "        model, \n",
        "        model_name, \n",
        "        train_loader,\n",
        "        val_loader, \n",
        "        test_loader,\n",
        "        n_epochs=100\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q2RoAlb-ytrE"
      },
      "outputs": [],
      "source": [
        "df_sparse = pd.DataFrame.from_dict(sparse_results, orient='index', columns=['Best val MAE', 'Test MAE', 'Train time', 'Train History'])\n",
        "df_dense = pd.DataFrame.from_dict(dense_results, orient='index', columns=['Best val MAE', 'Test MAE', 'Train time'])\n",
        "df_sparse['type'] = 'sparse'\n",
        "df_dense['type'] = 'dense'\n",
        "df = df_sparse.append(df_dense)\n",
        "\n",
        "sns.set(rc={'figure.figsize':(10, 6)})\n",
        "sns.barplot(x=df.index, y=\"Test MAE\", hue=\"type\", data=df);\n",
        "\n",
        "# You might want to save and download this plot\n",
        "# plt.savefig(\"comparison.png\")\n",
        "# files.download(\"comparison.png\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3G6fTYdwzLWj"
      },
      "source": [
        "Compare the models' performances under the two scenarios. Which models performed better/worst? Why do you think that is the case? Did you observe any differences between the fully-connected and sparse scenarios? Provide at least *two* arguments to explain the differences.\n",
        "\n",
        "---\n",
        "\n",
        "<font color='red'>❗️YOUR ANSWER HERE</font>\n",
        "\n",
        "---"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g80G1Pn1NWd1"
      },
      "source": [
        "---\n",
        "---\n",
        "---\n",
        "\n",
        "[Fin.](https://www.youtube.com/watch?v=b9434BoGkNQ)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rZNq0sMWCzDz"
      },
      "source": [
        "---\n",
        "---\n",
        "---\n",
        "\n",
        "# FAQ and General Advice\n",
        "\n",
        "### Unit Testing\n",
        "Consider using the unit test functions as sanity checks to make sure what you think is going on is actually going on. You can even consider writing your own unit tests beyond what we have asked in order to test specific properties of layers/models or numerical issues during training.\n",
        "\n",
        "### Theory vs. Empirical Results\n",
        "If you are confident that your proposed layer is theoretically sound, e.g. in Task 4.2. if your solution satisfied 3D equivariance but the results are not impressive or you are unable to achieve stable training, it may be due to numerical instability or engineering issues. If that is the case and you are not able to overcome those issues eventually, you may still submit whatever you have as a solution. You will be awarded partial marks as long as the theory is correct.\n",
        "\n",
        "### GPU rate-limit on Google Colab\n",
        "**TL;DR** Don’t panic, start early, save and load your results instead of re-running every time.\n",
        "\n",
        "We experienced rate-limits several times during the testing of the practical. It seems that there is an upper limit to the amount of GPU computation per user per 12-24 hours. When the limit is hit, Colab disconnects the GPU runtime (you do reconnect back to the GPU by the next day). Thus, we have tried to keep the practical as computationally simple as we can. \n",
        "\n",
        "We have several suggestions to make life more manageable, which we enumerate in the following bullet points:\n",
        "- If possible, do not leave things for the last moment. Start early so that you are not struggling with the rate limit on the day of the deadline!\n",
        "- If you do get rate-limited, you can consider writing and testing your model implementations with very small dataset sizes, e.g. 100 samples each. When you reconnect to the GPU, you can re-run your models with the full dataset.\n",
        "- If you find yourself hit by regular rate-limiting (e.g. if you also have other Colab projects running, this can happen a lot), you can save the results of each task to your Google Drive/local storage and simply load them each time you re-run the notebook.\n",
        "- You can use some combination of a new Google account, your cam.ac.uk account, and/or a new IP address to get a fresh GPU runtime if you have been rate-limited on one account.\n",
        "\n",
        "\n"
      ]
    },
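    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The advice above on saving and loading results instead of re-running every experiment can be sketched with `pickle` from the Python standard library. This is only one possible approach. The helper names `save_results` and `load_results`, the file name, and the toy dictionary are hypothetical. In the notebook you would pass your own `RESULTS` dictionary and a path on your Google Drive or local storage."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import pickle\n",
        "import tempfile\n",
        "\n",
        "\n",
        "def save_results(results, path):\n",
        "    \"\"\"Write a results dictionary to disk.\"\"\"\n",
        "    with open(path, \"wb\") as f:\n",
        "        pickle.dump(results, f)\n",
        "\n",
        "\n",
        "def load_results(path):\n",
        "    \"\"\"Load a results dictionary from disk, or return an empty one if the file is missing.\"\"\"\n",
        "    if not os.path.exists(path):\n",
        "        return {}\n",
        "    with open(path, \"rb\") as f:\n",
        "        return pickle.load(f)\n",
        "\n",
        "\n",
        "# Toy stand-in for RESULTS: (best val MAE, test MAE, train time), made-up numbers\n",
        "toy_results = {\"ToyModel\": (0.5, 0.6, 1.2)}\n",
        "demo_path = os.path.join(tempfile.gettempdir(), \"results_demo.pkl\")\n",
        "\n",
        "save_results(toy_results, demo_path)\n",
        "restored = load_results(demo_path)\n",
        "print(restored)"
      ]
    },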
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xM3HXnhwDk_n"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "_l0tU8gum9eT",
        "TAwoHyWs452X",
        "V03t1f51tW2w",
        "0fZTuGyVNC6n",
        "kY6J_vL3hgCN",
        "wqcJy1HtPgj6",
        "yO7iwRjvDD16",
        "c-az-clhTLLv",
        "04WVO1Mnukdm"
      ],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) \n[Clang 12.0.1 ]"
    },
    "vscode": {
      "interpreter": {
        "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
