{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Synthetic Example: $k$-chain geometric graphs\n",
    "\n",
    "We consider $k$-chain geometric graphs, which generalise the examples from PaiNN (Schütt et al.).\n",
    "A pair of $k$-chains consists of two graphs, each with $k+2$ nodes: $k$ nodes arranged in a line, plus $2$ end points whose orientation differentiates the two graphs.\n",
    "Thus, $k$-chain graphs are $(\\lfloor k/2 \\rfloor + 1)$-hop distinguishable, and $(\\lfloor k/2 \\rfloor + 1)$ iterations of GWL are sufficient to distinguish them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "sys.path.append('../../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch_geometric\n",
    "from torch_geometric.loader import DataLoader\n",
    "import e3nn\n",
    "from e3nn import o3\n",
    "from functools import partial\n",
    "\n",
    "print(\"PyTorch version {}\".format(torch.__version__))\n",
    "print(\"PyG version {}\".format(torch_geometric.__version__))\n",
    "print(\"e3nn version {}\".format(e3nn.__version__))\n",
    "\n",
    "from data_utils import create_kchains\n",
    "from plot_utils import plot_2d, plot_3d\n",
    "from train_utils import run_experiment\n",
    "from src.models import MPNNModel, EGNNModel, GVPGNNModel, TFNModel, SchNetModel, DimeNetPPModel, MACEModel\n",
    "\n",
    "# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)\n",
    "# print(f\"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}\")\n",
    "# print(f\"Is MPS available? {torch.backends.mps.is_available()}\")\n",
    "\n",
    "# Set the device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# device = \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
    "# device = \"cpu\"\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chain-length parameter passed to create_kchains\n",
    "k = 4\n",
    "\n",
    "# Build the k-chain dataset\n",
    "dataset = create_kchains(k=k)\n",
    "\n",
    "# Plot limit passed to the plotting helper, computed once from k\n",
    "plot_lim = 5*k\n",
    "\n",
    "# Render every graph in 3D (2D alternative: plot_2d(data, lim=plot_lim))\n",
    "for data in dataset:\n",
    "    plot_3d(data, lim=plot_lim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set model\n",
    "model_name = \"tfn\"\n",
    "\n",
    "# Constructors by model name; built once because the table does not depend on the depth being swept\n",
    "model_constructors = {\n",
    "    \"mpnn\": MPNNModel,\n",
    "    \"schnet\": SchNetModel,\n",
    "    \"dimenet\": DimeNetPPModel,\n",
    "    \"egnn\": EGNNModel,\n",
    "    \"gvp\": GVPGNNModel,\n",
    "    \"tfn\": TFNModel,\n",
    "    \"mace\": partial(MACEModel, correlation=2),\n",
    "}\n",
    "# Look the constructor up before any training so an unknown model_name fails immediately (KeyError)\n",
    "model_constructor = model_constructors[model_name]\n",
    "\n",
    "# Create dataloaders\n",
    "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
    "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
    "\n",
    "# Metrics returned by run_experiment for each depth, keyed by num_layers,\n",
    "# so they remain available after the sweep instead of being overwritten\n",
    "results = {}\n",
    "\n",
    "# Sweep four depths, starting at k//2 + 1 layers\n",
    "for num_layers in range(k//2 + 1, k//2 + 5):\n",
    "\n",
    "    print(f\"\\nNumber of layers: {num_layers}\")\n",
    "\n",
    "    # Fresh model for every depth\n",
    "    model = model_constructor(num_layers=num_layers, in_dim=1, out_dim=2)\n",
    "\n",
    "    best_val_acc, test_acc, train_time = run_experiment(\n",
    "        model,\n",
    "        dataloader,\n",
    "        val_loader,\n",
    "        test_loader,\n",
    "        n_epochs=100,\n",
    "        n_times=10,\n",
    "        device=device,\n",
    "        verbose=False\n",
    "    )\n",
    "\n",
    "    results[num_layers] = {\n",
    "        \"best_val_acc\": best_val_acc,\n",
    "        \"test_acc\": test_acc,\n",
    "        \"train_time\": train_time,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('pyg')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "68e86ebbde0223d423b9b665c3b1d327e092bdc1edc82ebfb87a5b0b6583dcb8"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
