{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a Simple GCN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "\n",
    "from molProp_train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'zinc250k'\n",
    "prop_name = 'qed'\n",
    "device = 'cuda:0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the full dataset object from the configuration values set above.\n",
    "data = MolData(\n",
    "    dataset=dataset,\n",
    "    prop_name=prop_name,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the held-out sample indices stored for this dataset.\n",
    "with open(f'data/valid_idx_{dataset.lower()}.json') as f:\n",
    "    test_idx = json.load(f)\n",
    "# For QM9 the index list is nested under 'valid_idxs' and its entries are cast\n",
    "# to int. Compare case-insensitively so this check agrees with the lower-cased\n",
    "# file name used above.\n",
    "if dataset.lower() == 'qm9':\n",
    "    test_idx = [int(i) for i in test_idx['valid_idxs']]\n",
    "\n",
    "# Force an integer dtype so that an empty index list can still index the mask.\n",
    "test_idx = np.asarray(test_idx, dtype=np.int64)\n",
    "# Every sample that is not held out goes to the training split.\n",
    "all_mask = np.ones(len(data), dtype=bool)\n",
    "all_mask[test_idx] = False\n",
    "train_idx = np.flatnonzero(all_mask)\n",
    "\n",
    "train_data = data[train_idx]\n",
    "test_data = data[test_idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-point both splits at a different target property via `change_prop`.\n",
    "# NOTE(review): 'homo' looks like a QM9 target while `dataset` is set to\n",
    "# 'zinc250k' in the config cell -- confirm the property exists for this dataset.\n",
    "prop_name = 'homo'\n",
    "for split in (train_data, test_data):\n",
    "    split.change_prop(prop_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regression model, placed on the configured device.\n",
    "model = SGCReg(data.num_node_features, num_layers=2).to(device)\n",
    "\n",
    "def collate_fn(x):\n",
    "    \"\"\"Merge a list of samples into one batch via `Batch.from_data_list`.\"\"\"\n",
    "    return Batch.from_data_list(x)\n",
    "\n",
    "batch_size = 64\n",
    "nepochs = 100\n",
    "# Shuffle only the training split; the evaluation metric sums over every batch,\n",
    "# so its value does not depend on sample order.\n",
    "train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)\n",
    "test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "criterion = torch.nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run one optimisation epoch over `train_loader`.\n",
    "\n",
    "    Uses the notebook-level `model`, `optimizer` and `criterion`.\n",
    "\n",
    "    Returns:\n",
    "        float: training loss averaged over the samples in\n",
    "        `train_loader.dataset`.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0.0\n",
    "    for data in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        out = model(data.x, data.edge_index, data.edge_attr, data.batch)\n",
    "        loss = criterion(out, data.y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # The loss is a per-batch mean, so weight it by the number of graphs in\n",
    "        # the batch before summing; dividing a sum of batch means by the dataset\n",
    "        # size would under-report the loss by roughly a factor of batch_size.\n",
    "        # `.item()` yields a plain float so no autograd history is retained.\n",
    "        total_loss += loss.item() * data.num_graphs\n",
    "\n",
    "    return total_loss / len(train_loader.dataset)\n",
    "\n",
    "def test(loader):\n",
    "    \"\"\"Evaluate the notebook-level `model` on every batch of `loader`.\n",
    "\n",
    "    Args:\n",
    "        loader: iterable of batches exposing `x`, `edge_index`, `edge_attr`,\n",
    "            `batch` and `y`; it must also provide `dataset` for the sample count.\n",
    "\n",
    "    Returns:\n",
    "        float: summed squared error divided by `len(loader.dataset)`.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    sq_err = 0.0\n",
    "    # Evaluation needs no gradients; disabling autograd avoids building and\n",
    "    # retaining a graph for every batch.\n",
    "    with torch.no_grad():\n",
    "        for data in loader:\n",
    "            pred = model(data.x, data.edge_index, data.edge_attr, data.batch)\n",
    "            sq_err += ((pred - data.y) ** 2).sum().item()\n",
    "    return sq_err / len(loader.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epoch setting read by the training loop; overrides the earlier `nepochs = 100`.\n",
    "nepochs = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001, Train MSE: 0.0001, Test MSE: 0.0025\n",
      "Epoch: 002, Train MSE: 0.0000, Test MSE: 0.0025\n",
      "Epoch: 003, Train MSE: 0.0000, Test MSE: 0.0025\n",
      "Epoch: 004, Train MSE: 0.0000, Test MSE: 0.0025\n"
     ]
    }
   ],
   "source": [
    "# `range` excludes its stop value, so add 1 to run exactly `nepochs` epochs\n",
    "# while keeping the printed epoch counter 1-based.\n",
    "for epoch in range(1, nepochs + 1):\n",
    "    train_loss = train()\n",
    "    test_loss = test(test_loader)\n",
    "    print(f'Epoch: {epoch:03d}, Train MSE: {train_loss:.4f}, Test MSE: {test_loss:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the trained weights, named after the dataset and target property.\n",
    "model_save_fname = f'config/constraints/regmodels/sgc_{dataset}_{prop_name}.pt'\n",
    "# Create the target directory first so a finished training run is not lost to\n",
    "# a missing folder.\n",
    "os.makedirs(os.path.dirname(model_save_fname), exist_ok=True)\n",
    "torch.save(model.state_dict(), model_save_fname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "44dc670dcdd6ffb1ba23034ae072504999a2c20bd6cc686fd82920ca8c3f3b47"
  },
  "kernelspec": {
   "display_name": "Python 3.7.15 64-bit ('moltemp': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.15"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
