{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "%%capture\n",
        "# Install Python packages and system tools. `-y` pre-answers apt's confirmation\n",
        "# prompt, which would otherwise be hidden by %%capture.\n",
        "!pip install wandb\n",
        "!apt-get install -y git\n",
        "!apt autoremove -y\n",
        "!pip3 install awscli\n",
        "\n",
        "# Create the workspace directories (data/ and out/); -p makes this safe to re-run.\n",
        "!mkdir -p /root/workspace/data/\n",
        "!mkdir -p /root/workspace/out/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q2ogKJ59zNcw"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "%cd /root/workspace\n",
        "\n",
        "# Clone the two public repositories used by the cells below.\n",
        "!git clone https://github.com/chaitjo/geometric-gnn-dojo.git\n",
        "!git clone https://github.com/Open-Catalyst-Project/ocp.git\n",
        "# NOTE(review): no cell in this notebook clones `steerable-v1`; this install and the\n",
        "# later `cp ./steerable-v1/...` steps assume it already exists under /root/workspace.\n",
        "!pip3 install -r ./steerable-v1/requirements.txt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J0oYjpC1axi1"
      },
      "outputs": [],
      "source": [
        "%%capture\n",
        "%cd /root/workspace/ocp/\n",
        "# Editable install of the cloned ocp checkout, followed by lmdb and orjson.\n",
        "!pip3 install -e .\n",
        "!pip3 install lmdb\n",
        "!pip3 install orjson"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8XHTFQfz_hsY"
      },
      "outputs": [],
      "source": [
        "# Stash any local modifications in the steerable-v1 checkout, then pull from its remote.\n",
        "%cd /root/workspace/steerable-v1/\n",
        "!git stash\n",
        "!git pull"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QPGcEH8K_lSb"
      },
      "outputs": [],
      "source": [
        "# Stash any local modifications in the geometric-gnn-dojo checkout, then pull from its remote.\n",
        "%cd /root/workspace/geometric-gnn-dojo/\n",
        "!git stash\n",
        "!git pull"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lHOxeU9Gw9HS"
      },
      "outputs": [],
      "source": [
        "# %%capture\n",
        "%cd /root/workspace\n",
        "# Copy the steerable-v1 files into the geometric-gnn-dojo checkout and register each\n",
        "# model in models/__init__.py. Every registration is guarded with `grep -qxF` so that\n",
        "# re-running this cell does not append the same import line a second time.\n",
        "!cp ./steerable-v1/train_utils.py ./geometric-gnn-dojo/experiments/utils/train_utils.py # remove once iclr is pulled\n",
        "\n",
        "!cp ./steerable-v1/comenet.py ./geometric-gnn-dojo/models/ # remove once iclr is pulled\n",
        "!grep -qxF \"from models.comenet import ComENetModel\" ./geometric-gnn-dojo/models/__init__.py 2>/dev/null || echo \"from models.comenet import ComENetModel\" >> ./geometric-gnn-dojo/models/__init__.py\n",
        "\n",
        "!cp ./steerable-v1/painn.py ./geometric-gnn-dojo/models/painn.py # remove once iclr is pulled\n",
        "!grep -qxF \"from models.painn import PaiNN\" ./geometric-gnn-dojo/models/__init__.py 2>/dev/null || echo \"from models.painn import PaiNN\" >> ./geometric-gnn-dojo/models/__init__.py\n",
        "\n",
        "!cp ./steerable-v1/escn.py ./geometric-gnn-dojo/models/escn.py # remove once iclr is pulled\n",
        "!grep -qxF \"from models.escn import eSCN\" ./geometric-gnn-dojo/models/__init__.py 2>/dev/null || echo \"from models.escn import eSCN\" >> ./geometric-gnn-dojo/models/__init__.py\n",
        "\n",
        "!cp ./steerable-v1/equiformer_v2.py ./geometric-gnn-dojo/models/equiformer.py # remove once iclr is pulled\n",
        "!grep -qxF \"from models.equiformer import EquiformerV2_OC20\" ./geometric-gnn-dojo/models/__init__.py 2>/dev/null || echo \"from models.equiformer import EquiformerV2_OC20\" >> ./geometric-gnn-dojo/models/__init__.py\n",
        "\n",
        "!cp ./steerable-v1/gemnet_t.py ./geometric-gnn-dojo/models/gemnet_t.py # remove once iclr is pulled\n",
        "!grep -qxF \"from models.gemnet_t import GemNetT\" ./geometric-gnn-dojo/models/__init__.py 2>/dev/null || echo \"from models.gemnet_t import GemNetT\" >> ./geometric-gnn-dojo/models/__init__.py\n",
        "\n",
        "!cp ./steerable-v1/gemnet_q.py ./geometric-gnn-dojo/models/gemnet_q.py # remove once iclr is pulled\n",
        "!grep -qxF \"from models.gemnet_q import GemNetOC\" ./geometric-gnn-dojo/models/__init__.py 2>/dev/null || echo \"from models.gemnet_q import GemNetOC\" >> ./geometric-gnn-dojo/models/__init__.py\n",
        "\n",
        "!cp ./steerable-v1/_steerable.py ./geometric-gnn-dojo/models/ # remove once iclr is pulled\n",
        "!cp ./steerable-v1/segnn.py ./geometric-gnn-dojo/models/ # remove once iclr is pulled\n",
        "!grep -qxF \"from models.segnn import SEGNN\" ./geometric-gnn-dojo/models/__init__.py 2>/dev/null || echo \"from models.segnn import SEGNN\" >> ./geometric-gnn-dojo/models/__init__.py"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UkC8_r-NLdXo"
      },
      "source": [
        "# Models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nrIzD1hSLhT_"
      },
      "source": [
        "# Datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cOFCwx4W7X1d"
      },
      "source": [
        "## Simple Chain Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T6MgPmXp7RAl"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "sys.path.append('/root/workspace/geometric-gnn-dojo/')\n",
        "\n",
        "import torch\n",
        "import torch_geometric\n",
        "from torch_geometric.data import Data\n",
        "from torch_geometric.loader import DataLoader\n",
        "from torch_geometric.utils import to_undirected\n",
        "import e3nn\n",
        "from functools import partial\n",
        "\n",
        "from torch_geometric.seed import seed_everything\n",
        "\n",
        "from experiments.utils.plot_utils import plot_3d\n",
        "\n",
        "def _kchain_graph(k, first_x, label, Q, b, cell):\n",
        "    \"\"\"Build one chain graph with k+2 nodes, placed by the map x -> Q @ x + b.\n",
        "\n",
        "    Args:\n",
        "        k: number of interior nodes, laid out 5 apart along the y axis.\n",
        "        first_x: x coordinate of the first end node; this is the only coordinate\n",
        "            in which the two graphs of a pair differ.\n",
        "        label: integer class label stored in ``y``.\n",
        "        Q: 3x3 matrix applied to every node position.\n",
        "        b: length-3 offset added after applying ``Q``.\n",
        "        cell: tensor stored unchanged as the ``cell`` attribute.\n",
        "\n",
        "    Returns:\n",
        "        A ``Data`` object whose edges form the undirected path 0-1-...-(k+1).\n",
        "    \"\"\"\n",
        "    n_nodes = k + 2\n",
        "    atoms = torch.LongTensor([0] * n_nodes)  # every node has atom type 0\n",
        "    edge_index = torch.LongTensor([list(range(n_nodes - 1)), list(range(1, n_nodes))])\n",
        "    pos = torch.FloatTensor(\n",
        "        [[first_x, -3, 0]] +\n",
        "        [[0, 5*i, 0] for i in range(k)] +\n",
        "        [[4, 5*(k-1) + 3, 0]]\n",
        "    )\n",
        "    # Transform node by node so the arithmetic is exactly Q @ x + b for each position.\n",
        "    transf_pos = torch.vstack([Q @ row + b for row in pos])\n",
        "    graph = Data(atoms=atoms, edge_index=edge_index, pos=transf_pos,\n",
        "                 y=torch.LongTensor([label]), natoms=n_nodes, cell=cell)\n",
        "    graph.edge_index = to_undirected(graph.edge_index)\n",
        "    return graph\n",
        "\n",
        "\n",
        "def create_kchains(k, n, seed=10):\n",
        "    \"\"\"Create n pairs of k-chain graphs that differ only in one end node.\n",
        "\n",
        "    For each pair a random matrix Q (the Q factor of a QR decomposition of a\n",
        "    random 3x3 matrix) and a random offset b are drawn once and applied to both\n",
        "    graphs, so the two graphs of a pair share the same placement.\n",
        "\n",
        "    Args:\n",
        "        k: number of interior chain nodes; must be >= 2.\n",
        "        n: number of graph pairs; must be >= 1.\n",
        "        seed: value passed to ``seed_everything`` before any random draw.\n",
        "            Defaults to 10, the value this function previously hard-coded.\n",
        "\n",
        "    Returns:\n",
        "        A list of 2*n ``Data`` objects where ``dataset[2*i]`` has label 0 and\n",
        "        ``dataset[2*i+1]`` has label 1.\n",
        "\n",
        "    Raises:\n",
        "        AssertionError: if ``k < 2`` or ``n < 1``.\n",
        "    \"\"\"\n",
        "    seed_everything(seed)\n",
        "    assert k >= 2\n",
        "    assert n >= 1\n",
        "\n",
        "    dataset = []\n",
        "    for _pair in range(n):\n",
        "        # The draw order (M first, then b) is fixed so seeded outputs do not change.\n",
        "        M = torch.rand(3,3)\n",
        "        Q, _r = torch.linalg.qr(M, mode='complete')\n",
        "        b = torch.rand(3)\n",
        "        # Identity cell, shaped (1, 3, 3), shared by both graphs of the pair.\n",
        "        cell = torch.diag(torch.ones(3,dtype=torch.float)).view(1,3,3)\n",
        "\n",
        "        dataset.append(_kchain_graph(k, -4, 0, Q, b, cell))  # label 0: first node at x = -4\n",
        "        dataset.append(_kchain_graph(k, 4, 1, Q, b, cell))  # label 1: first node at x = +4\n",
        "\n",
        "    return dataset\n",
        "\n",
        "# Create dataset\n",
        "k = 4\n",
        "dataset = create_kchains(k=k,n=1)\n",
        "for data in dataset:\n",
        "    print(data.pos)\n",
        "    plot_3d(data, lim=2*k)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8vp_z2Z9MbSn"
      },
      "source": [
        "# Experiments"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EWBbN5g-jB4_"
      },
      "source": [
        "## Simple Chain Experiment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D8D5ONl89czH"
      },
      "outputs": [],
      "source": [
        "# Create dataloaders\n",
        "import random\n",
        "\n",
        "from experiments.utils.train_utils import run_experiment\n",
        "from models import SchNetModel, DimeNetPPModel, SphereNetModel, ComENetModel, GemNetT, GemNetOC, EGNNModel, GVPGNNModel, PaiNN, eSCN, EquiformerV2_OC20, MACEModel, TFNModel, SEGNN\n",
        "\n",
        "\n",
        "total = 50\n",
        "seed_everything(10)\n",
        "permuted_g1 = list(range(total))\n",
        "permuted_g2 = list(range(total))\n",
        "random.shuffle(permuted_g1)\n",
        "random.shuffle(permuted_g2)\n",
        "\n",
        "print('split_pt1',permuted_g1)\n",
        "# print('split_pt2',permuted_g2)\n",
        "\n",
        "\n",
        "def run(model_name,cutoff_name=None):\n",
        "    \"\"\"Run the k-chain experiment for one model over a sweep of (k, num_layers).\n",
        "\n",
        "    Reads the notebook-level ``total`` and ``permuted_g1`` to split the graph\n",
        "    pairs into train (first 50%), val (next 30%) and test (the remainder).\n",
        "\n",
        "    Args:\n",
        "        model_name: key of the model registry below (e.g. 'schnet', 'mace_2').\n",
        "        cutoff_name: name of the constructor keyword that receives the value\n",
        "            5.1 (e.g. 'cutoff', 'r_max', 'max_radius'). When falsy, no such\n",
        "            keyword is passed.\n",
        "\n",
        "    Raises:\n",
        "        KeyError: if ``model_name`` is not in the registry. This is raised\n",
        "            before any dataset is built.\n",
        "    \"\"\"\n",
        "    correlation = 2\n",
        "    # Built once per call: none of these entries depend on the sweep variables.\n",
        "    model_registry = {\n",
        "        # INV\n",
        "        \"schnet\": SchNetModel,\n",
        "        \"dimenet\": DimeNetPPModel,\n",
        "        \"spherenet\": SphereNetModel,\n",
        "        \"comenet\": partial(ComENetModel, hidden_channels=128, num_output_layers=2),\n",
        "        # Equiv\n",
        "        \"egnn\": EGNNModel,\n",
        "        \"gvp\": partial(GVPGNNModel, s_dim=32, v_dim=1),\n",
        "        # Steerable\n",
        "        \"mace_1\": partial(MACEModel, correlation=correlation, max_ell=1),\n",
        "        \"mace_2\": partial(MACEModel, correlation=correlation, max_ell=2),\n",
        "        \"escn_1\": partial(eSCN, lmax_list=[1], mmax_list=[1], hidden_channels=256),#, sphere_channels= 16, hidden_channels = 128, edge_channels = 16, num_sphere_samples = 16),\n",
        "        \"escn_2\": partial(eSCN, lmax_list=[2], mmax_list=[2], hidden_channels=(256*4//9)),#, sphere_channels= 16, hidden_channels = 128, edge_channels = 16, num_sphere_samples = 16),\n",
        "        \"equiformer_0\":partial(EquiformerV2_OC20, attn_hidden_channels=64, lmax_list=[0], mmax_list=[0]),\n",
        "        \"equiformer_1\":partial(EquiformerV2_OC20, attn_hidden_channels=16, lmax_list=[1], mmax_list=[1]),\n",
        "        \"equiformer_2\":partial(EquiformerV2_OC20, attn_hidden_channels=7, lmax_list=[2], mmax_list=[2]),\n",
        "        # If Time\n",
        "        \"gemnet_t\": GemNetT,\n",
        "        \"gemnet_q\": GemNetOC,\n",
        "        \"painn\":PaiNN,\n",
        "        \"tfn\": TFNModel,\n",
        "        \"segnn\": SEGNN,\n",
        "    }\n",
        "    build_model = model_registry[model_name]\n",
        "    kwargs = {cutoff_name:5.1} if cutoff_name else {}\n",
        "\n",
        "    # Pair indices per split; these do not change across the sweep.\n",
        "    train_n = int(.5*total)\n",
        "    val_n = int(.3*total)\n",
        "    split_ids = {\n",
        "        'train': permuted_g1[:train_n],\n",
        "        'val': permuted_g1[train_n:train_n+val_n],\n",
        "        'test': permuted_g1[train_n+val_n:],\n",
        "    }\n",
        "\n",
        "    for k, num_layers in zip([2,2,2,3,3,3,3,4,4,4,4,4],[1,2,3,1,2,3,4,2,3,4,5,6]):\n",
        "        dataset = create_kchains(k=k, n=total)\n",
        "\n",
        "        # Pair i is dataset[2*i] (label 0) and dataset[2*i+1] (label 1). Both graphs of\n",
        "        # a pair go to the same split, with the label-1 graphs listed first.\n",
        "        splits = {\n",
        "            name: [dataset[2*i+1] for i in ids] + [dataset[2*i] for i in ids]\n",
        "            for name, ids in split_ids.items()\n",
        "        }\n",
        "        dataloader = DataLoader(splits['train'], batch_size=32, shuffle=True)\n",
        "        val_loader = DataLoader(splits['val'], batch_size=32, shuffle=False)\n",
        "        test_loader = DataLoader(splits['test'], batch_size=32, shuffle=False)\n",
        "\n",
        "        # Print the label counts of each split, in train/val/test order.\n",
        "        for name, count_data in splits.items():\n",
        "            all_y_values = torch.cat([data.y for data in count_data])\n",
        "            unique_values, counts = all_y_values.unique(return_counts=True)\n",
        "            value_counts = {value.item(): count.item() for value, count in zip(unique_values, counts)}\n",
        "            print(name,value_counts)\n",
        "\n",
        "        print(f\"\\nNumber of layers: {num_layers}\")\n",
        "        print(f\"Chain Length: {k}\")\n",
        "\n",
        "        model = build_model(num_layers=num_layers, in_dim=1, out_dim=2, **kwargs)\n",
        "\n",
        "        # The returned metrics are unpacked for reference; they are not used further here.\n",
        "        best_val_acc, test_acc, train_time = run_experiment(\n",
        "            model,\n",
        "            dataloader,\n",
        "            val_loader,\n",
        "            test_loader,\n",
        "            n_epochs=100,\n",
        "            n_times=10,\n",
        "            verbose=False,\n",
        "            device='cuda',\n",
        "        )\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6lTJPBsu7rum"
      },
      "outputs": [],
      "source": [
        "# SCHNET\n",
        "run('schnet','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QOaiX2Hk7xrO"
      },
      "outputs": [],
      "source": [
        "# DIMENET\n",
        "run('dimenet','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "25eg8c6n71bs"
      },
      "outputs": [],
      "source": [
        "# SPHERENET\n",
        "run('spherenet','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uSbAAb5A74zE"
      },
      "outputs": [],
      "source": [
        "# COMENET\n",
        "run('comenet','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hXMk2O8077C7"
      },
      "outputs": [],
      "source": [
        "#EGNN\n",
        "run('egnn')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kexHoFxo7-4k"
      },
      "outputs": [],
      "source": [
        "#GVP\n",
        "run('gvp','r_max')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-D2dOW7E4oX-"
      },
      "outputs": [],
      "source": [
        "#eSCN\n",
        "run('escn_1','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QFFlr6Aq7BVl"
      },
      "outputs": [],
      "source": [
        "#eSCN\n",
        "run('escn_2','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ify1LzqImWGd"
      },
      "outputs": [],
      "source": [
        "#MACE\n",
        "run('mace_1','r_max')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ifAV4R5x6r5"
      },
      "outputs": [],
      "source": [
        "#MACE\n",
        "run('mace_2','r_max')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T3E_UUnQOOd7"
      },
      "outputs": [],
      "source": [
        "#Equiformer\n",
        "run('equiformer_0','max_radius')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3i21lGU44tG6"
      },
      "outputs": [],
      "source": [
        "#Equiformer\n",
        "run('equiformer_1','max_radius')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xAGesm1i7JJ0"
      },
      "outputs": [],
      "source": [
        "#Equiformer\n",
        "run('equiformer_2','max_radius')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xrZfUSk-gZSw"
      },
      "outputs": [],
      "source": [
        "#Equiformer\n",
        "# NOTE(review): identical to the previous cell ('equiformer_2', 'max_radius'); this repeats that sweep.\n",
        "run('equiformer_2','max_radius')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YV72KvIjyINg"
      },
      "outputs": [],
      "source": [
        "#PaiNN\n",
        "run('painn','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DiSWZsTk4iYZ"
      },
      "outputs": [],
      "source": [
        "#GemNetT\n",
        "run('gemnet_t','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4q7ACGgE5LxE"
      },
      "outputs": [],
      "source": [
        "#GemNetQ\n",
        "run('gemnet_q','cutoff')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s6QxwWga44V0"
      },
      "outputs": [],
      "source": [
        "#TFN\n",
        "run('tfn','r_max')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hFNMzgSe5DsF"
      },
      "outputs": [],
      "source": [
        "#SEGNN\n",
        "run('segnn')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1IWRh0TKlQXA"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "machine_shape": "hm",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
