{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited project modules automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the directory two levels up importable (presumably the repo root holding `src/`).\n",
    "# Resolve it to an absolute path once and skip it if already present, so re-running this\n",
    "# cell does not keep appending duplicates and imports still resolve after an os.chdir.\n",
    "_repo_root = os.path.abspath(os.path.join('..', '..'))\n",
    "if _repo_root not in sys.path:\n",
    "    sys.path.append(_repo_root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "import torch_geometric\n",
    "from torch_geometric.loader import DataLoader\n",
    "import e3nn\n",
    "from e3nn import o3\n",
    "\n",
    "# Record library versions so results can be traced back to the environment that produced them.\n",
    "print(f\"PyTorch version {torch.__version__}\")\n",
    "print(f\"PyG version {torch_geometric.__version__}\")\n",
    "print(f\"e3nn version {e3nn.__version__}\")\n",
    "\n",
    "# Project-local helpers; `src.models` is presumably found via the sys.path entry added in the first cell.\n",
    "# NOTE: F and plot_3d are not used by the active code; o3 and OriginalMACEModel appear only in the\n",
    "# commented-out MACE alternative of the experiment cell.\n",
    "from data_utils import create_rotsym_envs\n",
    "from plot_utils import plot_2d, plot_3d\n",
    "from train_utils import run_experiment\n",
    "from src.models import (\n",
    "    MPNNModel, EGNNModel, GVPGNNModel, TFNModel, SchNetModel,\n",
    "    DimeNetPPModel, MACEModel, OriginalMACEModel,\n",
    ")\n",
    "\n",
    "# Run on CUDA when available, otherwise on CPU. On Apple silicon, torch.device('mps') can be\n",
    "# tried instead when torch.backends.mps.is_available() is True.\n",
    "_device_type = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "device = torch.device(_device_type)\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create and visualise the dataset: build the rotationally symmetric environments for the\n",
    "# chosen `fold` and draw each graph in 2D with plot_2d.\n",
    "fold = 2  # passed to create_rotsym_envs; keep in sync with `fold` in the experiment cell\n",
    "dataset = create_rotsym_envs(fold=fold)\n",
    "for data in dataset:\n",
    "    plot_2d(data, lim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set parameters\n",
    "model_name = \"egnn\"  # key into the model table below\n",
    "correlation = 1  # only used by \"mace\"\n",
    "max_ell = 2      # only used by \"tfn\" and \"mace\"\n",
    "fold = 2         # passed to create_rotsym_envs\n",
    "num_layers = 1\n",
    "seed = 0         # seeds torch's RNG (weight init, DataLoader shuffling) so re-runs are reproducible\n",
    "\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Create dataset\n",
    "dataset = create_rotsym_envs(fold)\n",
    "\n",
    "# Create dataloaders. Train, val and test all iterate over the same graphs; presumably\n",
    "# intentional (the question is whether the model can tell these environments apart,\n",
    "# not whether it generalises) -- confirm before reusing this setup elsewhere.\n",
    "train_loader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "val_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n",
    "test_loader = DataLoader(dataset, batch_size=1, shuffle=False)\n",
    "\n",
    "# Model constructors keyed by name; model-specific hyperparameters are bound with\n",
    "# `partial` so every entry accepts the same (num_layers, in_dim, out_dim) arguments.\n",
    "model_factories = {\n",
    "    \"mpnn\": MPNNModel,\n",
    "    \"schnet\": SchNetModel,\n",
    "    \"dimenet\": DimeNetPPModel,\n",
    "    \"egnn\": EGNNModel,\n",
    "    \"gvp\": GVPGNNModel,\n",
    "    \"tfn\": partial(TFNModel, max_ell=max_ell),\n",
    "    \"mace\": partial(MACEModel, max_ell=max_ell, correlation=correlation),\n",
    "    # \"mace\": partial(OriginalMACEModel, correlation=correlation, num_interactions=num_layers, num_elements=1, irreps_out=o3.Irreps(\"2x0e\")),\n",
    "}\n",
    "if model_name not in model_factories:\n",
    "    raise ValueError(f\"Unknown model_name {model_name!r}; expected one of {sorted(model_factories)}\")\n",
    "\n",
    "# in_dim=1 / out_dim=2: presumably one input node feature and two output classes -- confirm\n",
    "# against the labels produced by create_rotsym_envs.\n",
    "model = model_factories[model_name](num_layers=num_layers, in_dim=1, out_dim=2)\n",
    "\n",
    "best_val_acc, test_acc, train_time = run_experiment(\n",
    "    model,\n",
    "    train_loader,\n",
    "    val_loader,\n",
    "    test_loader,\n",
    "    n_epochs=100,\n",
    "    n_times=1,\n",
    "    device=device,\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# Show the headline numbers as the cell output.\n",
    "best_val_acc, test_acc, train_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "94aa676993820a604ac86f7af94f5432e989a749d5dd43e18f9507de2e8c2897"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
