{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GS4GalqayyJ2"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "!pip install wandb\n",
        "!apt-get install git\n",
        "!apt autoremove\n",
        "!pip3 install awscli\n",
        "\n",
        "!mkdir -p /root/workspace/data/\n",
        "!mkdir -p /root/workspace/out/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q2ogKJ59zNcw"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "%cd /root/workspace\n",
        "!git clone https://github.com/chaitjo/geometric-gnn-dojo.git\n",
        "!pip3 install -r /root/workspace/UnitSphere/requirements.txt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QPGcEH8K_lSb"
      },
      "outputs": [],
      "source": [
        "%cd /root/workspace/geometric-gnn-dojo/\n",
        "!git stash\n",
        "!git pull"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lHOxeU9Gw9HS"
      },
      "outputs": [],
      "source": [
        "# %%capture\n",
        "%cd /root/workspace\n",
        "!cp /root/workspace/UnitSphere/ext/train_nll_utils.py ./geometric-gnn-dojo/experiments/utils/train_utils.py # remove once iclr is pulled\n",
        "!cp /root/workspace/UnitSphere/ext/comenet.py ./geometric-gnn-dojo/models/ # remove once iclr is pulled\n",
        "!echo \"from models.comenet import ComENetModel\" >> ./geometric-gnn-dojo/models/__init__.py"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UkC8_r-NLdXo"
      },
      "source": [
        "# Models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nrIzD1hSLhT_"
      },
      "source": [
        "# Datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ljG7Au3a1NN8"
      },
      "outputs": [],
      "source": [
        "from abc import ABCMeta\n",
        "import ast\n",
        "\n",
        "import time\n",
        "import torch\n",
        "from typing import Tuple\n",
        "from torch import Tensor\n",
        "from torch_sparse import SparseTensor\n",
        "import torch.nn.functional as F\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import sys\n",
        "\n",
        "sys.path.append('/root/workspace/alignment/pyorbit/utils/')\n",
        "from scipy.spatial import ConvexHull\n",
        "\n",
        "def triplets(\n",
        "    edge_index: Tensor,\n",
        "    num_nodes: int,\n",
        ") -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:\n",
        "    row, col = edge_index  # j->i\n",
        "\n",
        "    value = torch.arange(row.size(0), device=row.device)\n",
        "    adj_t = SparseTensor(row=col, col=row, value=value,\n",
        "                         sparse_sizes=(num_nodes, num_nodes))\n",
        "    adj_t_row = adj_t[row]\n",
        "    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)\n",
        "\n",
        "    # Node indices (k->j->i) for triplets.\n",
        "    idx_i = col.repeat_interleave(num_triplets)\n",
        "    idx_j = row.repeat_interleave(num_triplets)\n",
        "    idx_k = adj_t_row.storage.col()\n",
        "    mask = idx_i != idx_k  # Remove i == k triplets.\n",
        "    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]\n",
        "\n",
        "    # Edge indices (k-j, j->i) for triplets.\n",
        "    idx_kj = adj_t_row.storage.value()[mask]\n",
        "    idx_ji = adj_t_row.storage.row()[mask]\n",
        "\n",
        "    return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji\n",
        "\n",
        "class Frame(metaclass=ABCMeta):\n",
        "    def __init__(self, tol=1e-2, *args, **kwargs):\n",
        "        super().__init__()\n",
        "        self.tol = tol\n",
        "\n",
        "    def get_frame(self, data, *args, **kwargs):\n",
        "\n",
        "        # TRANSLATION INVARIANCE\n",
        "        data = self.check_type(data) # Assert Type\n",
        "        data = data - np.mean(data, axis=0) # Assert Centered\n",
        "        data = data[np.linalg.norm(data, axis=1) > self.tol]\n",
        "\n",
        "        # PROJECT ONTO SPHERE\n",
        "        distances = np.linalg.norm(data, axis=1, keepdims=False)\n",
        "        shell_data =  data/np.linalg.norm(data, axis=1, keepdims=True)\n",
        "        data_duplicate_dict = {}\n",
        "        for i, d in enumerate(shell_data):\n",
        "          for j in range(i+1, len(shell_data)):\n",
        "            if np.allclose(d, shell_data[j]):\n",
        "              if i not in data_duplicate_dict:\n",
        "                data_duplicate_dict[i] = [j]\n",
        "              else:\n",
        "                data_duplicate_dict[i] += [j]\n",
        "\n",
        "        # FIND NEAREST TWO NEIGHBORS\n",
        "        edges = []\n",
        "        hull = ConvexHull(shell_data, qhull_options='Qx')\n",
        "        for simplex in hull.simplices:\n",
        "          edges.append(simplex)\n",
        "          if simplex[0] in data_duplicate_dict:\n",
        "            for j in data_duplicate_dict[simplex[0]]:\n",
        "              edges.append([j, simplex[1]])\n",
        "              # edges.append([simplex[0], j])\n",
        "          if simplex[1] in data_duplicate_dict:\n",
        "            for j in data_duplicate_dict[simplex[1]]:\n",
        "              edges.append([simplex[0], j])\n",
        "              # edges.append([j, simplex[1]])\n",
        "\n",
        "        edge_index = np.array(list(edges))\n",
        "        edge_index = torch.from_numpy(edge_index).T.to(torch.long)\n",
        "        edge_index = to_undirected(edge_index)\n",
        "        # Generate edge attributes\n",
        "        print(distances.shape)\n",
        "        edge_attr = torch.from_numpy(distances).reshape(-1,1)\n",
        "\n",
        "        # DIMENET WAY\n",
        "        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(\n",
        "            edge_index,\n",
        "            num_nodes=shell_data.shape[0])\n",
        "\n",
        "        # Calculate distances.\n",
        "        shell_data = torch.from_numpy(shell_data)\n",
        "        dist = (shell_data[i] - shell_data[j]).pow(2).sum(dim=-1).sqrt()\n",
        "        # Calculate angles.\n",
        "        pos_jk, pos_ij = shell_data[idx_j] - shell_data[idx_k], shell_data[idx_i] - shell_data[idx_j]\n",
        "        a = (pos_ij * pos_jk).sum(dim=-1)\n",
        "        b = pos_ij[:, 0] * pos_jk[:, 1] - pos_ij[:, 1] * pos_jk[:, 0] #torch.cross(pos_ij, pos_jk, dim=1).norm(dim=-1)\n",
        "        angle = torch.atan2(b, a)\n",
        "        angle_attr = torch.zeros(shell_data.shape[0], shell_data.shape[0])\n",
        "        for i,index in enumerate(idx_kj):\n",
        "          edge = edge_index[:,index]\n",
        "          angle_attr[edge[0], edge[1]] = angle[i]\n",
        "          angle_attr[edge[1], edge[0]] = angle[i]\n",
        "        # make 8xN matrix filled in with angles or zeros elsewhere\n",
        "        edge_attr = torch.cat([edge_attr, angle_attr], dim=1)\n",
        "\n",
        "        return edge_index, edge_attr\n",
        "\n",
        "    def check_type(self, data, *args, **kwargs):\n",
        "        if isinstance(data, torch.Tensor):\n",
        "            return data.detach().cpu().numpy()\n",
        "        elif isinstance(data, np.ndarray):\n",
        "            return data\n",
        "        else:\n",
        "            raise TypeError(f\"Data type not supported {type(data)}\")\n",
        "# dataset = create_kchains(k=6, connectivity='convhull')\n",
        "# for data in dataset:\n",
        "#     plot_3d(data, lim=2*k)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wbaGRFGPuh3E"
      },
      "outputs": [],
      "source": [
        "from scipy.spatial import Delaunay, delaunay_plot_2d\n",
        "from scipy.spatial import Voronoi, voronoi_plot_2d\n",
        "\n",
        "\n",
        "def compute_convhull_edges(pos, vis=False):\n",
        "        edges = []\n",
        "        hull = ConvexHull(pos, qhull_options='Qx')\n",
        "        for simplex in hull.simplices:\n",
        "          edges.append(simplex)\n",
        "        edge_index = np.array(list(edges))\n",
        "        edge_index = torch.from_numpy(edge_index).T.to(torch.long)\n",
        "        edge_index = to_undirected(edge_index)\n",
        "        return edge_index\n",
        "\n",
        "\n",
        "def compute_voronoi_edges(pos, vis=False):\n",
        "    pos_np = pos.numpy()  # Convert to numpy array\n",
        "    tri = Delaunay(pos)\n",
        "    vor = Voronoi(pos_np)\n",
        "    if vis:\n",
        "      try:\n",
        "        voronoi_plot_2d(vor)\n",
        "        delaunay_plot_2d(tri)\n",
        "      except:\n",
        "        pass\n",
        "    rows, cols = tri.vertex_neighbor_vertices\n",
        "    edges = []\n",
        "    for i in range(len(rows) - 1):\n",
        "        start, end = rows[i], rows[i + 1]\n",
        "        neighbors = cols[start:end]\n",
        "        for neighbor in neighbors:\n",
        "            edges.append([i, neighbor])\n",
        "\n",
        "    return edges\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cOFCwx4W7X1d"
      },
      "source": [
        "## Simple Chain Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T6MgPmXp7RAl"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "sys.path.append('/root/workspace/geometric-gnn-dojo/')\n",
        "\n",
        "import torch\n",
        "import torch_geometric\n",
        "from torch_geometric.data import Data\n",
        "from torch_geometric.loader import DataLoader\n",
        "from torch_geometric.transforms import KNNGraph\n",
        "from torch_geometric.utils import to_undirected\n",
        "import e3nn\n",
        "from functools import partial\n",
        "\n",
        "from torch_geometric.seed import seed_everything\n",
        "\n",
        "from experiments.utils.plot_utils import plot_3d\n",
        "\n",
        "def create_kchains(k,connectivity='radius'):\n",
        "    seed_everything(10)\n",
        "    assert k >= 2\n",
        "    assert connectivity in ['radius', 'knn', 'voronoi', 'convhull', 'full', 'unitsphere']\n",
        "\n",
        "    dataset = []\n",
        "\n",
        "    # Graph 0\n",
        "    atoms = torch.LongTensor( [0] + [0] + [0]*(k-1) + [0] )\n",
        "    cell = torch.diag(torch.ones(3,dtype=torch.float)).view(1,3,3)\n",
        "    edge_index = torch.LongTensor( [ [i for i in range((k+2) - 1)], [i for i in range(1, k+2)] ] )\n",
        "    pos = torch.FloatTensor(\n",
        "        [[-4, -3, 0]] +\n",
        "        [[0, 5*i , 0] for i in range(k)] +\n",
        "        [[4, 5*(k-1) + 3, 0]]\n",
        "    )\n",
        "    y = torch.LongTensor([0])  # Label gvp0\n",
        "    data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y, natoms=k+2, cell=cell)\n",
        "\n",
        "    # Edges\n",
        "    if connectivity == 'voronoi':\n",
        "      voronoi_edges = compute_voronoi_edges(data1.pos[:,:-1])\n",
        "      data1.edge_index = torch.tensor(voronoi_edges, dtype=torch.long).t().contiguous()\n",
        "    elif connectivity == 'convhull':\n",
        "      data1.edge_index = compute_convhull_edges(data1.pos[:,:-1])\n",
        "    elif connectivity == 'unitsphere':\n",
        "      frame = Frame()\n",
        "      data1.edge_index, data1.edge_attr =  frame.get_frame(data1.pos[:,:-1])\n",
        "    elif connectivity == 'knn':\n",
        "      data1 = KNNGraph(3)(data1)\n",
        "    elif connectivity == 'full':\n",
        "      edge_index = []\n",
        "      for i in range(k+2):\n",
        "        for j in range(k+2):\n",
        "          edge_index.append([i,j])\n",
        "          edge_index.append([j,i])\n",
        "      data1.edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()\n",
        "\n",
        "    edge_index = to_undirected(data1.edge_index)\n",
        "    edges_set = set(map(tuple, edge_index.t().tolist()))\n",
        "    data1.edge_index = torch.tensor(list(edges_set), dtype=torch.long).t()\n",
        "\n",
        "    dataset.append(data1)\n",
        "\n",
        "    # Graph 1\n",
        "    atoms = torch.LongTensor( [0] + [0] + [0]*(k-1) + [0] )\n",
        "    edge_index = torch.LongTensor( [ [i for i in range((k+2) - 1)], [i for i in range(1, k+2)] ] )\n",
        "    pos = torch.FloatTensor(\n",
        "        [[4, -3, 0]] +\n",
        "        [[0, 5*i , 0] for i in range(k)] +\n",
        "        [[4, 5*(k-1) + 3, 0]]\n",
        "    )\n",
        "    y = torch.LongTensor([1])  # Label 1\n",
        "    data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y, natoms=k+2, cell=cell)\n",
        "\n",
        "    # Edges\n",
        "    if connectivity == 'voronoi':\n",
        "      voronoi_edges = compute_voronoi_edges(data2.pos[:,:-1])\n",
        "      data2.edge_index = torch.tensor(voronoi_edges, dtype=torch.long).t().contiguous()\n",
        "    elif connectivity == 'unitsphere':\n",
        "      frame = Frame()\n",
        "      data2.edge_index, data2.edge_attr =  frame.get_frame(data2.pos[:,:-1])\n",
        "    elif connectivity == 'convhull':\n",
        "      data2.edge_index = compute_convhull_edges(data2.pos[:,:-1])\n",
        "    elif connectivity == 'knn':\n",
        "      data2 = KNNGraph(3)(data2)\n",
        "    elif connectivity == 'full':\n",
        "      edge_index = []\n",
        "      for i in range(k+2):\n",
        "        for j in range(k+2):\n",
        "          edge_index.append([i,j])\n",
        "          edge_index.append([j,i])\n",
        "      data2.edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()\n",
        "\n",
        "    edge_index = to_undirected(data2.edge_index)\n",
        "    edges_set = set(map(tuple, edge_index.t().tolist()))\n",
        "    data2.edge_index = torch.tensor(list(edges_set), dtype=torch.long).t()\n",
        "\n",
        "    dataset.append(data2)\n",
        "\n",
        "    return dataset\n",
        "\n",
        "# Create dataset\n",
        "for connectivity in ['radius','knn','convhull','voronoi','full','unitsphere']:\n",
        "  k = 6\n",
        "  print(f'Connectivity: {connectivity}')\n",
        "  dataset = create_kchains(k=k, connectivity=connectivity)\n",
        "  for data in dataset:\n",
        "      print(data.edge_index)\n",
        "      plot_3d(data, lim=2*k)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8vp_z2Z9MbSn"
      },
      "source": [
        "# Experiments"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EWBbN5g-jB4_"
      },
      "source": [
        "## Simple Chain Experiment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D8D5ONl89czH"
      },
      "outputs": [],
      "source": [
        "# Create dataloaders\n",
        "import random\n",
        "\n",
        "from experiments.utils.train_utils import run_experiment\n",
        "from models import SchNetModel, DimeNetPPModel, SphereNetModel, ComENetModel\n",
        "\n",
        "def run(model_name,cutoff_name=None):\n",
        "  k = 6\n",
        "  num_layers = 1\n",
        "  for connectivity in ['radius','knn','convhull','voronoi','full','unitsphere']:\n",
        "  # for connectivity in ['full', 'unitsphere']:\n",
        "      print('*'*20 + f'\\nConnectivity: {connectivity}\\n' + '*'*20)\n",
        "      dataset = create_kchains(k=k, connectivity=connectivity)\n",
        "      for cutoff in range(1,11):\n",
        "      # for cutoff in [8]:\n",
        "        print(f\"\\nCutoff: {cutoff}\")\n",
        "        print(f\"Chain Length: {k}\")\n",
        "\n",
        "\n",
        "        # Create dataloaders\n",
        "        dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
        "        val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
        "        test_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
        "\n",
        "\n",
        "        # use_edge_attr = True if connectivity == 'convhull' else False\n",
        "        # use_edge_attr = False\n",
        "\n",
        "        correlation = 2\n",
        "        kwargs = {cutoff_name:cutoff} if cutoff_name else {}\n",
        "        model = {\n",
        "            # INV\n",
        "            \"schnet\": partial(SchNetModel,  num_gaussians=64, num_filters=64, pool='mean'),\n",
        "            \"dimenet\": DimeNetPPModel,\n",
        "            \"spherenet\": partial(SphereNetModel, out_emb_channels=256),\n",
        "            \"comenet\": partial(ComENetModel, hidden_channels=128, num_radial=8, num_spherical=8),\n",
        "        }[model_name](num_layers=num_layers, in_dim=1, out_dim=2, **kwargs)\n",
        "\n",
        "        best_val_acc, test_acc, train_time = run_experiment(\n",
        "            model,\n",
        "            dataloader,\n",
        "            val_loader,\n",
        "            test_loader,\n",
        "            n_epochs=100,\n",
        "            n_times=10,\n",
        "            verbose=False,\n",
        "            device='cuda',\n",
        "        )\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6lTJPBsu7rum"
      },
      "outputs": [],
      "source": [
        "# SCHNET\n",
        "run('schnet','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QOaiX2Hk7xrO"
      },
      "outputs": [],
      "source": [
        "# DIMENET\n",
        "run('dimenet','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "25eg8c6n71bs"
      },
      "outputs": [],
      "source": [
        "# SPHERENET\n",
        "run('spherenet','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uSbAAb5A74zE"
      },
      "outputs": [],
      "source": [
        "# COMENET\n",
        "run('comenet','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "a0gA1wedmqoU"
      },
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}