{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GS4GalqayyJ2"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "# Environment bootstrap: Python packages, git, and the workspace folders.\n",
        "# %pip installs into the running kernel's environment; -y keeps apt from waiting on a prompt.\n",
        "%pip install wandb awscli\n",
        "!apt-get install -y git\n",
        "!apt-get autoremove -y\n",
        "\n",
        "!mkdir -p /root/workspace/data/ /root/workspace/out/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q2ogKJ59zNcw"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "# Fetch the project sources into the workspace and install UnitSphere's requirements.\n",
        "# The directory guards keep the cell re-runnable: git clone fails on an existing checkout.\n",
        "%cd /root/workspace\n",
        "![ -d UnitSphere ] || git clone https://github.com/Utah-Math-Data-Science/UnitSphere.git\n",
        "![ -d geometric-gnn-dojo ] || git clone https://github.com/chaitjo/geometric-gnn-dojo.git\n",
        "%pip install -r /root/workspace/UnitSphere/requirements.txt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QPGcEH8K_lSb"
      },
      "outputs": [],
      "source": [
        "# Bring the geometric-gnn-dojo checkout up to date: stash local modifications, then pull.\n",
        "%cd /root/workspace/geometric-gnn-dojo/\n",
        "!git stash\n",
        "!git pull"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lHOxeU9Gw9HS"
      },
      "outputs": [],
      "source": [
        "# Overlay the UnitSphere extensions onto the geometric-gnn-dojo checkout.\n",
        "%cd /root/workspace\n",
        "!cp /root/workspace/UnitSphere/ext/train_mse_utils.py ./geometric-gnn-dojo/experiments/utils/train_utils.py # remove once iclr is pulled\n",
        "!cp /root/workspace/UnitSphere/ext/comenet.py ./geometric-gnn-dojo/models/ # remove once iclr is pulled\n",
        "# Register ComENetModel only once, so re-running the cell does not append duplicate import lines.\n",
        "!grep -qxF \"from models.comenet import ComENetModel\" ./geometric-gnn-dojo/models/__init__.py || echo \"from models.comenet import ComENetModel\" >> ./geometric-gnn-dojo/models/__init__.py"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UkC8_r-NLdXo"
      },
      "source": [
        "# Models"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from abc import ABCMeta\n",
        "import ast\n",
        "import torch\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from mpl_toolkits.mplot3d import Axes3D\n",
        "from matplotlib import cm\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "from torch_geometric.transforms import BaseTransform\n",
        "\n",
        "import sys\n",
        "sys.path.append('/root/workspace/UnitSphere/alignment/pyorbit/utils/')\n",
        "from alignment3D import *\n",
        "from geometry import angle_between_vectors, planar_normal, project_onto_plane\n",
        "from hopcroft import PartitionRefinement\n",
        "from qhull import Qhull\n",
        "\n",
        "sys.path.append('/root/workspace/UnitSphere/alignment/pyorbit/vis/')\n",
        "from visualizer import Visualizer, plot_axes, plot_mol, plot_shell, plot_3d_pointcloud, plot_3d_polyhedron, plot_point, plot_plane\n",
        "\n",
        "def build_adjacency_list(edges):\n",
        "    \"\"\"Build an undirected adjacency mapping from an iterable of two-element edges.\n",
        "\n",
        "    Every endpoint becomes a key. Both the keys and each neighbour list are returned\n",
        "    in ascending order; a repeated edge yields repeated neighbours.\n",
        "    \"\"\"\n",
        "    neighbours = {}\n",
        "    for u, v in edges:\n",
        "        neighbours.setdefault(u, [])\n",
        "        neighbours.setdefault(v, [])\n",
        "        neighbours[u].append(v)\n",
        "        neighbours[v].append(u)\n",
        "    return {node: sorted(neighbours[node]) for node in sorted(neighbours)}\n",
        "\n",
        "def get_key(dct, value):\n",
        "    \"\"\"Return every key of ``dct`` whose value equals ``value``, in insertion order.\"\"\"\n",
        "    return [key for key, val in dct.items() if val == value]\n",
        "\n",
        "def direct_graph(edges):\n",
        "    \"\"\"Expand each undirected edge into two directed copies: as given, then reversed.\"\"\"\n",
        "    directed = []\n",
        "    for edge in edges:\n",
        "        directed.extend((list(edge), list(edge[::-1])))\n",
        "    return directed\n",
        "\n",
        "def custom_round(number, tolerance):\n",
        "    \"\"\"Round ``number`` to the number of decimals implied by ``tolerance`` (1e-2 -> 2 decimals).\"\"\"\n",
        "    # int() truncates towards zero, so a tolerance such as 5e-3 maps to 2 decimals.\n",
        "    decimals = int(-np.log10(tolerance))\n",
        "    return round(number, decimals)\n",
        "\n",
        "def list_rotate(lst):\n",
        "    \"\"\"Rotate ``lst`` so that the first occurrence of its minimum element comes first.\"\"\"\n",
        "    # min over positions keyed by element picks the earliest minimal entry\n",
        "    start = min(range(len(lst)), key=lst.__getitem__)\n",
        "    return lst[start:] + lst[:start]\n",
        "\n",
        "class Molecule:\n",
        "    \"\"\"Plain container holding ``data`` as ``pos`` and ``cat_data`` as ``z``.\"\"\"\n",
        "\n",
        "    def __init__(self, data=None, cat_data=None):\n",
        "        self.z = cat_data\n",
        "        self.pos = data\n",
        "\n",
        "class Frame(BaseTransform):\n",
        "    def __init__(self, tol=1e-2, *args, **kwargs):\n",
        "        \"\"\"Set up the transform.\n",
        "\n",
        "        Args:\n",
        "            tol: Numerical tolerance stored on the instance; the other methods use it\n",
        "                for rounding and for threshold comparisons.\n",
        "            *args, **kwargs: Accepted but not used here.\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        self.tol = tol\n",
        "        # Convex-hull helper; get_frame calls its get_chull_graph method.\n",
        "        self.chull = Qhull()\n",
        "        # NOTE(review): convert_partition reads self.hopcroft, which is never assigned here.\n",
        "\n",
        "    def __call__(self, data):\n",
        "        \"\"\"Attach convex-hull connectivity and radial features to ``data``.\n",
        "\n",
        "        Reads ``data.pos`` and ``data.atoms``; writes ``edge_index_hull``,\n",
        "        ``edge_attr_hull`` and ``radial_attr``. ``edge_index`` is only filled in (with a\n",
        "        copy of the hull edges) when the sample does not already carry one.\n",
        "\n",
        "        NOTE(review): get_frame drops points closer than ``tol`` to the centroid, so the\n",
        "        hull indices refer to the filtered point set -- confirm this matches ``data.pos``.\n",
        "        \"\"\"\n",
        "        _, _, edge_index_hull, edge_attr_hull, radial_arr = self.get_frame(data.pos, data.atoms.squeeze())\n",
        "        hull_index = torch.tensor(edge_index_hull, dtype=torch.long).contiguous()\n",
        "        data.edge_index_hull = hull_index\n",
        "        data.edge_attr_hull = torch.tensor(edge_attr_hull, dtype=torch.float)\n",
        "        # An existing edge_index is left untouched; the debug print of it has been removed.\n",
        "        if getattr(data, 'edge_index', None) is None:\n",
        "            # clone so that edge_index and edge_index_hull stay independent tensors\n",
        "            data.edge_index = hull_index.clone()\n",
        "        data.radial_attr = radial_arr\n",
        "        return data\n",
        "\n",
        "    def align(self, data, shell_data, cat_data, pth):\n",
        "        \"\"\"Apply the alignment steps selected by ``pth`` to ``data`` and ``shell_data``.\n",
        "\n",
        "        The i-th entry of ``pth`` indexes the shell point handed to the i-th step\n",
        "        (z_axis_alignment, zy_planar_alignment, sign_alignment, in that order).\n",
        "        ``cat_data`` is accepted but not used. Returns the transformed pair.\n",
        "        \"\"\"\n",
        "        steps = dict(enumerate((z_axis_alignment, zy_planar_alignment, sign_alignment)))\n",
        "        for step_no, shell_idx in enumerate(pth):\n",
        "            step = steps[step_no]\n",
        "            # both calls see the shell point as it was before this step\n",
        "            data = step(data, shell_data[shell_idx])\n",
        "            shell_data = step(shell_data, shell_data[shell_idx])\n",
        "        return data, shell_data\n",
        "\n",
        "    def traverse(self, sorted_graph, shell_data, shell_rank):\n",
        "        \"\"\"Pick ``shell_rank`` reference vertices from the sorted hull graph.\n",
        "\n",
        "        Args:\n",
        "            sorted_graph: sequence of entries where ``entry[0]`` and ``entry[1]`` are\n",
        "                collections of vertex indices; the first index of the first entry seeds the walk.\n",
        "            shell_data: array of shell points, indexable by vertex index.\n",
        "            shell_rank: number of vertices to return (1, 2 or 3).\n",
        "\n",
        "        Returns:\n",
        "            ``[v0]``, ``[v0, v1]`` or ``[v0, v1, v2]``. ``v1`` is the first candidate whose\n",
        "            dot product with ``v0`` exceeds ``tol`` in absolute value; ``v2`` comes from\n",
        "            v2_subroutine (tried with ``v0``/``v1`` in both roles).\n",
        "\n",
        "        Raises:\n",
        "            AssertionError: if no admissible second or third vertex exists.\n",
        "        \"\"\"\n",
        "        edge = 0\n",
        "        v0 = sorted_graph[edge][0][0]\n",
        "        if shell_rank == 1:\n",
        "            return [v0]\n",
        "        s0 = shell_data[v0]\n",
        "\n",
        "        v1 = None\n",
        "        while v1 is None and edge < len(sorted_graph):\n",
        "            for idx in sorted_graph[edge][1]:\n",
        "                if idx != v0 and np.abs(np.dot(s0, shell_data[idx])) > self.tol:\n",
        "                    v1 = idx\n",
        "                    break\n",
        "            if v1 is None:\n",
        "                edge += 1\n",
        "\n",
        "        # Fail loudly instead of returning [v0, None] or indexing shell_data with None.\n",
        "        assert v1 is not None, 'v1 is None'\n",
        "\n",
        "        if shell_rank == 2:\n",
        "            return [v0, v1]\n",
        "\n",
        "        v2 = self.v2_subroutine(v0, v1, edge, sorted_graph, shell_data, shell_rank)\n",
        "        if v2 is None:\n",
        "            v2 = self.v2_subroutine(v1, v0, edge, sorted_graph, shell_data, shell_rank)\n",
        "\n",
        "        assert v2 is not None, 'v2 is None'\n",
        "\n",
        "        return [v0, v1, v2]\n",
        "\n",
        "    def v2_subroutine(self, v0, v1, edge, sorted_graph, shell_data, shell_rank):\n",
        "        \"\"\"Search ``sorted_graph`` from position ``edge`` onwards for a third vertex.\n",
        "\n",
        "        Only entries whose first component contains ``v1`` are considered. The result is\n",
        "        the first candidate, different from ``v0`` and ``v1``, whose dot products with both\n",
        "        shell points exceed ``tol`` in absolute value, or ``None`` when there is none.\n",
        "        ``shell_rank`` is accepted but not used.\n",
        "        \"\"\"\n",
        "        s0, s1 = shell_data[v0], shell_data[v1]\n",
        "        for position in range(edge, len(sorted_graph)):\n",
        "            if v1 not in sorted_graph[position][0]:\n",
        "                continue\n",
        "            for idx in sorted_graph[position][1]:\n",
        "                if idx != v0 and idx != v1:\n",
        "                    candidate = shell_data[idx]\n",
        "                    if np.abs(np.dot(s0, candidate)) > self.tol and np.abs(np.dot(s1, candidate)) > self.tol:\n",
        "                        return idx\n",
        "        return None\n",
        "\n",
        "\n",
        "    def convert_partition(self, dist_hash, g_hash, r_encoding, g_encoding):\n",
        "        \"\"\"Translate the keys of the Hopcroft partition back into hash keys and vertex ids.\n",
        "\n",
        "        Each partition key is parsed with ``ast.literal_eval`` into a pair of\n",
        "        ``(radial_id, geometric_id)`` entries. For every pair this returns the matching keys\n",
        "        of ``dist_hash`` / ``g_hash`` and the keys of ``r_encoding`` that carry the radial\n",
        "        ids. Both result lists are ordered by the sorted hash-key edges.\n",
        "        ``g_encoding`` is accepted but not used.\n",
        "\n",
        "        NOTE(review): relies on ``self.hopcroft``, which this class never assigns --\n",
        "        confirm that the caller sets it before use.\n",
        "        \"\"\"\n",
        "        parsed_edges = [tuple(ast.literal_eval(key)) for key in self.hopcroft._partition.keys()]\n",
        "        labelled = []\n",
        "        for src, dst in parsed_edges:\n",
        "            edge_keys = [\n",
        "                (get_key(dist_hash, src[0]), get_key(g_hash, src[1])),\n",
        "                (get_key(dist_hash, dst[0]), get_key(g_hash, dst[1])),\n",
        "            ]\n",
        "            vertex_pair = [get_key(r_encoding, src[0]), get_key(r_encoding, dst[0])]\n",
        "            labelled.append((edge_keys, vertex_pair))\n",
        "\n",
        "        # stable sort on the hash-key edges; the vertex pairs follow the same permutation\n",
        "        labelled.sort(key=lambda pair: pair[0])\n",
        "        ret_edges = [edge_keys for edge_keys, _ in labelled]\n",
        "        ret_graph = [vertex_pair for _, vertex_pair in labelled]\n",
        "        return ret_edges, ret_graph\n",
        "\n",
        "\n",
        "    def construct_dfa(self, encoding, graph):\n",
        "        \"\"\"Label every edge of ``graph`` with the encodings of its two endpoints.\n",
        "\n",
        "        Returns ``(dfa_set, dfa_encoding)``: the list of string labels in edge order, and a\n",
        "        dict mapping ``(src, dst)`` to ``(label, position of the edge in graph)``.\n",
        "        \"\"\"\n",
        "        dfa_set, dfa_encoding = [], {}\n",
        "        for position, edge in enumerate(graph):\n",
        "            src, dst = edge[0], edge[1]\n",
        "            label = str([encoding[src], encoding[dst]])\n",
        "            dfa_set.append(label)\n",
        "            dfa_encoding[(src, dst)] = (label, position)\n",
        "        return dfa_set, dfa_encoding\n",
        "\n",
        "    def align_center(self, pointcloud):\n",
        "        \"\"\"Translate ``pointcloud`` so that its mean over axis 0 sits at the origin.\"\"\"\n",
        "        centroid = np.mean(pointcloud, axis=0)\n",
        "        return pointcloud - centroid\n",
        "\n",
        "    def get_hull_geometric_info(self, shell_data,\n",
        "                                adj_list,\n",
        "                                shell_rank):\n",
        "        \"\"\"Record, for every hull vertex, the offset to each of its neighbours.\n",
        "\n",
        "        Returns ``{vertex: {neighbour: (distance, (dx, dy, dz))}}`` where the offset is\n",
        "        ``shell_data[neighbour] - shell_data[vertex]`` and the distance is its norm.\n",
        "        For ``shell_rank == 1`` every distance entry is zero.\n",
        "        \"\"\"\n",
        "        s_feature = {}\n",
        "        for point, neighbours in adj_list.items():\n",
        "            r_ij = shell_data[neighbours] - shell_data[point]\n",
        "            d_ij = np.linalg.norm(r_ij, axis=1)\n",
        "            if shell_rank == 1:\n",
        "                d_ij = np.zeros_like(d_ij)\n",
        "            s_feature[point] = {\n",
        "                nbr: (dist, (vec[0], vec[1], vec[2]))\n",
        "                for nbr, dist, vec in zip(neighbours, d_ij, r_ij)\n",
        "            }\n",
        "        return s_feature\n",
        "\n",
        "    def geometric_encoding(self, shell_data,\n",
        "                           adj_list,\n",
        "                           shell_rank,\n",
        "                           angle_sorted=False):\n",
        "        \"\"\"Describe every hull edge by its length and two rounded in-plane angles.\n",
        "\n",
        "        For each vertex the offsets to its neighbours are passed to ``project_onto_plane``\n",
        "        together with the vertex itself (presumably projecting onto the plane normal to the\n",
        "        vertex -- confirm in the geometry module). When ``shell_rank == 3`` each neighbour\n",
        "        gets the angles between its projection and those of the previous and next neighbour\n",
        "        (cyclically); otherwise both angles are 0.\n",
        "\n",
        "        Args:\n",
        "            shell_data: array of shell points, indexable by vertex index.\n",
        "            adj_list: mapping vertex -> list of neighbouring vertices.\n",
        "            shell_rank: rank of the shell; 1 zeroes the distances, 3 enables the angles.\n",
        "            angle_sorted: store each angle pair in ascending order instead of (previous, next).\n",
        "\n",
        "        Returns:\n",
        "            ``(g_hash, encoding, s_feature)``. ``s_feature`` maps\n",
        "            ``vertex -> {neighbour: (distance, (angle_a, angle_b))}``. The hash-based encoding\n",
        "            is disabled: ``g_hash`` and ``encoding`` are ``None`` once any vertex has been\n",
        "            processed and empty dicts when ``adj_list`` is empty.\n",
        "        \"\"\"\n",
        "        g_hash, encoding, s_feature = {}, {}, {}\n",
        "\n",
        "        for point, neighbours in adj_list.items():\n",
        "            r_ij = shell_data[neighbours] - shell_data[point]\n",
        "            d_ij = np.linalg.norm(r_ij, axis=1)\n",
        "            if shell_rank == 1:\n",
        "                d_ij = np.zeros_like(d_ij)\n",
        "            projection = project_onto_plane(r_ij, shell_data[point])\n",
        "\n",
        "            count = len(projection)\n",
        "            angle = []\n",
        "            for i in range(count):\n",
        "                if shell_rank != 3:\n",
        "                    angle.append((0, 0))\n",
        "                    continue\n",
        "                # i - 1 wraps to the last projection for i == 0; (i + 1) % count wraps to the first\n",
        "                to_prev = angle_between_vectors(projection[i], projection[i - 1])\n",
        "                to_next = angle_between_vectors(projection[i], projection[(i + 1) % count])\n",
        "                if angle_sorted:\n",
        "                    angle.append(tuple(sorted([to_next, to_prev])))\n",
        "                else:\n",
        "                    angle.append((to_prev, to_next))\n",
        "\n",
        "            s_feature[point] = {\n",
        "                neighbours[ct]: (\n",
        "                    d,\n",
        "                    (custom_round(angles[0], self.tol), custom_round(angles[1], self.tol)),\n",
        "                )\n",
        "                for ct, (angles, d) in enumerate(zip(angle, d_ij))\n",
        "            }\n",
        "\n",
        "            # the rotation-normalised hash encoding is switched off\n",
        "            g_hash = None\n",
        "            encoding = None\n",
        "\n",
        "        return g_hash, encoding, s_feature\n",
        "\n",
        "\n",
        "    def check_type(self, data, *args, **kwargs):\n",
        "        \"\"\"Return ``data`` as a NumPy array.\n",
        "\n",
        "        NumPy arrays are returned unchanged; torch tensors are detached, moved to the CPU\n",
        "        and converted. Any other type raises ``TypeError``.\n",
        "        \"\"\"\n",
        "        if isinstance(data, np.ndarray):\n",
        "            return data\n",
        "        if isinstance(data, torch.Tensor):\n",
        "            return data.detach().cpu().numpy()\n",
        "        raise TypeError(f\"Data type not supported {type(data)}\")\n",
        "\n",
        "    def project_sphere(self, data, cat_data, *args, **kwargs):\n",
        "        \"\"\"Project points onto the unit sphere and group the ones sharing a direction.\n",
        "\n",
        "        Args:\n",
        "            data: array of point coordinates, one row per point. Rows must have a non-zero\n",
        "                norm (get_frame removes near-origin points before calling this).\n",
        "            cat_data: per-point categorical values, indexable by a boolean mask.\n",
        "            *args, **kwargs: Accepted but not used.\n",
        "\n",
        "        Returns:\n",
        "            dists_hash: maps each sorted tuple of rounded ``(radius, category)`` pairs to the\n",
        "                ``id()`` of that tuple (only meaningful within the current process).\n",
        "            encoding: maps each unique-direction index to the id of its tuple.\n",
        "            arr: the unique unit directions as returned by ``np.unique(..., axis=0)``.\n",
        "            proj_index_record: unique-direction index -> list of point indices.\n",
        "            proj_index_record_reverse: point index -> unique-direction index.\n",
        "        \"\"\"\n",
        "        distances = np.linalg.norm(data, axis=1, keepdims=False)\n",
        "        unit = data / np.linalg.norm(data, axis=1, keepdims=True)\n",
        "        # NOTE(review): np.unique compares rows exactly, so only bit-identical directions merge.\n",
        "        arr, inverse = np.unique(unit, axis=0, return_inverse=True)\n",
        "        # With ``axis`` given the inverse is not 1-D on every NumPy release; flatten it so its\n",
        "        # entries are scalars that can serve as dict keys and in boolean masks.\n",
        "        inverse = np.asarray(inverse).reshape(-1)\n",
        "\n",
        "        # record which points were projected onto which unique direction\n",
        "        proj_index_record = {}\n",
        "        for point_idx, shell_idx in enumerate(inverse):\n",
        "            proj_index_record.setdefault(shell_idx, []).append(point_idx)\n",
        "\n",
        "        encoding = {}\n",
        "        dists_hash = {}\n",
        "        for val in set(inverse):\n",
        "            mask = inverse == val\n",
        "            dists = tuple(sorted(\n",
        "                (custom_round(d, self.tol), custom_round(c, self.tol))\n",
        "                for d, c in zip(distances[mask], cat_data[mask])\n",
        "            ))\n",
        "            if dists not in dists_hash:\n",
        "                dists_hash[dists] = id(dists)\n",
        "            encoding[val] = dists_hash[dists]\n",
        "\n",
        "        proj_index_record_reverse = {\n",
        "            point_idx: shell_idx\n",
        "            for shell_idx, members in proj_index_record.items()\n",
        "            for point_idx in members\n",
        "        }\n",
        "\n",
        "        return dists_hash, encoding, arr, proj_index_record, proj_index_record_reverse\n",
        "\n",
        "    def get_recover_adj(self,\n",
        "                        adj_list,\n",
        "                        shell_data_proj_id_rcrd):\n",
        "        \"\"\"Lift the hull adjacency from shell vertices back to point indices.\n",
        "\n",
        "        Args:\n",
        "            adj_list: mapping shell vertex -> list of neighbouring shell vertices.\n",
        "            shell_data_proj_id_rcrd: mapping shell vertex -> list of point indices that\n",
        "                project onto it.\n",
        "\n",
        "        Returns:\n",
        "            Mapping point index -> sorted list of neighbouring point indices, with the keys\n",
        "            in ascending order.\n",
        "        \"\"\"\n",
        "        # every point inherits the shell-level neighbours of the vertex it projects onto\n",
        "        shell_neighbours = {\n",
        "            member: adj_list[shell_idx]\n",
        "            for shell_idx in adj_list\n",
        "            for member in shell_data_proj_id_rcrd[shell_idx]\n",
        "        }\n",
        "\n",
        "        # expand each neighbouring shell vertex into the points behind it\n",
        "        recovered = {}\n",
        "        for member, neighbours in shell_neighbours.items():\n",
        "            expanded = []\n",
        "            for shell_idx in neighbours:\n",
        "                expanded += shell_data_proj_id_rcrd[shell_idx]\n",
        "            recovered[member] = sorted(expanded)\n",
        "\n",
        "        return dict(sorted(recovered.items()))\n",
        "\n",
        "    ### modified by hyh ###\n",
        "    def get_merged_edge_index(self,\n",
        "                              adj_list,\n",
        "                              shell_data_proj_id_rcrd,\n",
        "                              data_edge_index):\n",
        "        \"\"\"Merge the lifted hull adjacency with an existing edge index.\n",
        "\n",
        "        Args:\n",
        "            adj_list: mapping shell vertex -> list of neighbouring shell vertices.\n",
        "            shell_data_proj_id_rcrd: mapping shell vertex -> list of point indices that\n",
        "                project onto it.\n",
        "            data_edge_index: two aligned sequences ``(sources, targets)`` of existing edges.\n",
        "\n",
        "        Returns:\n",
        "            Mapping point index -> neighbour list: the hull neighbours first (unsorted),\n",
        "            followed by the targets from ``data_edge_index`` that were not already present.\n",
        "            Points without outgoing edges in ``data_edge_index`` keep only their hull neighbours.\n",
        "        \"\"\"\n",
        "        # every point inherits the shell-level neighbours of the vertex it projects onto\n",
        "        shell_neighbours = {}\n",
        "        for shell_idx in adj_list:\n",
        "            for member in shell_data_proj_id_rcrd[shell_idx]:\n",
        "                shell_neighbours[member] = adj_list[shell_idx]\n",
        "\n",
        "        merged = {}\n",
        "        for member, neighbours in shell_neighbours.items():\n",
        "            expanded = []\n",
        "            for shell_idx in neighbours:\n",
        "                expanded += shell_data_proj_id_rcrd[shell_idx]\n",
        "            merged[member] = expanded\n",
        "\n",
        "        # group the existing edges by source node\n",
        "        existing = {}\n",
        "        for src, dst in zip(data_edge_index[0], data_edge_index[1]):\n",
        "            existing.setdefault(int(src), []).append(int(dst))\n",
        "\n",
        "        for member, neighbours in merged.items():\n",
        "            # .get: a point that never appears as a source used to raise KeyError here\n",
        "            for dst in existing.get(member, []):\n",
        "                if dst not in neighbours:\n",
        "                    neighbours.append(dst)\n",
        "\n",
        "        return merged\n",
        "\n",
        "    def merge_coord_info(self,\n",
        "                         data, s_feature,\n",
        "                         shell_data_proj_id_rcrd):\n",
        "        \"\"\"Map every point index to ``{'R': norm of its row in data}``.\n",
        "\n",
        "        The point indices are taken from the value lists of ``shell_data_proj_id_rcrd``.\n",
        "        ``s_feature`` is accepted but not used.\n",
        "        \"\"\"\n",
        "        return {\n",
        "            member: {'R': np.linalg.norm(data[member])}\n",
        "            for members in shell_data_proj_id_rcrd.values()\n",
        "            for member in members\n",
        "        }\n",
        "\n",
        "    def get_radial_arr(self, data):\n",
        "        \"\"\"Return the norm of every entry of ``data`` as a list, in order.\"\"\"\n",
        "        return [np.linalg.norm(data[i]) for i in range(len(data))]\n",
        "\n",
        "    def adj_arr(self, adj_list):\n",
        "        \"\"\"Flatten an adjacency mapping into ``[sources, targets]`` lists of ints.\"\"\"\n",
        "        sources, targets = [], []\n",
        "        for node, neighbours in adj_list.items():\n",
        "            for nbr in neighbours:\n",
        "                sources.append(int(node))\n",
        "                targets.append(int(nbr))\n",
        "        return [sources, targets]\n",
        "\n",
        "    def edge_attr_arr(self, s_feature,\n",
        "                      proj_id_rcrd_rvrs,\n",
        "                      edge_index_hull):\n",
        "        \"\"\"Collect one ``[distance, dx, dy, dz]`` row per edge of ``edge_index_hull``.\n",
        "\n",
        "        Both endpoints are first mapped through ``proj_id_rcrd_rvrs``; the resulting pair\n",
        "        selects an entry of ``s_feature`` whose first item is the distance and whose second\n",
        "        item supplies the three remaining values.\n",
        "        \"\"\"\n",
        "        attr_arr = []\n",
        "        for i, src in enumerate(edge_index_hull[0]):\n",
        "            dst = edge_index_hull[1][i]\n",
        "            feature = s_feature[proj_id_rcrd_rvrs[src]][proj_id_rcrd_rvrs[dst]]\n",
        "            offset = feature[1]\n",
        "            attr_arr.append([feature[0], offset[0], offset[1], offset[2]])\n",
        "        return attr_arr\n",
        "\n",
        "    def get_frame(self, data, cat_data, data_edge_index=None, *args, **kwargs):\n",
        "        \"\"\"Compute convex-hull edges and per-edge geometry for a point cloud.\n",
        "\n",
        "        Steps: convert the inputs to NumPy, centre the points, drop points within ``tol``\n",
        "        of the centroid, project the rest onto the unit sphere, build the hull graph of the\n",
        "        unique directions and lift its edges back to the retained points.\n",
        "\n",
        "        Args:\n",
        "            data: point coordinates (NumPy array or torch tensor), one row per point.\n",
        "            cat_data: per-point categorical values (NumPy array or torch tensor).\n",
        "            data_edge_index: accepted for interface compatibility; not used.\n",
        "            *args, **kwargs: forwarded to project_sphere.\n",
        "\n",
        "        Returns:\n",
        "            ``(data, cat_data, edge_index_hull, edge_attr_hull, radial_arr)``: the centred and\n",
        "            filtered inputs, the hull edges as ``[sources, targets]``, one\n",
        "            ``[distance, dx, dy, dz]`` row per edge, and the norm of every retained point.\n",
        "\n",
        "        NOTE(review): all returned indices refer to the filtered points, not to the rows of\n",
        "        the input -- confirm callers account for points removed near the centroid.\n",
        "        \"\"\"\n",
        "        data = self.check_type(data)\n",
        "        cat_data = self.check_type(cat_data)\n",
        "        data = self.align_center(data)\n",
        "\n",
        "        # points (numerically) at the centroid have no direction and are dropped\n",
        "        keep = np.linalg.norm(data, axis=1) > self.tol\n",
        "        data = data[keep]\n",
        "        cat_data = cat_data[keep]\n",
        "\n",
        "        # PROJECT ONTO SPHERE (the two hash results are not needed here)\n",
        "        _, _, shell_data, proj_id_rcrd, proj_id_rcrd_rvrs = self.project_sphere(data, cat_data, *args, **kwargs)\n",
        "\n",
        "        # GET CONVEX HULL\n",
        "        shell_rank = np.linalg.matrix_rank(shell_data, tol=self.tol)\n",
        "        shell_n = shell_data.shape[0]\n",
        "        shell_graph = self.chull.get_chull_graph(shell_data, shell_rank, shell_n)\n",
        "\n",
        "        # GET GEOMETRIC ENCODING\n",
        "        adj_list = build_adjacency_list(shell_graph)\n",
        "        s_feature = self.get_hull_geometric_info(shell_data, adj_list, shell_rank)\n",
        "\n",
        "        # lift the shell-level adjacency back to the retained points\n",
        "        rcvr_adj_list = self.get_recover_adj(adj_list, proj_id_rcrd)\n",
        "        edge_index_hull = self.adj_arr(rcvr_adj_list)\n",
        "        edge_attr_hull = self.edge_attr_arr(s_feature, proj_id_rcrd_rvrs, edge_index_hull)\n",
        "        radial_arr = self.get_radial_arr(data)\n",
        "\n",
        "        return data, cat_data, edge_index_hull, edge_attr_hull, radial_arr"
      ],
      "metadata": {
        "id": "Yp9819OaVAqQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "# Based on the code from: https://github.com/TUM-DAML/gemnet_pytorch\n",
        "# https://github.com/TUM-DAML/gemnet_pytorch/blob/master/gemnet/model/layers/basis_utils.py\n",
        "# https://github.com/TUM-DAML/gemnet_pytorch/blob/master/gemnet/model/layers/basis_layers.py\n",
        "\n",
        "import math\n",
        "import torch\n",
        "import sympy as sym\n",
        "import numpy as np\n",
        "from scipy.optimize import brentq\n",
        "from scipy import special as sp\n",
        "from math import pi as PI\n",
        "from scipy.special import binom\n",
        "from torch_geometric.nn.models.schnet import GaussianSmearing\n",
        "\n",
        "\n",
        "def Jn(r, n):\n",
        "    \"\"\"\n",
        "    Evaluate the spherical Bessel function of order n at r numerically (SciPy's spherical_jn).\n",
        "    \"\"\"\n",
        "    value = sp.spherical_jn(n, r)\n",
        "    return value\n",
        "\n",
        "\n",
        "def Jn_zeros(n, k):\n",
        "    \"\"\"\n",
        "    Compute the first k zeros of the spherical bessel functions of order 0 .. n-1.\n",
        "\n",
        "    Returns a float32 array of shape (n, k). Order 0 has its zeros at multiples of pi; the\n",
        "    zeros of each higher order are searched with brentq between consecutive entries of the\n",
        "    previous order's zeros.\n",
        "    \"\"\"\n",
        "    zerosj = np.zeros((n, k), dtype=\"float32\")\n",
        "    zerosj[0] = np.arange(1, k + 1) * np.pi\n",
        "    brackets = np.arange(1, k + n) * np.pi\n",
        "    roots = np.zeros(k + n - 1, dtype=\"float32\")\n",
        "    for order in range(1, n):\n",
        "        for j in range(k + n - 1 - order):\n",
        "            roots[j] = brentq(Jn, brackets[j], brackets[j + 1], (order,))\n",
        "        # From here on ``brackets`` aliases ``roots``: entry j is overwritten only after\n",
        "        # entries j and j + 1 have been read, so updating in place is safe.\n",
        "        brackets = roots\n",
        "        zerosj[order][:k] = roots[:k]\n",
        "\n",
        "    return zerosj\n",
        "\n",
        "\n",
        "def spherical_bessel_formulas(n):\n",
        "    \"\"\"\n",
        "    Build the sympy formulas of the spherical bessel functions of order 0 .. n-1.\n",
        "\n",
        "    Uses j_i(x) = (-x)^i * (1/x * d/dx)^i * sin(x)/x and returns the formulas as a list\n",
        "    indexed by order; the order-0 formula is always included.\n",
        "    \"\"\"\n",
        "    x = sym.symbols(\"x\")\n",
        "    current = sym.sin(x) / x  # (1/x * d/dx)^0 applied to sin(x)/x, i.e. j_0\n",
        "    formulas = [current]\n",
        "    for order in range(1, n):\n",
        "        derived = sym.diff(current, x) / x\n",
        "        formulas.append(sym.simplify(derived * (-x) ** order))\n",
        "        current = sym.simplify(derived)\n",
        "    return formulas\n",
        "\n",
        "\n",
        "def bessel_basis(n, k):\n",
        "    \"\"\"\n",
        "    Compute the sympy formulas for the normalized and rescaled spherical bessel functions up to\n",
        "    order n (excluded) and maximum frequency k (excluded).\n",
        "    Returns:\n",
        "        bess_basis: list\n",
        "            Bessel basis formulas taking in a single argument x.\n",
        "            Has length n where each element has length k. -> In total n*k many.\n",
        "    \"\"\"\n",
        "    zeros = Jn_zeros(n, k)\n",
        "    # per (order, root): 1 / sqrt(0.5 * j_{order+1}(root)**2); sqrt(1/c**3) is not taken into account yet\n",
        "    normalizer = [\n",
        "        1 / np.array([0.5 * Jn(zeros[order, i], order + 1) ** 2 for i in range(k)]) ** 0.5\n",
        "        for order in range(n)\n",
        "    ]\n",
        "\n",
        "    f = spherical_bessel_formulas(n)\n",
        "    x = sym.symbols(\"x\")\n",
        "    # rescale every formula so that its i-th root lands on x = 1\n",
        "    return [\n",
        "        [\n",
        "            sym.simplify(normalizer[order][i] * f[order].subs(x, zeros[order, i] * x))\n",
        "            for i in range(k)\n",
        "        ]\n",
        "        for order in range(n)\n",
        "    ]\n",
        "\n",
        "\n",
        "def sph_harm_prefactor(l, m):\n",
        "    \"\"\"Computes the constant pre-factor for the spherical harmonic of degree l and order m.\n",
        "    Parameters\n",
        "    ----------\n",
        "        l: int\n",
        "            Degree of the spherical harmonic. l >= 0\n",
        "        m: int\n",
        "            Order of the spherical harmonic. -l <= m <= l\n",
        "    Returns\n",
        "    -------\n",
        "        factor: float\n",
        "    \"\"\"\n",
        "    # sqrt((2*l+1)/(4*pi) * (l-|m|)!/(l+|m|)!)\n",
        "    # math.factorial replaces np.math.factorial: the np.math alias no longer exists in NumPy 2.\n",
        "    abs_m = abs(m)\n",
        "    return (\n",
        "        (2 * l + 1)\n",
        "        / (4 * np.pi)\n",
        "        * math.factorial(l - abs_m)\n",
        "        / math.factorial(l + abs_m)\n",
        "    ) ** 0.5\n",
        "\n",
        "\n",
        "def associated_legendre_polynomials(L, zero_m_only=True, pos_m_only=True):\n",
        "    \"\"\"Computes sympy formulas of the associated legendre polynomials up to degree L (excluded).\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "        L: int\n",
        "            Degree up to which to calculate the associated legendre polynomials (degree L is excluded).\n",
        "            L <= 0 yields an empty list.\n",
        "        zero_m_only: bool\n",
        "            If True only calculate the polynomials where m=0.\n",
        "        pos_m_only: bool\n",
        "            If True only calculate the polynomials where m>=0. Overwritten by zero_m_only.\n",
        "    Returns\n",
        "    -------\n",
        "        polynomials: list\n",
        "            One list per degree l with 2*l + 1 slots, indexed by m (negative m via negative\n",
        "            list indices). Slots that were not requested keep the integer placeholder 0;\n",
        "            filled slots hold expressions in the symbol z (P_00 is the integer 1).\n",
        "    \"\"\"\n",
        "    # calculations from http://web.cmb.usc.edu/people/alber/Software/tomominer/docs/cpp/group__legendre__polynomials.html\n",
        "    # `np.math` is no longer available in NumPy 2.0; use the stdlib module directly.\n",
        "    import math\n",
        "\n",
        "    z = sym.symbols(\"z\")\n",
        "    P_l_m = [[0] * (2 * l + 1) for l in range(L)]  # for order l: -l <= m <= l\n",
        "    if L <= 0:\n",
        "        # No degree requested: the table is empty and there is nothing to fill in.\n",
        "        return P_l_m\n",
        "\n",
        "    P_l_m[0][0] = 1\n",
        "    if zero_m_only:\n",
        "        # m = 0. The degree-1 row only exists when L > 1, so guard the write.\n",
        "        if L > 1:\n",
        "            P_l_m[1][0] = z\n",
        "        for l in range(2, L):\n",
        "            P_l_m[l][0] = sym.simplify(\n",
        "                ((2 * l - 1) * z * P_l_m[l - 1][0] - (l - 1) * P_l_m[l - 2][0]) / l\n",
        "            )\n",
        "        return P_l_m\n",
        "\n",
        "    # for m >= 0\n",
        "    for l in range(1, L):\n",
        "        P_l_m[l][l] = sym.simplify(\n",
        "            (1 - 2 * l) * (1 - z ** 2) ** 0.5 * P_l_m[l - 1][l - 1]\n",
        "        )  # P_00, P_11, P_22, P_33\n",
        "\n",
        "    for m in range(0, L - 1):\n",
        "        P_l_m[m + 1][m] = sym.simplify(\n",
        "            (2 * m + 1) * z * P_l_m[m][m]\n",
        "        )  # P_10, P_21, P_32, P_43\n",
        "\n",
        "    for l in range(2, L):\n",
        "        for m in range(l - 1):  # P_20, P_30, P_31\n",
        "            P_l_m[l][m] = sym.simplify(\n",
        "                (\n",
        "                    (2 * l - 1) * z * P_l_m[l - 1][m]\n",
        "                    - (l + m - 1) * P_l_m[l - 2][m]\n",
        "                )\n",
        "                / (l - m)\n",
        "            )\n",
        "\n",
        "    if not pos_m_only:\n",
        "        # for m < 0: P_l(-m) = (-1)^m * (l-m)!/(l+m)! * P_lm\n",
        "        for l in range(1, L):\n",
        "            for m in range(1, l + 1):  # P_1(-1), P_2(-1) P_2(-2)\n",
        "                P_l_m[l][-m] = sym.simplify(\n",
        "                    (-1) ** m\n",
        "                    * math.factorial(l - m)\n",
        "                    / math.factorial(l + m)\n",
        "                    * P_l_m[l][m]\n",
        "                )\n",
        "\n",
        "    return P_l_m\n",
        "\n",
        "\n",
        "def real_sph_harm(L, spherical_coordinates, zero_m_only=True):\n",
        "    \"\"\"\n",
        "    Builds sympy formulas of the real spherical harmonics up to degree L (excluded).\n",
        "    Variables are either spherical coordinates phi and theta (or cartesian coordinates x,y,z) on the UNIT SPHERE.\n",
        "    Parameters\n",
        "    ----------\n",
        "        L: int\n",
        "            Degree up to which to calculate the spherical harmonics (degree L is excluded).\n",
        "        spherical_coordinates: bool\n",
        "            - True: the formulas are expressed in phi and theta.\n",
        "            - False: the formulas are expressed in x, y and z.\n",
        "        zero_m_only: bool\n",
        "            If True only calculate the harmonics where m=0.\n",
        "    Returns\n",
        "    -------\n",
        "        Y_lm_real: list\n",
        "            One list per degree l < L. With zero_m_only each inner list holds the single\n",
        "            m=0 formula (L formulas in total); otherwise it holds 2*l + 1 formulas indexed\n",
        "            by m, negative m via negative list indices (L^2 formulas in total).\n",
        "    \"\"\"\n",
        "    z = sym.symbols(\"z\")\n",
        "    legendre = associated_legendre_polynomials(L, zero_m_only)\n",
        "    # One slot per degree when only m = 0 is kept, else 2l + 1 slots (-l <= m <= l).\n",
        "    harmonics = [[0] * (1 if zero_m_only else 2 * l + 1) for l in range(L)]\n",
        "\n",
        "    if spherical_coordinates:\n",
        "        # Express the polynomials in theta: z -> cos(theta). Integer placeholders\n",
        "        # have no .subs and are left untouched.\n",
        "        theta = sym.symbols(\"theta\")\n",
        "        for row in legendre:\n",
        "            for idx, poly in enumerate(row):\n",
        "                if not isinstance(poly, int):\n",
        "                    row[idx] = poly.subs(z, sym.cos(theta))\n",
        "\n",
        "    # Y_lm = N * P_lm(cos(theta)) * exp(i*m*phi)\n",
        "    #             { sqrt(2) * (-1)^m * N * P_l|m| * sin(|m|*phi)   if m < 0\n",
        "    # Y_lm_real = { Y_lm                                           if m = 0\n",
        "    #             { sqrt(2) * (-1)^m * N * P_lm * cos(m*phi)       if m > 0\n",
        "    for l in range(L):\n",
        "        harmonics[l][0] = sym.simplify(sph_harm_prefactor(l, 0) * legendre[l][0])  # Y_l0\n",
        "\n",
        "    if zero_m_only:\n",
        "        return harmonics\n",
        "\n",
        "    phi = sym.symbols(\"phi\")\n",
        "    for l in range(1, L):\n",
        "        for m in range(1, l + 1):\n",
        "            # m > 0 goes to slot m (cosine part)\n",
        "            harmonics[l][m] = sym.simplify(\n",
        "                2 ** 0.5 * (-1) ** m * sph_harm_prefactor(l, m) * legendre[l][m] * sym.cos(m * phi)\n",
        "            )\n",
        "            # m < 0 goes to slot -m (sine part), built from the same P_l|m|\n",
        "            harmonics[l][-m] = sym.simplify(\n",
        "                2 ** 0.5 * (-1) ** m * sph_harm_prefactor(l, -m) * legendre[l][m] * sym.sin(m * phi)\n",
        "            )\n",
        "\n",
        "    if not spherical_coordinates:\n",
        "        # Cartesian form: phi -> atan2(y, x)\n",
        "        x = sym.symbols(\"x\")\n",
        "        y = sym.symbols(\"y\")\n",
        "        for row in harmonics:\n",
        "            for idx in range(len(row)):\n",
        "                row[idx] = sym.simplify(row[idx].subs(phi, sym.atan2(y, x)))\n",
        "    return harmonics\n",
        "\n",
        "class Envelope(torch.nn.Module):\n",
        "    def __init__(self, exponent):\n",
        "        super(Envelope, self).__init__()\n",
        "        self.p = exponent + 1\n",
        "        self.a = -(self.p + 1) * (self.p + 2) / 2\n",
        "        self.b = self.p * (self.p + 2)\n",
        "        self.c = -self.p * (self.p + 1) / 2\n",
        "\n",
        "    def forward(self, x):\n",
        "        p, a, b, c = self.p, self.a, self.b, self.c\n",
        "        x_pow_p0 = x.pow(p - 1)\n",
        "        x_pow_p1 = x_pow_p0 * x\n",
        "        x_pow_p2 = x_pow_p1 * x\n",
        "        return 1. / x + a * x_pow_p0 + b * x_pow_p1 + c * x_pow_p2\n",
        "\n",
        "class dist_emb(torch.nn.Module):\n",
        "    \"\"\"Radial embedding: envelope(d) * sin(freq * d) on distances scaled by the cutoff.\n",
        "\n",
        "    `freq` is a learnable parameter with `num_radial` entries, reset to 1*PI, 2*PI, ...\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_radial, cutoff=5.0, envelope_exponent=5):\n",
        "        super().__init__()\n",
        "        self.cutoff = cutoff\n",
        "        self.envelope = Envelope(envelope_exponent)\n",
        "\n",
        "        # Allocated uninitialized; reset_parameters() below fills in the values.\n",
        "        self.freq = torch.nn.Parameter(torch.Tensor(num_radial))\n",
        "\n",
        "        self.reset_parameters()\n",
        "\n",
        "    def reset_parameters(self):\n",
        "        \"\"\"Set the frequencies to PI * (1, 2, ..., num_radial).\"\"\"\n",
        "        count = self.freq.numel()\n",
        "        self.freq.data = torch.arange(1, count + 1).float().mul_(PI)\n",
        "\n",
        "    def forward(self, dist):\n",
        "        \"\"\"Embed `dist`; a trailing dimension of size num_radial is appended.\"\"\"\n",
        "        scaled = dist.unsqueeze(-1) / self.cutoff\n",
        "        return self.envelope(scaled) * (self.freq * scaled).sin()\n",
        "\n",
        "class angle_emb(torch.nn.Module):\n",
        "    \"\"\"Joint distance/angle embedding from Bessel formulas and m=0 spherical harmonics.\n",
        "\n",
        "    The sympy formulas are compiled once into torch callables and evaluated in forward().\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_radial, num_spherical, cutoff=8.0):\n",
        "        super().__init__()\n",
        "        assert num_radial <= 64\n",
        "        self.num_spherical = num_spherical\n",
        "        self.num_radial = num_radial\n",
        "        self.cutoff = cutoff\n",
        "\n",
        "        radial_exprs = bessel_basis(num_spherical, num_radial)\n",
        "        harmonic_exprs = real_sph_harm(\n",
        "            num_spherical, spherical_coordinates=True, zero_m_only=True\n",
        "        )\n",
        "        self.sph_funcs = []\n",
        "        self.bessel_funcs = []\n",
        "\n",
        "        x = sym.symbols(\"x\")\n",
        "        theta = sym.symbols(\"theta\")\n",
        "        torch_ops = {\"sin\": torch.sin, \"cos\": torch.cos, \"sqrt\": torch.sqrt}\n",
        "        for l, harmonics in enumerate(harmonic_exprs):\n",
        "            sph = sym.lambdify([theta], harmonics[0], torch_ops)\n",
        "            if l == 0:\n",
        "                # Degree 0 is a constant; add it to zeros so the result has the\n",
        "                # same shape as the input tensor.\n",
        "                first_sph = sph\n",
        "                sph = lambda theta: torch.zeros_like(theta) + first_sph(theta)\n",
        "            self.sph_funcs.append(sph)\n",
        "            # num_radial Bessel callables for this degree, in frequency order\n",
        "            self.bessel_funcs.extend(\n",
        "                [sym.lambdify([x], radial_exprs[l][n], torch_ops) for n in range(num_radial)]\n",
        "            )\n",
        "\n",
        "    def forward(self, dist, angle):\n",
        "        \"\"\"Return a (-1, num_spherical * num_radial) tensor of radial * angular products.\"\"\"\n",
        "        scaled = dist / self.cutoff\n",
        "        radial = torch.stack([fn(scaled) for fn in self.bessel_funcs], dim=1)\n",
        "        angular = torch.stack([fn(angle) for fn in self.sph_funcs], dim=1)\n",
        "        n, k = self.num_spherical, self.num_radial\n",
        "        # Broadcast each degree's harmonic over that degree's k radial functions.\n",
        "        return (radial.view(-1, n, k) * angular.view(-1, n, 1)).view(-1, n * k)\n",
        "\n",
        "\n",
        "class torsion_emb(torch.nn.Module):\n",
        "    \"\"\"Joint distance/angle/torsion embedding from Bessel formulas and all real spherical harmonics.\n",
        "\n",
        "    The sympy formulas are compiled once into torch callables and evaluated in forward().\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_radial, num_spherical, cutoff=8.0):\n",
        "        super().__init__()\n",
        "        assert num_radial <= 64\n",
        "        self.num_radial = num_radial\n",
        "        self.num_spherical = num_spherical\n",
        "        self.cutoff = cutoff\n",
        "\n",
        "        radial_exprs = bessel_basis(num_spherical, num_radial)\n",
        "        harmonic_exprs = real_sph_harm(\n",
        "            num_spherical, spherical_coordinates=True, zero_m_only=False\n",
        "        )\n",
        "        self.sph_funcs = []\n",
        "        self.bessel_funcs = []\n",
        "\n",
        "        x = sym.symbols(\"x\")\n",
        "        theta = sym.symbols(\"theta\")\n",
        "        phi = sym.symbols(\"phi\")\n",
        "        torch_ops = {\"sin\": torch.sin, \"cos\": torch.cos, \"sqrt\": torch.sqrt}\n",
        "        for l, harmonics in enumerate(harmonic_exprs):\n",
        "            for expr in harmonics:\n",
        "                sph = sym.lambdify([theta, phi], expr, torch_ops)\n",
        "                if l == 0:\n",
        "                    # Degree 0 is a constant; add it to zeros so the result has\n",
        "                    # the same shape as the input tensor.\n",
        "                    first_sph = sph\n",
        "                    sph = lambda theta, phi: torch.zeros_like(theta) + first_sph(theta, phi)\n",
        "                self.sph_funcs.append(sph)\n",
        "            # num_radial Bessel callables for this degree, in frequency order\n",
        "            self.bessel_funcs.extend(\n",
        "                [sym.lambdify([x], radial_exprs[l][j], torch_ops) for j in range(num_radial)]\n",
        "            )\n",
        "\n",
        "        # Number of harmonics per degree (2l + 1); used to repeat the radial part in forward().\n",
        "        self.register_buffer(\n",
        "            \"degreeInOrder\", torch.arange(num_spherical) * 2 + 1, persistent=False\n",
        "        )\n",
        "\n",
        "    def forward(self, dist, theta, phi):\n",
        "        \"\"\"Return a (-1, num_spherical**2 * num_radial) tensor of radial * angular products.\"\"\"\n",
        "        scaled = dist / self.cutoff\n",
        "        radial = torch.stack([fn(scaled) for fn in self.bessel_funcs], dim=1)\n",
        "        angular = torch.stack([fn(theta, phi) for fn in self.sph_funcs], dim=1)\n",
        "\n",
        "        n, k = self.num_spherical, self.num_radial\n",
        "        # Repeat each degree's k radial values once per harmonic of that degree,\n",
        "        # and each harmonic k times, so both sides line up element-wise.\n",
        "        radial = radial.view((-1, n, k)).repeat_interleave(self.degreeInOrder, dim=1).view((-1, n ** 2 * k))\n",
        "        angular = angular.repeat_interleave(k, dim=1)\n",
        "        return radial * angular\n"
      ],
      "metadata": {
        "id": "53wNRwtKZ9R4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from torch_geometric.nn.conv import MessagePassing\n",
        "from torch_scatter import scatter, scatter_min\n",
        "\n",
        "def get_angle_torsion(edge_index,\n",
        "                      vecs, dist,\n",
        "                      num_nodes,\n",
        "                      cutoff=9999):\n",
        "    \"\"\"Compute a polar angle `theta` and a torsion angle `phi` for every edge.\n",
        "\n",
        "    For each target node the nearest and second-nearest incoming edges (ranked by\n",
        "    `dist`) serve as references. `theta` is the angle between an edge's reversed\n",
        "    vector and the nearest reference; `phi` is the angle between the planes spanned\n",
        "    with the nearest and the second-nearest reference.\n",
        "\n",
        "    Args:\n",
        "        edge_index: pair (j, i) of 1-D index tensors with one entry per edge;\n",
        "            the per-node reductions are grouped by `i`.\n",
        "        vecs: per-edge vectors; the cross products require a last dimension of size 3.\n",
        "        dist: per-edge distances used to rank the edges.\n",
        "        num_nodes: number of nodes (size of the per-node reductions).\n",
        "        cutoff: offset added to each node's nearest edge so that the second\n",
        "            `scatter_min` selects the runner-up.\n",
        "\n",
        "    Returns:\n",
        "        (theta, phi): two tensors with one entry per edge.\n",
        "    \"\"\"\n",
        "    _, i = edge_index\n",
        "\n",
        "    # Nearest incoming edge per node. Indices >= len(i) (what scatter_min yields\n",
        "    # for nodes without incoming edges) are reset to edge 0 to stay indexable.\n",
        "    _, argmin0 = scatter_min(dist, i, dim_size=num_nodes)\n",
        "    argmin0[argmin0 >= len(i)] = 0\n",
        "\n",
        "    # Second-nearest: penalise the winners by `cutoff` and reduce again.\n",
        "    add = torch.zeros_like(dist)\n",
        "    add[argmin0] = cutoff\n",
        "    _, argmin1 = scatter_min(dist + add, i, dim_size=num_nodes)\n",
        "    argmin1[argmin1 >= len(i)] = 0\n",
        "\n",
        "    # Reference vectors of each edge's target node, gathered per edge.\n",
        "    pos_ji = vecs\n",
        "    pos_in0 = vecs[argmin0][i]\n",
        "    pos_in1 = vecs[argmin1][i]\n",
        "\n",
        "    # `dim=-1` is explicit on purpose: without it torch.cross picks the first\n",
        "    # dimension of size 3, which is the edge dimension when there are exactly\n",
        "    # three edges.\n",
        "    plane1 = torch.cross(-pos_ji, pos_in0, dim=-1)\n",
        "    plane2 = torch.cross(-pos_ji, pos_in1, dim=-1)\n",
        "\n",
        "    # Calculate angles.\n",
        "    a = ((-pos_ji) * pos_in0).sum(dim=-1)\n",
        "    b = plane1.norm(dim=-1)\n",
        "    theta = torch.atan2(b, a)\n",
        "    theta[theta < 0] = theta[theta < 0] + math.pi\n",
        "\n",
        "    # Calculate torsions.\n",
        "    dist_ji = pos_ji.pow(2).sum(dim=-1).sqrt()\n",
        "    a = (plane1 * plane2).sum(dim=-1)  # cos_angle * |plane1| * |plane2|\n",
        "    b = (torch.cross(plane1, plane2, dim=-1) * pos_ji).sum(dim=-1) / dist_ji\n",
        "    phi = torch.atan2(b, a)\n",
        "    phi[phi < 0] = phi[phi < 0] + math.pi\n",
        "\n",
        "    return theta, phi"
      ],
      "metadata": {
        "id": "_91ACrZFc8DF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.nn import Linear, ReLU, SiLU, Sequential\n",
        "from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool\n",
        "from torch_scatter import scatter\n",
        "\n",
        "\n",
        "class MPNNLayer(MessagePassing):\n",
        "    \"\"\"Message passing layer: MLP messages, scatter aggregation, MLP node update.\"\"\"\n",
        "\n",
        "    def __init__(self, emb_dim, activation=\"relu\", norm=\"layer\", aggr=\"add\"):\n",
        "        \"\"\"Vanilla Message Passing GNN layer\n",
        "\n",
        "        Args:\n",
        "            emb_dim: (int) - hidden dimension `d`\n",
        "            activation: (str) - non-linearity within MLPs (swish/relu)\n",
        "            norm: (str) - normalisation layer (layer/batch)\n",
        "            aggr: (str) - aggregation name; given to `MessagePassing` and used\n",
        "                as the `reduce` argument of `scatter`\n",
        "        \"\"\"\n",
        "        # Set the aggregation function\n",
        "        super().__init__(aggr=aggr)\n",
        "\n",
        "        self.emb_dim = emb_dim\n",
        "        self.activation = {\"swish\": SiLU(), \"relu\": ReLU()}[activation]\n",
        "        # Holds the normalisation *class*; it is instantiated per MLP stage below.\n",
        "        self.norm = {\"layer\": torch.nn.LayerNorm, \"batch\": torch.nn.BatchNorm1d}[norm]\n",
        "\n",
        "        # Message MLP: maps the concatenated [h_i, h_j] (2d) to a message (d)\n",
        "        self.mlp_msg = Sequential(\n",
        "            Linear(2 * (emb_dim), emb_dim),\n",
        "            self.norm(emb_dim),\n",
        "            self.activation,\n",
        "            Linear(emb_dim, emb_dim),\n",
        "            self.norm(emb_dim),\n",
        "            self.activation,\n",
        "        )\n",
        "        # Update MLP: maps the concatenated [h, aggregated messages] (2d) to new features (d)\n",
        "        self.mlp_upd = Sequential(\n",
        "            Linear(2 * emb_dim, emb_dim),\n",
        "            self.norm(emb_dim),\n",
        "            self.activation,\n",
        "            Linear(emb_dim, emb_dim),\n",
        "            self.norm(emb_dim),\n",
        "            self.activation,\n",
        "        )\n",
        "\n",
        "    def forward(self, h, edge_index):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            h: (n, d) - initial node features\n",
        "            edge_index: edge indices handed to `propagate` (PyG's layout is (2, e))\n",
        "        Returns:\n",
        "            out: (n, d) - updated node features\n",
        "        \"\"\"\n",
        "        out = self.propagate(edge_index, h=h)\n",
        "        return out\n",
        "\n",
        "    def message(self, h_i, h_j):\n",
        "        \"\"\"Compute one message per edge from the two endpoint features.\"\"\"\n",
        "        msg = torch.cat([h_i, h_j], dim=-1)\n",
        "        msg = self.mlp_msg(msg)\n",
        "        return msg\n",
        "\n",
        "    def aggregate(self, inputs, index, dim_size=None):\n",
        "        \"\"\"Reduce the messages per target node.\n",
        "\n",
        "        `dim_size` is filled in by `propagate`. Passing it on keeps one output row\n",
        "        per node even when the highest-numbered nodes receive no message, so that\n",
        "        `update` can concatenate the result with `h`.\n",
        "        \"\"\"\n",
        "        msg_aggr = scatter(\n",
        "            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr\n",
        "        )\n",
        "        return msg_aggr\n",
        "\n",
        "    def update(self, aggr_out, h):\n",
        "        \"\"\"Combine the previous features with the aggregated messages.\"\"\"\n",
        "        upd_out = self.mlp_upd(torch.cat([h, aggr_out], dim=-1))\n",
        "        return upd_out\n",
        "\n",
        "    def __repr__(self) -> str:\n",
        "        return f\"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})\"\n",
        "\n",
        "\n",
        "class MPNNModel(torch.nn.Module):\n",
        "    \"\"\"\n",
        "    Stack of MPNNLayer blocks followed by graph-level pooling and an MLP head.\n",
        "    \"\"\"\n",
        "    def __init__(\n",
        "        self,\n",
        "        num_layers: int = 5,\n",
        "        emb_dim: int = 128,\n",
        "        in_dim: int = 1,\n",
        "        out_dim: int = 1,\n",
        "        activation: str = \"relu\",\n",
        "        norm: str = \"layer\",\n",
        "        aggr: str = \"sum\",\n",
        "        pool: str = \"sum\",\n",
        "        residual: bool = True,\n",
        "        equivariant_pred: bool = False,\n",
        "        *kwargs\n",
        "    ):\n",
        "        \"\"\"\n",
        "        Build the embedding, the message passing stack and the prediction head.\n",
        "\n",
        "        Parameters:\n",
        "        - num_layers (int): Number of MPNNLayer blocks (default: 5)\n",
        "        - emb_dim (int): Dimension of the node embeddings (default: 128)\n",
        "        - in_dim (int): Number of rows of the input embedding table (default: 1)\n",
        "        - out_dim (int): Output dimension of the model (default: 1)\n",
        "        - activation (str): Activation inside the layers, \"swish\" or \"relu\" (default: \"relu\")\n",
        "        - norm (str): Normalization inside the layers, \"layer\" or \"batch\" (default: \"layer\")\n",
        "        - aggr (str): Aggregation method passed to the layers (default: \"sum\")\n",
        "        - pool (str): Global pooling method, \"mean\" or \"sum\" (default: \"sum\")\n",
        "        - residual (bool): Whether to use residual connections (default: True)\n",
        "        - equivariant_pred (bool): Stored on the instance; forward() pools the same way for both values (default: False)\n",
        "        - *kwargs: extra positional arguments; collected and ignored\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        self.equivariant_pred = equivariant_pred\n",
        "        self.residual = residual\n",
        "\n",
        "        # Lookup table turning integer node types into initial features\n",
        "        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)\n",
        "\n",
        "        # Message passing stack\n",
        "        self.convs = torch.nn.ModuleList(\n",
        "            [MPNNLayer(emb_dim, activation, norm, aggr) for _ in range(num_layers)]\n",
        "        )\n",
        "\n",
        "        # Graph-level readout\n",
        "        self.pool = {\"mean\": global_mean_pool, \"sum\": global_add_pool}[pool]\n",
        "\n",
        "        self.pred = torch.nn.Sequential(\n",
        "            torch.nn.Linear(emb_dim, emb_dim),\n",
        "            torch.nn.ReLU(),\n",
        "            torch.nn.Linear(emb_dim, out_dim)\n",
        "        )\n",
        "\n",
        "    def forward(self, batch):\n",
        "        \"\"\"Run the model on `batch` (reads `atoms`, `edge_index` and `batch`).\"\"\"\n",
        "        h = self.emb_in(batch.atoms)  # (n,) -> (n, d)\n",
        "\n",
        "        for conv in self.convs:\n",
        "            delta = conv(h, batch.edge_index)  # (n, d) -> (n, d)\n",
        "            if self.residual:\n",
        "                h = h + delta\n",
        "            else:\n",
        "                h = delta\n",
        "\n",
        "        # The readout is the same whether or not equivariant_pred is set.\n",
        "        pooled = self.pool(h, batch.batch)  # (n, d) -> (batch_size, d)\n",
        "\n",
        "        return self.pred(pooled)  # (batch_size, out_dim)"
      ],
      "metadata": {
        "id": "Kuau-QL94pVk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch.nn import Linear, ReLU, SiLU, Sequential\n",
        "from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool\n",
        "from torch_scatter import scatter\n",
        "\n",
        "class EMPNNLayer(MessagePassing):\n",
        "    \"\"\"Message passing layer whose messages also see per-edge geometric features.\"\"\"\n",
        "\n",
        "    def __init__(self,\n",
        "                 emb_dim,\n",
        "                 attr_dim,\n",
        "                 num_radial: int = 2,\n",
        "                 num_spherical: int = 2,\n",
        "                 cutoff: float = 8.0,\n",
        "                 activation=\"relu\", norm=\"layer\", aggr=\"add\"):\n",
        "        \"\"\"Message Passing GNN layer with edge attributes\n",
        "\n",
        "        Args:\n",
        "            emb_dim: (int) - hidden dimension `d`\n",
        "            attr_dim: accepted for signature compatibility; not used by this layer\n",
        "            num_radial: (int) - with `num_spherical`, fixes the expected edge\n",
        "                feature width `num_radial * num_spherical**2`\n",
        "            num_spherical: (int) - see `num_radial`\n",
        "            cutoff: accepted for signature compatibility; not used by this layer\n",
        "            activation: (str) - non-linearity within MLPs (swish/relu)\n",
        "            norm: accepted for signature compatibility; not used by this layer\n",
        "            aggr: (str) - aggregation name; given to `MessagePassing` and used\n",
        "                as the `reduce` argument of `scatter`\n",
        "        \"\"\"\n",
        "        # Set the aggregation function\n",
        "        super().__init__(aggr=aggr)\n",
        "\n",
        "        self.emb_dim = emb_dim\n",
        "        self.activation = {\"swish\": SiLU(), \"relu\": ReLU()}[activation]\n",
        "\n",
        "        # Message MLP: maps the concatenated [h_i, h_j, edge_attr] to a message (d)\n",
        "        self.mlp_msg = Sequential(\n",
        "            Linear(2 * emb_dim + num_radial * num_spherical**2, emb_dim),\n",
        "            self.activation,\n",
        "            Linear(emb_dim, emb_dim),\n",
        "        )\n",
        "        # Update MLP: maps the concatenated [h, aggregated messages] (2d) to new features (d)\n",
        "        self.mlp_upd = Sequential(\n",
        "            Linear(2 * emb_dim, emb_dim),\n",
        "            self.activation,\n",
        "            Linear(emb_dim, emb_dim),\n",
        "        )\n",
        "\n",
        "    def forward(self, h, pos, edge_index, edge_attr):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            h: (n, d) - initial node features\n",
        "            pos: node positions; accepted but not used by this layer\n",
        "            edge_index: edge indices handed to `propagate` (PyG's layout is (2, e))\n",
        "            edge_attr: per-edge features, one row per edge of `edge_index`\n",
        "        Returns:\n",
        "            out: (n, d) - updated node features\n",
        "        \"\"\"\n",
        "        out = self.propagate(edge_index, h=h, edge_attr=edge_attr)\n",
        "        return out\n",
        "\n",
        "    def message(self, h_i, h_j, edge_attr):\n",
        "        \"\"\"Compute one message per edge from both endpoints and the edge features.\"\"\"\n",
        "        msg = torch.cat([h_i, h_j, edge_attr], dim=-1)\n",
        "        msg = self.mlp_msg(msg)\n",
        "        return msg\n",
        "\n",
        "    def aggregate(self, inputs, index, dim_size=None):\n",
        "        \"\"\"Reduce the messages per target node.\n",
        "\n",
        "        `dim_size` is filled in by `propagate`. Passing it on keeps one output row\n",
        "        per node even when the highest-numbered nodes receive no message, so that\n",
        "        `update` can concatenate the result with `h`.\n",
        "        \"\"\"\n",
        "        msg_aggr = scatter(\n",
        "            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr\n",
        "        )\n",
        "        return msg_aggr\n",
        "\n",
        "    def update(self, aggr_out, h):\n",
        "        \"\"\"Combine the previous features with the aggregated messages.\"\"\"\n",
        "        upd_out = self.mlp_upd(torch.cat([h, aggr_out], dim=-1))\n",
        "        return upd_out\n",
        "\n",
        "    def __repr__(self) -> str:\n",
        "        return f\"{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})\"\n",
        "\n",
        "\n",
        "class EMPNNModel(torch.nn.Module):\n",
        "    \"\"\"\n",
        "    Stack of EMPNNLayer blocks driven by distance/angle/torsion edge embeddings,\n",
        "    followed by graph-level pooling and an MLP head.\n",
        "    \"\"\"\n",
        "    def __init__(\n",
        "        self,\n",
        "        num_layers: int = 5,\n",
        "        emb_dim: int = 128,\n",
        "        attr_dim: int = 4,\n",
        "        in_dim: int = 1,\n",
        "        out_dim: int = 1,\n",
        "        num_radial: int = 2,\n",
        "        num_spherical: int = 2,\n",
        "        cutoff: float = 8.0,\n",
        "        activation: str = \"swish\",\n",
        "        norm: str = \"layer\",\n",
        "        aggr: str = \"sum\",\n",
        "        pool: str = \"sum\",\n",
        "        residual: bool = True,\n",
        "        equivariant_pred: bool = False,\n",
        "        *kwargs\n",
        "    ):\n",
        "        \"\"\"\n",
        "        Build the embeddings, the message passing stack and the prediction head.\n",
        "\n",
        "        Parameters:\n",
        "        - num_layers (int): Number of EMPNNLayer blocks (default: 5)\n",
        "        - emb_dim (int): Dimension of the node embeddings (default: 128)\n",
        "        - attr_dim (int): Passed through to the layers (default: 4)\n",
        "        - in_dim (int): Number of rows of the input embedding table (default: 1)\n",
        "        - out_dim (int): Output dimension of the model (default: 1)\n",
        "        - num_radial (int): Number of radial basis functions (default: 2)\n",
        "        - num_spherical (int): Number of spherical harmonic degrees (default: 2)\n",
        "        - cutoff (float): Distance scale of the edge embedding (default: 8.0)\n",
        "        - activation (str): Activation inside the layers, \"swish\" or \"relu\" (default: \"swish\")\n",
        "        - norm (str): Passed through to the layers (default: \"layer\")\n",
        "        - aggr (str): Aggregation method passed to the layers (default: \"sum\")\n",
        "        - pool (str): Global pooling method, \"mean\" or \"sum\" (default: \"sum\")\n",
        "        - residual (bool): Whether to use residual connections (default: True)\n",
        "        - equivariant_pred (bool): Stored on the instance; not read by forward() (default: False)\n",
        "        - *kwargs: extra positional arguments; collected and ignored\n",
        "        \"\"\"\n",
        "        super().__init__()\n",
        "        self.equivariant_pred = equivariant_pred\n",
        "        self.residual = residual\n",
        "\n",
        "        # Embedding lookup for initial node features\n",
        "        self.emb_in = torch.nn.Embedding(in_dim, emb_dim)\n",
        "\n",
        "        # Forward `cutoff` so the constructor argument actually sets the distance\n",
        "        # scale; the default (8.0) equals torsion_emb's own default.\n",
        "        self.feature_emb = torsion_emb(num_radial=num_radial,\n",
        "                                       num_spherical=num_spherical,\n",
        "                                       cutoff=cutoff)\n",
        "\n",
        "        # Stack of GNN layers\n",
        "        self.convs = torch.nn.ModuleList()\n",
        "        for _ in range(num_layers):\n",
        "            self.convs.append(EMPNNLayer(emb_dim, attr_dim, num_radial, num_spherical, cutoff, activation, norm, aggr))\n",
        "\n",
        "        # Global pooling/readout function\n",
        "        self.pool = {\"mean\": global_mean_pool, \"sum\": global_add_pool}[pool]\n",
        "\n",
        "        self.pred = torch.nn.Sequential(\n",
        "            torch.nn.Linear(emb_dim, emb_dim),\n",
        "            torch.nn.SiLU(),\n",
        "            torch.nn.Linear(emb_dim, out_dim)\n",
        "        )\n",
        "\n",
        "    def forward(self, batch):\n",
        "        \"\"\"Run the model on `batch`.\n",
        "\n",
        "        Reads `atoms`, `edge_index_hull`, `edge_attr_hull`, `pos`, `edge_index`\n",
        "        and `batch` from the input object.\n",
        "        \"\"\"\n",
        "        h = self.emb_in(batch.atoms)  # (n,) -> (n, d)\n",
        "\n",
        "        # Column 0 of the hull edge attributes is the length, the rest the vector.\n",
        "        edge_index_hull, edge_attr_hull = batch.edge_index_hull, batch.edge_attr_hull\n",
        "        dist_hull = edge_attr_hull[:, 0]\n",
        "        vecs_hull = edge_attr_hull[:, 1:]\n",
        "\n",
        "        theta_hull, phi_hull = get_angle_torsion(edge_index = edge_index_hull,\n",
        "                                                 vecs = vecs_hull,\n",
        "                                                 dist = dist_hull,\n",
        "                                                 num_nodes = batch.atoms.size(0))\n",
        "\n",
        "        edge_attr = self.feature_emb(dist_hull, theta_hull, phi_hull)\n",
        "        # NOTE(review): edge_attr has one row per *hull* edge, yet the layers run on\n",
        "        # batch.edge_index - this presumably requires both edge sets to match in\n",
        "        # size and order; confirm against the dataset construction.\n",
        "        for conv in self.convs:\n",
        "            h_update = conv(h, batch.pos, batch.edge_index, edge_attr)\n",
        "            h = h + h_update if self.residual else h_update\n",
        "\n",
        "        out = self.pool(h, batch.batch)  # (n, d) -> (batch_size, d)\n",
        "\n",
        "        return self.pred(out)  # (batch_size, out_dim)\n"
      ],
      "metadata": {
        "id": "H1OoHe5Pvzm-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nrIzD1hSLhT_"
      },
      "source": [
        "# Datasets"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from scipy.spatial import Delaunay, delaunay_plot_2d\n",
        "from scipy.spatial import Voronoi, voronoi_plot_2d\n",
        "from scipy.spatial import ConvexHull\n",
        "\n",
        "\n",
        "def compute_convhull_edges(pos, vis=False):\n",
        "    \"\"\"Return the symmetrised edges of the convex hull of `pos`.\n",
        "\n",
        "    Args:\n",
        "        pos: Point coordinates handed to `scipy.spatial.ConvexHull`.\n",
        "        vis: Unused; kept so the signature mirrors `compute_voronoi_edges`.\n",
        "\n",
        "    Returns:\n",
        "        The result of `to_undirected` applied to a long tensor holding one\n",
        "        column per hull edge.\n",
        "    \"\"\"\n",
        "    from itertools import combinations\n",
        "\n",
        "    hull = ConvexHull(pos, qhull_options='Qx')\n",
        "    # A hull facet lists 2 vertices for 2-D input but 3 or more in higher\n",
        "    # dimensions, so expand every facet into its vertex pairs instead of\n",
        "    # transposing the facet array directly (which is only a valid 2 x N edge\n",
        "    # index in the 2-D case).\n",
        "    pairs = [pair\n",
        "             for simplex in hull.simplices.tolist()\n",
        "             for pair in combinations(simplex, 2)]\n",
        "    edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()\n",
        "    return to_undirected(edge_index)\n",
        "\n",
        "\n",
        "def compute_voronoi_edges(pos, vis=False):\n",
        "    pos_np = pos.numpy()  # Convert to numpy array\n",
        "    tri = Delaunay(pos)\n",
        "    vor = Voronoi(pos_np)\n",
        "    if vis:\n",
        "      try:\n",
        "        voronoi_plot_2d(vor)\n",
        "        delaunay_plot_2d(tri)\n",
        "      except:\n",
        "        pass\n",
        "    rows, cols = tri.vertex_neighbor_vertices\n",
        "    edges = []\n",
        "    for i in range(len(rows) - 1):\n",
        "        start, end = rows[i], rows[i + 1]\n",
        "        neighbors = cols[start:end]\n",
        "        for neighbor in neighbors:\n",
        "            edges.append([i, neighbor])\n",
        "\n",
        "    return edges\n"
      ],
      "metadata": {
        "id": "wbaGRFGPuh3E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cOFCwx4W7X1d"
      },
      "source": [
        "## Simple Chain Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T6MgPmXp7RAl"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "sys.path.append('/root/workspace/geometric-gnn-dojo/')\n",
        "\n",
        "import scipy\n",
        "import torch\n",
        "import torch_geometric\n",
        "from torch_geometric.data import Data\n",
        "from torch_geometric.loader import DataLoader\n",
        "from torch_geometric.transforms import KNNGraph, RadiusGraph, RemoveIsolatedNodes\n",
        "from torch_geometric.utils import to_undirected\n",
        "import e3nn\n",
        "from functools import partial\n",
        "\n",
        "from torch_geometric.seed import seed_everything\n",
        "\n",
        "from experiments.utils.plot_utils import plot_3d\n",
        "\n",
        "def _connect_kchain_graph(data, connectivity, knn_k, num_nodes):\n",
        "    \"\"\"Attach a symmetrised, de-duplicated `edge_index` to `data`.\n",
        "\n",
        "    Args:\n",
        "        data: Graph holding `pos`; the transform-based schemes may return a new object.\n",
        "        connectivity: One of the scheme names validated by `create_kchains`.\n",
        "        knn_k: Neighbour count used by the 'knn' scheme.\n",
        "        num_nodes: Number of node indices enumerated by the 'full' scheme.\n",
        "\n",
        "    Returns:\n",
        "        The graph with `edge_index` set.\n",
        "    \"\"\"\n",
        "    if connectivity == 'radius':\n",
        "        data = RadiusGraph(4)(data)\n",
        "    elif connectivity == 'knn':\n",
        "        data = KNNGraph(k=knn_k)(data)\n",
        "    elif connectivity == 'unitsphere':\n",
        "        data = Frame()(data)\n",
        "    elif connectivity == 'voronoi':\n",
        "        # The last column of pos is dropped before triangulating.\n",
        "        voronoi_edges = compute_voronoi_edges(data.pos[:, :-1])\n",
        "        data.edge_index = torch.tensor(voronoi_edges, dtype=torch.long).t().contiguous()\n",
        "    elif connectivity == 'convhull':\n",
        "        data.edge_index = compute_convhull_edges(data.pos[:, :-1])\n",
        "    elif connectivity == 'full':\n",
        "        # Every ordered pair of node indices, self-loops included.\n",
        "        all_pairs = [[src, dst] for src in range(num_nodes) for dst in range(num_nodes)]\n",
        "        data.edge_index = torch.tensor(all_pairs, dtype=torch.long).t().contiguous()\n",
        "\n",
        "    # Symmetrise, then keep each (source, target) pair once.\n",
        "    edge_index = to_undirected(data.edge_index)\n",
        "    unique_edges = set(map(tuple, edge_index.t().tolist()))\n",
        "    data.edge_index = torch.tensor(list(unique_edges), dtype=torch.long).t()\n",
        "    return data\n",
        "\n",
        "\n",
        "def create_kchains(k,connectivity='radius'):\n",
        "    \"\"\"Build ten graphs made of an outer square and an inner square.\n",
        "\n",
        "    Graph 0 keeps the inner square as defined and is labelled 0. Each of the\n",
        "    nine following graphs multiplies the inner square by a matrix built from a\n",
        "    freshly drawn random angle and is labelled with a rescaled copy of that angle.\n",
        "    The global seed is reset to 10 on every call.\n",
        "\n",
        "    Args:\n",
        "        k: Sets the number of atom entries and `natoms` (both k + 2); must be >= 2.\n",
        "        connectivity: Edge construction scheme, one of 'radius', 'knn', 'voronoi',\n",
        "            'convhull', 'full' or 'unitsphere'.\n",
        "\n",
        "    Returns:\n",
        "        List of ten `Data` objects with `atoms`, `pos`, `y`, `natoms`, `cell`\n",
        "        and `edge_index` set.\n",
        "    \"\"\"\n",
        "    seed_everything(10)\n",
        "    assert k >= 2\n",
        "    assert connectivity in ['radius', 'knn', 'voronoi', 'convhull', 'full', 'unitsphere']\n",
        "\n",
        "    # NOTE(review): pos always holds 8 points (two squares), which equals\n",
        "    # k + 2 only for k = 6 - confirm other values of k are meant to be supported.\n",
        "    num_nodes = k + 2\n",
        "    cell = torch.diag(torch.ones(3,dtype=torch.float)).view(1,3,3)\n",
        "    outer_box = torch.FloatTensor([[-4, -4, 0], [-4,4,0], [4,4,0], [4,-4,0]])\n",
        "    inner_box = torch.FloatTensor([[-2, -2, 0], [-2,2,0], [2,2,0], [2,-2,0]])\n",
        "\n",
        "    # Graph 0: inner square left as is, label 0\n",
        "    data1 = Data(atoms=torch.LongTensor([0] * num_nodes),\n",
        "                 pos=torch.cat([outer_box, inner_box]),\n",
        "                 y=torch.FloatTensor([0]),\n",
        "                 natoms=num_nodes,\n",
        "                 cell=cell)\n",
        "    # NOTE(review): 'knn' uses 4 neighbours here but 3 for the graphs below -\n",
        "    # confirm the difference is intended.\n",
        "    dataset = [_connect_kchain_graph(data1, connectivity, knn_k=4, num_nodes=num_nodes)]\n",
        "\n",
        "    # Graphs 1-9: inner square transformed by a random angle\n",
        "    for _ in range(9):\n",
        "        # NOTE(review): the angle is drawn in [0, 90) and passed to cos/sin\n",
        "        # without conversion, while the label scales it by 2*pi/180 - confirm\n",
        "        # the intended units of both.\n",
        "        random_rotation = 90*torch.rand(1)\n",
        "        rotation = torch.FloatTensor([[np.cos(random_rotation), -np.sin(random_rotation), 0],\n",
        "                                      [np.sin(random_rotation), np.cos(random_rotation), 0],\n",
        "                                      [0, 0, 1]])\n",
        "\n",
        "        data2 = Data(atoms=torch.LongTensor([0] * num_nodes),\n",
        "                     pos=torch.cat([outer_box, torch.matmul(inner_box, rotation)]),\n",
        "                     y=torch.FloatTensor([2*np.pi*random_rotation/180]),\n",
        "                     natoms=num_nodes,\n",
        "                     cell=cell)\n",
        "        dataset.append(_connect_kchain_graph(data2, connectivity, knn_k=3, num_nodes=num_nodes))\n",
        "\n",
        "    return dataset\n",
        "\n",
        "# Build the dataset for each selected connectivity scheme, then print and plot every graph.\n",
        "# Schemes accepted by create_kchains: 'radius', 'knn', 'convhull', 'voronoi', 'full', 'unitsphere'.\n",
        "k = 6\n",
        "for connectivity in ['unitsphere']:\n",
        "  print(f'Connectivity: {connectivity}')\n",
        "  dataset = create_kchains(k=k, connectivity=connectivity)\n",
        "  for data in dataset:\n",
        "      print(data)\n",
        "      plot_3d(data, lim=2*k)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8vp_z2Z9MbSn"
      },
      "source": [
        "# Experiments"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EWBbN5g-jB4_"
      },
      "source": [
        "## Simple Chain Experiment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D8D5ONl89czH"
      },
      "outputs": [],
      "source": [
        "# Create dataloaders\n",
        "import random\n",
        "\n",
        "from experiments.utils.train_utils import run_experiment\n",
        "from models import SchNetModel, DimeNetPPModel, SphereNetModel, ComENetModel\n",
        "\n",
        "def run(model_name,connect_list=[],cutoff_name=None):\n",
        "  k = 6\n",
        "  num_layers = 1\n",
        "  for connectivity in connect_list:\n",
        "    print('*'*20 + f'\\nConnectivity: {connectivity}\\n' + '*'*20)\n",
        "    dataset = create_kchains(k=k, connectivity=connectivity)\n",
        "    for cutoff in range(5,11):\n",
        "      print(f\"\\nCutoff: {cutoff}\")\n",
        "      print(f\"Chain Length: {k}\")\n",
        "\n",
        "\n",
        "      # Create dataloaders\n",
        "      dataloader = DataLoader(dataset[:6], batch_size=1, shuffle=True)\n",
        "      test_loader = DataLoader(dataset[6:8], batch_size=2, shuffle=False)\n",
        "      val_loader = DataLoader(dataset[8:], batch_size=2, shuffle=False)\n",
        "\n",
        "\n",
        "      # use_edge_attr = True if connectivity == 'convhull' else False\n",
        "      # use_edge_attr = False\n",
        "\n",
        "      correlation = 2\n",
        "      kwargs = {cutoff_name:cutoff} if cutoff_name else {}\n",
        "      model = {\n",
        "          \"empnn\": partial(EMPNNModel, emb_dim=256, num_radial=32, num_spherical=3),\n",
        "          \"mpnn\": MPNNModel,\n",
        "          \"schnet\": partial(SchNetModel,  num_gaussians=256, num_filters=8),\n",
        "          \"dimenet\": DimeNetPPModel,\n",
        "          \"spherenet\": partial(SphereNetModel, out_emb_channels=256),\n",
        "          \"comenet\": partial(ComENetModel, hidden_channels=128, num_radial=8, num_spherical=8),\n",
        "      }[model_name](num_layers=num_layers, in_dim=1, out_dim=1, **kwargs)\n",
        "\n",
        "      best_val_acc, test_acc, train_time = run_experiment(\n",
        "          model,\n",
        "          dataloader,\n",
        "          val_loader,\n",
        "          test_loader,\n",
        "          n_epochs=100,\n",
        "          n_times=10,\n",
        "          verbose=False,\n",
        "          device='cuda',\n",
        "      )"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# EMPNN on the 'unitsphere' connectivity; cutoff_name keeps its default.\n",
        "run(model_name='empnn', connect_list=['unitsphere'])"
      ],
      "metadata": {
        "id": "LHHkZQ-D6VW7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# MPNN on the radius, kNN and Voronoi graphs; cutoff_name keeps its default.\n",
        "run(model_name='mpnn', connect_list=['radius', 'knn', 'voronoi'])"
      ],
      "metadata": {
        "id": "uL93JsAgxBw6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6lTJPBsu7rum"
      },
      "outputs": [],
      "source": [
        "# SchNet on the radius, kNN and Voronoi graphs; the cutoff is passed as `cutoff`.\n",
        "run(model_name='schnet', connect_list=['radius', 'knn', 'voronoi'], cutoff_name='cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QOaiX2Hk7xrO"
      },
      "outputs": [],
      "source": [
        "# DimeNet++ on the radius, kNN and Voronoi graphs; the cutoff is passed as `cutoff`.\n",
        "run(model_name='dimenet', connect_list=['radius', 'knn', 'voronoi'], cutoff_name='cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "25eg8c6n71bs"
      },
      "outputs": [],
      "source": [
        "# SphereNet on the radius, kNN and Voronoi graphs; the cutoff is passed as `cutoff`.\n",
        "# Caution: be careful with the SphereNet embedding cutoff.\n",
        "run(model_name='spherenet', connect_list=['radius', 'knn', 'voronoi'], cutoff_name='cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uSbAAb5A74zE"
      },
      "outputs": [],
      "source": [
        "# ComENet on the radius, kNN and Voronoi graphs; the cutoff is passed as `cutoff`.\n",
        "run(model_name='comenet', connect_list=['radius', 'knn', 'voronoi'], cutoff_name='cutoff')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "machine_shape": "hm",
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 0
}