{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Slash Class tutorial  \n",
    "This notebook demonstrates how the Slash class works.\n",
    "\n",
    "\n",
    "### NOTE: This notebook may need some refactoring"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start importing modules...\n"
     ]
    },
    {
     "ename": "ImportError",
     "evalue": "cannot import name 'dataList'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-2-e4118ff8e7bc>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;31m#own modules\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdataGen\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataList\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobsList\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mnetwork\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mNet\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mspnasp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mImportError\u001b[0m: cannot import name 'dataList'"
     ]
    }
   ],
   "source": [
    "print(\"start importing modules...\")\n",
    "\n",
    "# standard library\n",
    "import importlib  # not used below; presumably kept for importlib.reload(...) while developing\n",
    "import random\n",
    "import sys\n",
    "import time\n",
    "\n",
    "# third-party\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# Seed every RNG source BEFORE the project modules are imported, in case dataGen\n",
    "# draws random numbers while building dataList/obsList at import time.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# make the project-local packages importable (paths are relative to this notebook)\n",
    "sys.path.append('../')\n",
    "sys.path.append('../SpnAsp/')\n",
    "sys.path.append('../EinsumNetwork/')\n",
    "\n",
    "# own modules (resolved via the sys.path entries above)\n",
    "from dataGen import dataList, obsList, test_loader, train_loader\n",
    "from network import Net\n",
    "import spnasp\n",
    "from spnasp import SpnASP\n",
    "import utils\n",
    "\n",
    "print(\"...done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Our training inputs are a set of images and a set of observations as labels\n",
    "\n",
    "NeurASP expects the data as a list of dictionaries (hash maps), one per sample. That's because we need to map each input, for example the two images of an addition sample, to its corresponding atoms in the logic program.  \n",
    "Example MNIST digit addition: [{'im1':[...], 'im2':[...] }, {'im1':[...], 'im2':[...] }, ...]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataList and obsList length: 30000 30000\n",
      "observations for first five samples: [':- not addition(i1, i2, 12).' ':- not addition(i1, i2, 15).'\n",
      " ':- not addition(i1, i2, 6).' ':- not addition(i1, i2, 5).'\n",
      " ':- not addition(i1, i2, 7).']\n",
      "keys for each input dict_keys(['i1', 'i2'])\n"
     ]
    }
   ],
   "source": [
    "# train set: one input dict per sample, paired index-wise with one observation (the label)\n",
    "print(\"dataList and obsList length:\", len(dataList), len(obsList))\n",
    "assert len(dataList) == len(obsList), \"every input sample needs exactly one observation\"\n",
    "print(\"observations for first five samples:\",obsList[0:5])\n",
    "print(\"keys for each input\", dataList[0].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([100, 1, 28, 28])\n",
      "torch.Size([100])\n"
     ]
    }
   ],
   "source": [
    "# The test set contains single images with their class-index labels;\n",
    "# it is used for testing the neural network on its own.\n",
    "one_test = next(iter(test_loader))  # first batch only\n",
    "test_images, test_labels = one_test[0], one_test[1]\n",
    "print(test_images.shape)\n",
    "print(test_labels.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SpnAsp Program and Hyperparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training hyperparameters\n",
    "BATCH_SIZE = 100\n",
    "EPOCHS = 10\n",
    "LR = 0.01\n",
    "\n",
    "# Logic program: two images i1 and i2 whose digits must add up to the observed sum.\n",
    "# The spn(...) rule attaches the model registered under the name 'digit' (see nnMapping) to every image.\n",
    "program = '''\n",
    "img(i1). img(i2).\n",
    "addition(A,B,N) :- digit(0,A,N1), digit(0,B,N2), N=N1+N2.\n",
    "spn(digit(1,X), [0,1,2,3,4,5,6,7,8,9]) :- img(X).\n",
    "'''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SPN ASP Instantiation\n",
    "- Instantiate neural networks or sum-product networks.\n",
    "- Define nnMapping: a dictionary that maps neural network names (i.e., strings) to the neural network objects (i.e., torch.nn.Module objects).\n",
    "- Define optimizers: a dictionary that maps neural network names (i.e., strings) to optimizers (we use the Adam optimizer here)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# train the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "########\n",
    "# Start training and testing\n",
    "########\n",
    "\n",
    "# NOTE(review): not used anywhere in this notebook yet; presumably meant as the target for saving the trained model\n",
    "saveModelPath = 'data/model.pt'\n",
    "\n",
    "#setup new SPN ASP program given the network parameters\n",
    "#first define the network(s); the parameters live in one dict so they can also be exported with the results\n",
    "network_params = {'structure': 'binary-trees',\n",
    "                  'depth': 8,\n",
    "                  'num_repetitions': 10,\n",
    "                  'num_var': 784}\n",
    "m = Net(**network_params)\n",
    "\n",
    "#then create a mapping from network to atom and network optimizer\n",
    "nnMapping = {'digit': m}\n",
    "optimizers = {'digit': torch.optim.Adam(m.parameters(), lr=LR, eps=1e-7)}\n",
    "\n",
    "#then instantiate SpnAsp using the networks, the mappings and the logic program\n",
    "SpnASPobj = SpnASP(program, nnMapping, optimizers)\n",
    "\n",
    "#the export names only depend on the run configuration, so compute them once (uses the constants from the hyperparameter cell)\n",
    "suffix = \"_epoch:{}_bs:{}_lr:{}\".format(EPOCHS, BATCH_SIZE, LR)\n",
    "export_path = \"../../results/plots\" + suffix +\".svg\"\n",
    "\n",
    "#metric lists: one [accuracy, epoch] entry is appended per epoch, so they must stay Python lists\n",
    "train_accuracy_list = []\n",
    "test_accuracy_list = []\n",
    "startTime = time.time()\n",
    "\n",
    "# train the network and evaluate the performance\n",
    "for e in range(EPOCHS):\n",
    "    print('Epoch {}...'.format(e+1))\n",
    "\n",
    "    time_train = time.time()\n",
    "    #learn from the data: one epoch per loop iteration so we can evaluate in between\n",
    "    SpnASPobj.learn(dataList=dataList, obsList=obsList, epoch=1, smPickle='data/stableModels.pickle', batchSize=BATCH_SIZE)\n",
    "    timestamp_train = utils.time_delta_now(time_train)\n",
    "\n",
    "    time_test = time.time()\n",
    "    test_acc, _ = SpnASPobj.testNN('digit', test_loader)\n",
    "    train_acc, _ = SpnASPobj.testNN('digit', train_loader)\n",
    "    train_accuracy_list.append([train_acc, e])\n",
    "    test_accuracy_list.append([test_acc, e])\n",
    "    timestamp_test = utils.time_delta_now(time_test)\n",
    "\n",
    "    #print statistics\n",
    "    print('Train Acc: {:0.2f}%, Test Acc: {:0.2f}%'.format(train_acc, test_acc))\n",
    "    print('--- train time:  ---', timestamp_train)\n",
    "    print('--- test time:  ---' , timestamp_test)\n",
    "    print('--- total time from beginning:  ---', utils.time_delta_now(startTime))\n",
    "\n",
    "    #export results after every epoch; pass array copies instead of rebinding the lists,\n",
    "    #because rebinding them to np.ndarray would break .append in the next epoch\n",
    "    utils.export_results(test_accuracy_list=np.array(test_accuracy_list),\n",
    "                         train_accuracy_list=np.array(train_accuracy_list),\n",
    "                         export_path=export_path, export_suffix=suffix,\n",
    "                         network_params=str(network_params))\n",
    "    print(\"results saved\\n\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stable models in SpnASP\n",
    "In this section we take a look at how the SpnASP object stores the logic program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct a fresh SpnASP object to inspect. `program` and `LR` are the ones defined in the\n",
    "# \"SpnAsp Program and Hyperparams\" section, so this section cannot drift from the training setup.\n",
    "m = Net(structure = 'binary-trees',\n",
    "        depth = 8,\n",
    "        num_repetitions = 10,\n",
    "        num_var = 784)\n",
    "\n",
    "nnMapping = {'digit': m}\n",
    "optimizers = {'digit': torch.optim.Adam(m.parameters(), lr=LR, eps=1e-7)}\n",
    "SpnASPobj = SpnASP(program, nnMapping, optimizers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(SpnASPobj.mvpp.keys(),\"\\n\")\n",
    "\n",
    "print(\"NNPROB:\",SpnASPobj.mvpp['nnProb'],\"\\n\") #neural network probabilities for image 1 and 2\n",
    "print(SpnASPobj.mvpp['atom'],\"\\n\") #neural network probabilities for image 1 and 2 as string atoms\n",
    "\n",
    "print(\"neural network rules:\",SpnASPobj.mvpp['nnPrRuleNum'],\"\\n\") #number of neural-network rules; presumably 2 here (one per image) - verify against the printed value\n",
    "\n",
    "\n",
    "print(\"The programm and its subparts:\\n\")\n",
    "print(SpnASPobj.mvpp['program'],\"\\n\") #The whole Program\n",
    "print(SpnASPobj.mvpp['program_pr'],\"\\n\") #The NeurASP Part of the Program\n",
    "print(SpnASPobj.mvpp['program_asp'],\"\\n\") #The ASP Part of the Program"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"examples for a stable model\")\n",
    "first_obs_models = SpnASPobj.stableModels[0]  # presumably all stable models of the first observation\n",
    "for stable_model in first_obs_models[:2]:  # slice: the first observation may have fewer than two models\n",
    "    print(stable_model)\n",
    "print()\n",
    "print(\"Amount of stable models\",len(SpnASPobj.stableModels),\", which is one for each obersvation\")\n",
    "print(\"Amount of atoms in a stable model\",len(first_obs_models[0]),\"\\n\")\n",
    "# NOTE: the number of stable models per observation depends on the target sum,\n",
    "# e.g. only 8+9 and 9+8 sum up to 17."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
