{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Are High-Degree Representations Really Uncessary in Equivariant Graph Neural Networks?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Background*: \n",
    "As symmetric graphs, five regular polyhedra are invariant to rotations up to certain angles. Interestingly, we theoretically proved that any equivariant GNN on these symmetric graphs will degenerate to a zero function if the degree of their representation is fixed to be 1. \n",
    "\n",
    "*Experiment*: \n",
    "In this notebook, we evaluate equivarinat layers on their ability to distinguish the orientation of $k$-fold strcuture in 3D space. An [$L$-fold symmetric structure](https://en.wikipedia.org/wiki/Rotational_symmetry) does not change when rotated by an angle $\\frac{2\\pi}{L}$ around a point (in 2D) or axis (3D)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import math\n",
    "import torch\n",
    "import torch_geometric\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import to_undirected\n",
    "import e3nn\n",
    "from functools import partial\n",
    "\n",
    "print(\"PyTorch version {}\".format(torch.__version__))\n",
    "print(\"PyG version {}\".format(torch_geometric.__version__))\n",
    "print(\"e3nn version {}\".format(e3nn.__version__))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.utils.plot_utils import plot_2d, plot_3d\n",
    "from experiments.utils.train_utils import run_experiment\n",
    "from models.schnet import SchNetModel\n",
    "from models.egnn import EGNNModel\n",
    "from models.gvpgnn import GVPGNNModel\n",
    "from models.tfn import TFNModel\n",
    "from models.mace import MACEModel\n",
    "from models.hegnn import HEGNNModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_rotsym_envs(fold=3):\n",
    "    dataset = []\n",
    "\n",
    "    # Environment 0\n",
    "    atoms = torch.LongTensor([ 0 ] + [ 0 ] * fold)\n",
    "    edge_index = torch.LongTensor( [ [0] * fold, [i for i in range(1, fold+1)] ] )\n",
    "    x = torch.Tensor([1,0,0])\n",
    "    pos = [\n",
    "        torch.Tensor([0,0,0]),  # origin\n",
    "        x,   # first spoke \n",
    "    ]\n",
    "    for count in range(1, fold):\n",
    "        R = e3nn.o3.matrix_z(torch.Tensor([2*math.pi/fold * count])).squeeze(0)\n",
    "        pos.append(x @ R.T)\n",
    "    pos = torch.stack(pos)\n",
    "    y = torch.LongTensor([0])  # Label 0\n",
    "    data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
    "    data1.edge_index = to_undirected(data1.edge_index)\n",
    "    dataset.append(data1)\n",
    "    \n",
    "    # Environment 1\n",
    "    pos = pos @ e3nn.o3.rand_matrix()\n",
    "    y = torch.LongTensor([1])  # Label 1\n",
    "    data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
    "    data2.edge_index = to_undirected(data2.edge_index)\n",
    "    dataset.append(data2)\n",
    "    \n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create dataset\n",
    "fold = 2\n",
    "dataset = create_rotsym_envs(fold)\n",
    "for data in dataset:\n",
    "    plot_3d(data, lim=1)\n",
    "\n",
    "# Create dataloaders\n",
    "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n",
    "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"hegnn\"\n",
    "correlation = 2\n",
    "max_ell = 10\n",
    "\n",
    "model = {\n",
    "    \"schnet\": SchNetModel,\n",
    "    \"egnn\": partial(EGNNModel, equivariant_pred=True),\n",
    "    \"gvp\": partial(GVPGNNModel, equivariant_pred=True),\n",
    "    \"tfn\": partial(TFNModel, max_ell=max_ell, equivariant_pred=True),\n",
    "    \"mace\": partial(MACEModel, max_ell=max_ell, correlation=correlation, equivariant_pred=True),\n",
    "    \"hegnn\": partial(HEGNNModel, max_ell=max_ell, all_ell=False, equivariant_pred=True),\n",
    "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
    "\n",
    "best_val_acc, test_acc, train_time = run_experiment(\n",
    "    model, \n",
    "    dataloader,\n",
    "    val_loader, \n",
    "    test_loader,\n",
    "    n_epochs=100,\n",
    "    n_times=2,\n",
    "    device=device,\n",
    "    verbose=False\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn-dojo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
