{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Open MatSci ML Toolkit Tutorial: Training your Custom Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we demonstrate how to set up an _**Open MatSci ML Toolkit**_ experiment, from selecting a dataset to implementing your own custom graph neural network (GNN) model. This workflow is recommended for testing custom models on your development machine, such as a laptop, before deploying them on a cluster or a multi-GPU machine and training them on the full dataset. _**Open MatSci ML Toolkit**_ exposes a set of interfaces (as abstract base classes) that, with the help of [PyTorch Lightning](https://www.pytorchlightning.ai/), let you get a model running from scratch in a few lines of code."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by importing a few useful libraries. These include the `warnings` module from the Python standard library, [PyTorch](https://pytorch.org/), [PyTorch Lightning](https://www.pytorchlightning.ai/), and [DGL](https://www.dgl.ai/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (C) 2022 Intel Corporation\n",
    "# SPDX-License-Identifier: MIT License\n",
    "\n",
    "import warnings\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "import dgl, torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import optim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we import the abstract classes that implement the _**Open MatSci ML Toolkit**_ interface: a data module and a model module. For this tutorial, we focus on the structure to energy and forces (S2EF) task and import the corresponding modules. The data module, `S2EFDGLDataModule`, provides access to the development dataset shipped with _**Open MatSci ML Toolkit**_ via its `from_devset` method. The model module, `S2EFLitModule`, ensures that the developed model interfaces properly with the toolkit's data pipeline and PyTorch Lightning; in particular, it implements the `forward` and `training_step` methods needed for this task. `AbstractEnergyModel` registers the model with PyTorch Lightning and specifies that the output should be an energy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ocpmodels.lightning.data_utils import S2EFDGLDataModule\n",
    "from ocpmodels.models import AbstractEnergyModel, S2EFLitModule"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reproducibility, we leverage the seeding mechanisms of both PyTorch Lightning and DGL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42\n",
    "\n",
    "pl.seed_everything(SEED)\n",
    "dgl.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Definition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This section discusses how to construct a model and integrate it into **Open MatSci ML Toolkit** to run on the OCP dataset. The steps are simple and can be summarized as follows:\n",
    "\n",
    "1. Start by implementing/choosing a graph neural network layer\n",
    "2. Using this layer, construct a layered model that subclasses `AbstractEnergyModel`; this interfaces the model with PyTorch Lightning. Note that any `AbstractEnergyModel` must output a scalar value representing energy\n",
    "3. Implement any customization for the model data pipeline; this can be achieved by subclassing `S2EFLitModule` and overriding its `_get_inputs` method\n",
    "\n",
    "For best practices on designing DGL models, please refer to our model guideline given [here](ocpmodels/models/README.md).\n",
    "\n",
    "Now we can look at an example of how to build a new model and integrate it with PyTorch Lightning:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# select/define a convolution layer and create a model\n",
    "\n",
    "from dgl.nn import GraphConv, AvgPooling\n",
    "\n",
    "class GraphConvModel(AbstractEnergyModel):\n",
    "    def __init__(self, num_layers, in_dim, hidden_dim):\n",
    "        super().__init__()\n",
    "        # layer widths: in_dim -> hidden_dim -> ... -> hidden_dim\n",
    "        sizes = [in_dim] + [hidden_dim] * num_layers\n",
    "        layers = []\n",
    "        for indx, (_in, _out) in enumerate(zip(sizes[:-1], sizes[1:])):\n",
    "            # apply SiLU after every layer except the last\n",
    "            layers.append(\n",
    "                GraphConv(_in, _out, activation=F.silu if indx < num_layers - 1 else None)\n",
    "            )\n",
    "        self.convs = nn.ModuleList(layers)\n",
    "\n",
    "        # average the node features into one vector per graph\n",
    "        self.readout = AvgPooling()\n",
    "\n",
    "        output_dim = 1  # energy is a scalar\n",
    "        self.proj = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, graph, features):\n",
    "        for layer in self.convs:\n",
    "            features = layer(graph, features)\n",
    "        pooled = self.readout(graph, features)\n",
    "        # drop only the trailing singleton dimension so that a batch\n",
    "        # containing a single graph keeps its batch axis\n",
    "        out = self.proj(pooled).squeeze(-1)\n",
    "        return out"
   ]
  },
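  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the layer construction concrete, here is a minimal sketch in plain Python that reproduces only the size bookkeeping from `GraphConvModel.__init__`. The helper name `layer_dims` is hypothetical and not part of the toolkit; it lists the input size, output size, and activation that each `GraphConv` layer would receive for the configuration used later in this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# illustrative sketch only; `layer_dims` is a hypothetical helper\n",
    "def layer_dims(num_layers, in_dim, hidden_dim):\n",
    "    # same bookkeeping as in GraphConvModel.__init__\n",
    "    sizes = [in_dim] + [hidden_dim] * num_layers\n",
    "    dims = []\n",
    "    for indx, (_in, _out) in enumerate(zip(sizes[:-1], sizes[1:])):\n",
    "        # every layer except the last is followed by SiLU\n",
    "        act = 'silu' if indx < num_layers - 1 else None\n",
    "        dims.append((_in, _out, act))\n",
    "    return dims\n",
    "\n",
    "dims = layer_dims(num_layers=3, in_dim=9, hidden_dim=128)\n",
    "for entry in dims:\n",
    "    print(entry)"
   ]
  },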
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we specialize our model inputs by subclassing `S2EFLitModule` and overriding `_get_inputs`. This ensures that the DGL graph object is unpacked into the arguments our model's `forward` method expects. If your model accepts a DGL graph directly, this step is not required."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# implement custom pipeline\n",
    "from dgl import AddSelfLoop\n",
    "\n",
    "# note: this subclass reuses (shadows) the imported name `S2EFLitModule`\n",
    "class S2EFLitModule(S2EFLitModule):\n",
    "    def _get_inputs(self, batch):\n",
    "        # GraphConv rejects zero in-degree nodes by default, so add self-loops\n",
    "        graph_transform = AddSelfLoop()\n",
    "        graph = graph_transform(batch.get(\"graph\"))\n",
    "        features = []\n",
    "        # reshape 1-D node data into a column, then stack all node data column-wise\n",
    "        for _, val in graph.ndata.items():\n",
    "            features.append(val if val.dim() > 1 else val.view(-1, 1))\n",
    "        features = torch.hstack(features)\n",
    "        return graph, features"
   ]
  },
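  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the feature stacking in `_get_inputs`, the sketch below applies the same reshape-and-`hstack` logic to a hand-made dictionary of tensors. The keys `charge` and `pos` and their shapes are made up for illustration and are not the actual node data fields of the devset. The width of the resulting matrix is what `in_dim` has to match when constructing the GNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# hypothetical node data for a graph with 3 nodes (names and shapes are assumptions)\n",
    "ndata = {\n",
    "    'charge': torch.tensor([1.0, 6.0, 8.0]),  # 1-D, shape (3,)\n",
    "    'pos': torch.zeros(3, 3),                 # 2-D, shape (3, 3)\n",
    "}\n",
    "\n",
    "# 1-D entries become a single column; 2-D entries are kept as they are\n",
    "columns = [val if val.dim() > 1 else val.view(-1, 1) for val in ndata.values()]\n",
    "features = torch.hstack(columns)\n",
    "\n",
    "# one row per node, one column per stacked feature\n",
    "print(features.shape)"
   ]
  },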
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below we instantiate the GNN and the corresponding task module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration\n",
    "REGRESS_FORCES = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gnn = GraphConvModel(num_layers=3, in_dim=9, hidden_dim=128)\n",
    "# create the S2EF task; `lr` and `gamma` are passed to the task module, which\n",
    "# presumably uses them for its optimizer and learning-rate schedule\n",
    "model = S2EFLitModule(gnn, lr=1e-3, gamma=0.1, regress_forces=REGRESS_FORCES)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Module"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each data module, including `S2EFDGLDataModule`, provides a `from_devset` method that loads the smaller development set. It accepts data-loading arguments such as `batch_size` and `num_workers`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration\n",
    "BATCH_SIZE = 16\n",
    "NUM_WORKERS = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# grab the devset; the `DataModule` provides the splits and the data\n",
    "# loaders that the trainer consumes below\n",
    "data_module = S2EFDGLDataModule.from_devset(\n",
    "    batch_size=BATCH_SIZE, num_workers=NUM_WORKERS\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration\n",
    "MAX_EPOCHS = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(max_epochs=MAX_EPOCHS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    trainer.fit(model, datamodule=data_module)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext tensorboard\n",
    "%tensorboard --logdir $trainer.logger.log_dir"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "a5404e44c4c870fdd2e378f0acaba09fbe7ca9d23ffa1df0a4b4ecf82b7b72d7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
