{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
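   "source": [
    "# Chirality counterexamples\n",
    "\n",
    "Each section builds a pair of point environments around a central node (labels 0 and 1) and trains the selected model to tell them apart with `run_experiment`."
   ]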
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "import operator\n",
    "# Make the project packages imported below (experiments, models, models_chiral) findable.\n",
    "sys.path.append('../')\n",
    "\n",
    "import torch\n",
    "import torch_geometric\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import to_undirected\n",
    "import e3nn\n",
    "from functools import partial\n",
    "\n",
    "print(f\"PyTorch version {torch.__version__}\")\n",
    "print(f\"PyG version {torch_geometric.__version__}\")\n",
    "print(f\"e3nn version {e3nn.__version__}\")\n",
    "\n",
    "from experiments.utils.plot_utils import plot_3d\n",
    "from experiments.utils.train_utils import run_experiment\n",
    "from models import SchNetModel, DimeNetPPModel, SphereNetModel, EGNNModel, GVPGNNModel, TFNModel, MACEModel\n",
    "from models_chiral import BasicModel_chiral, EGNNModel_chiral\n",
    "\n",
    "# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device('cuda') if use_cuda else torch.device('cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Four-body non-chiral counterexample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_four_body_nonchiral_envs():\n",
    "    \"\"\"Build the two labelled environments of the four-body non-chiral counterexample.\n",
    "\n",
    "    Each environment is a star graph: node 0 sits at the origin and is joined\n",
    "    (in both directions) to seven outer points -- three `a` points, three `b`\n",
    "    points rotated about the y-axis by a fixed angle, and one apex point `c`.\n",
    "    Environment 1 negates every position of environment 0 except `c`, whose\n",
    "    y-coordinate is flipped before the negation so that `c` ends up in place.\n",
    "\n",
    "    Returns:\n",
    "        list[Data]: [environment 0 (y=0), environment 1 (y=1)].\n",
    "    \"\"\"\n",
    "    def _star_graph(pos, label):\n",
    "        # Join the center (node 0) to every other node, then make the edges undirected.\n",
    "        n_nodes = pos.size(0)\n",
    "        edge_index = torch.stack([\n",
    "            torch.zeros(n_nodes - 1, dtype=torch.long),\n",
    "            torch.arange(1, n_nodes),\n",
    "        ])\n",
    "        atoms = torch.zeros(n_nodes, dtype=torch.long)  # every node has atom type 0\n",
    "        data = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=torch.LongTensor([label]))\n",
    "        data.edge_index = to_undirected(data.edge_index)\n",
    "        return data\n",
    "\n",
    "    a = torch.FloatTensor([[3, 2, -4], [0, 2, 5], [-3, 2, -4]])\n",
    "    b = torch.FloatTensor([[3, -2, -4], [0, -2, 5], [-3, -2, -4]])\n",
    "    c_x, c_y, c_z = 0, 5, 0\n",
    "\n",
    "    # Arbitrary but fixed rotation about the y-axis (deterministic, not random).\n",
    "    angle = 2 * torch.pi / 10\n",
    "    Q = e3nn.o3.matrix_y(torch.tensor(angle, dtype=torch.float32))\n",
    "    b_rot = b @ Q  # each row is a point, multiplied on the right by Q\n",
    "\n",
    "    origin = torch.FloatTensor([[0, 0, 0]])\n",
    "    pos0 = torch.cat([origin, a, b_rot, torch.FloatTensor([[c_x, +c_y, c_z]])])\n",
    "    pos1 = -torch.cat([origin, a, b_rot, torch.FloatTensor([[c_x, -c_y, c_z]])])\n",
    "\n",
    "    return [_star_graph(pos0, 0), _star_graph(pos1, 1)]\n",
    "\n",
    "# Create dataset\n",
    "dataset = create_four_body_nonchiral_envs()\n",
    "for data in dataset:\n",
    "    plot_3d(data, lim=5)\n",
    "\n",
    "# Create dataloaders (the same two graphs serve as the train, validation and test sets)\n",
    "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
    "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set model\n",
    "model_name = \"basic_chair\"\n",
    "# `color` and `tp` are forwarded only to the *_chair model constructors.\n",
    "color = True\n",
    "tp = True\n",
    "# `correlation` is only consumed by the MACE entry below.\n",
    "correlation = 4\n",
    "\n",
    "# Constructors keyed by name; only the selected one is called.\n",
    "model_builders = {\n",
    "    \"schnet\": SchNetModel,\n",
    "    \"dimenet\": DimeNetPPModel,\n",
    "    \"spherenet\": SphereNetModel,\n",
    "    \"egnn\": EGNNModel,\n",
    "    \"gvp\": GVPGNNModel,\n",
    "    \"tfn\": TFNModel,\n",
    "    \"mace\": partial(MACEModel, correlation=correlation),\n",
    "    \"basic_chair\": partial(BasicModel_chiral, color=color, tp=tp),\n",
    "    \"egnn_chair\": partial(EGNNModel_chiral, color=color, tp=tp),\n",
    "}\n",
    "build_model = model_builders[model_name]\n",
    "model = build_model(num_layers=1, in_dim=1, out_dim=2)\n",
    "\n",
    "best_val_acc, test_acc, train_time = run_experiment(\n",
    "    model,\n",
    "    dataloader,\n",
    "    val_loader,\n",
    "    test_loader,\n",
    "    n_epochs=100,\n",
    "    n_times=10,\n",
    "    device=device,\n",
    "    verbose=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Four-body chiral counterexample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_four_body_chiral_envs():\n",
    "    \"\"\"Build the two labelled environments of the four-body chiral counterexample.\n",
    "\n",
    "    Each environment is a star graph: node 0 sits at the origin and is joined\n",
    "    (in both directions) to four outer points -- a1, a2, a3 (with a3 rotated\n",
    "    about the y-axis by a fixed angle) and an apex point c. The environments\n",
    "    are identical except that c lies at +y in environment 0 and at -y in\n",
    "    environment 1.\n",
    "\n",
    "    Returns:\n",
    "        list[Data]: [environment 0 (y=0), environment 1 (y=1)].\n",
    "    \"\"\"\n",
    "    def _star_graph(pos, label):\n",
    "        # Join the center (node 0) to every other node, then make the edges undirected.\n",
    "        n_nodes = pos.size(0)\n",
    "        edge_index = torch.stack([\n",
    "            torch.zeros(n_nodes - 1, dtype=torch.long),\n",
    "            torch.arange(1, n_nodes),\n",
    "        ])\n",
    "        atoms = torch.zeros(n_nodes, dtype=torch.long)  # every node has atom type 0\n",
    "        data = Data(atoms=atoms, edge_index=edge_index, pos=pos, y=torch.LongTensor([label]))\n",
    "        data.edge_index = to_undirected(data.edge_index)\n",
    "        return data\n",
    "\n",
    "    a12 = torch.FloatTensor([[3, 0, -4], [0, 0, 5]])\n",
    "    a3 = torch.FloatTensor([[-3, 0, -4]])\n",
    "    c_x, c_y, c_z = 0, 5, 0\n",
    "\n",
    "    # Arbitrary but fixed rotation about the y-axis (deterministic, not random).\n",
    "    angle = 2 * torch.pi / 7\n",
    "    Q = e3nn.o3.matrix_y(torch.tensor(angle, dtype=torch.float32))\n",
    "    a3_rot = a3 @ Q  # the row point is multiplied on the right by Q\n",
    "\n",
    "    origin = torch.FloatTensor([[0, 0, 0]])\n",
    "    pos0 = torch.cat([origin, a12, a3_rot, torch.FloatTensor([[c_x, +c_y, c_z]])])\n",
    "    pos1 = torch.cat([origin, a12, a3_rot, torch.FloatTensor([[c_x, -c_y, c_z]])])\n",
    "\n",
    "    return [_star_graph(pos0, 0), _star_graph(pos1, 1)]\n",
    "\n",
    "# Create dataset\n",
    "dataset = create_four_body_chiral_envs()\n",
    "for data in dataset:\n",
    "    plot_3d(data, lim=5)\n",
    "\n",
    "# Create dataloaders (the same two graphs serve as the train, validation and test sets)\n",
    "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
    "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set model\n",
    "model_name = \"basic_chair\"\n",
    "# `color` and `tp` are forwarded only to the *_chair model constructors.\n",
    "color = True\n",
    "tp = False\n",
    "\n",
    "# `correlation` is only consumed by the MACE entry below.\n",
    "correlation = 4\n",
    "# The TFN/MACE entries use hidden irreps containing both parities (e/o) for l = 0, 1, 2.\n",
    "model = {\n",
    "    \"schnet\": SchNetModel,\n",
    "    \"dimenet\": DimeNetPPModel,\n",
    "    \"spherenet\": SphereNetModel,\n",
    "    \"egnn\": EGNNModel,\n",
    "    \"gvp\": GVPGNNModel,\n",
    "    \"tfn\": partial(TFNModel, hidden_irreps=e3nn.o3.Irreps('64x0e + 64x0o + 64x1e + 64x1o + 64x2e + 64x2o')),\n",
    "    \"mace\": partial(MACEModel, correlation=correlation, hidden_irreps=e3nn.o3.Irreps('32x0e + 32x0o + 32x1e + 32x1o + 32x2e + 32x2o')),\n",
    "    \"basic_chair\": partial(BasicModel_chiral, color=color, tp=tp),\n",
    "    \"egnn_chair\": partial(EGNNModel_chiral, color=color, tp=tp),\n",
    "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
    "\n",
    "best_val_acc, test_acc, train_time = run_experiment(\n",
    "    model, \n",
    "    dataloader,\n",
    "    val_loader, \n",
    "    test_loader,\n",
    "    n_epochs=100,\n",
    "    n_times=10,\n",
    "    device=device,\n",
    "    verbose=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A special case that requires special coloring\n",
    "\n",
    "All neighbor points lie on a sphere centered at the central node, and their centroid coincides with the center of that sphere."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "from e3nn.o3 import matrix_x, matrix_y, matrix_z\n",
    "\n",
    "# Seed torch's global RNG so the random rotations below (and hence the dataset)\n",
    "# are reproducible under Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Number of independent random rotations of the base 5-point pattern.\n",
    "times = 3\n",
    "def get_rotation():\n",
    "    return e3nn.o3.rand_matrix()\n",
    "rotations = [get_rotation() for _ in range(times)]\n",
    "\n",
    "# Five unit vectors in the xy-plane whose centroid is the origin, scaled to radius 5.\n",
    "base_pos = torch.Tensor([\n",
    "    [-1, 0, 0],\n",
    "    [1/3, sqrt(8)/3, 0],\n",
    "    [1/3, -sqrt(8)/3, 0],\n",
    "    [1/6, sqrt(35)/6, 0],\n",
    "    [1/6, -sqrt(35)/6, 0],\n",
    "]) * 5\n",
    "\n",
    "# Node 0 is the center at the origin; the rotated copies of the pattern are its neighbors.\n",
    "my_pos = torch.cat([torch.zeros(1, 3)] + [base_pos @ rot for rot in rotations], dim=0)\n",
    "n_neighbors = my_pos.size(0) - 1\n",
    "my_atoms = torch.zeros(my_pos.size(0), dtype=torch.long)\n",
    "# Directed star edges center -> neighbor, built directly as int64.\n",
    "my_edge_index = torch.stack([\n",
    "    torch.zeros(n_neighbors, dtype=torch.long),\n",
    "    torch.arange(1, n_neighbors + 1),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_special_case_envs():\n",
    "    \"\"\"Return the special-case star graph and its point inversion as a labelled pair.\n",
    "\n",
    "    Environment 0 (label 0) uses the global `my_pos` as-is; environment 1\n",
    "    (label 1) uses `-my_pos`. Both reuse the global `my_atoms` and\n",
    "    `my_edge_index`, with the edges made undirected.\n",
    "    \"\"\"\n",
    "    envs = []\n",
    "    for label, positions in enumerate((my_pos, -my_pos)):\n",
    "        graph = Data(\n",
    "            atoms=my_atoms,\n",
    "            edge_index=my_edge_index,\n",
    "            pos=positions,\n",
    "            y=torch.LongTensor([label]),\n",
    "        )\n",
    "        graph.edge_index = to_undirected(graph.edge_index)\n",
    "        envs.append(graph)\n",
    "    return envs\n",
    "\n",
    "# Create dataset\n",
    "dataset = create_special_case_envs()\n",
    "for data in dataset:\n",
    "    plot_3d(data, lim=5)\n",
    "\n",
    "# Create dataloaders (train, validation and test all iterate over the same pair)\n",
    "dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "val_loader = DataLoader(dataset, batch_size=2, shuffle=False)\n",
    "test_loader = DataLoader(dataset, batch_size=2, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set model\n",
    "model_name = \"egnn_chair\"\n",
    "# `color` and `tp` are forwarded only to the *_chair model constructors.\n",
    "color = True\n",
    "tp = False\n",
    "\n",
    "# `correlation` is only consumed by the MACE entry below.\n",
    "correlation = 4\n",
    "# The TFN/MACE entries use hidden irreps containing both parities (e/o) for l = 0, 1, 2.\n",
    "model = {\n",
    "    \"schnet\": SchNetModel,\n",
    "    \"dimenet\": DimeNetPPModel,\n",
    "    \"spherenet\": SphereNetModel,\n",
    "    \"egnn\": EGNNModel,\n",
    "    \"gvp\": GVPGNNModel,\n",
    "    \"tfn\": partial(TFNModel, hidden_irreps=e3nn.o3.Irreps('64x0e + 64x0o + 64x1e + 64x1o + 64x2e + 64x2o')),\n",
    "    \"mace\": partial(MACEModel, correlation=correlation, hidden_irreps=e3nn.o3.Irreps('32x0e + 32x0o + 32x1e + 32x1o + 32x2e + 32x2o')),\n",
    "    \"basic_chair\": partial(BasicModel_chiral, color=color, tp=tp),\n",
    "    \"egnn_chair\": partial(EGNNModel_chiral, color=color, tp=tp),\n",
    "}[model_name](num_layers=1, in_dim=1, out_dim=2)\n",
    "\n",
    "# Training for more epochs (n_epochs) can increase the classifier's accuracy on this case.\n",
    "best_val_acc, test_acc, train_time = run_experiment(\n",
    "    model, \n",
    "    dataloader,\n",
    "    val_loader, \n",
    "    test_loader,\n",
    "    n_epochs=100,\n",
    "    n_times=10,\n",
    "    device=device,\n",
    "    verbose=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "\n",
    "import torch\n",
    "import e3nn\n",
    "\n",
    "# Unscaled copy of the planar 5-point pattern used for the special case above.\n",
    "X = torch.Tensor([\n",
    "    [-1, 0, 0],\n",
    "    [1/3, sqrt(8)/3, 0],\n",
    "    [1/3, -sqrt(8)/3, 0],\n",
    "    [1/6, sqrt(35)/6, 0],\n",
    "    [1/6, -sqrt(35)/6, 0],\n",
    "])\n",
    "\n",
    "sh_irreps = e3nn.o3.Irreps.spherical_harmonics(5)\n",
    "spherical_harmonics = e3nn.o3.SphericalHarmonics(\n",
    "    sh_irreps, normalize=True, normalization=\"component\"\n",
    ")\n",
    "# Named `sh_tp` (not `tp`) so it does not overwrite the boolean `tp` model flag.\n",
    "sh_tp = e3nn.o3.FullyConnectedTensorProduct(sh_irreps, sh_irreps, '3x1o')\n",
    "\n",
    "rot_num, test_times = 10, 100\n",
    "det_tol = 1e-2  # |det| above this counts as a non-degenerate 3x3 frame\n",
    "successes = 0\n",
    "for _trial in range(test_times):\n",
    "    # Union of `rot_num` randomly rotated copies of the pattern.\n",
    "    G = torch.cat([X @ e3nn.o3.rand_matrix() for _ in range(rot_num)], dim=0)\n",
    "    SH = torch.sum(spherical_harmonics(G), dim=0)\n",
    "    # Contract the summed harmonics with themselves into three l=1 odd-parity vectors.\n",
    "    virtual_node = sh_tp(SH, SH).view(3, 3)\n",
    "    if torch.abs(torch.det(virtual_node)) > det_tol:\n",
    "        successes += 1\n",
    "# Divide once at the end: summing 1 / test_times in floating point drifts from the exact fraction.\n",
    "successful_rate = successes / test_times\n",
    "print(successful_rate)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "aesc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
