{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import sys\n",
    "\n",
    "# Make the repository root importable so `src.models` (imported in the next cell) resolves.\n",
    "# Assumes the kernel's working directory is two levels below the repository root.\n",
    "# The membership check keeps re-runs of this cell from appending duplicate entries.\n",
    "REPO_ROOT = '../../'\n",
    "if REPO_ROOT not in sys.path:\n",
    "    sys.path.append(REPO_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "import torch_geometric\n",
    "from torch_geometric.loader import DataLoader\n",
    "import e3nn\n",
    "from e3nn import o3\n",
    "\n",
    "print(\"PyTorch version {}\".format(torch.__version__))\n",
    "print(\"PyG version {}\".format(torch_geometric.__version__))\n",
    "print(\"e3nn version {}\".format(e3nn.__version__))\n",
    "\n",
    "from data_utils import create_two_body_envs, create_three_body_envs, create_four_body_chiral_envs, create_four_body_nonchiral_envs\n",
    "from plot_utils import plot_2d, plot_3d\n",
    "from train_utils import run_experiment\n",
    "from src.models import MPNNModel, EGNNModel, GVPGNNModel, TFNModel, SchNetModel, DimeNetPPModel, MACEModel, OriginalMACEModel\n",
    "\n",
    "# Seed the Python, NumPy and PyTorch RNGs so that loader shuffling and weight\n",
    "# initialisation repeat across Restart & Run All (some CUDA kernels may still be\n",
    "# nondeterministic).\n",
    "SEED = 0\n",
    "torch_geometric.seed_everything(SEED)\n",
    "\n",
    "# Prefer CUDA when available, otherwise fall back to CPU.\n",
    "# To try Apple's MPS backend instead: torch.device('mps') if torch.backends.mps.is_available().\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiments\n",
    "\n",
    "Each experiment plots the graphs produced by one `create_*_envs` generator and then trains and evaluates a model on them with `run_env_experiment` (defined below). The same graphs back the train, validation and test loaders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model(model_name, num_layers, correlation):\n",
    "    \"\"\"Instantiate a model from `src.models` by short name.\n",
    "\n",
    "    Every model is built with `num_layers`, `in_dim=1` and `out_dim=2`.\n",
    "    `correlation` is only forwarded to MACE; the other models ignore it.\n",
    "    Raises KeyError for an unknown `model_name`.\n",
    "    \"\"\"\n",
    "    model_factories = {\n",
    "        \"mpnn\": MPNNModel,\n",
    "        \"schnet\": SchNetModel,\n",
    "        \"dimenet\": DimeNetPPModel,\n",
    "        \"egnn\": EGNNModel,\n",
    "        \"gvp\": GVPGNNModel,\n",
    "        \"tfn\": TFNModel,\n",
    "        \"mace\": partial(MACEModel, correlation=correlation),\n",
    "        # Alternative MACE backend, kept for reference:\n",
    "        # \"mace\": partial(OriginalMACEModel, correlation=correlation, num_interactions=num_layers, num_elements=1, irreps_out=o3.Irreps(\"2x0e\")),\n",
    "    }\n",
    "    return model_factories[model_name](num_layers=num_layers, in_dim=1, out_dim=2)\n",
    "\n",
    "\n",
    "def run_env_experiment(dataset, model_name=\"mace\", num_layers=1, correlation=2,\n",
    "                       eval_batch_size=1, n_epochs=100, n_times=10):\n",
    "    \"\"\"Plot every graph in `dataset`, then train and evaluate a model on it.\n",
    "\n",
    "    Training uses batch_size=1 with shuffling; validation and test use\n",
    "    `eval_batch_size` without shuffling. All three loaders wrap the same\n",
    "    `dataset`. Runs on the notebook-level `device`.\n",
    "\n",
    "    Returns the result of `run_experiment`, which this notebook unpacks as\n",
    "    `(best_val_acc, test_acc, train_time)`.\n",
    "    \"\"\"\n",
    "    for data in dataset:\n",
    "        plot_3d(data, lim=5)\n",
    "\n",
    "    train_loader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "    val_loader = DataLoader(dataset, batch_size=eval_batch_size, shuffle=False)\n",
    "    test_loader = DataLoader(dataset, batch_size=eval_batch_size, shuffle=False)\n",
    "\n",
    "    model = build_model(model_name, num_layers, correlation)\n",
    "\n",
    "    return run_experiment(\n",
    "        model,\n",
    "        train_loader,\n",
    "        val_loader,\n",
    "        test_loader,\n",
    "        n_epochs=n_epochs,\n",
    "        n_times=n_times,\n",
    "        device=device,\n",
    "        verbose=False,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two-body environments (correlation=2)\n",
    "best_val_acc, test_acc, train_time = run_env_experiment(\n",
    "    create_two_body_envs(), model_name=\"mace\", num_layers=1, correlation=2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Three-body environments (correlation=3)\n",
    "best_val_acc, test_acc, train_time = run_env_experiment(\n",
    "    create_three_body_envs(), model_name=\"mace\", num_layers=1, correlation=3\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Four-body non-chiral environments (correlation=4)\n",
    "best_val_acc, test_acc, train_time = run_env_experiment(\n",
    "    create_four_body_nonchiral_envs(), model_name=\"mace\", num_layers=1, correlation=4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Four-body chiral environments (correlation=4)\n",
    "# NOTE(review): validation/test use batch_size=2 here, unlike the other experiments (1).\n",
    "# TODO confirm whether this difference is intentional.\n",
    "best_val_acc, test_acc, train_time = run_env_experiment(\n",
    "    create_four_body_chiral_envs(), model_name=\"mace\", num_layers=1, correlation=4,\n",
    "    eval_batch_size=2,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
