{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports and environment setup.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "import operator\n",
    "# Make the parent directory importable (presumably the repo root holding the local `experiments` package imported below).\n",
    "sys.path.append('../')\n",
    "\n",
    "import torch\n",
    "import torch_geometric\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import to_undirected\n",
    "import e3nn\n",
    "from functools import partial\n",
    "# NOTE(review): `operator`, `DataLoader` and `partial` are not referenced anywhere in this notebook.\n",
    "\n",
    "# Report library versions for reproducibility.\n",
    "print(\"PyTorch version {}\".format(torch.__version__))\n",
    "print(\"PyG version {}\".format(torch_geometric.__version__))\n",
    "print(\"e3nn version {}\".format(e3nn.__version__))\n",
    "\n",
    "from experiments.utils.plot_utils import plot_3d\n",
    "\n",
    "# Set the device (NOTE(review): `device` is not referenced by any later cell).\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Four-body non-chiral counterexample is geometrically isomorphic\n",
    "Although the original GWL test claims that MACE (5-body) can distinguish these two geometric graphs, we could not reproduce this result. We further found that the two geometric graphs can be superimposed by a rotation (see the match check below), i.e. they are geometrically isomorphic. We therefore speculate that there is a typo in the original GWL test (or that the use of equivariant predictions introduces pose information)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared NumPy matrices used to build and align the non-chiral environments below.\n",
    "angle = 2 * torch.pi / 10 # arbitrary fixed angle (36 degrees); not random\n",
    "Q = e3nn.o3.matrix_y(torch.tensor(angle)).numpy()  # rotation about the y-axis by `angle`\n",
    "Ix = torch.diag(torch.Tensor([-1, 1, 1])).numpy()  # reflection x -> -x\n",
    "Iy = torch.diag(torch.Tensor([1, -1, 1])).numpy()  # reflection y -> -y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_four_body_nonchiral_envs():\n",
    "    \"\"\"Build the two environments of the four-body non-chiral counterexample.\n",
    "\n",
    "    Each environment is an 8-node star graph (all atom types 0): node 0 sits at the\n",
    "    origin and is connected (undirected) to nodes 1-7. Environment 0 lists the\n",
    "    `a` triplet, the `b` triplet rotated by the global matrix `Q` (row-vector\n",
    "    convention, `p @ Q`), and the apex at +y. Environment 1 lists the rotated `b`\n",
    "    triplet and then the `a` triplet, each in reversed order, with the apex\n",
    "    mirrored to -y.\n",
    "\n",
    "    Returns:\n",
    "        list[Data]: [environment 0 (y=0), environment 1 (y=1)].\n",
    "    \"\"\"\n",
    "    a1, a2, a3 = [3, 2, -4], [0, 2, 5], [-3, 2, -4]\n",
    "    b1, b2, b3 = [3, -2, -4], [0, -2, 5], [-3, -2, -4]\n",
    "    c_x, c_y, c_z = 0, 5, 0\n",
    "\n",
    "    def make_env(neighbours, label):\n",
    "        # Star graph: centre (node 0) at the origin, linked to every neighbour.\n",
    "        num_nodes = len(neighbours) + 1\n",
    "        atoms = torch.LongTensor([0] * num_nodes)\n",
    "        edge_index = torch.LongTensor([[0] * (num_nodes - 1), list(range(1, num_nodes))])\n",
    "        pos = torch.FloatTensor([[0, 0, 0], *neighbours])\n",
    "        data = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=torch.LongTensor([label]))\n",
    "        data.edge_index = to_undirected(data.edge_index)\n",
    "        return data\n",
    "\n",
    "    # Environment 0: a-triplet, Q-rotated b-triplet, apex at +y.\n",
    "    env_0 = make_env([a1, a2, a3, b1 @ Q, b2 @ Q, b3 @ Q, [c_x, +c_y, c_z]], label=0)\n",
    "    # Environment 1: the triplets swap places (each reversed) and the apex flips to -y,\n",
    "    # so this is not a mere permutation of environment 0's positions.\n",
    "    env_1 = make_env([b3 @ Q, b2 @ Q, b1 @ Q, a3, a2, a1, [c_x, -c_y, c_z]], label=1)\n",
    "\n",
    "    return [env_0, env_1]\n",
    "\n",
    "# Create dataset\n",
    "dataset = create_four_body_nonchiral_envs()\n",
    "for data in dataset:\n",
    "    plot_3d(data, lim=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the non-chiral pair instead of reusing the global `dataset`, which the\n",
    "# chiral section below overwrites; this keeps the cell correct when re-run.\n",
    "pos_0, pos_1 = (env.pos for env in create_four_body_nonchiral_envs())\n",
    "# Candidate alignment: reflect x, rotate by Q, reflect y (row vectors, `p @ R`).\n",
    "# Converted to a tensor of matching dtype so the arithmetic does not mix NumPy and torch.\n",
    "R = torch.as_tensor(Ix @ Q @ Iy, dtype=pos_1.dtype)\n",
    "\n",
    "match_loss = torch.norm(pos_0 - pos_1 @ R)\n",
    "print(f'Match loss of two geometric graphs: {match_loss}')\n",
    "matrix_type = 'rotation' if torch.abs(torch.det(R) - 1) < 1e-5 else 'reflection'\n",
    "print(f'$R$ is a {matrix_type}.')\n",
    "assert match_loss < 1e-4, 'R does not superimpose the two environments'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Four-body chiral counterexample is geometrically isomorphic\n",
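    "Furthermore, despite its name, this example is not chiral: it has no rotational symmetry axis, but it does have a mirror plane ($x = 0$, which swaps $a_1$ and $a_3$), so it is superimposable on its mirror image. The two environments are related by a rotation of $\\pi$ about the $z$-axis, as checked below."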
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_four_body_chiral_envs():\n",
    "    \"\"\"Build the two environments of the four-body 'chiral' counterexample.\n",
    "\n",
    "    Each environment is a 5-node star graph (all atom types 0): node 0 sits at the\n",
    "    origin and is connected (undirected) to nodes 1-4. Environment 0 lists a1, a2,\n",
    "    a3 and the apex at +y; environment 1 lists a3, a2, a1 with the apex mirrored\n",
    "    to -y.\n",
    "\n",
    "    Returns:\n",
    "        list[Data]: [environment 0 (y=0), environment 1 (y=1)].\n",
    "    \"\"\"\n",
    "    a1, a2, a3 = [3, 0, -4], [0, 0, 5], [-3, 0, -4]\n",
    "    c_x, c_y, c_z = 0, 5, 0\n",
    "\n",
    "    def make_env(neighbours, label):\n",
    "        # Star graph: centre (node 0) at the origin, linked to every neighbour.\n",
    "        num_nodes = len(neighbours) + 1\n",
    "        atoms = torch.LongTensor([0] * num_nodes)\n",
    "        edge_index = torch.LongTensor([[0] * (num_nodes - 1), list(range(1, num_nodes))])\n",
    "        pos = torch.FloatTensor([[0, 0, 0], *neighbours])\n",
    "        data = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=torch.LongTensor([label]))\n",
    "        data.edge_index = to_undirected(data.edge_index)\n",
    "        return data\n",
    "\n",
    "    # Environment 0: a1, a2, a3, apex at +y.\n",
    "    env_0 = make_env([a1, a2, a3, [c_x, +c_y, c_z]], label=0)\n",
    "    # Environment 1: the triplet in reversed order, apex flipped to -y.\n",
    "    env_1 = make_env([a3, a2, a1, [c_x, -c_y, c_z]], label=1)\n",
    "\n",
    "    return [env_0, env_1]\n",
    "\n",
    "# Create dataset\n",
    "dataset = create_four_body_chiral_envs()\n",
    "for data in dataset:\n",
    "    plot_3d(data, lim=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the chiral pair explicitly rather than relying on the global `dataset`.\n",
    "pos_0, pos_1 = (env.pos for env in create_four_body_chiral_envs())\n",
    "\n",
    "# Rotation by pi about the z-axis (row vectors, `p @ R`). Built directly as a tensor,\n",
    "# and without overwriting the global `angle` that the earlier cell used to build Q.\n",
    "R = e3nn.o3.matrix_z(torch.tensor(torch.pi)).to(pos_1.dtype)\n",
    "\n",
    "match_loss = torch.norm(pos_0 - pos_1 @ R)\n",
    "print(f'Match loss of two geometric graphs: {match_loss}')\n",
    "matrix_type = 'rotation' if torch.abs(torch.det(R) - 1) < 1e-5 else 'reflection'\n",
    "print(f'$R$ is a {matrix_type}.')\n",
    "assert match_loss < 1e-4, 'R does not superimpose the two environments'"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "aesc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
