{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SLASH class tutorial  \n",
    "This notebook demonstrates how the `SLASH` class works, using MNIST digit addition as the example task.\n",
    "\n",
    "\n",
    "### NOTE: This notebook may need some refactoring\n",
    "### First, load all the modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start importing modules...\n",
      "...done\n"
     ]
    }
   ],
   "source": [
    "print(\"start importing modules...\")\n",
    "\n",
    "# standard library\n",
    "import importlib\n",
    "import sys\n",
    "import time\n",
    "\n",
    "# third-party: numpy, torch, torchvision\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "from torchvision.transforms import transforms\n",
    "# free unused memory held by PyTorch's CUDA caching allocator\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# make the project-local packages importable (paths are relative to the notebook's working directory)\n",
    "for extra_path in ('../', '../SLASH/', '../EinsumNetworks/src/'):\n",
    "    sys.path.append(extra_path)\n",
    "\n",
    "# own modules\n",
    "from dataGen import MNIST_Addition, get_data_and_query_list  # dataList, obsList, test_loader, train_loader\n",
    "from network_nn import Net_nn\n",
    "import slash\n",
    "from slash import SLASH\n",
    "import utils\n",
    "from einsum_wrapper import EiNet\n",
    "from mvpp import MVPP\n",
    "\n",
    "# seeds\n",
    "utils.set_manual_seed(1)\n",
    "print(\"...done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Second, load the data and the queries built from it\n",
    "This means that we need to generate\n",
    "- $\\mathbf{\\mathcal{D}}$ -- the list of the corresponding data entries, i.e. $x_{\\cdot}$\n",
    "- $\\mathbf{Q}$ -- the list of queries for the logic program\n",
    "- $x_l$ -- the raw data to be mapped by $\\mathcal{T}(\\cdot)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the torchvision transforms suitable for SPNs: convert each image to a tensor,\n",
    "# normalise it with the commonly used MNIST mean/std, and flatten the 1x28x28 image to a 784-dim vector\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081, )), transforms.Lambda(lambda x: torch.flatten(x))])\n",
    "# Generate the training dataset\n",
    "train_dataset = MNIST_Addition(torchvision.datasets.MNIST(root='./data/', train=True, download=True, transform=transform), 'data/train_data.txt', True)\n",
    "# Generate the queries and the observation list\n",
    "dataList, queryList = get_data_and_query_list(train_dataset)\n",
    "# Load the 'pure' datasets for the mapping. download=True so these loaders do not rely on the\n",
    "# download triggered above (torchvision skips downloading when the files already exist)\n",
    "test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True, transform=transform), batch_size=100, shuffle=True)\n",
    "train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=True, download=True, transform=transform), batch_size=100, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Our training inputs are a set of images and a set of queries as labels\n",
    "\n",
    "SLASH expects the data as a list of dictionaries (hashmaps). That's because each input, for example the two images of an addition sample, has to be mapped to its corresponding atoms in the logic program.  \n",
    "Example for MNIST digit addition: [{'i1':[...], 'i2':[...]}, {'i1':[...], 'i2':[...]}, ...]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataList and queryList length: 30000 30000\n",
      "observations for first five samples: [':- not addition(i1, i2, 12).' ':- not addition(i1, i2, 15).'\n",
      " ':- not addition(i1, i2, 6).' ':- not addition(i1, i2, 5).'\n",
      " ':- not addition(i1, i2, 7).']\n",
      "keys for each input dict_keys(['i1', 'i2'])\n",
      "/home/arseny/Documents/splpmln/src/slash_mnist_digit_addition\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOsAAADrCAYAAACICmHVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAFv0lEQVR4nO3dz4tNfxzH8XO+KCllptSYMt9S/gArYqMQCxvNwkYs7Oz5D6SklAWykX9AlgjNZsqKkrX6aoryLWHvfNf6zn3fuT9m5r7ufTyW3uec+2n07KM+zty267oGmHx/bfcCgI0RK4QQK4QQK4QQK4QQK4TYOcjFbds654FN1nVdu96f21khhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghhFghxM7tXkCCK1eulPPHjx+P9PxHjx71nD19+nSkZ+/evbucLy8vl/PFxcWes3379pX33rx5s5w/e/asnPMnOyuEECuEECuEECuEECuEECuEECuEaLuu2/jFbbvxi6fIkydPyvmlS5c27bPbti3ng/z9jfvz+33279+/y/mNGzfK+d27d8v5tOq6bt0fup0VQogVQogVQogVQogVQogVQogVQnifdQs8fPiwnL98+bLn7MSJE+W9m33O+vXr156zs2fPlveeOXOmnB87dmyoNc0qOyuEECuEECuEECuEECuEECuEECuE8D7rBoz6PuuRI0fK+YcPHwZe0yRYWloq5+/evSvnc3Nz5XzHjh0Dr2kaeJ8VwokVQogVQogVQogVQogVQogVQniftam/g7Rpmuby5cvlfLPfKZ1Unz9/Ludra2vlfH5+fpzLmXp2VgghVgghVgghVgghVgghVgjh6KZpmlOnTpXzfkczs3p00+/Ia2FhoZzP6s9tWHZWCCFWCCFWCCFWCCFWCCFWCCFWCOGcdQy+ffs20jzVwYMHy/n+/fu3aCWzwc4KIcQKIcQKIcQKIcQKIcQKIcQKIZyzNk3z6dOnkeY/f/4s51++fBl4TQmWl5dHuv/Vq1djWslssLNCCLFCCLFCCLFCCLFCCLFCCLFCCOesTdOsrq6W84sXL5bz+/fvj3M5MQ4cOFDO27Yt5+/fvx/ncqaenRVCiBVCiBVCiBVCiBVCiBVCiBVCtIN8R2bbtr5Qcx1zc3Pl/Pv371u0kvE7evRoz1m/91F37dpVzk+ePFnO3759W86nVdd16x5Q21khhFghhFghhFghhFghhFghhFfkxiD5aKafa9eu9Zzt2bOnvPfNmzflfFaPZoZlZ4UQYoUQYoUQYoUQYoUQYoUQYoUQzlln3KFDh8r56dOnh372nTt3hr6X/7OzQgixQgixQgixQgixQgixQgixQgjnrDPuwYMH5bz6Wsfnz5+X97548WKoNbE+OyuEECuEECuEECuEECuEECuEECuEcM465ZaWlsr54cOHh37269evh76XwdlZIYRYIYRYIYRYIYRYIYRYIYSjm3AXLlwo5/fu3Svni4uL5fzHjx89Zx8/fizvZbzsrBBCrBBCrBBCrBBCrBBCrBBCrBDCOWu4hYWFct7vHLWf8+fP95ytrq6O9GwGY2eFEGKFEGKFEGKFEGKFEGKFEGKFEM5Zwx0/fryct21bzvudlTpLnRx2VgghVgghVgghVgghVgghVgghVgjhnHXC7d27t5yfO3eunHddV85v37498JrYHnZWCCFWCCFWCCFWCCFWCCFWCCFWCOGcdcJdvXq1nM/Pz5fztbW1cr6ysjLoktgmdlYIIVYIIVYIIVYIIVYIIVYI4ehmwl2/fn2k+2/dulXOf/36NdLz2Tp2Vggh
VgghVgghVgghVgghVgghVgjR9vtVlX9c3LYbvxgYStd1635Pp50VQogVQogVQogVQogVQogVQogVQgz6Puu/TdP8sxkLAZqmaZq/ew0G+k8RwPbxz2AIIVYIIVYIIVYIIVYIIVYIIVYIIVYIIVYI8R+MpMuVf27IowAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(\"dataList and queryList length:\", len(dataList), len(queryList))\n",
    "print(\"observations for first five samples:\", queryList[0:5])\n",
    "print(\"keys for each input\", dataList[0].keys())\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Show one raw input: the 'i2' image of the sample at index 1, reshaped from the flat 784-vector back to 28x28\n",
    "sample_image = np.array(dataList[1]['i2'], dtype='float')\n",
    "pixels = sample_image.reshape((28, 28))\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "ax.imshow(pixels, cmap='gray')\n",
    "fig.savefig('handwritten_digit_seven.svg', dpi=600, bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The test set contains single images with their labels as class indices; it is used to test the probabilistic circuit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([100, 784])\n",
      "torch.Size([100])\n"
     ]
    }
   ],
   "source": [
    "# Peek at one test batch: the images and their class-index labels\n",
    "test_images, test_labels = next(iter(test_loader))\n",
    "print(test_images.shape)\n",
    "print(test_labels.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### SLASH Program and Hyperparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "program = '''\n",
    "img(i1). img(i2).\n",
    "addition(A,B,N) :- digit(0,A,N1), digit(0,B,N2), N=N1+N2.\n",
    "pc(digit(1,X), [0,1,2,3,4,5,6,7,8,9]) :- img(X).\n",
    "'''\n",
    "\n",
    "# Define the dictionary of the hyper-parameters (key order is kept as it is also exported via exp_dict)\n",
    "params = dict(\n",
    "    exp_name='MNIST-Addition',\n",
    "    # circuit structure over the flattened 28x28 image (784 variables)\n",
    "    structure='poon-domingos', pd_num_pieces=[4, 7, 28],\n",
    "    depth=8, num_repetitions=10, num_var=784,\n",
    "    pd_width=28, pd_height=28, use_spn=True,\n",
    "    credentials='AS', class_count=10,\n",
    "    # optimisation: learning rate, batch size, epochs, p_num forwarded to SLASH.learn\n",
    "    lr=0.01, bs=100, epochs=10, p_num=8, drop_out=0.0,\n",
    "    learn_prior=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SLASH Instantiation\n",
    "- Instantiate the probabilistic circuits or neural networks.\n",
    "- Define `networkMapping`: a dictionary that maps PC/NN names (i.e., strings) to the PC/NN objects (i.e., `torch.nn.Module` objects).\n",
    "- Define `optimizers`: a dictionary that maps PC/NN names (i.e., strings) to their optimizers (we use the Adam optimizer here).\n",
    "\n",
    "These steps are performed at the beginning of the training cell below."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "P(C) is learnable.\n",
      "P(C) is tensor([-2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,\n",
      "        -2.3026, -2.3026], device='cuda:0', requires_grad=True)\n",
      "train SPN with EM: False\n",
      "The number of the trainable parameters: 6663440\n",
      "Epoch 1...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/300 [00:00<?, ?it/s]<string>:6: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
      "100%|██████████| 300/300 [01:14<00:00,  4.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward time:  6.891417503356934\n",
      "asp time: 29.700540781021118\n",
      "backward time:  37.686354637145996\n",
      "Train Acc: 0.33%, Test Acc: 0.33%\n",
      "--- train time:  --- 0h:1m:14s\n",
      "--- test time:  --- 0 days, 0 hours, 0 minutes, 14 seconds, 427 milliseconds\n",
      "--- total time from beginning:  --- 0 days, 0 hours, 1 minutes, 29 seconds, 894 milliseconds\n",
      "Epoch 2...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/300 [00:00<?, ?it/s]<string>:6: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
      "100%|██████████| 300/300 [01:15<00:00,  3.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "forward time:  6.90355372428894\n",
      "asp time: 30.002299070358276\n",
      "backward time:  38.03681302070618\n",
      "Train Acc: 0.81%, Test Acc: 0.82%\n",
      "--- train time:  --- 0h:1m:15s\n",
      "--- test time:  --- 0 days, 0 hours, 0 minutes, 15 seconds, 612 milliseconds\n",
      "--- total time from beginning:  --- 0 days, 0 hours, 2 minutes, 59 seconds, 640 milliseconds\n",
      "Epoch 3...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/300 [00:00<?, ?it/s]<string>:6: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
      " 76%|███████▌  | 228/300 [00:57<00:18,  3.94it/s]"
     ]
    }
   ],
   "source": [
    "########\n",
    "# Start training and testing\n",
    "########\n",
    "\n",
    "# Set up a new SLASH program given the network parameters:\n",
    "# First, define the network (an EiNet probabilistic circuit).\n",
    "m = EiNet(structure = params['structure'],\n",
    "          pd_num_pieces = params['pd_num_pieces'],\n",
    "          depth = params['depth'],\n",
    "          num_repetitions = params['num_repetitions'],\n",
    "          num_var = params['num_var'],\n",
    "          class_count=params['class_count'],\n",
    "          pd_width=params['pd_width'],\n",
    "          pd_height=params['pd_height'],\n",
    "          use_em=False,\n",
    "          learn_prior=params['learn_prior']\n",
    "          )\n",
    "\n",
    "# Second, map the network name used in the program ('digit') to the network and to its optimizer.\n",
    "networkMapping = {'digit': m}\n",
    "optimizers = {'digit': torch.optim.Adam(m.parameters(), lr=params['lr'], eps=1e-7)}\n",
    "\n",
    "params['num_trainable_params'] = sum(p.numel() for p in m.parameters() if p.requires_grad)\n",
    "print('The number of the trainable parameters:', params['num_trainable_params'])\n",
    "\n",
    "# Third, instantiate SLASH using the logic program, the network mapping and the optimizers\n",
    "SLASHobj = SLASH(program, networkMapping, optimizers)\n",
    "\n",
    "# Metric lists: [accuracy, epoch] pairs and one confusion matrix per epoch.\n",
    "# They are re-created on every run of this cell, so they are plain lists inside the loop.\n",
    "train_accuracy_list = []\n",
    "test_accuracy_list = []\n",
    "confusion_matrix_list = []\n",
    "startTime = time.time()\n",
    "\n",
    "# Train the network and evaluate its performance after every epoch\n",
    "for e in range(params['epochs']):\n",
    "    print('Epoch {}...'.format(e+1))\n",
    "\n",
    "    time_train = time.time()\n",
    "    SLASHobj.learn(dataList=dataList, queryList=queryList, epoch=1, batchSize=params['bs'], p_num=params['p_num']) # learn from the data\n",
    "    # Save the training time for one epoch\n",
    "    timestamp_train = utils.time_delta_now(time_train, simple_format=True)\n",
    "    params['train_time'] = timestamp_train\n",
    "\n",
    "    time_test = time.time()\n",
    "    test_acc, _, confusion_matrix = SLASHobj.testNetwork('digit', test_loader, ret_confusion=True)\n",
    "    confusion_matrix_list.append(confusion_matrix)\n",
    "    train_acc, _ = SLASHobj.testNetwork('digit', train_loader)\n",
    "    train_accuracy_list.append([train_acc, e])\n",
    "    test_accuracy_list.append([test_acc, e])\n",
    "    timestamp_test = utils.time_delta_now(time_test)\n",
    "\n",
    "    # Print statistics\n",
    "    print('Train Acc: {:0.2f}%, Test Acc: {:0.2f}%'.format(train_acc, test_acc))\n",
    "    print('--- train time:  ---', timestamp_train)\n",
    "    print('--- test time:  ---' , timestamp_test)\n",
    "    print('--- total time from beginning:  ---', utils.time_delta_now(startTime))\n",
    "\n",
    "# Export results\n",
    "train_accuracy_list = np.array(train_accuracy_list)\n",
    "test_accuracy_list = np.array(test_accuracy_list)\n",
    "\n",
    "# NOTE(review): ':' is not allowed in Windows file names; consider another separator if portability matters\n",
    "suffix = \"_epoch:{}_bs:{}_lr:{}\".format(params['epochs'], params['bs'], params['lr'])\n",
    "export_path = \"../../results/plots\" + suffix +\".svg\"\n",
    "\n",
    "utils.export_results(test_accuracy_list=test_accuracy_list, train_accuracy_list=train_accuracy_list,\n",
    "                     export_path = export_path, export_suffix = suffix, confusion_matrix=confusion_matrix_list[-1],\n",
    "                     exp_dict=params\n",
    "                     )\n",
    "print(\"results saved\\n\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## One example\n",
    "### 1. The query $Q$ under consideration is ':- not addition(i1, i2, 15).'\n",
    "### 2. Output probabilities for both images: $P(C_{j}|X=i_{1})$ and $P(C_{j}|X=i_{2})$\n",
    "### 3. Stable models $I$ that satisfy $I\\models Q$\n",
    "### 4. Gradients $\\cfrac{\\partial \\log(P_{\\Pi(\\theta)}(Q))}{\\partial p}$\n",
    "### 5. Probability of the query $P_{\\Pi(\\theta)}(Q)$ "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The query of for the sum of 15 ':- not addition(i1, i2, 15).'\n",
      "\n",
      "Output probabilities, P(C_j|i1), for the first image are \n",
      " [0.   0.   0.   0.   0.   0.04 0.   0.02 0.01 0.94]\n",
      "Output probabilities, P(C_j|i2), for the first image are \n",
      " [0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]\n",
      "The belonging stable models:\n",
      "I_1 is ['digit(0,i1,6)', 'digit(0,i2,9)', 'addition(i1,i1,12)', 'addition(i2,i1,15)', 'addition(i1,i2,15)', 'addition(i2,i2,18)', 'img(i1)', 'img(i2)']\n",
      "\n",
      "I_2 is ['digit(0,i1,8)', 'digit(0,i2,7)', 'addition(i2,i1,15)', 'addition(i1,i1,16)', 'addition(i2,i2,14)', 'addition(i1,i2,15)', 'img(i1)', 'img(i2)']\n",
      "\n",
      "I_3 is ['digit(0,i1,9)', 'digit(0,i2,6)', 'addition(i2,i1,15)', 'addition(i1,i1,18)', 'addition(i2,i2,12)', 'addition(i1,i2,15)', 'img(i1)', 'img(i2)']\n",
      "\n",
      "I_4 is ['digit(0,i1,7)', 'digit(0,i2,8)', 'addition(i1,i1,14)', 'addition(i2,i1,15)', 'addition(i1,i2,15)', 'addition(i2,i2,16)', 'img(i1)', 'img(i2)']\n",
      "\n",
      "The gradients with respect to the output probabilities\n",
      "[-1.07 -1.07 -1.07 -1.07 -1.07 -1.07 -1.07 -1.07 -1.07  1.07]\n",
      "[-1.03 -1.03 -1.03 -1.03 -1.03 -1.03  0.97 -1.01 -0.99 -1.03]\n",
      " \n",
      "The probability for the query is 0.9358996748924255\n"
     ]
    }
   ],
   "source": [
    "QUERY_IDX = 29918  # position of the inspected query in queryList\n",
    "SAMPLE_IDX = 14  # position in SLASHobj's stored per-sample results -- NOTE(review): assumed to be the same sample as QUERY_IDX, confirm\n",
    "\n",
    "print(f\"The query for the sum of 15 '{queryList[QUERY_IDX]}'\\n\")\n",
    "print(f\"Output probabilities, P(C_j|i1), for the first image are \\n {np.around(SLASHobj.networkOutputs['digit']['i1'][SAMPLE_IDX].tolist(), decimals=2)}\")\n",
    "print(f\"Output probabilities, P(C_j|i2), for the second image are \\n {np.around(SLASHobj.networkOutputs['digit']['i2'][SAMPLE_IDX].tolist(), decimals=2)}\")\n",
    "print(\"The belonging stable models:\")\n",
    "for i, stable_model in enumerate(SLASHobj.stableModels[SAMPLE_IDX], start=1):\n",
    "    print(f\"I_{i} is {stable_model}\\n\")\n",
    "print(\"The gradients with respect to the output probabilities\")\n",
    "print(f\"{np.around(list(SLASHobj.networkGradients[SAMPLE_IDX][0]), decimals=2)}\")\n",
    "print(f\"{np.around(list(SLASHobj.networkGradients[SAMPLE_IDX][1]), decimals=2)}\\n \")\n",
    "print(f\"The probability for the query is {SLASHobj.prob_q[SAMPLE_IDX]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
