{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Are High-Degree Representations Really Uncessary in Equivariant Graph Neural Networks?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In actual scenarios, slight perturbations (such as molecular vibrations) will destroy strict symmetry and no longer produce the degradation phenomenon. We therefore designed this perturbation experiment for a simple study and were surprised to find that HEGNN can bring better robustness through the introduction of high-degree steerable features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch_geometric\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import to_undirected\n",
    "import e3nn\n",
    "from functools import partial\n",
    "\n",
    "print(\"PyTorch version {}\".format(torch.__version__))\n",
    "print(\"PyG version {}\".format(torch_geometric.__version__))\n",
    "print(\"e3nn version {}\".format(e3nn.__version__))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.utils.plot_utils import plot_2d, plot_3d\n",
    "from experiments.utils.train_utils import run_experiment\n",
    "from models.schnet import SchNetModel\n",
    "from models.egnn import EGNNModel\n",
    "from models.gvpgnn import GVPGNNModel\n",
    "from models.tfn import TFNModel\n",
    "from models.mace import MACEModel\n",
    "from models.hegnn import HEGNNModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from experiments.data.reg_poly_coords import pos_dict, ver_num_dict\n",
    "\n",
    "eps = 0.01\n",
    "\n",
    "# NSR = 1 / SNR\n",
    "def add_noise(x : torch.FloatTensor, NSR : float = 0.05):\n",
    "    x = x - torch.mean(x, dim=0, keepdim=True)\n",
    "    avg_norm = torch.mean(torch.norm(x, dim=1), dim=0)\n",
    "    noise = torch.randn_like(x) * avg_norm * NSR\n",
    "    return x + noise\n",
    "\n",
    "def disturbance(NSR : float = 0.05):\n",
    "    for k in pos_dict.keys():\n",
    "        pos_dict[k] = add_noise(pos_dict[k], NSR)\n",
    "    return\n",
    "\n",
    "disturbance(eps)\n",
    "\n",
    "def get_edge(pos):\n",
    "    assert isinstance(pos, torch.Tensor)\n",
    "    num_node = pos.size(0)\n",
    "    edge_index = [[], []]\n",
    "    for i in range(num_node):\n",
    "        for j in range(i+1, num_node):\n",
    "            edge_index[0].append(i)\n",
    "            edge_index[1].append(j)\n",
    "\n",
    "    return torch.LongTensor(edge_index)\n",
    "\n",
    "# create environments\n",
    "def create_envs(face_num=20):\n",
    "    dataset = []\n",
    "\n",
    "    # Environment 0\n",
    "    atoms = torch.zeros(ver_num_dict[face_num], dtype=torch.long)\n",
    "    pos = pos_dict[face_num]\n",
    "    edge_index = get_edge(pos)\n",
    "\n",
    "    y = torch.LongTensor([0])  # Label 0\n",
    "    data1 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
    "    data1.edge_index = to_undirected(data1.edge_index)\n",
    "    dataset.append(data1)\n",
    "\n",
    "    # Environment 1\n",
    "    # pos = torch.matmul(pos, torch.Tensor(random_rotation_matrix()))\n",
    "    pos = pos @ e3nn.o3.rand_matrix()\n",
    "    \n",
    "    y = torch.LongTensor([1])  # Label 1\n",
    "    data2 = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
    "    data2.edge_index = to_undirected(data2.edge_index)\n",
    "    dataset.append(data2)\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# creat dataset\n",
    "face_num = 4    # Only select from 4, 6, 8, 12, 20\n",
    "dataset = create_envs(face_num)\n",
    "for data in dataset:\n",
    "    plot_3d(data, lim=2)\n",
    "\n",
    "# Create dataloaders\n",
    "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
    "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"hegnn\"\n",
    "\n",
    "correlation = 2\n",
    "max_ell = 3\n",
    "\n",
    "model = {\n",
    "    \"schnet\": SchNetModel,\n",
    "    \"egnn\": partial(EGNNModel, equivariant_pred=True),\n",
    "    \"gvp\": partial(GVPGNNModel, equivariant_pred=True),\n",
    "    \"tfn\": partial(TFNModel, max_ell=max_ell, equivariant_pred=True),\n",
    "    \"mace\": partial(MACEModel, max_ell=max_ell, correlation=correlation, equivariant_pred=True),\n",
    "    \"hegnn\": partial(HEGNNModel, max_ell=max_ell, all_ell=False, equivariant_pred=True),\n",
    "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
    "\n",
    "for batch in test_loader:\n",
    "    print(model(batch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set parameters\n",
    "model_name = \"hegnn\"\n",
    "\n",
    "correlation = 2\n",
    "max_ell = 3\n",
    "\n",
    "model = {\n",
    "    \"schnet\": SchNetModel,\n",
    "    \"egnn\": partial(EGNNModel, equivariant_pred=True),\n",
    "    \"gvp\": partial(GVPGNNModel, equivariant_pred=True),\n",
    "    \"tfn\": partial(TFNModel, max_ell=max_ell, equivariant_pred=True),\n",
    "    \"mace\": partial(MACEModel, max_ell=max_ell, correlation=correlation, equivariant_pred=True),\n",
    "    \"hegnn\": partial(HEGNNModel, max_ell=max_ell, all_ell=True, equivariant_pred=True),\n",
    "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
    "\n",
    "best_val_acc, test_acc, train_time = run_experiment(\n",
    "    model, \n",
    "    dataloader,\n",
    "    val_loader, \n",
    "    test_loader,\n",
    "    n_epochs=100,\n",
    "    n_times=10,\n",
    "    device=device,\n",
    "    verbose=False\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn-dojo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
