{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Completeness Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Standard library\n",
    "import sys\n",
    "import operator\n",
    "from functools import partial\n",
    "\n",
    "# Make the parent directory importable; the local `experiments`, `models`\n",
    "# and `models_cpl` packages imported below presumably live there\n",
    "sys.path.append('../')\n",
    "\n",
    "# Third-party\n",
    "import torch\n",
    "import torch_geometric\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import to_undirected\n",
    "import e3nn\n",
    "\n",
    "print(f\"PyTorch version {torch.__version__}\")\n",
    "print(f\"PyG version {torch_geometric.__version__}\")\n",
    "print(f\"e3nn version {e3nn.__version__}\")\n",
    "\n",
    "# Local project modules\n",
    "from experiments.utils.plot_utils import plot_3d\n",
    "from experiments.utils.train_utils import run_experiment\n",
    "from models import SchNetModel, DimeNetPPModel, SphereNetModel, EGNNModel, GVPGNNModel, TFNModel, MACEModel\n",
    "from models_cpl import BasicModel_cpl, SchNetModel_cpl, GVPGNNModel_cpl, EGNNModel_cpl, TFNModel_cpl\n",
    "\n",
    "# Run on the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
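    "## Two-body counterexample\n",
    "\n",
    "Pair of local neighbourhoods that are indistinguishable when comparing their set of $2$-body scalars, i.e. the unordered set of distances between the central atom and each of its neighbours. In both environments the two neighbours lie at distance $5$ from the centre; only the angle between them differs ($\\approx 53^\\circ$ vs. $180^\\circ$)."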
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_two_body_envs():\n",
    "    \"\"\"Build the two-body counterexample as two 3-node star graphs.\n",
    "\n",
    "    In each environment a central node 0 at the origin is joined by undirected\n",
    "    edges to two neighbours, both at distance 5 from the centre. The\n",
    "    centre-neighbour distances therefore coincide across the two environments;\n",
    "    only the angle between the neighbours differs (about 53 degrees vs. 180).\n",
    "\n",
    "    Returns:\n",
    "        list of torch_geometric ``Data`` graphs with ``atoms``, ``edge_index``,\n",
    "        ``pos`` and ``y`` attributes; the graph at index i has label ``y = [i]``.\n",
    "    \"\"\"\n",
    "    centre = [0, 0, 0]\n",
    "    # Neighbour coordinates per environment; the list index is the class label\n",
    "    neighbours_per_env = [\n",
    "        [[5, 0, 0], [3, 0, 4]],   # Environment 0 (label 0)\n",
    "        [[5, 0, 0], [-5, 0, 0]],  # Environment 1 (label 1)\n",
    "    ]\n",
    "\n",
    "    dataset = []\n",
    "    for label, neighbours in enumerate(neighbours_per_env):\n",
    "        num_nodes = 1 + len(neighbours)\n",
    "        # All nodes share atom value 0, so only the geometry differs\n",
    "        atoms = torch.zeros(num_nodes, dtype=torch.long)\n",
    "        # Centre -> neighbour edges; made undirected after building the graph\n",
    "        edge_index = torch.LongTensor([\n",
    "            [0] * len(neighbours),\n",
    "            list(range(1, num_nodes)),\n",
    "        ])\n",
    "        pos = torch.FloatTensor([centre] + neighbours)\n",
    "        y = torch.LongTensor([label])\n",
    "        env = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
    "        env.edge_index = to_undirected(env.edge_index)\n",
    "        dataset.append(env)\n",
    "\n",
    "    return dataset\n",
    "\n",
    "# Build both environments and plot them\n",
    "dataset = create_two_body_envs()\n",
    "for data in dataset:\n",
    "    plot_3d(data, lim=5)\n",
    "\n",
    "# Create dataloaders: training visits one shuffled environment per step;\n",
    "# validation and test see both environments together in a fixed order\n",
    "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
    "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the architecture to evaluate (any key of the dictionary below)\n",
    "model_name = \"basic_cpl\"\n",
    "\n",
    "# Correlation order handed to MACE; not used by any other entry\n",
    "correlation = 2\n",
    "\n",
    "# Instantiate the chosen model with num_layers=1, in_dim=1, out_dim=2 (two labels)\n",
    "model = {\n",
    "    \"schnet\": SchNetModel,\n",
    "    \"dimenet\": DimeNetPPModel,\n",
    "    \"spherenet\": SphereNetModel,\n",
    "    \"egnn\": EGNNModel,\n",
    "    \"gvp\": GVPGNNModel,\n",
    "    \"tfn\": TFNModel,\n",
    "    \"mace\": partial(MACEModel, correlation=correlation),\n",
    "    \"basic_cpl\": BasicModel_cpl,\n",
    "    \"schnet_cpl\": SchNetModel_cpl,\n",
    "    \"gvp_cpl\": GVPGNNModel_cpl,\n",
    "    \"egnn_cpl\": EGNNModel_cpl,\n",
    "    \"tfn_cpl\": TFNModel_cpl,\n",
    "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
    "\n",
    "# Train and evaluate with the shared experiment loop from train_utils\n",
    "best_val_acc, test_acc, train_time = run_experiment(\n",
    "    model, dataloader, val_loader, test_loader,\n",
    "    n_epochs=100, n_times=10, device=device, verbose=False,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
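    "## Three-body counterexample\n",
    "\n",
    "Pair of local neighbourhoods that are indistinguishable when comparing their set of $3$-body scalars, i.e. the unordered set of (distance, distance, angle) triples formed at the central atom by every pair of its neighbours. The two environments differ only in the sign of the $y$-coordinate of the last neighbour."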
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_three_body_envs():\n",
    "    \"\"\"Build the three-body counterexample as two 5-node star graphs.\n",
    "\n",
    "    In each environment a central node 0 at the origin is joined by undirected\n",
    "    edges to four neighbours. The environments differ only in the sign of the\n",
    "    y-coordinate of the last neighbour; the multiset of (distance, distance,\n",
    "    angle) triples formed at the centre by pairs of neighbours is the same in\n",
    "    both.\n",
    "\n",
    "    Returns:\n",
    "        list of torch_geometric ``Data`` graphs with ``atoms``, ``edge_index``,\n",
    "        ``pos`` and ``y`` attributes; the graph at index i has label ``y = [i]``.\n",
    "    \"\"\"\n",
    "    a_x, a_y, a_z = 5, 0, 5\n",
    "    b_x, b_y, b_z = 5, 5, 5\n",
    "    c_x, c_y, c_z = 0, 5, 5\n",
    "\n",
    "    centre = [0, 0, 0]\n",
    "    # Neighbours common to both environments\n",
    "    shared_neighbours = [\n",
    "        [a_x, a_y, a_z],\n",
    "        [+b_x, +b_y, b_z],\n",
    "        [-b_x, -b_y, b_z],\n",
    "    ]\n",
    "    # The last neighbour is mirrored in y; the list index is the class label\n",
    "    neighbours_per_env = [\n",
    "        shared_neighbours + [[c_x, +c_y, c_z]],  # Environment 0 (label 0)\n",
    "        shared_neighbours + [[c_x, -c_y, c_z]],  # Environment 1 (label 1)\n",
    "    ]\n",
    "\n",
    "    dataset = []\n",
    "    for label, neighbours in enumerate(neighbours_per_env):\n",
    "        num_nodes = 1 + len(neighbours)\n",
    "        # All nodes share atom value 0, so only the geometry differs\n",
    "        atoms = torch.zeros(num_nodes, dtype=torch.long)\n",
    "        # Centre -> neighbour edges; made undirected after building the graph\n",
    "        edge_index = torch.LongTensor([\n",
    "            [0] * len(neighbours),\n",
    "            list(range(1, num_nodes)),\n",
    "        ])\n",
    "        pos = torch.FloatTensor([centre] + neighbours)\n",
    "        y = torch.LongTensor([label])\n",
    "        env = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=y)\n",
    "        env.edge_index = to_undirected(env.edge_index)\n",
    "        dataset.append(env)\n",
    "\n",
    "    return dataset\n",
    "\n",
    "# Build both environments and plot them\n",
    "dataset = create_three_body_envs()\n",
    "for data in dataset:\n",
    "    plot_3d(data, lim=5)\n",
    "\n",
    "# Training visits one shuffled environment per step; validation and test\n",
    "# see both environments together in a fixed order\n",
    "dataloader = DataLoader(dataset, shuffle=True, batch_size=1)\n",
    "val_loader = DataLoader(dataset, shuffle=False, batch_size=2)\n",
    "test_loader = DataLoader(dataset, shuffle=False, batch_size=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the architecture to evaluate (any key of the dictionary below)\n",
    "model_name = \"basic_cpl\"\n",
    "\n",
    "# Correlation order handed to MACE; not used by any other entry\n",
    "correlation = 3\n",
    "\n",
    "# Instantiate the chosen model with num_layers=1, in_dim=1, out_dim=2 (two labels)\n",
    "model = {\n",
    "    \"schnet\": SchNetModel,\n",
    "    \"dimenet\": DimeNetPPModel,\n",
    "    \"spherenet\": SphereNetModel,\n",
    "    \"egnn\": EGNNModel,\n",
    "    \"gvp\": GVPGNNModel,\n",
    "    \"tfn\": TFNModel,\n",
    "    \"mace\": partial(MACEModel, correlation=correlation),\n",
    "    \"basic_cpl\": BasicModel_cpl,\n",
    "    \"schnet_cpl\": SchNetModel_cpl,\n",
    "    \"gvp_cpl\": GVPGNNModel_cpl,\n",
    "    \"egnn_cpl\": EGNNModel_cpl,\n",
    "    \"tfn_cpl\": TFNModel_cpl,\n",
    "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
    "\n",
    "# Train and evaluate with the shared experiment loop from train_utils\n",
    "best_val_acc, test_acc, train_time = run_experiment(\n",
    "    model, dataloader, val_loader, test_loader,\n",
    "    n_epochs=100, n_times=10, device=device, verbose=False,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
