{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Details\n",
    "\n",
    "This notebook contains the synthetic-data experiment with low noise (sigma^2 = 0.05). Because the HNCPD problem is non-convex, results vary from run to run, so the output of this notebook will not exactly match the results reported in the paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loading packages and functions\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"./src\")\n",
    "import torch\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from NNCPD import NNCPD, weights_H, Energy_Loss_Tensor, Recon_Loss, L21_Norm, outer_product, outer_product_np, PTF, random_NNCPD, Fro_Norm\n",
    "from lsqnonneg_module import LsqNonneg\n",
    "from trainNNCPD import train\n",
    "#\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from writer import Writer\n",
    "\n",
    "from sklearn.decomposition import NMF\n",
    "\n",
    "import tensorly as tl\n",
    "from tensorly import unfold as tl_unfold\n",
    "from tensorly.decomposition import parafac, non_negative_parafac"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Data Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## set the default tensor type and build the synthetic data tensor\n",
    "torch.set_default_tensor_type(torch.DoubleTensor)\n",
    "\n",
    "# tensor dimensions (the block coordinates below assume n1 = n2 = n3 = 40;\n",
    "# the decomposition rank r is set in the training cell)\n",
    "n1 = 40\n",
    "n2 = 40\n",
    "n3 = 40\n",
    "\n",
    "a = 1  # value of the three large blocks\n",
    "b = 3  # value of the smaller sub-blocks nested inside them\n",
    "\n",
    "X = np.zeros((n1, n2, n3))\n",
    "\n",
    "# three large blocks, disjoint in every mode\n",
    "X[0:10, 0:15, 0:15] = a\n",
    "X[10:25, 15:25, 15:30] = a\n",
    "X[25:40, 25:40, 30:40] = a\n",
    "\n",
    "# sub-blocks of value b inside the first, second and third large block\n",
    "X[0:5, 0:10, 0:5] = b\n",
    "X[5:10, 10:15, 5:15] = b\n",
    "\n",
    "X[10:15, 15:18, 15:20] = b\n",
    "X[15:20, 18:23, 20:25] = b\n",
    "X[20:25, 23:25, 25:30] = b\n",
    "\n",
    "X[25:30, 25:35, 30:35] = b\n",
    "X[30:40, 35:40, 35:40] = b\n",
    "\n",
    "# add nonnegative noise: 0.05 * |N(0, 1)| entrywise (half-normal, so X stays >= 0)\n",
    "np.random.seed(1)\n",
    "X = X + 0.05 * np.abs(np.random.randn(n1, n2, n3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shared color limits so every heatmap in this notebook uses the same scale\n",
    "vmin, vmax = 0, 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize Original Data Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2gAAAEkCAYAAABaExIDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dzY9t2XnX8WedunWrbt1ODMgWIpZFBhkStz2JTCRADG1FMOZf6CGD4GRMZGIRS0ioB0SyIEi8DJOBEQMUEQk5YkLapkUkPEFMEjsgBd9br11nMWhboPb5feuep+qUV3V/P8Ns75ez91rP3qsr9/eMOWdJkiRJkn76Nj/tC5AkSZIkfcgFmiRJkiQtwgWaJEmSJC3CBZokSZIkLcIFmiRJkiQtwgWaJEmSJC3i2T7/40//paP58587PtS1vLH3/+QzcduA/aihAO4HGwccNG2ifehCaLfnf3YBx6SDPnCbhcc8V1X/gbfO1fxth7gnGzjmtnFMuo90/XSusN/F9od1vb2kM6Jl6tD3cx2qLezYHLPdGxYP+cB1rarq+f+COnSI6tyZP9352J1zT6JGtU/Y2y1dCx2uUWvIxfZVXc+nX4dW0v0265jw54Xn329+D2EBbsz/7o9+4Hfuh5rFJn7INk91iO+hxvz/0Qn3P167Zu/e8WL7Kn4P7bVA+/nPHdd//vefa1zYw3r76+/EbeM27zeP8jba7/Ykbzu6hvOFZ7y5gX3gidBC8ee++d24bRzlSjZvPoAT7l8IxnH+AfOWvlrBNu83jvJDxR5/cMza7L5f4xn8tg/yfcT9rq7ydcBvGyfP8zEvLvMx0/HgXPU8f4TQudJY+Pb/+d03vq5dVqlDf+2f5Dp0BI+gW4fwP+zQgipMf7qODZUF+DD67Df/Kxy0+dEB+8WaQjWjWaMGzYPrXNQHvOxbNSrUp6o7as11flnh/Cf0TAe8d8K1YD2H66dame7XH776vbzPG1ilDq2k+23W8cFZ3va5f/pevg6Y/1iHaP5fhfHcXThQjaJ3Llxj6z/cVq5RWNdu88Mez+Hbhb6H4Hx0zJrwrZfuSXeMNM717Vf5e8j/F0dJkiRJWoQLNEmSJElahAs0SZIkSVqECzRJkiRJWsReISGrwH9seoAQGwz1oPM1/nH+gH9jiP/cFP5RZkFISPcfnKfz0T+WJ+MkJ7HQPw6nf2SP4Sh0v9I+EATSRv+ImK6RQknO8r+enufnb3JVHzlgb/zkf0x7gDTPn4JncCspSIMCOAjVhi1M1U2YPhg0SWFLFHBF/zh8A//wncIybuCYIRwCRxiM5/GcQiqgtsHvnhRgcQPHTGFFFFZwCf/IHnQDlcbz07wfBa6kGgUhARguQOMu1diHjhUU1qh4vylzAabO8Ss4FYVGYOgQzMcQBFJVNU53z4P5+nXepxk61E147HzzVEFwDwUE4QuQgoWaCY9Uf+l3pxqb9+AgEApG+iDUNsrHouuQJEmSJD0eF2iSJEmStAgXaJIkSZK0CBdokiRJkrQIF2iSJEmStAgXaJIkSZK0iCcZs08w5hWWo5QK2oqOrZyeSeeiyM1B8dwU5QzRqxRFT2J0KVwHxt5TvPKLHOU8X0PW+fMcNz+eQRxtiHPFe0VR9LDfOIbocYjMnhBhXdeNaF9qqXAD8cJnL/J+qRUARek+IYPioWE/io6m7Ptue45Uo+j6t5BSjfn8hGoURbLTfimymYosxCTPa7gOmiNU226gcB/nqO1UbyiCu3Wv6o56Qq00aD+Kt07PgCK4SS9BXA+s03ao2zaJv6NoXMLpGq2Fqiq/I+FdR3UBW0pAix1sOwTfGh1YMw6A2hLgNy59Y4VngDWWnult/lbqfPf4FzRJkiRJWoQLNEmSJElahAs0SZIkSVqECzRJkiRJWoQLNEmSJElahAs0SZIkSVrE04zZb0TbV3Gc64Q7sYV0TzrmJiR1UhQt
RWlPug6KgIcIUtxGebQhepkiSDFm+wgiW+kaIaYaY3Epzvkm7NeMV8Xoa0K/jdAzSPcSYmUxuhfaI8QYXhxzT8eWEs0hWR3rCebzw7ZGTcHWInD9GG9NYD5SvDW24AjRy3i8a5irEG+NqM0GxNRzZPPu+UNx09SaoxtFT/WLIr/5mLufD10/txCANgepfn08ytBSqN1H/O5pthbi1iIH+NvDAeZBPuABIuzhvUvXn9oc4Xykby96/8O3BrYrwVYg+0fwU8w+tSRqfevBu9u/oEmSJEnSIlygSZIkSdIiXKBJkiRJ0iJcoEmSJEnSIlygSZIkSdIinmaKYzN9aQu/ltITKZENpfAoCueD35ZSIasgMa+qagvbKP2GhGM2g+YQpq6dnuQdKfWnkUCEv40SnijRqIuS4UIyUVVOJ6LUOxw/Jzn1bqSYLUqZfEI2lApLoX6U8EghUHDbKP0x1TZKoCX0uzFND+YjjdlxAnM8za1NHrPdVNVOmmRVYaLZbMwFTDqkuU/vCELPhq6f0l9TQi2lWl5cwvEanzIfkzq0FEqka3xHUR2l7zlKQaSEvnH2Ih+TkhrDHMf3Ko1ZqpXtpGy4mY3fRvUEr/GYXmQ0gPL1b16exW3bH77a+3yYAgx1rZtUnPgXNEmSJElahAs0SZIkSVqECzRJkiRJWoQLNEmSJElahAs0SZIkSVqECzRJkiRJWsSTjNmn6NWU7F3FUdSbnOSO0dcYfZ/uLuW1N+Ky73KI6OV0zHaUczeKnvaj6HuKlU37QbwqRulTJD5F31IsK91nOOYY+/83GWwhcAUtEGJkbrNPxo+8/yefqbd/852d2zDCHpLQU+Q8zbn3fv3dvPET6iu/8zfzxm68PbXZaMSrY0wy1QVCv23AMak2hPlDtQajr09P836vX+dt8G7EOGpoS5Ai88fLl/lcL+D6IYI/3mMaB2/g/T/9TP3iN3bXoWcXeT+qQ7E1Dw0veuXSNw8NdUg7p++h9766Rk38yr/4W3EbtsSACP7Cdh+N7x6ax9SiAlrbUP2aVxDdT/M4Rf5TSxW6/g19UFNfBWidcAktOB66BRL9Nqhf8Rppjr7pNUmSJEmSDssFmiRJkiQtwgWaJEmSJC3CBZokSZIkLcIFmiRJkiQtwgWaJEmSJC3iScbsbyAZ88Ej8e84X2eJS20CKPoWWwjESPPCCGjYq2oL+Z9hG8a1dqLt7wLXOG8hHxqM0xCnm+Jm6464VoiHpUh8jLBvRgXH6+xGTlOEbTrm/dKta8w7IqLTfvCIaE7qzWGEPc0DqrGdsQl1AecO1ROKeaZ47m50dPrdVM+pnkAUddH7g0x8KeVN6TdQNH+3lUl63pvmb/7xcWfVs3BLt5CETl0XEvoGodo14SdSdD/t9xTEaPiqVmuLqqq6gfdxir6nugAtb8YzeKid9/tdzs7ytYQ5iVH6VIfOz2G/xvdE3dFeBOr9SDWAvkfpHsMY6fCzRJIkSZIW4QJNkiRJkhbhAk2SJEmSFuECTZIkSZIW4QJNkiRJkhbhAk2SJEmSFvEkY/bHbY7b3EBwPMXKEoruJyNcJsX9U5Q+xuJC9Gc3VhqleFKIXe5G8GNkK8XpAopzTXH6GCt7DL8N4/nhgRN6bhTDnZ4btitojq30u+8Z4Tw3EGMNtxOj+dM1NR/PJxbExsf2FVU8/qA2xPlDUcjNVgDjxWk+JLTEaEvXQrULW6PAb2u2OaH2AuMEnnd6PlBHu89tXqQs/Hv2+6gccY9R+hR936lD9DPgsQ5qO9T85llFdzxj6wWqUWn+U2uRZmsb+h4a9MBpblGNDdvwew5j73MPCmzFRC0QMPoe1gTpGcA3G7YCaLYCibvsvYckSZIk6SBcoEmSJEnSIlygSZIkSdIiXKBJkiRJ0iJcoEmSJEnSIlygSZIkSdIinmTM/u0J5dQ+/PkonjvF7OI2WhY3I8gx+jPFDFdh9OegWNNGRPG8hShquA6MbKU4XdoG7QBqhGu5huuH
e9WN4MdYcohzxej+EPuL95GezeVV3i9F2N5zjo6ZY6y3kPp/S0MlPNoYe62dMEKZ6hDVGoonTuMZIt6xpQfMOYzS77YXoDYh6V5S7SKQAE2wvQhEbXON2r2NaiVpxarft91HwTseatwGnkNqH5Ja9lRxjaK4f2ofRDH72K5kETRmca5u4F1N8yBE39OrDr+v6HuOouhvoO0FtBDAb4Z0LXC8Qe2uqP5C2x5s6QH3ktsEhQlH7Ra632yNtkP+BU2SJEmSFuECTZIkSZIW4QJNkiRJkhbhAk2SJEmSFuECTZIkSZIW4QJNkiRJkhbxJGP2Nzlts7bwizA6tnlMir+N+3Tjjmk5TbGgtCNEzmJMfTjfISLxJ8Tb4+2nOFeIh47tACBmFyN4R/5teE8gHrbT5qCqYmQ5Xv8JRKe/ep33SxG2DxFdH45B83jAUE9TpDO/FXRbUXRqSjeKHuYVRUAX1ZPzCzgfXGeKoqda2YmUrjuir7EFQnMy48ssoHpIz6ZbK+8wKr/LsdZAyY/7wU/AJ0CPjl4vdNAnUBPx/Q5jHWtNo+0QvTsLYuPb8+oY5gF9a0C7HIyOjzvR/M73mNr2tFppVHGbllhTmueiGtXgX9AkSZIkaREu0CRJkiRpES7QJEmSJGkRLtAkSZIkaREu0CRJkiRpES7QJEmSJGkRg+K1P+rsL39u/sLf+/s7t01KBYVtMXIe9vmjr74LJ/tk+vIv/HLcNl6+jNvm+fmDXgdGOUNc/jg7y/s1rxHjqOFaYiwrxYTT776C6NhuPDTEkuN+IX6c4oU7UeBVFSN///DV79Wf3/5ZO2z/U0efnl966+/s3IbxyjT+0v3sRnQ3x157/tB4SMfD+HpqzZHvybf++A/2vo5D+OI/fCduuz3N+3VboGBkOSU2038iDbd5k1PC6xZSvT/729+FHeEHULz1TY4Kp7Gc6hfGbNPx4PrTN86D1KGzX9m5Da8VWkDMFL1Oz4DqM0SMU2w8tmug2kbtWlJNoRpLcfOPXIfe/s1cU1J7BOomge0WqA7BMWm/n/ud9/NGnONQcAJqZYDnAvRub3vo76HGfPv2q9+tP/9gdx3yL2iSJEmStAgXaJIkSZK0CBdokiRJkrQIF2iSJEmStAgXaJIkSZK0CBdokiRJkrSIRjbz7v/zoGRvStUMKaoUT6odIDK3KOa9GXmaYLwqof0oupdi1Wm/05N8vuPdEcnz8jLvA9HJ4wTORfHQFH1N94vGQoqBpehYEqL0P7yMcB3tYOsf2YzcnoDuGbUU6cbpJ3QdEMVL44j2w8j89BzgmT/w3Xh8NAWo1DQjrNvvK7jRncj/o1zqUbeO0vzH/V7vbp0yXkK7lRRBX4VjObagoDp5X43Y/6qK10RtLw4xV/E9TlHi0DYmPj96dtWL9D+EFKVfVTXT9IFLPIJOP9S2CltaEXjHzVtolxHGZDdKH9s7wDOldkXYAqnRFgLbzFDdoLZJVL8Cl0GSJEmStAgXaJIkSZK0CBdokiRJkrQIF2iSJEmStAgXaJIkSZK0iL1THFNa4xbCADcQXmJa48PApJpmohym7YRjxsSsqipICqPUq3H2Il/Hq9f5mHRPKLXwZv8kSsqTGhRxSqk/F3BPXpzma4G0oPhM4bm1EgILxt19w7dm5UQqSmOkxNJNSKqCsYBjndxSmlYvTRNTp9LvpjnXea5PBKWgUVLbLYSxUiIbHbOTaIrXT6fq1t+T3am2VcXJZFCj6iJcKc03GpP5TLkePnIK4I9hYuZ1GEhQ1+j5zPOLfCHw7qF7jbWG6m8nKZjej1cw6R5ZSly9hamDibFwq7CeEJiP4wISqtM+lE5NKKn4Bt4tNCbp+4VqykOnmNOYvEkppnnOuDySJEmSpEW4QJMkSZKkRbhAkyRJkqRFuECTJEmSpEW4QJMkSZKkRbhAkyRJkqRF7J0TnSJ+KUq/EyV87yjuTxqMo4WIXopeBumRUqQpRvpDdO+ECNjxqZ/N54P9JkQs
d6LoKR4az0Ux1U2DxkK4Tox3hwhubFdwDNdxLxOvKcExFu4LtjOgcQnxvXgd1NKgGW89b652H4/GM8VlPwGTOio0uwR028VsoCRiPHw4JnXt2Db/kyvVKKqj2N6FWqDAuyCei2K2TyHyO9XmznfJ/28zYn3A2hBaenx4zHBfaJ8PYIDBPaO2I+M55MMPeHbUrqERaU6x5TQeDoFqSprI1H5jC+WXagYVjS08NprHOL7idcC7n95HVGu634j0HUU1KlwLtjKg92bnOmAc+xc0SZIkSVqECzRJkiRJWoQLNEmSJElahAs0SZIkSVqECzRJkiRJWoQLNEmSJElaxH4Z6xPiJx94qYcxl/oJ8/wib6SYUYropcjcFOVOedOEYr0hepWinClGfFDGcroWaiEAKA54NuLiq6rm5e7o9A83Qqx/eG54PyjeHSKXcWzdy8ixwDSOKOY5xeNCbO44O4vb5msYl9TaAtoWYPA97BfjrXHOwXhoxtQ/KoqihmFA7x2KvqZ4azofRr2nbXD/8b3ZrDWIxgm8C1JNxAjukxylP8/P87nie+CeMe1zxvpAtaGuoHaHeUxzH9tvUDuD5vsMxxGNh1Bv6B2IbRwOMZ7BeODTYfsN+oyiwQDzH+fWTaNtDXxf4TOFlhgUl9+N0sfvr/SNAq2K6DrwXI3WIv4FTZIkSZIW4QJNkiRJkhbhAk2SJEmSFuECTZIkSZIW4QJNkiRJkhbhAk2SJEmSFrFXzP7x/76sv/Kv/tveJ5kQK5tivydEwH75t78Ut23+4l/I1/E6R/FSJPh4+TJuq5vrvC2d6xayUCEum/y77/2n1n4P7Qtfeydu2+bkUoyHpnjbSSOYbmUj3npQLC5EaX/2n7+fN1JkNsXCU5wxxbhTVG2AcfkUb5vinyGK9o3FeH+4oVBTUvQ9RjlDtPchUJz2pLrRaUkA28bznCn/9tfz/CfdOT7DT/vOr73buo5D+MI/yveEfvc2/O5uCxp8x529yDvS/KcxBHNnNmoARum/OM07xpqHjSve4IIguvzyMu5GUdypRQq15sDnSrHfFLveadtRheMhXicdj0Bs+ef/cZ5z3Xj79xapKd0a236mHdQaCVoEYVuF7ncDROanGoVR+jQXOy0j4Gf5FzRJkiRJWoQLNEmSJElahAs0SZIkSVqECzRJkiRJWoQLNEmSJElahAs0SZIkSVrEXjH7VTPHUlJkK0VnpuhriOmkmOftq9f5Oigym66xGaedYjUpprNmjkLFeP4nDmP2KQ25GZm7gbT5GJlPKa9wjdQyohtvS8csGCfjNMRRQ7sIHK8gjleqFW925DyXjyECl6JzU4Qv1CGyeSu35pgXF/k6jk/yfnT9p3m/FH09J4whMKBG4Tym7iIwDXAe79/l5NGlVgBV3J4j3S+6H3i8ExhbV3Aj6b0J0dephU4V1AaKqYYo/XkBsfbxe+Ge7T7GyO05rvP9nPhtA5HgjX0w0pxqJbXtoHcW1YZOKxOqa/Db6NHS/Llv54XHQHWU5n/3XZZi6rFmkObzxvEKY4GuM31/dVsL4bdN3C8PVv+CJkmSJEmLcIEmSZIkSYtwgSZJkiRJi3CBJkmSJEmLcIEmSZIkSYtwgSZJkiRJi+jlZ+8QY6qLY1RjzCXFVUKUK8b0UoTtOMBaNV0nXCNGzj4FEG87mj8NY3EJXMuWhsJDdzOAeFsc5xB1TJHTBXHs83J3HPU4gmukMUmR0akm3DPdurYzx1h3509qidGsaxSznWLvP9wvx4WTeUnZy7vPR7HrWGOhVtIcpwhobHtBr4Jel4rH1Y3Zb6RYD+qcQO9GGOdFbTY2VGQbmeU03zBKHwo6jeX7mDPGcVMroFaNwn3yIMI2LrQNWrVgbSPpXUHvR3jmNFaonmBLnKdQT0D3Wwm/0cPjxpoBdWFSTxVoH9RZR1QVj/Ok+W6nVgD5mHlA+hc0SZIkSVqECzRJkiRJWoQLNEmSJElahAs0SZIkSVqECzRJkiRJ
WoQLNEmSJElaxIPF7GNceCNmHOOtIYpzvHyZr+PqCq6D8okBRX+meGKKBD6H3wb7feFr78Rt7XjrRjrpH3313f13OpDP/1a+J53If9pnC4nlGEdL5/uZt/L5Xp/n/U4hgv887Af70NVPiGPOUbv3zNk/2tR4K8xzajFANSpFzocY7ao7on2bcf/j7EVrv04EOcZl0/UfU4+KvAnrCSWyN9LaV0J1A+tQ6lIBr6pbmPopEv7DgzZrFES8xzlVVeMmjD2oQ/P163wual2RxjLVgzcxcm2nex2/C6pySwB49+PcpxoF94zarnTHShoPM72T6o56CPeY4uZpzm2an4GPieroUf7E7X2rku6YpCh6GpNnZ3k/GgsUi5/mB6xZqP0RyXUo7+Nf0CRJkiRpES7QJEmSJGkRLtAkSZIkaREu0CRJkiRpES7QJEmSJGkRLtAkSZIkaRH7ZWvOHBVJca4Yb50iKyGKcxw9z8eDKP0c+11Vt83I2ZN8LfNViAWmCGuKsO1EoVbVoFRQOB3GysMjWEU33jpG9MIwoHjeGK9axVHPlzmqFsEcSFG1GHXcjJM/WLz1dta8CtG5EIE7nsNcDfca233Qc6WY3hvobUHx3BRbDvUytSWh2GKKNMd7gj0Z8iaMm6f/jNhLPH5U3fYCI21r1iGM0oZ3I7XSwDkANSUZFLMN4x8jxNPc6MbFH1C616PgwVKtoXNRrTlE3Us1lmpXqvNVXOuhLuB8fAL1BL/LYIrT/MexkNq0HKC9Q9F+9G1PGu11RjO2f0Kbn1i/LvJc8y9okiRJkrQIF2iSJEmStAgXaJIkSZK0CBdokiRJkrQIF2iSJEmStIj9YgHHiElo8+Ii70bpaSlh5bgXE0hpNJTUiGl0cL5BqTkvXuw+XjM9qSAhhpKJMLGQgqxo2xNIO6IHF5MaK6en0T5tlH4K6WmYskWpRSlZDdK3MLWLUpcOZQxOpEso6S0kVVVIQKzi9DH0bHddqKqakFRF2/A5hGeLyZ10f+Fc23Abq6qOIJCNYDJZeKRf+No7sBNsovQ3CgGGgMHv/uq7eeMj+sq//Nt5I80NSn+DRLNxun/qYkoc/XBjIw26qsZxuI57hziOOBcGJh02TjxyfR5UgmE/SozFa4QX4Tg9zfvdhLFCv+05/DgYK1QzNhCiG5NTq+rzv5VrylEINKS060FDHX72e1/t1ZMv/7NfztfSSUilxM5mKiymGNO3MbyTBl3LD3+4e5/w7V7FdW1ewvs7/Taoa/4FTZIkSZIW4QJNkiRJkhbhAk2SJEmSFuECTZIkSZIW4QJNkiRJkhbhAk2SJEmSFrFfTvScOSrymPKV94/ApAhojPeEuF3cj2Kq4ZjzIuSrFsRwUwRvivsujiDeQGQrRV9TXD5FvT4JlGbciOCne3xLcboUpQ8RqzH6vYrH0A1EZjdi8SfMm3ECLTRgbhwKRtFDTG/agi0L6D6/gLhpaJeBcccA24ukc0GbAGzxMPO2I7j9WE9grlIsdpqr3bYj3XjuZ48/1PdGcwPbKsDz7kTpV1WOp9/m+Ya1kmpsir6Gc72xdG8o2huiuNPvwGd3liPBKcIbW7XQ/KexQm1C0rVsIVq92caF2mVQraERgbUt1CGM0qeU+v3L+Z1arWkqP7fu8bqtExDVGrqWUL+4XReci1oBNPgXNEmSJElahAs0SZIkSVqECzRJkiRJWoQLNEmSJElahAs0SZIkSVqECzRJkiRJWsT+OZkUYx+0Iispmv8oryspJpX2IxN+cyeqGiO4KfoX4s63lEa7/yOrKo6qxRjbRXSjamdISqVYXIruxnhoiKkviDquQRm9sA3GawSx8BgZn871EFG0Kd66WzfS76Brhfvcar9RhZHmGH1NtS3EIXevMd6ruiNKH4sz7NZI2qZI/FtIhm/Hcz9suvJBUNsOGlt1C/WrG1Ufnim2tSAwXuOjufczm/F9vYW4+Q28/2MrnW7NpNhyauNydZ23QY2i9gJxHNH7
ilArA3pXU2sh2A/LV+d7CG4Vfs814XcIPdPGNz+OkRu4kzQmaR1B7RgoFr9Tv6hdT7N1TeJf0CRJkiRpES7QJEmSJGkRLtAkSZIkaREu0CRJkiRpES7QJEmSJGkRLtAkSZIkaRH7Z253I1GTRnxsjKKte8RcUpQoxWlTBHn6bRR9C9GxFKW9oSRRSC4dzSj65Iu/8U7chhG2kJJKEfYDUlL/y6+/mzc+oi9/82/kjRAdS+OcDIqcTbGyNO6arQBie41ulPb/O0D8HRg3T60J0j3rRAzfAePym3HaHJMeIrOprtHxaHyBCW8brEN0S8JQwprXrCdbun445iqwntB7neYAvW/pPR3qDY5/inCHc8V48QcoQ+nYWIOxFUjYRr8PIvGxtcL5RdxG8LZRS6Lnu98j1IZpUkQ6bKI2GxjBDz8Ov186n8WP3ZoD7iU+01QbMJq/N7nmq9d54zEUdWr3E9rMVMH8pToE87e9xgj8C5okSZIkLcIFmiRJkiQtwgWaJEmSJC3CBZokSZIkLcIFmiRJkiQtwgWaJEmSJC1i/5j9gOKtMXoyHQ9iLmNsbhXGe/ajM3M8KcXYxmNS3DFFgUOsKcVKY+QsRb3CthEuha6Dslwxphqu44MT2G8RGPN6nSOSEUWkw/wYaQzd5H0w6pg8dEuOH5twTZh3TFH0u483TvIAm1dX+XgUs92MO6ZxRBHkKe48xV5/uFMvQhnncTPWnI6Z2nOk+nTXNrwOKNvUJmQV7edN7014X1HLmwdH8w1i3A8Fvwto/ne+lagudL9d4LlSLH6rhQpF89N3GZ3rABH2nbqBteuhY/vvuhZqudBp6UPjjuoh1JpxCu9bak9D31gwXuN+8N7HevLAtca/oEmSJEnSIlygSZIkSdIiXKBJkiRJ0iJcoEmSJEnSIlygSZIkSdIiXKBJkiRJ0iIGRpV+xKeOPj2/dPYruw9EEb4U053ioSnmleKtIRKYIluLImcpOhMip1P0L8bU0r2i6z+ERvwtRrxDKwZEUbuXMBaaMe8pehXjWmHcfeuP/6B1HYfw9m++s/P/vqHk5O5/xgm35Hv/+ht1/qf/sx2EjHWI5kgjih51IqXrjnhrArUGWyGEeYCxy9e5Nwe2QOm2ZGiKUUkdoOQAABFaSURBVO6NulzFbS82P/NW3u/8Im/rxDxX5fcHvP/oeCvVoS98bXcdGs1WDFuY9ike/Xv/5p516Nmn519/6+/u3gi1hsSWGC9O8z4w9misd+P5CX5PpmM2v2uwtRPUc3xHAKxt4Xe3W6NA/cVvXPgO/9b7v5+P+Qn19td31yFsCUUdSRrtov77v/1GnX9/dx3yL2iSJEmStAgXaJIkSZK0CBdokiRJkrQIF2iSJEmStAgXaJIkSZK0CBdokiRJkrQIyADdYfQiSim6uBOxitGlJyf5Oi4v8zFvc3bmhG0YHR9iWXEf2tZsPUD3i6Jj6Vmn/XB8UPQ1mLcQOUtR+o02AVUQiw3He/QWCE0pxnrC5WPMPj3SECvbTNL+yEF2HwXH84uXeVtqrXCVaxfOY4rEpwhoiEmesxkdHeKc50WuhxhF320vkCLxi9tljE/9bN4vRIxTK4CaUBfo/sP7gyKzB7yTMM481Vhq03AM7W5WEm4XxeVT4dhQAvrB/jP0yPMExhi2hjne/WxpruL4guvotvTB7zmqDWlONt/TVGOx1lD9hf3GpnFMitKn70r6Zuu2adJPSJH5E14fGKUPtSYeE87lX9AkSZIkaREu0CRJkiRpES7QJEmSJGkRLtAkSZIkaREu0CRJkiRpES7QJEmSJGkR+8XsA4p5pyj0FKNKcdkYpU1x8xS9TMeE/SiyNR4P4loRRajSNTaj9OvFad726vXuc0HMbjeWfPPyLG6j6OsJ8b14nSe7o6oxlpyijhfSinqFLgdbSPxOEbZwqjczRo4ap1YOUBtS5DHWDIrSpyhkmnM0V8O4rKqq6/yQZtoGv63dQoBitunZULsM+G0x
optitiHeuqoXfU3Pbd5CLDn97hSzT9dxA+daSGrrgRHWMHzarUDuY8743TNe5pYeRTHvYU7Oxji5C9aobisNaAGR6hCOZ9Kto1AbsLZ17jO0RhrUUoJqZaOlivaT2hFVcUuih+Zf0CRJkiRpES7QJEmSJGkRLtAkSZIkaREu0CRJkiRpES7QJEmSJGkR+8XnjBHTsQalx1AiTULJj5RYSOl8lMhG6UPnF/mYlECUEiohPYnS0+Z1Lw2MkpXIDEmNVXckIcWd4LdNSJSipMZLSO18nlPvaJzMq933mRIv5/l53LaSlKxIKWidpMaqnBh571C1UXlu3ebEv0H5kSmhi5JMKY2tm0wG0risuqO2pW1QMzDpENJRMU22Wb+4XqZxQEmNGb5bmtePtRmSOdPzxjS5XqDfoxv7hx9zUiPtF+pXTLR9UwPm+RYeBI2VhFJC4b1E4xlrFL3foe5hqnUjsRvnTjfpu+A9TtfSSLbEpEyA9ZzO10gV/0QLjxS/hyCw+wiCvreh1FMd8i9okiRJkrQIF2iSJEmStAgXaJIkSZK0CBdokiRJkrQIF2iSJEmStAgXaJIkSZK0iP2yoGdh3GvcrRFdOk5yluW8gNh7ilanyFOK56drucox7/G3UZQ+xaRiXDNcYzPWn2LlO60TMEqf0LmOoT0CxeJCVHirhcATibdOSfMbmqLwn3HuHVXdsZ01r3fH6VNLibRPVdU4e7F7A8S147xqxux3I6cpFhtboCQ05w5RY6kVCFxLPF8zerygjg6q9YTuP43JdJ2nUOvheCuZ4fFQ244BQ5Jq1EhT6r79PqDtELbgoBoV2sa050cjGr6KI+xpPI+Ts7zfTD1e4PppPNN1QJQ+om8sqOmxDtF9pJYeVOsf+ttFP4m61tAjhel2FF4ftI9/QZMkSZKkRbhAkyRJkqRFuECTJEmSpEW4QJMkSZKkRbhAkyRJkqRFuECTJEmSpEXsn8k59l/TDYgFpjjazvEmxWJD5ClFymOcdiP+lq6RImcxlvX6Ou/WjGWmCP4YY03xwhQ9DpG58+IyH7MZP0wx3PlkEMEL0eNPAcXlU/Q1/Seen0YEP40VlOY4jSGKa2/CmORjGGMUAZ/mVnd+pLjsuiPCnmKqoSbiMcM23IdqzWVz/JDb++a5f0SIYr/LL37jnbiN5vgGks6xBQcMoe/86rt54yP5pf/wg/sdgNoO4RjLzy99h+D8OMl1YVDbC4rnp28UMCp/h8RvDZirVA/xnpyexm3o9et8zJe5hUD6/uq2WyEUwU9tmr74G3n+b+nTLAzxLcxvio6nmHqqJ7ewjMBvFPCdX1ugDv3HXIf8C5okSZIkLcIFmiRJkiQtwgWaJEmSJC3CBZokSZIkLcIFmiRJkiQtwgWaJEmSJC1ivwzQOXOc6Isca4rR1ykyFONhIV4V4qHHCeR0gknR9xTPfxPidOl4TXT/i+JoKfqaovtjCwHIUKV2BXT91F6gGc9NzzRFts4PIB757EXctpL02ygyt2C4YpR++s8/h4zfp3h4qBuxJUazxQbFHdP8x3nQidIvmCPUdqRzvKqqG2jb0W1FAdcyjkNNpxYCUPPGthdh30XjK41ljB6HtikbOBVFX5MUwV1VVc1jPhlz5nkOUeg071rtX6CeYB2Ca8Q2FXQp9D5O9bIZ94/vd4jLx1h/qOnbH77K+6VWRvTb6PuK2q3cwPihCH4YkhRTn74NKEr/EI7y7Wpd/1PgX9AkSZIkaREu0CRJkiRpES7QJEmSJGkRLtAkSZIkaREu0CRJkiRpES7QJEmSJGkR+8XsV+UYT4pDpcjQ4xC9DJHSk6JEKcL6CNajFG8L0cXzEqLXwzHxGiECmswryCBtwjjnEN9LbQ4ogp/ioRFEp1MKbOu3UUsCaiWxkBi1C6nKFKVP0d3btN8B43nxGUGkeZqTFLtcA+oJxWxDFH2nntwl1j2K0qZ2JRQdfQz3C+Ly6xYGBUXmb8MzpToEEdxYf5utB/C9
Qy1jwrVgdDo+t7xbd/4XHJNivT8WBsxJGusnL/O2VKMoWp1i+5tR+thaCLTmT5rDdY9WOc2WSjVhQFNN77RpwW+lXispbElC7XIaEfy0D7bfOECbHYrSf+x2AA/p415CJUmSJOnJcIEmSZIkSYtwgSZJkiRJi3CBJkmSJEmLcIEmSZIkSYtwgSZJkiRJi9gv033k2NN5nWPeMU4UIo/j8Sj6miKZCew3TkIrgCqO4A8Rq3T9GMtKsdiH0L2XwTg7i9vm+Xne73m+//PiIp+QYnFBjOCnMQLXuJTGI6V4boznDVG79x7FmxHnELVrwAjoFK+c2oBU1bzMrRW4pQdEQFOEPcVKU4R6jAKHnQ5Rax64nlRVjJWe0KYFWzFQvPgltDKh1i8UpX+6fxz4OIJaQ61MKJ0bupxs4XQUYY1R2x8Lg9vzJLRP+magbyhqu0DvwG7bC4DzLtVsGLPjxWk+GbW2wdYDvdpGNT19/+I3T+Pbt6qq6PsR4vkxir7zjqfhAzWDas2kFUn39WHMviRJkiTpvlygSZIkSdIiXKBJkiRJ0iJcoEmSJEnSIlygSZIkSdIiXKBJkiRJ0iL2i9mvkSNKIW+bYlRTnHM3Vhb3gwjSQXnBFF1MEbExVhYyTSGSGeN5KcqZ9qPI1itonZCi6OlcEMGLMdUU3Usx1QP++0OKVaf9jnOu7LyE37aQFGOL0bfd/4xzqK4Q25nj9GEedFpYYJR+J7a/iuPm07yqqkFx2jTvQo3C2kVzh9qcvILo6Lde5m1Qa1qoTUA3XrwTqV5QK6uqaAyFMT5voS5Tu49m3PRR89HMR+4K8+jmzDXlBtp90CHDM6dag21o6J2Lx6RM9t6DnSEWn+bHPM9tdLBdEX3rdVqS1B3vj9QOgO4/nAvfLWCc5P2o7cUWTpfa5dBA3sCrqvs9Qcck7e+XBTzhS5ckSZKkjxcXaJIkSZK0CBdokiRJkrQIF2iSJEmStAgXaJIkSZK0CBdokiRJkrSIPbM8J0dEBxx9HyKgIdIU4467Ecq0H0ROdwJnx2mIZK2quoFMY4q+plYAFA+bj8jiOIDrmPn+jyOI2W9dB8dRU6x/3YaIZIrFhW1vf/2dfBkQb7+FNHOKCqZ46/f+wbt54yP5pd//wf0OMEYeLycUOZ3HQxybNK/ilhzbj+eqqkHn686fNP+7rS0oZp9qM9X0ZnuR2G6h2W6lC99x6Rqrqqh1Qnre3RYC1EmGupVAPDe254Bh8rF3DBHw9D4O8eo09+sa2i5AXPuk/VJsfFVs/1B1j++2dB1Ua2g/qIf4zUDXT+2iwvPBFlMAnw21JbiAdxz2d8ibUgcqOt6WujRQ+YVtVKOoDuH5Fudf0CRJkiRpES7QJEmSJGkRLtAkSZIkaREu0CRJkiRpES7QJEmSJGkR+6U4zpwmhClilFSVksQg1YsS0ggm+3RTuCilJxxzHEFSG9wrSgobz3N6HV5jV0gLo4Q0TGSCZKXxvPds8J5QutXr17uPR+On6RbCsjYQQjnhUlLq0scHpMlSChcJc4RS0DCV7BBzDsyLy7gtpn5RIi+luHXqeVXNS0iGPIWIrk5yMNUaSsqkhEf6bXBMHEOUyBbuVzsxksoXpbjBUMak2eZUfDI2I48zmv+N8Ywo4RHQ+3GeX8B+DxzPSamKzd+G349Ut2k/msfp/UHzm2pNs0bR+TDNkMJfwzb6BqFvF6oZkwLaobQRSpRcnX9BkyRJkqRFuECTJEmSpEW4QJMkSZKkRbhAkyRJkqRFuECTJEmSpEW4QJMkSZKkRewXsz9Gjhqn6FiIJ46R+RAXirGsFDNMMfUU806R7BDLGiNbmzHVCGNZIQIa4rQHRJbHdgAURU2R+Ng6AZ4btXegyPXrnAMb7wk9NxhbFCv7LKcZc4zt46a4Pxnd6OI0x2nuYKQ/1Rpq1zBo/jcj4MO4xTYUzbGO8/EA0eOxNndjqiFCHNsEUPQ9RYXT++qBUew9lV+qQxTPT9HdHxvp
xtE3Cs2DTtuhE2hRQefqtoaguUq1Id0rmHOjOz8oSp+usXe2XGOpfQi9q5qtAHAig84cp/YbE4YPfbtM+GkUwU8tBJ5y2yH/giZJkiRJi3CBJkmSJEmLcIEmSZIkSYtwgSZJkiRJi3CBJkmSJEmLcIEmSZIkSYvYM2a/YvznvIUI9bMX+ZhXObq4g2LjKeYdo+g3Obp0dNoBwPHqtpcJilHOtN8VxM1DZCvFSsdzwTVi9HCzTUDd5t+GseQpTpvuMVzjBpKCZzPemvb7JP9nF2zX0Iij7sanc9sIQDnDFLVN8zHE6U9oNYGtUSjC/iL3jRgv4D1ArQcgjjq20rjOmfIYIQ7tBdpjgX4bxbF3ajqMEYy+pjRzGsq07QnHW7+ROXObhAeOcp/Yvgbe4dRKg965dC2dNgGw3/gA7gh9X8FcPUh7FPoeStdJ9YTaN9E1NlFMPenE1FPs/RauA9sHdWP94TW3uk/wp5wkSZIkrcUFmiRJkiQtwgWaJEmSJC3CBZokSZIkLcIFmiRJkiQtwgWaJEmSJC1iv5h9gLHSEG9fIWp0nue4ZopCxXPdwDaKUIaoVIyjTvcEopXpPmLscjeCG2BUbToXxPpiPDf9Nno21CaA4q3pmDSGErj/B4mphv1GLz35Yw+ji9Pzo/lN46uJag229CCdOG2Ka4d7UsfHreugZ4P3OR2T6ijFkp/mVgyx/UZx3ZvQJoDEGG5q+7Khti+ty+jHczfP93SM+CxwjHVa89BV4Lvs4a8Dz0ctccL5sObR/aB58Bza7zTbLWHdS2096LuAIv2pDtH1UysjuPwNvBpTS5/u/KbYfmwtRK9buJZu/VqBf0GTJEmSpEW4QJMkSZKkRbhAkyRJkqRFuECTJEmSpEW4QJMkSZKkRbhAkyRJkqRFDIw3/+j/eIwfVNX/ONzlSPoE+Ktzzs90d7YOSXoA1iFJP22xDu21QJMkSZIkHY7/L46SJEmStAgXaJIkSZK0CBdokiRJkrQIF2iSJEmStAgXaJIkSZK0CBdokiRJkrQIF2iSJEmStAgXaJIkSZK0CBdokiRJkrSI/wtCTEXyem6U+gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 864x360 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Max-intensity projection of X along each of its three modes, one panel per mode.\n",
    "color = \"viridis\"\n",
    "fig, axs = plt.subplots(1, 3, constrained_layout=True, figsize=(12,5))\n",
    "for mode, ax in enumerate(axs):\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    X_max = np.max(X, axis=mode)\n",
    "    ax.imshow(X_max, cmap=plt.get_cmap(color), vmin=vmin, vmax=vmax)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a NumPy view (X_np) and a torch tensor (X) of the same data; they share memory.\n",
    "# np.asarray / torch.as_tensor accept both an ndarray and a CPU tensor, so this cell\n",
    "# can be re-run safely (torch.from_numpy(X) raises once X is already a tensor).\n",
    "X_np = np.asarray(X)\n",
    "X = torch.as_tensor(X_np)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Neural NCPD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jvendrow/.local/lib/python3.6/site-packages/torch/nn/_reduction.py:43: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4befea00b6574ebba0d51b248245ab7b"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Unsupervised Neural NCPD. Each mode gets the size list [n, 5, 3]; the factors of\n",
    "# both levels (A1/S1 and A2/S2) are read back from the history further down.\n",
    "# r is forwarded to train() as its rank argument (TODO confirm exact meaning in trainNNCPD).\n",
    "r = 7\n",
    "n1, n2, n3 = X.shape\n",
    "net = NNCPD(\n",
    "    [n1, 5, 3],\n",
    "    [n2, 5, 3],\n",
    "    [n3, 5, 3],\n",
    ")\n",
    "\n",
    "loss_func = Energy_Loss_Tensor()\n",
    "\n",
    "history_unsupervised = train(net, X, loss_func, r, epoch=10000, lr1=0, lr2=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAWdUlEQVR4nO3dbYxc133f8e9/ZrhLcSlTlMQaFEmJlKAkpRw/yFtFbtrUSGVHcl0pQFxASl3bTQKhD0TVuEBLxYWAqn0TJzDaIEJjNXERFFVkxXFb1qUhpK7zwg2iiooUPZoRJUsm6QeRsp5IirucnX9fzJ3ZeVhqR+SsZu/d7wdYcO65Z++cO3f527Pn3nNvZCaSpPKrTboBkqTxMNAlqSIMdEmqCANdkirCQJekimhM6o0vvfTS3Llz56TeXpJK6ZFHHjmemVuWWjexQN+5cycHDhyY1NtLUilFxItnW+eQiyRVhIEuSRVhoEtSRRjoklQRBrokVYSBLkkVYaBLUkWULtBfPTXP1x7/3qSbIUmrzsQmFp2rPfc9yrcOHef9Oy5i++YNk26OJK0apeuhf+/VNwE4faY14ZZI0upSukCPaP/rk5YkqV/pAr1WJHrLPJekPiMFekTcGBEHI+JQROxdYv1nIuJYRDxWfP3K+JvaVq+1A33BRJekPsueFI2IOnAP8BHgCPBwROzLzKcHqn45M/esQBv7LPbQDXRJ6jVKD/064FBmPp+Z88D9wC0r26yz+/j7tgKLPXVJUtsogb4NONyzfKQoG/QLEfF4RHwlInYstaGIuD0iDkTEgWPHjp1Dc+HKSzcC9tAladC4Tor+T2BnZr4X+GPg95eqlJn3ZuZsZs5u2bLkAzeWta7uGLokLWWUQD8K9Pa4txdlXZn5cmbOFYu/C3xwPM0b1hlqObNgoEtSr1EC/WHg6ojYFRFTwK3Avt4KEbG1Z/Fm4JnxNbHfunq7yfbQJanfsle5ZGYzIvYADwJ14EuZ+VRE3A0cyMx9wD+LiJuBJvAj4DMr1eBOD7254ExRSeo10r1cMnM/sH+g7K6e13cCd463aUvrjKE37aFLUp/SzRSt19pNbrbsoUtSr9IFeqM75GIPXZJ6lS/QHXKRpCWVL9C7Qy4GuiT1Kl2gd0+KepWLJPUpXaB3L1u0hy5JfUoX6J2JRZ4UlaR+pQv0xfuhO+QiSb1KF+jripOi3stFkvqVLtDr3m1RkpZUukDvTCw645CLJPUpbaAvOOQiSX1KF+jd+6E75CJJfUoX6BFBoxZe5SJJA0oX6NC+n4vXoUtSv1IG+rpazZmikjSglIFer4f3cpGkAaUM9IY9dEkaUtJAdwxdkgaVM9DrYQ9dkgaUM9Br4TNFJWlAOQO97hi6JA0qZ6DXvMpFkgaVM9CdWCRJQ0oZ6HUvW5SkIaUM9FpAKw10SepV0kAPzHNJ6lfKQK9H2EOXpAGlDPRwyEWShpQy0GsROK9IkvqVMtDrNYdcJGlQKQPdIRdJGlbKQK9F4LwiSepX0kCHtIcuSX1KGeiOoUvSsFIGeniViyQNGSnQI+LGiDgYEYciYu9b1PuFiMiImB1fE4c59V+Shi0b6BFRB+4BbgJ2A7dFxO4l6l0I3AE8NO5GDqo5U1SShozSQ78OOJSZz2fmPHA/cMsS9f4t8OvA6TG2b0m1WuDNFiWp3yiBvg043LN8pCjriohrgR2Z+b/eakMRcXtEHIiIA8eOHXvbje1ozxQ10SWp13mfFI2IGvAF4F8sVzcz783M2cyc3bJlyzm/p2PokjRslEA/CuzoWd5elHVcCLwH+JOIeAG4Hti3kidG22PoK7V1SSqnUQL9YeDqiNgVEVPArcC+zsrMfC0zL83MnZm5E/gz4ObMPLAiLcaTopK0lGUDPTObwB7gQeAZ4IHMfCoi7o6Im1e6gUupBY6hS9KAxiiVMnM/sH+g7K6z1P3w+TfrrTnk
IknDSjlTtFbzpKgkDSpnoDuGLklDShzok26FJK0uJQ10h1wkaVApAz2cKSpJQ0oZ6HXv5SJJQ0oZ6A65SNKwkga6V7lI0qByBnrNJxZJ0qByBrpDLpI0pKSB7pCLJA0qZaCHE4skaUgpA70eAXjHRUnqVcpAr7Xz3GEXSepRzkAvEt0OuiQtKmWghz10SRpSykDvjqEb6JLUVcpAr4VDLpI0qJSB7pCLJA0rZaDXvGxRkoaUMtDrxVUuCwa6JHWVMtA716Eb55K0qJSBHg65SNKQUgZ63YlFkjSklIHu1H9JGlbKQO8MuXhSVJIWlTLQOzNF7aBL0qJSBnqtaLVDLpK0qJyB7r1cJGlIKQM9DHRJGlLKQK97cy5JGlLKQPeyRUkaVspA97JFSRpWykDvzBS1gy5Ji0oZ6A65SNKwkga6Qy6SNGikQI+IGyPiYEQcioi9S6z/RxHxREQ8FhHfiojd42/qopo355KkIcsGekTUgXuAm4DdwG1LBPZ9mfmTmfl+4PPAF8be0h7d+6E75CJJXaP00K8DDmXm85k5D9wP3NJbITNf71mcYYWfPeGQiyQNa4xQZxtwuGf5CPBTg5Ui4p8CnwWmgJ9dakMRcTtwO8Dll1/+dtvaVXNikSQNGdtJ0cy8JzOvAv4V8K/PUufezJzNzNktW7ac83s55CJJw0YJ9KPAjp7l7UXZ2dwP/Pz5NGo5nZOiCwa6JHWNEugPA1dHxK6ImAJuBfb1VoiIq3sW/w7w7PiaOGzxOvSVfBdJKpdlx9AzsxkRe4AHgTrwpcx8KiLuBg5k5j5gT0TcAJwBXgE+vZKN9va5kjRslJOiZOZ+YP9A2V09r+8Yc7veUq37xCIDXZI6Sj5TdMINkaRVpJyB7iPoJGlIOQPdIRdJGlLqQHfIRZIWlTLQ6w65SNKQUga6D4mWpGGlDHSvQ5ekYaUM9Hon0B1Dl6SuUgZ6+Ag6SRpSykBffGKRgS5JHaUM9Lr3Q5ekIaUM9JpDLpI0pJSB3r1s0S66JHWVMtC9H7okDStloNc9KSpJQ0oZ6OFJUUkaUspA7w65mOiS1FXKQHfIRZKGlTLQaw65SNKQUga6U/8laVgpA73udeiSNKSUge6QiyQNK2Wgd4ZcFrx/riR1lTTQg6l6jfkFu+iS1FHKQAeYatSYb9pDl6SOcgf6wsKkmyFJq0Z5A71uD12SepU30B1ykaQ+pQ70OQNdkrpKG+jT9tAlqU9pA719UtRAl6SO8gZ63SEXSepV3kB3yEWS+pQ20B1Dl6R+JQ70OnNNJxZJUkdpA/2CqTpvzhvoktQxUqBHxI0RcTAiDkXE3iXWfzYino6IxyPiGxFxxfib2m/jdIMTc82VfhtJKo1lAz0i6sA9wE3AbuC2iNg9UO1RYDYz3wt8Bfj8uBs6aMNUnVPzC6RPLZIkYLQe+nXAocx8PjPngfuBW3orZOY3M/NUsfhnwPbxNnPYzHSDZiu9dFGSCqME+jbgcM/ykaLsbH4Z+PpSKyLi9og4EBEHjh07NnorlzAzVQfgpMMukgSM+aRoRHwSmAV+Y6n1mXlvZs5m5uyWLVvO671mphsAnPLEqCQB0BihzlFgR8/y9qKsT0TcAHwO+FuZOTee5p3dxiLQPTEqSW2j9NAfBq6OiF0RMQXcCuzrrRARHwC+CNycmS+Nv5nDNnR76Aa6JMEIgZ6ZTWAP8CDwDPBAZj4VEXdHxM1Ftd8ANgJ/GBGPRcS+s2xubDZOt8fQT8w55CJJMNqQC5m5H9g/UHZXz+sbxtyuZW2cXgfA62+eeaffWpJWpdLOFN160XoAvvfqmxNuiSStDqUN9HetX8emC9Zx+JVTy1eWpDWgtIEO8OPvvpA/fe5lDv/oFCfmmrRazhqVtHaNNIa+Wn3qr1/Bnvse5W9+/psARMCGdXVmphtsnG5w8cwU/+BDV3DL+99qHpQkVUOpA/3j772MXZfO8MSR13j99BlOzC1wcq7JybkmJ+aaPPrdV7nj
/sd4z7ZNXLVl46SbK0krqtSBDnDNZZu45rJNS6771rPH+eTvPcTLJ+a56vwmpkrSqlfqMfTlrF/X3j0fhCFpLah0oNdrAUDTk6WS1oBKB3qj1t49r36RtBZUOtCLPLeHLmlNqHSgd3roCwa6pDWg0oFeL/bOQJe0FlQ80O2hS1o7qh3o0b7KxUCXtBZUO9DrBrqktaPagd7poaeBLqn6qh3oTiyStIasiUB3YpGktWBNBLo9dElrwZoIdHvoktaCSgd6wx66pDWk0oFeK65yaXmVi6Q1oNKB3u2hLxjokqqv0oFeq3kduqS1o9KBDu0Towut1qSbIUkrbo0E+qRbIUkrr/qBHvbQJa0N1Q90e+iS1ojGpBuw0uYXWnzp/36HViaNWtCo16jXIBNaCUm2X7eSpH2JYyZk9i+/9uYZLp6ZYuN0g5npBjNTdTbPTHHTe7Yy1aj870VJJVD5QP+ln97FAwcO89U/P0KzlTRbyUIrqQUEQUT7evXuv0AERES7TvHvibkmp8+0ih7/4lUztduCv/u+yya3g5JUqHyg773pJ9h700+MbXuZyVyzxQ9eO82Hf/NP+OHrp8e2bUk6H44VvE0Rwfp1dbZvvgCAk3MLE26RJLUZ6OeoUa8x3ahxcr456aZIEmCgn5eN0w1OzhnoklYHA/08bJiuc2reIRdJq8NIgR4RN0bEwYg4FBF7l1j/MxHx5xHRjIhPjL+Zq9PMlD10SavHsoEeEXXgHuAmYDdwW0TsHqj2XeAzwH3jbuBqNjPdcAxd0qoxymWL1wGHMvN5gIi4H7gFeLpTITNfKNatqTmZG6bqvHHaQJe0OowS6NuAwz3LR4CfWpnmlMvG6QaPfvdV/vS540w3asUkpPbkpM5kpaUnLvVPWtp20QU06p7OkHR+3tGJRRFxO3A7wOWXX/5OvvWKuGrLRr7+5A/4xf/00Hlt55PXX86/+/mfHFOrJK1VowT6UWBHz/L2ouxty8x7gXsBZmdnS//Uic9+5Mf46DXv5uTcAnPNBRIgF+//0sqkVRS2km5Z0p5x2srkV7/8F7z48qmJ7oekahgl0B8Gro6IXbSD/FbgF1e0VSVRqwXv3X7ReW3jjx456pUyksZi2YHbzGwCe4AHgWeABzLzqYi4OyJuBoiIvxYRR4C/B3wxIp5ayUZXyQVTXssuaTxGGkPPzP3A/oGyu3peP0x7KEZv0/p1deaaa+riIEkrxEsrJmy6UWPeQJc0Bgb6hE01asw1HXKRdP4M9AmbbtSYO2MPXdL5M9AnbLpRZ86HnkoaAwN9wpoLLeabrb7H2umd9crJeZ7+3uuTbsbYtfyZWnMq/wi61e7gD98A4Kpf28/mDevOa1sRMVw2VGfJ71y2zijbiXPazpINWn47A2XLvXcmLLSSCGjU+le+UEzs2nHxBdQjODG3QKPWvi1DK2H9uhpnFtrhuNDK7nK9Frx6ap4Tc012XjJDK9vPrK0X219oJT86Oc9fuXCaBF58+RSXzEyxacM6Wq1kIZNXTp7hwvUNGvXgzfkFLly/rt32nizuPKz8TLNFo16jUY/F9dFTqdDK7NuvKy7ZwPdfO81UvcamCxZ/xqYbNYj2JrL4jLKYDNe5NcVbHcPeiXS1iG47axHdf4e+fZnD3VxoT7jrffB6JpxZaLXL8izb6Ju0tzi5D9rnqbJnXefB8J1t12tBo97f1na9xc9m4EeG1083OfbGHFdeOtP9DIFux6w++A0D7rjhx7h5BZ5FbKBP2G/fdi2/9t+eYMuF093/iOdiqW9NcoQ6y29nsNaS2xkoG3zvpeuc23YGi5beznBpvdb+j73Qsy4T1tVrPPvSCa69fDOthBeOn+TySzYQtP9jZrZ/CSRw/MQcmzdMsa5eo9lq8cLxk/zo1DzXbNvU958+s/1L5fljJ9l16QwR7VtFbJxusJBJPYJ6LfjWoeP81a3vohbBX/7wDa657F3dIOn9ZVeP9lOyziy0aLayu76V2Q2TiCCL
IIV2oH/wis1s33wB79sOj7z4CtdfeQlJ+0HpC63+T7f3IelAMbt54NgPfKad+xYttnlx3wf/QFjqmAzKzvaKugnUi3seNVuLZbHE93Qf8E7/L6MzCzm8jsWNLBQPjx/U146BNz36ypsce2OO3cXx6qgXx2Sw/qDz7bydTYzyIa+E2dnZPHDgwETeW5LKKiIeyczZpdY5hi5JFWGgS1JFGOiSVBEGuiRVhIEuSRVhoEtSRRjoklQRBrokVcTEJhZFxDHgxXP89kuB42NsThm4z2uD+7w2nM8+X5GZW5ZaMbFAPx8RceBsM6Wqyn1eG9zntWGl9tkhF0mqCANdkiqirIF+76QbMAHu89rgPq8NK7LPpRxDlyQNK2sPXZI0wECXpIooXaBHxI0RcTAiDkXE3km351xFxI6I+GZEPB0RT0XEHUX5xRHxxxHxbPHv5qI8IuK3iv1+PCKu7dnWp4v6z0bEpye1T6OKiHpEPBoRXyuWd0XEQ8W+fTkipory6WL5ULF+Z8827izKD0bEz01mT0YTERdFxFci4tsR8UxEfKjqxzkifrX4uX4yIv4gItZX7ThHxJci4qWIeLKnbGzHNSI+GBFPFN/zWzHK8xozszRfQB14DrgSmAL+Atg96Xad475sBa4tXl8I/CWwG/g8sLco3wv8evH6Y8DXaT/Y6nrgoaL8YuD54t/NxevNk96/Zfb9s8B9wNeK5QeAW4vXvwP84+L1PwF+p3h9K/Dl4vXu4thPA7uKn4n6pPfrLfb394FfKV5PARdV+TgD24DvABf0HN/PVO04Az8DXAs82VM2tuMK/L+ibhTfe9OybZr0h/I2P8APAQ/2LN8J3Dnpdo1p3/4H8BHgILC1KNsKHCxefxG4raf+wWL9bcAXe8r76q22L2A78A3gZ4GvFT+sx4HG4DEGHgQ+VLxuFPVi8Lj31lttX8CmItxioLyyx7kI9MNFSDWK4/xzVTzOwM6BQB/LcS3WfbunvK/e2b7KNuTS+UHpOFKUlVrxJ+YHgIeAd2fm94tVPwDeXbw+276X7TP598C/BFrF8iXAq5nZLJZ729/dt2L9a0X9Mu3zLuAY8J+LYabfjYgZKnycM/Mo8JvAd4Hv0z5uj1Dt49wxruO6rXg9WP6WyhbolRMRG4E/Av55Zr7euy7bv5orc11pRHwceCkzH5l0W95BDdp/lv/HzPwAcJL2n+JdFTzOm4FbaP8yuwyYAW6caKMmYBLHtWyBfhTY0bO8vSgrpYhYRzvM/2tmfrUo/mFEbC3WbwVeKsrPtu9l+kx+Grg5Il4A7qc97PIfgIsiolHU6W1/d9+K9ZuAlynXPh8BjmTmQ8XyV2gHfJWP8w3AdzLzWGaeAb5K+9hX+Th3jOu4Hi1eD5a/pbIF+sPA1cXZ8inaJ1D2TbhN56Q4Y/17wDOZ+YWeVfuAzpnuT9MeW++Uf6o4W3498Frxp92DwEcjYnPRM/poUbbqZOadmbk9M3fSPnb/JzP/PvBN4BNFtcF97nwWnyjqZ1F+a3F1xC7gatonkFadzPwBcDgifrwo+tvA01T4ONMeark+IjYUP+edfa7sce4xluNarHs9Iq4vPsNP9Wzr7CZ9UuEcTkJ8jPYVIc8Bn5t0e85jP/4G7T/HHgceK74+Rnvs8BvAs8D/Bi4u6gdwT7HfTwCzPdv6JeBQ8fUPJ71vI+7/h1m8yuVK2v9RDwF/CEwX5euL5UPF+it7vv9zxWdxkBHO/k94X98PHCiO9X+nfTVDpY8z8G+AbwNPAv+F9pUqlTrOwB/QPkdwhvZfYr88zuMKzBaf33PAbzNwYn2pL6f+S1JFlG3IRZJ0Fga6JFWEgS5JFWGgS1JFGOiSVBEGuiRVhIEuSRXx/wGwrlhDR787MAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Loss values recorded by train() over the run.\n",
    "history_unsupervised.plot_scalar(\"loss\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final (last recorded) value of each factor from the training history,\n",
    "# fetched in the same key order: per-mode X1, then level-1 A/S, then level-2 A/S.\n",
    "X1, X2, X3 = (history_unsupervised.get(key)[-1] for key in ('A_X1', 'B_X1', 'C_X1'))\n",
    "\n",
    "A_A1, A_S1, B_A1, B_S1, C_A1, C_S1 = (\n",
    "    history_unsupervised.get(key)[-1]\n",
    "    for key in ('A_A1', 'A_S1', 'B_A1', 'B_S1', 'C_A1', 'C_S1')\n",
    ")\n",
    "\n",
    "A_A2, A_S2, B_A2, B_S2, C_A2, C_S2 = (\n",
    "    history_unsupervised.get(key)[-1]\n",
    "    for key in ('A_A2', 'A_S2', 'B_A2', 'B_S2', 'C_A2', 'C_S2')\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize S Matrices\n",
    "\n",
    "Note that these S matrices will not look as clean as those in the paper: the order of topics in NCPD is arbitrary, so obtaining a \"nice\" S matrix requires reordering the topics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fc9944084a8>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAK0AAADrCAYAAAAbrhYyAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAEAElEQVR4nO3csYqcVRiA4W92jQYjKkEbISRiIUhSWVhZq7WFhZdgK3gNirX3YOEFpLA32mmqBBTBRiyDRLPJ/haCVXZkwOX4kudpz8B3mHk5TPXttm0bKDlafQE4lGjJES05oiVHtOSIlpynDvnwS5ePt2tXLpzXXfa6e/u5JXNnZrbT02Wzn1R/zO/zYPtz97izg6K9duXCfHvzyn9zqwO99/rbS+bOzJzeu7ds9pPq1vb1mWf+HpAjWnJES45oyREtOaIlR7TkiJYc0ZIjWnJES45oyREtOaIlR7TkiJYc0ZIjWnJES45oyREtOaIlR7TkiJYc0ZIjWnJES45oyTloAd2dHy7Nu6++dV532ev4lctL5s6sX0D36U+3ls3+ZNHvvY+XlhzRkiNackRLjmjJES05oiVHtOSIlhzRkiNackRLjmjJES05oiVHtOSIlhzRkiNackRLjmjJES05oiVHtOSIlhzRkiNackRLjmjJES05B636nG2b7eThOV1lv9Pnn10yd2bmo7t3ls2eWbxu8+h4zdxHZx95ackRLTmiJUe05IiWHNGSI1pyREuOaMkRLTmiJUe05IiWHNGSI1pyREuOaMkRLTmiJUe05IiWHNGSI1pyREuOaMkRLTmiJUe05IiWHNGSc9Cqz90zT8/x1avndZe9tm1bMndm5os3ri+b/bcHyyZf/27N9/79h2efeWnJES05oiVHtOSIlhzRkiNackRLjmjJES05oiVHtOSIlhzRkiNackRLjmjJES05oiVHtOSIlhzRkiNackRLjmjJES05oiVHtOSIlhzRkiNacg7aTztHR7NdunhOV9nv46++XDJ3Zuaz124sm73a7TdPl8y9v2ctrpeWHNGSI1pyREuOaMkRLTmiJUe05IiWHNGSI1pyREuOaMkRLTmiJUe05IiWHNGSI1pyREuOaMkRLTmiJUe05IiWHNGSI1pyREuOaMkRLTmHrfo8OZndL7+e01X2+/z9D5bMnZl55/Y3y2bPzNy88cK64btF79qjs4+8tOSIlhzRkiNackRLjmjJES05oiVHtOSIlhzRkiNackRLjmjJES05oiVHtOSIlhzRkiNackRLjmjJES05oiVHtOSIlhzRkiNackRLjmjJOWzV58zMtp3DNf7dyYsXl8ydmfnx/svLZs/MzO7hutFHuzWDT88+8tKSI1pyREuOaMkRLTmiJUe05IiWHNGSI1pyREuOaMkRLTmiJUe05IiWHNGSI1pyREuOaMkRLTmiJUe05IiWHNGSI1pyREuOaMkRLTmiJWe3HbC6c7fb/TYzP5/fdeAfV7dte+yO1YOihf8Dfw/IES05oiVHtOSIlhzRkiNackRLjmjJ+Qvv3FUPYJqwkAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show C_S1 with each column scaled by its mean so the columns are visually comparable.\n",
    "# The clamp avoids dividing by zero when a column is entirely zero (assumes S is nonnegative).\n",
    "plt.imshow((C_S1 / torch.mean(C_S1, dim=0).clamp(min=1e-12)).T)\n",
    "plt.xticks([], [])\n",
    "plt.yticks([], [])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7fc9943d9240>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAG8AAADrCAYAAABn5MiuAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAADfElEQVR4nO3dz6pVZRjA4bWORxMPDTScZBDRHxqJEBThsMAGEs68Aq+ieZfRrIZBEZEDL0AjCYeFAxFEcFRwTD3qarRp0NmbVrhb/eB5pt+G92X/+PbsY4/TNA007Sy9AP+eeGHihYkXJl6YeGG7cz58bHxpOj7sbWuXjd45+3CRuSu/3DqxyNxHw/7wZHo8HnY2K97xYW/4YPzoxWw109WrPy8yd+XCq+cWmXt9urb2zM9mmHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFzbrocnbZ/eH7374aVu7bHThtfcXmfuXZwvP/zs3L0y8MPHCxAsTL0y8MPHCxAsTL0y8MPHCxAsTL0y8MPHCxAsTL0y8MPHCxAsTL0y8MPHCxAsTL0y8MPHCxAsTL0y8sFlPvH69tTdcPPPetnbZaNw99L+Q/jPf3P1xkbnnP9lfe+bmhYkXJl6YeGHihYkXJl6YeGHihYkXJl6YeGHihYkXJl6YeGHihYkXJl6YeGHihYkXJl6YeGHihYkXJl6YeGHihYkXNuuJ1+67R4ZXvji5rV02uvf5W4vMXbn05qyv6oW5/ej7tWduXph4YeKFiRcmXph4YeKFiRcmXph4YeKFiRcmXph4YeKFiRcmXph4YeKFiRcmXph4YeKFiRcmXph4YeKFiRcmXph4YeKFzXp09uz2zvDb5RPb2mWjK9e+XmTuypfnzy0z+GD9/XLzwsQLEy9MvDDxwsQLEy9MvDDxwsQLEy9MvDDxwsQLEy9MvDDxwsQLEy9MvDDxwsQLEy9MvDDxwsQLEy9MvDDxwsQLm/XEazo4GJ7eu7+tXTb66uMPF5m78tmNbxeZe+XT39eeuXlh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oXNeuI1juOwc+zotnbZ6PmplxeZu3LzjzcWmfvw+YO1Z25emHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4oWJFyZemHhh4zRN//zD4/hgGIY721uHQ7w+TdPpww5mxeP/xc9mmHhh4oWJFyZemHhh4oWJFyZe2J9lLkcFfyW8gAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show C_S2 with each column scaled by its mean so the columns are visually comparable.\n",
    "# The clamp avoids dividing by zero when a column is entirely zero (assumes S is nonnegative).\n",
    "plt.imshow((C_S2 / torch.mean(C_S2, dim=0).clamp(min=1e-12)).T)\n",
    "plt.xticks([], [])\n",
    "plt.yticks([], [])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Rank 7 Approximation (for both Neural NCPD and Standard HNCPD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reconstruction loss: 12.015365937320105\n",
      "Relative reconstruction loss: 0.09044921085441056\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2gAAAEkCAYAAABaExIDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAS60lEQVR4nO3dsc8k510H8Jl370wUpBQEkwI5SUPpwkgYh45Uxz9AhdJzRbBQJBpT4QoHDlJcE8EfABISVUyRKkEyUaITKVJAEVkUkXBokLAS7N2hsHTCuXd+vn3emWe/M/v5lN7MzLO7M7/d77u574zTNA0AAABc3s2lFwAAAMBHBDQAAIAQAhoAAEAIAQ0AACCEgAYAABBCQAMAAAhx75z/8a/+ymH64kv311rLc/vRT1689BLi3Pvp+5deAjGqW2eM3VYx52fT/wz/O/2seSHmUK59z6HWW9Jc/prjWXuZQ0lSZuK+5xB7Us2hswLaF1+6P3zvn15aZlV38MqbDy+9hDif++b3L70EUkyn+cfGjj+a39z+3eedn3/rTrs1h3Lteg5V11Wl5zXHs3Y+h5KkzMTmOTRzrgzDMAynkHsGp6wxZR29Vc+7MvOavPPB2/OHajsSAAAASxPQAAAAQghoAAAAIQQ0AACAEGeVhMBF+Mf5ucr35tBtGddunM7/R9nTOP+PnVv2NwxD+z+gXtoa/0i9dZ7s+TWB/6eaG3PzpnnWVFKuud6utbhjp3yDBQAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABACDX7e5FSK5tU5Xqtr0nP2wu4lUGEqjK/67FSrv+Uax8orXK7j9Y5VM0NM4XOfLsCAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAEAINft7kVJvvYbWKvc9vyZEq+qh16jEb6mjbl1HeayUKmrX/uVNp/nHDod+6yBac5V+ZY05lDJTqueWssat6Pl51XAsv6ABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAEAILY5Au6qpbdDUtqRV2s6WpkXsWdfabFm17/ZcS8rrTz+t59fWzxUNj+dpeU1az5GGY/kFDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIXZXs19VUU/j8hWqSx9vE1XaW5FSmbvnetuqSnslP/rJi8Mrbz689bGe1/+TNx4vur89ePC3v33pJXxkjZrtrV/HvZ9by/HWWONK71s1h1K0fp9onZVP/jRjJsbMoT3b+jxs1fF5+wUNAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhNlmz31od21rBvcbx6CClBjal7n8Yln9NplPx4GHZY91R71twXKU9X3Ot++z5miTNmkrLOrfy3HZg87Oy920jekqZsdVnf3X7nTW2a3XI+o7yi/yCBgAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEJus2d+6NSr9CZZSiwtra61QXlrrNVdVaTfuczxk/B10SvpsmXstV3j9Od8mqvQra8yhkPOvdZ5MxxVq6rcu5D2dk/HJAQAAgIAGAACQQkADAAAIIaABAACEENAAAABCCGgAAAAhNlmz31pT37s6dvNVtVtXVe22qOp5q2MdDsuu4y6WrpXtWZ1Ovj2fD40V8FH19nANdjyHYuZJ62vce7sNu75nDAAAEEpAAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEJus2a+sUW3fWuvfcx1lzXtrPeka++yp5xqrYy1dbZ+kvJVB0O0F4K72fB2zWWt8B0m6XRFcqw18ywYAALgOAhoAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIQYpzMqWj/9ay9Nv/H7fzy3o6YFtFS2PnnjcdOx9uzBF1699BLu5qY4D1rrrat9plijurvltVzj9Z/xzgdvD/99+q/mN+czN5+d
Xvul37v9wWqtPV+Xjq9nsxXW+A8//u78Lof5WzLcFH8rbNnuS2+9PrvN0HksfPDL84+NC58KU/HcPv/1H8w/2HsOBXjn59+6+xy6/+D8DbcwG9bQcj50nkOV33l0+3ffSnU9Vqq58OGn2vb5+T//ftuGS0s6/wNmVDWH/IIGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQ95baUUtdPgsKqAsdhiGrJrjnWlJe/2vnfbioQ/U5MM3/PbDarrqMy+OFGNtavduOFTR+S8fj7f99LP5mXF3bc/v7pH2yS61zqLWCv6ee82T3lv6OuPAtBEwuAACAEAIaAABACAENAAAghIAGAAAQQkADAAAIsViL4zi1taFofwTYv9bGxZut/x2xetpLty5u5eN06WZFTY0soOusqYomG7fjFgs3K/ZkqgEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIMRiNftcWHhd6J201qRW23U0VvXih75rnBbv9d6JkHNl646Nt1upKvhPw2l+u+HQdLyuel5yLu++lp4bW59DG/geUs2TFGP1Mqa8xFPxOva+7UW1ltN2f4fa7soBAAB2RkADAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQy9XsV9WfG2+O3YSkytOltVb39qz8LeqRp8bqcRq03nZhbrut114D62n5jGm9bQwkSfpembSWBe3zWQEAAGyQgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACHOqtm/99P3h8/9zQ+WXUFDvfWDb746v781KmyXrrBfYY3/+O4/N21XORb18Ifx9ufwW3/xR4uvo9Xp/vxjY/GWTg1v6Vi8bb/+18U1s0atcnV+zZ3Lp+JJV/s7HucfW7P6tuV167XNXbarLD03Vljjl77++uL7bPGvf/L40kt46uVHD5fdYestbXpXuC99vKVvoXEpaevppePzXuN7yA+/ljFTmudJz1sxVceqvmu07rNa/wa+v88eatG9AQAA0ExAAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEGfV7F+tNevCF1JV4p+G+ZrRmyKjV9s1ddFvxFxl/lRVWG+lOXnpc3kD18ZubL2ee43rp9png2NVyZxi4ecMrCNqnvT8rF7jWK373PB3lO2uHAAAYGcENAAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACHH5mv0bncGpjkX3tWQPK6hqmQ+H+ceWruCv1lHVFrcuI+QOAodV6qGX3+Wc8lYgsLaWWvkVrrm5W+UMQ99rZI15suM7HPELvNUAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABBCQAMAAAhxfs3+0nXOS1tjfdWtAFqOt8IaX3v0+uL7bPHDrz2+9BKeevnRw7YNZ5qCy3be6sHe18zSx6v2t/S18bwSbs+RPgsvofVPfgFv52oamsdbjXt+HRMlzKEkx0sv4CN7vt3E2HGecFl+QQMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQIjza/ZhKxpb0MeG7TZTuD7N3UNghb/VzB1riVdrrxX3a7wPLVrXUVVAV9XXrW/nFuq0t7BG2ux1DrVKmV+wA64mAACAEAIaAABACAENAAAghIAGAAAQQkADAAAIocVxJyZNYc9oaWNkI25uP+HHcf5CmKbihGhpY5tZQ/d1fMJamhTrGA/zf9f78NPFPluX2PCSvPzoYds6qmNVjxV/6vz2V98qNpx3rM6TGYfivPvKN77ctI7ynFz6vGtdR2G8f/vXnPHDC31oVq/ZwnOoe8tkyHM73Z/frPV7QTlTWvbZePq1zpOv/NXvth1wzhrnXc95Mgzz6+y9jhl+QQMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQAg1+3vRscL6TsfraPLnBy6krNK/VtVL0nOetFbpN+6zpS5/GIbhg6Z1OO+grNJf4Rpvqe4vb420wjzh
wqbT3AOzm/gKCwAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEGr2d2Kca/Ac6jrXlnrYYWhvqo2x9BOoKnNvgu5JcLz0AuAWS1+PK1T6V7OymrFLX3LHYh1Jo2Zxu35ydNP7lkQL789H+IKWninl/g63/+cP57fxCxoAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEKcX7M/UyM5jvNVkdPUr5S9eR2nxjUuXdNZrGM8zOfp4wvFPjvWyr786GHbOqpjVY8Vf2L49lffKjacd2w4X+8X590ffOPLTesoz8nW8+4wU/Xauo7C+MLtJ+V43P/fhVLmYZSe9dbVsRrXUS6j2OcLxblQmZtD94vdVXMIuIPi0mqa6I2X6qeKa7y42xIbtP9vSgAAABshoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQIjza/aX1lLhXVSMX22FNas7Fo81lNezY+bQLaqXJKUdvvVtK7ZruW3HMNTzZs6u59Aatx2B51WcfmPDJT413naoqtJvnTVXa26mtM6T1tt1zfALGgAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQozn1EF/5uaz02v3H6y4HOikqlFduCr17Xe/t+j+7uI3/+wPux1rGm9/jf/97/5yeP8//6O5F9scYlNCKuDf/vG/XHoJT5lDnbR+znX8fKyMh/nfEMrvrh3XuIoVZkbS9Z9ibg7NzYxhGIZx4VsZ/NvfP5qdQ35BAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABAiHuXXgDQT1UfCyxvdM09Y24OtVZYm2tcC/NkOS1zo3XWtMw2v6ABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEmn2grIBVYc1zuwk5V05tde1rmBqr44FGa8yhkJlinnRQnD5jx/PAL2gAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABDi/BbHhJau43H+sbHInNOpbbst2PNza9V6riac451pamQRIU1nSUbX1jOq1li4sx3PIfOkg5DT50q/uQMAAOQR0AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIMT5NfsJ9aWttfF7rpvf83NrVZ2rVZV+wjm+kp711qr7gdvMzYZqPpknQBfVqCm+QlUzquW7l2/1AAAAIQQ0AACAEAIaAABACAENAAAghIAGAAAQQkADAAAIcX7NfoLpNP9YVTffut0WVLXxPe24op4NWPoa38rM6HnbiCu9RUWrqeOtLSqvvPnw0kt46skbjy+9hOHV77y33s5T5tDxeP6xhmEYThnfo6Zjcaw9K+bo1PgSp1z/SbfSiJhD352fQ0HfMAAAAK6bgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACG2WbPfWuWaVIu9NPXWz2q99UDKLQtWMFdjW1XftuzvYpa+xvc8M/Zsx9fwGnZz/adImUNrzK9rnYkbnymt9fZz27Vs80mSKvgTXOmVBgAAkEdAAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAENus2Z9O849VFbCt28HOXWOFLTvWetuRNT5bDoe2tXS0RmU27Eo1U6rrf2md50nP7wa+h3ycVAIAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBDbrNnnWW4hABBn3Hh1tOprFjH3HcX3k7O0zpOUuny37Xh+rgwAAIAQAhoAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIRQs78XrVW1N0X16kkdKnCGap70ZHZBlp51+mvMITOFzvyCBgAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEMvV7E+n+cda6lWX3t9dtqv0rKlf41h7ro4NeW6vvPnw0kt46skbjy+9hOHV77x36SXA7Vb4bJmmjDkErKTjLQTMk+vhFzQAAIAQAhoAAEAIAQ0AACCEgAYAABBCQAMA
AAixXIvj0i02HVtx2Kmq9bKnahmthUwhTw0+JqQ5tVlre3C13XBoXk4vY2Mz3DQaRATqPYfK639pbfOkusar63huu5ZtPknrGvdKCgIAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQIjza/YTqsu3XuW8ZwnnR5ridG2ut9azT6KU67/1M6L19i7FduMV1kPDRa0xh6qZ0vG2UObJ9fALGgAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQpxfs0+mrddbwzkSznfn+rO8Jps0qe5ukzCHklTX/3S6/b+vUVFvDm2SOfRxfkEDAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAECI82v21Zdm8r5clXG68vd77nyfq3IehrY652p/h8P5+7uEnlXg1etVqd6bpd/TSuuxiu2mKb86unWeXH0tdsIcar12Kq37LGdix3nZ+/tQ6+vccqgNzBOW4Rc0AACAEAIaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEGKcz6nXHcXxvGIZ311sOcAW+ME3Ti60bm0PAAswh4NJm59BZAQ0AAID1+L84AgAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABACAENAAAghIAGAAAQQkADAAAIIaABAACE+D/lC0PPuyPJjwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 864x360 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Rank-7 approximation from the first-layer factors X1, X2, X3\n",
    "approx = outer_product(X1, X2, X3).numpy()\n",
    "\n",
    "# Frobenius-norm error of the approximation, absolute and relative to the data tensor\n",
    "recon_err = np.linalg.norm(np.ndarray.flatten(X_np - approx), 2)\n",
    "print(\"Reconstruction loss:\", recon_err)\n",
    "print(\"Relative reconstruction loss:\", recon_err / np.linalg.norm(np.ndarray.flatten(X_np), 2))\n",
    "\n",
    "# Maximum projection of the approximation along each of the three modes\n",
    "fig, axs = plt.subplots(1, 3, constrained_layout=True, figsize=(12,5))\n",
    "for mode, panel in enumerate(axs):\n",
    "    X_max = np.max(approx, axis=mode)\n",
    "    panel.imshow(X_max, vmin=vmin, vmax=vmax)\n",
    "    panel.set_xticks([])\n",
    "    panel.set_yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Rank 5 Approximation for Neural NCPD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reconstruction loss: 30.474334998624315\n",
      "Relative reconstruction loss: 0.2294045446736759\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2gAAAEkCAYAAABaExIDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAPiklEQVR4nO3dsa4c1R0G8Jm9BpsUQUEhSEiWafIAjpCVN4ifIO8Ql1FKU8VKKiokFyRPkDxASBtSIAvkjhqkSCCZNJEQYNidFFFAFp5j5twzs9/M/H4lm5k9uzvnf/e7N/6mH4ahAwAA4PwO514AAAAA/yOgAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABAiCtT/sc/felieO36c3Ot5Qf78JOXz72EOFc++7zwaF94rHSbhdJxa9D6FhJrfz8yfDl83j0evqx+M82hXDlzaI65lnJLGnOoha3MoSQpM7E8h2rZd7RXmkOTAtpr15/rHvz9eptVXcLNe3fOvYQ4r7z9YPzBi4vxx47H8cf6wh9YDyPD6pTyJabruuHU9nyl92Ptxj7Prmv+mb739TuXOt4cyhUzh2rPV9J6ntSqXX9pjy9tbKaYQ6uWMhOLc6hWaX4tqXYfJO3/FAHfV0tzaMPfOAEAANZFQAMAAAghoAEAAIQQ0AAAAEJMKgmBZyr9Q/q1/+P8WnO8J7VG11L5D6CTXtuO9cP0f+w89OP/aLzmfFHmuPb2Or/gByrNjbF5s/pZAzPxDQoAACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACHU7NOWeuvvS6qbb72WpNfGU5Xq9BPOBwA8ybcrAACAEAIaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEULMPsGH9MDz1v9fW5Y+dj3Annxsw0djcOFTebqU0h0rnnOO4cP6CBgAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIbQ4wp4Mp5EHLhqfr+u6fr+//yk1HWpPXMAc12XpnADQ0H6/QQEAAIQR0AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIISa/UtqXac9Rz03fKt19f0ZqvQ//OTl7ua9O099bMn98/Du/abn24Lbf741/uCpcJuA0nV0qPjc5rgu56jnrzln4f3oC9f4cKy8TUDp/S98pv3F+GsbW0tx/V3WbSZKcyhF7a05amflwzcyZuLtP73e/qSl+bUGrdc/x/tRe861fzYj/AUNAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhVlmzX1sdW1Kqla19vrHjaits53jd7Mxo5fdF4/N1Z6ngL3ELCwBgDbK+QQEAAOyYgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQlw59wJqDH0/+lg/DFXHJT3fqg2n8cf6yt8HlM65Boeg6+B47gUAwIxqv2tAEFcxAABACAENAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCrLJmv2Tt1fZrXz9PcRq/FQMA1Crd6meOc67iO8oct+ZR3c/CXHEAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABBCQAMAAAjRDxMqWn/0s+vDz3/927ETVS2gprL14d37Vc+1Zbdv3Gp/0kNFnW5SpXyparemMrfwfvSF63g4Vlb+lt7/wvvcX4y/trG11BxT672v3+n+c/p3dVfzjw8vDb+88quWS6KRf/1uhjlU4fGLOXPo2qOMWvJX33r/3Ev4Ts3Plsbe++pv5tBGffRG3RzqC2PjeDVjplx8WXfJ3vhD0P5PET6H/AUNAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhrrQ6UU1dPgDsVanWexj5kVpzTJzj8dwr6LouozadZZX2zyqsff1JwueQ
v6ABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAECIZi2O/VBXLaP9EZim77re75ZYv5rWxdU0NZZE7N/LvpHm0Bqtfv+sff1JIvbv+AeasDoAAAA6AQ0AACCGgAYAABBCQAMAAAghoAEAAIQQ0AAAAEI0q9kHRhxm6MWd45xrsvfXzyb0hbvTNK8DH06NT3gJEfXWDZhDq1Pac6tQWn/pckza/ynC51D26gAAAHZEQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBDtavZrqz9h606Vvb6lCufSOS9sOFiD5lX6JeGV0rCE0p5bfQV/if2/Oj4xAACAEAIaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEmFSzf+Wzz7tX3n4w11p+sNtvvz7+YKlKdDi1X0yIj94ovCcLevyTY92BtX3ThV7cFz69qDtnY9ff/KD9SQsV/MNQ6Aoe2QPDN0vujQZdxrW3LmBWj18sfC6lLd744/y6dg7N4PB4/Mds61rv6tr+OX427qHW2xyKVPweUrlJhqsZ3x+P11awr0rzpHYuLHnOOWbX6PrHZ8gKPmkAAIB9ENAAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACDEpJp9cpXqmkutsrXHNX+yWtW90gtKqmKOqL5ewWdGe0HbYElzjL2E54JYa/hewPkt+X1o9LnGr9WEb2sAAAB0AhoAAEAMAQ0AACCEgAYAABBCQAMAAAghoAEAAIRQs8+oRSv4t+wQ9LqTKv/Zl9I22PBlueSdR6pHbMTtN6CR5l9emGSOebLDGbW/VwwAABBKQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBATa/b7rru4mGclU9RWhW+4pvN4NaOnerh6OvcSvnV6PuTzVm0PAOengp+VCPkGCwAAgIAGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhJtbsA6s2NL4NQunWFaPP5bYDu+Rjn12pQRzoVOmzGv6CBgAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIbQ4wjm1blXsunKz4iYM87xvU23+fZ7u2qOMhrTD47ofbbUtiKViuGuftX++mnVE7Jn/28TeCZlDSUI+1xc+vWh+ztPzbV9baa+W5sLhq9ondK1+T8j1OiZ7dQAAADsioAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQAg1+3tQrF5ebBWLq62xbf1ctNZn1OOqLf6e0r6q3Y81e6v1/r7M8y25luJz1e6Z0nWesA/PJmQOJTETn1SaC8ut4tlK1/GSn2ntOja6D7f5qgAAAFZIQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBBq9vdgw1X6JTH11sUD/Y5kukGdMyxljhl1iCoZr2QO8QxJl3ntPk75jrLDORTyzgMAACCgAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABAiOk1+6eddraHu/iiUBdaahKtrod/+n8+fpGT+Q9fnXsFgZaszB19rstW2/Y51b88Yaj8aGuPa/1ctbfLKJ1zjuereS5aM4eot+RcYJ1MFwAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhJhes89+qGxu4xD0Rh6P515BV39vhxUpfea1tyqpPefYcTu9ZcocFdalcy5Zmb2aeu6dXnubl/SzLtxq9uqWhc8hf0EDAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAECI6TX7CTWq4dWY53Djj+/XHTicxh/rN5zf56hBH/HOxw+anu8yfvH73yz2XEP/9Pf4m7/8c7E1nM0cM6r2nAvOy1ffMocmMYdmNz6H/rHYGlan9rpsPWsq13H9zQ+qjmu+H0tzraR25hXWn7T/U4zNobGZ0XVd1w9tr/Fv/jo+hzb8kw8AAGBdBDQAAIAQAhoAAEAIAQ0AACCEgAYAABBCQAMAAAgxvWYfWK1SfSzAEsbmUKnC2uxidbZ8i5AN
GJ0ppVFTaNlvXc/v6gEAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQIjpNfun6VWRQDb11sBSaiqnzSggWc1cK/EXNAAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACDG9xfEQ0JakSRKa0oIGAJDBX9AAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABBies0+sFr9sNwtKlT3A08zNhtK88k8AZK1nlH+ggYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBDTa/ZPy9V0A7Aiw6nuuL7wu8LSOUvHrcGCP09v3ruz2HM9y8O798+9hO7Wu4/OvYRcKd/zatdxPI4/VpoZKa+7VuF1p+z/pFtppM+hlf90AwAA2A4BDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQkyv2T8sW4P5VGuvQoUzGauxTaq+BZZV2v8lZgOsw5I/42vnie8hT/IXNAAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhJhes6/iHjZnjxW20MxwGn+sz/89aGn/11Zmw26U9n9rK5gntXwPedJ2P2kAAICVEdAAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACDE9Jr9Q0ANpqp/AJid6mtYvyX3sdt2tOEvaAAAACEENAAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACDG9Zl/FPQB8p/e7Toix9H60/5mBqwoAACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACGm1+yTaTjVHVeqhy2dM6VWtnaNC94u4ua9O1XH9cP4Goe+rzrnw7v3q45r6da7j869BAC2ag3fXeAZXKkAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABBCiyPrtoZGprrCxW4oHVh5zs047P0NCHU89wLOZOWtcaXG2JLaNtnNGJtDtS3BNecrzcLW67jMWpa09ByqbdGuUTlPaluhx46rOYZp8n9yAAAA7ISABgAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIdTsw9wKjbPV9da779kPUFthzfasoEqfBbWum6893xy193PU+re29H5M2f8zrGP3t9I4o5CrCgAAAAENAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCqNln3YbT+GMp1bcFpQrb2gp+gCnMIdiA0vehSmP7X/3+/PK/wQIAAOyEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACHU7MMZ1VZYl47bRf3tKaD6+7CD93mq47kXcCYrv90HlRLmUJKUmXgsDKI59uMM9fajzJPd8EkDAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACBEP0yo+e77/lHXdR/PtxxgB24Mw/By7cHmENCAOQSc2+gcmhTQAAAAmI//iyMAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQIj/AvF//8Q3E85GAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 864x360 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Rank-5 approximation: each mode's factor is A1 @ S1\n",
    "approx = outer_product(torch.mm(A_A1,A_S1), torch.mm(B_A1,B_S1), torch.mm(C_A1,C_S1)).numpy()\n",
    "\n",
    "# Frobenius-norm error of the approximation, absolute and relative to the data tensor\n",
    "recon_err = np.linalg.norm(np.ndarray.flatten(X_np - approx), 2)\n",
    "print(\"Reconstruction loss:\", recon_err)\n",
    "print(\"Relative reconstruction loss:\", recon_err / np.linalg.norm(np.ndarray.flatten(X_np), 2))\n",
    "\n",
    "# Maximum projection of the approximation along each of the three modes\n",
    "fig, axs = plt.subplots(1, 3, constrained_layout=True, figsize=(12,5))\n",
    "for mode, panel in enumerate(axs):\n",
    "    X_max = np.max(approx, axis=mode)\n",
    "    panel.imshow(X_max, vmin=vmin, vmax=vmax)\n",
    "    panel.set_xticks([])\n",
    "    panel.set_yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Rank 3 Approximation for Neural NCPD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reconstruction loss: 62.09267763486261\n",
      "Relative reconstruction loss: 0.4674209442482673\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2gAAAEkCAYAAABaExIDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAQjElEQVR4nO3dsa7kVh0GcPve7N0VQSiASAQSCRXUCF4GeABKJCigSB0JKEDwDLwDJaJCKEi0BAmCgoSiQJno7iY7pgggol3/d+fcY/uz5/crc2P7jMfneL6Z5PM4TdMAAADA9q62HgAAAAAfE9AAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACDEC+f8y9effnF64fOfW2osz238aOsR5Ln5xwdbD2EYhmEYtx7A//EAiUy30/vDo+m2+VK5/synpnsvv9RzSE2mD6+3HsJzmpsJ/Wfr/XfeL/5aHa91tiatOHN6r0R7eM1tlrhC5tx5HVrk89B6c3UJ473HWw9hGIZhuPnLw62HAM+lWofOCmgvfP5zwxd/9L0+o7qDm3/u5IPROLPYTv0X26+88Yfu+2wxjjk3kvIZf6fib1edX0N1rBTVa+48/t99+Os7bX/v5ZeGr/zku51G0+723Re3HsJzmWbWoXGBdeir3yvWobH4DzamU9sBq32maH1tc9Z+zb3Xw2GYXVPG6/nX1ryez7jrOrTE56E15+oSHrxSfUGznle/9dbWQ4DnUq1DO7i7AQAAXAYBDQAAIISABgAAEEJAAwAACHFWSQhn2sn/2NtT+T9yr6213GIPpR481dQ458a5Qp8DWLVgoLXAYg9lH62O/NroqvtcPXI5KhycOwcAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEL0q9lX58owDOOY84ZOKV8/qO1fzZHr8ltNM+dk1fp9AOC5pXyEBQAAuHgCGgAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhOhXs9/Y2LxqBfQSFdzVOOeOp94a2k3jMHWeQ0eu51enz2KqR4hcue5YTu97AKTxCxoAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAITo1+JYlaCVRYcrNvGs3fqjZQjgsk2nvvsbfa+6a42flY6qamM8crvuxarWQ2vbJzgbAAAAIQQ0AACAEAIaAABACAENAAAghIAGAAAQQkADAAAI0a9m/wLrYYENjNNs/bLKZoZhGIaroBvSKeR70CXOSes+e48l6f1+lh0NdQ2t63K53Z6uh7WcQu5/SVX64ddJ0JkCAAC4bAIaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEOL9mv3dT51zLZXGcql21aNleRO+xJL223VuzVja8rhUV/KswD4AVVOs5HIFf0AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIISABgAAEOL8mn3YC5Xf/B9V+itY89EWFXMfDs16ztH5BQ0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACHOr9mfay+uGk9bGo+LbaagBuXeY0l6bS2mKaj6VuV3H3sfP+sJuVbGMWMcwzAMU/U1aLVGhZxLIGdNifqMxaL8ggYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBDn1+wD50mp+69Uld5p45/GYWp4HkW1zTg+/TW2HOeihVwrZbX92lrPSe9zqbY/Q+9HEi1hxWncsi4/a7slqLdnbUm3MQAAgIsmoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQIizavbHe4+HB6+8/9S/tValtrgdXuy6v6VMM697XKIeNqTeOsp06ru/ceXvM5aoxZ65TsZx/lhlZfkG193NX2+H1779p9WPe5Y9Pbago7//4Juzf6uWvc63iOHRSwuc48Z69AfvFX9c8VL40i/fXO9gjaY1T8gd3X/ng+Gr3//j1sPI
ErK2vf36/DrU6nS/eG0tL7u6vRf7u75t+1zw2hv58391IdfrHL+gAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABACAENAAAgxFk1+5XeVfoAHMeqt4jGSvzmlvdqO7fGJ/V+BEoTbwxncLkcT/g65Bc0AACAEAIaAABACAENAAAghIAGAAAQQkADAAAI0a3FcZqqaqx5R25/HBvPCcDRVMth621gdp+tS+8SS/YSrZEtx9qDsfjOOKJxjYtlHh9Ptd6sZv7NThgdAAAAg4AGAAAQQ0ADAAAIIaABAACEENAAAABCCGgAAAAhutXsH7kuv9U0c07U78MdXQXMoZM17xxL3CJm91kda4m67DUruCt7vyRV6e9bwrq8FPP4eCLWm/k32y9oAAAAIQQ0AACAEAIaAABACAENAAAghIAGAAAQQkADAAAI0a1mf2qsjj9yPb86fYANJC29KRX8AOyGX9AAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABDirJr9m788HF79zp/7juA00zN8VXQTz23zrO1axnGXfa44jrdf/+YCgznfo88+3noI//Pg3evZv1VPd5h7OkLLNsMwDK/+5M35P1aqa6HVdHr6P8552/ajmI/jOP+3qfpaLGUdavTopYze+A8/+1H/nVaTvFgcrh7ea9ms+zrUfE9tVR4v4bvhu82ncZif59M0fz6rtSFFNf7SEtdRgyU+h0z3n37v/PiPDe9p40R+fNs4d3Zw/1jd3GUyNp7jmc9XrRJWSQAAAAYBDQAAIIaABgAAEEJAAwAACCGgAQAAhBDQAAAAQpxVsw+XoGq/hefRXFO9d9XLrlqeW0/X3D4bK/Gb6rLvsh2wDylzvHWNZXf8ggYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBBq9snXWpm9opBhPNvoO5k1jOP8NXvoCv7Wmufe9dCtE7K1gn83CwDQZInHczSNY71DsS2f1gAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIISABgAAEGL7mv2rhs7QYpvd11u3nI9hGE43xWurdtl6Smb2OZXjaK2pbRvk6aZtKC3K4Vfv6WmBa3Lt4wGrWbPRG9iAx3Yw+AUNAAAghoAGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhtq/Z72wXVfpH0HKa9UPDsVXrQuuTNFqWjWqtaX7cR4adDx/2ac3JVR2r91pJLL+gAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABAiMO1OAIX4JTd1jo11Zw+Q/hrHoZhePBe1Z5YbNi5mezqYc6t7cE/tx7Bf1TXz3Ra4HjF979LHO9sd5tP01C0RhfnetrD1+I7WGsqD9697r7P0033XTa5vm3ccOfv6aoi1ie/oAEAAMQQ0AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIEROF3En4zjfyTxbiXsEDVXUd9pudn8LnOOpqu7uf7zqcLCoqt533Pn3aa1TtWG7JZahJexlnOxUVa1+ddwb3RLzylxlbTu/4wMAAByHgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEOV7N/6Cr9SvWyqzbd1tM1t8/WSvzWbvsFOvFb6nRV8wO7tfYjHCIeGWHR7q56TAikCV+HEkYHAADAIKABAADEENAAAABCCGgAAAAhBDQAAIAQAhoAAECI7Wv2TzOd5ldFBe7cNnfZrlXLPhd4bde3GZXBj+8XmX+Juv/C9W3/fTZZ4rpLOh79RNT+LmTFx30c4bEXR3gNbKz6rLGEkPWrde4s8SQgaJUxmwAAABDQAAAAUghoAAAAIQQ0AACAEAIaAABACAENAAAgxPY1+xzL2lW0qm/PM52e/s9D6pE5sCWe/jCzz9a67Gq7tbWMRRU4n9D6SKKdW2Ie997nXtYhtuNTGQAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQpxfs1/VtvbUepwl
xtd7nwu8ttfeeLNxMJ0lVfeuda0+w1s/+/rWQ4jy8Me/3XoIfSRd6yG+9Ivfr3vAlMdDzD2+Yhhixvj2r7629RCiPPrhb+6+k5Z7TMh9qbmCP2X8hS//tPg8tMRcrfbZYoFxvPXzbzQOJkR1u228JG9e/qBtw46qdSjjzgEAAICABgAAkEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIc6v2QfOk9TG3rshOem1wVZCqvQrYzFXp8Z1odpnpfV4cGdLzNWU+d86jpYK+wVq75s/T3T+HNK6VvZeY0OuKgAAAAQ0AACAEAIaAABACAENAAAghIAGAAAQQkADAAAIoWYftrRE3fSa1ffqsgFge9Np6xHQkV/QAAAAQghoAAAAIQQ0AACAEAIaAABACAENAAAghBZH2FLVuFg1JK7Z1NhqD2MEmk3FGjWa/ySqmg5Hv1mQw9UIAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQavZhaVVdftI+e9vDGIHdV+LvffwQ68Bzq2XdaF1rWrbzCxoAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEKo2QfqSvw91exOp/WOdfL91hNG5+QJa16ThYfvfmrrIUSZPrrrtTq1vbcpc6Qae7W2hVzPpeocp5z/JRSv7eZf1+sNo/ERO9PKnzUeTtuvidU6dOArFQAAYF8ENAAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACKFmH5bWWh27RPX93D6XqLfdUz0/AEAIv6ABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEmn3ynYq++asDd7mvWX2/RKU/l2c6bT2Cj42+e9xctW5U6w3cVes6tId1I2WNLUzF3B/3MPdD1q4dXI0AAACXQUADAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQavbJd+Qq/coS1fctFbGttbJbvG17qEmGS7CHOm36sv4ubwfneBdV+pWQ8ee/0wAAABdCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBBq9gHoYwcV0IeWdP6rx2yE1FhftKRrhUOZirm/iwr+kLXLDAUAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQIh91uxfVR2YF+rx43WPN1fRe1q5Q3U6zf8tpEb4/ssfzP5tajxdY+MUaD1eT+MLxXsGtCvXw2q77iMBeEJVwc8nZXyCBQAAQEADAABIIaABAACEENAAAABCCGgAAAAh9tniCP8V0tTYqrWNMcXex09nVYvgmna+Lixi7abGluNV28StNeNlXmeX+JoPYCzmVu9mxepYlZiGx5BWWzMNAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAh1OxDqKmoek2pt6/GyAXaQwX3VTF5Tgtc0Hs4JylC1jV4wtrrRmcxFfY8N3cOAACAEAIaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEULPPvk2n+b+pt95UyqMA4BPWrsSu1qje9r7mVW9N3Hoytb23u3+PLvSeu4Mq/cpYDF8Ff6YDzyYAAIB9EdAAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEmn3Y0NTY3Fttl1Jv3/ra2LE1K+UrR677rqScf/IduS5/iXmQck7M8YsRcsUBAAAgoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQIhxOqMLexzH94Zh+NtywwEuwGvTNH2hdWPrENCBdQjY2uw6dFZAAwAAYDn+E0cAAIAQAhoAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBD/BoYl6m1LxsWKAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 864x360 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Rank-3 approximation: each mode's factor is A1 @ (A2 @ S2)\n",
    "approx = outer_product(torch.mm(A_A1, torch.mm(A_A2,A_S2)), torch.mm(B_A1, torch.mm(B_A2,B_S2)), torch.mm(C_A1, torch.mm(C_A2,C_S2))).numpy()\n",
    "\n",
    "# Frobenius-norm error of the approximation, absolute and relative to the data tensor\n",
    "recon_err = np.linalg.norm(np.ndarray.flatten(X_np - approx), 2)\n",
    "print(\"Reconstruction loss:\", recon_err)\n",
    "print(\"Relative reconstruction loss:\", recon_err / np.linalg.norm(np.ndarray.flatten(X_np), 2))\n",
    "\n",
    "# Maximum projection of the approximation along each of the three modes\n",
    "fig, axs = plt.subplots(1, 3, constrained_layout=True, figsize=(12,5))\n",
    "for mode, panel in enumerate(axs):\n",
    "    X_max = np.max(approx, axis=mode)\n",
    "    panel.imshow(X_max, vmin=vmin, vmax=vmax)\n",
    "    panel.set_xticks([])\n",
    "    panel.set_yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Standard HNCPD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank-7 nonnegative CP decomposition (same rank as the Neural NCPD rank-7 approximation)\n",
    "X_np = np.asarray(X)\n",
    "r = 7\n",
    "factors_tl = non_negative_parafac(X_np, r)[1]\n",
    "\n",
    "X_1 = factors_tl[0]\n",
    "X_2 = factors_tl[1]\n",
    "X_3 = factors_tl[2]\n",
    "\n",
    "# Report the rank-7 error so it can be compared with the Neural NCPD rank-7 result\n",
    "X_approx_7 = outer_product_np(X_1, X_2, X_3)\n",
    "recon_err = np.linalg.norm(np.ndarray.flatten(X_np - X_approx_7), 2)\n",
    "print(\"Reconstruction loss:\", recon_err)\n",
    "print(\"Relative reconstruction loss:\", recon_err / np.linalg.norm(np.ndarray.flatten(X_np), 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Rank 5 (for HNCPD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "69.83321370690754\n",
      "0.5256901124915707\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2gAAAEkCAYAAABaExIDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAQRElEQVR4nO3dPc8lZR0G8JnzsLta2IhbELNZggkNsRO+g9/AxERtl9ZGDWppYaKJ2c5K/AY2aIkVGqgUsdAgRkmAYMGby7JnxgJYE/aZe3fuZ16umfP7lRxm5j4v83/OdQauafu+bwAAAFjfYe0FAAAA8DEBDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQjw05l/+0hfP+kevXZprLQ/s5TevTr7Pvh1+rN3AnQguvfHB2kuYUe0bUHhTq/ZZ2t+pKr2O579et/r3m9v9reoXM2UO/fm9h9deQpzLr95aewkzMoem0rbrP4f/du/tYg4lSZmJOXNo/N/H+23XVp7/G/gau7j0OTQqoD167VLzx99dm2ZVF/DEzacn3+fWA9q1n744/OCh8OS6DTy5vqvbri1cIK7ZZ2l/p6r0Op6dnfuPX/jwuQsdMmUOPf78t9deQpzHvvXK8IPm0DT73MEcaj93Ze0lNC+895sLbZ8yh5KkzMTHvvPX6XdaM6NK53flXGgfGvW1/f+7LN3zuOa5bX2eN/lzaPuTHgAAYCcENAAAgBACGgAAQAgBDQAAIETd/224Q1soAgEYbWi4lZqRAMbyRQom4woaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCqNlPU2qp1YoNAAC75goaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCqNn/RF+osG9L1fdTH2vaQwGnrjRwYA1dt/YKmINZw5aEzyFX0AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIIQWx09M3dSYciwCHIKarTofPgCAZK6gAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABACAENAAAghJp9mNueq+3b5X/jeeX1q82Tz9yYdJ99xZ0Qrv/t1qRr2IW+G37suNwyopRekyX3t8K5OqS/c2ftJTR9f7G5/JfXrzZf++H4OVS6zc7QHKrZ5n7bzeH630Nm4nEDw6byPK4+d6Y+/3fwvSZ9DuVMbAAAgBMnoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQIj91eyXmj8rqrTvu88apXVsv7mUzzrUfvBmMHU1brEq+GzaY81o6Tpq4Byl+VSao7XbXUDbTD83avZndkGlmu9DtfOk4liuoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIMT+avYhzdTV9gB7VFthnXQrE2AblpwbFcdyBQ0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACH2V7M/R2tmSoNvyjoYJ6kCeurK/3Yfv/H0QW8RcJqG5lBbGNul2VXaDsi2j29XAAAAOyCgAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABAiP3V7G+duu/9mbraPknfFR48W2wZF1VVR1166kv/9FWzltI2JX7Wg1nUzKHiNrV/enwPgdX5UwsAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBCjavZffvNq88TNp89/cME614+++n7lwRY21H/bz9BhW6o7P05/uE0oVsAvuL92x7+DlJ7bTLcXOHv7g+bhX780y75HOeii/qx/ffep4QdrX67Sx2hgn8fLdYequt1CUz/SS8cb2mfNNk3TNNd/8uKDLWoBbbv+uXPRFcTMoSQhM/Gf3yvModpzvHDXmJq5UTszLr9TeLCwz0du5pz/KdLn0I6/OQIAAGyLgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACFG1ez3baH6
t3IBczTOAwOmvvVAlXnq9wk3x9s+sM/aKvoktZX/m3AI+G04oGKbmcxw7ix5PlYfa88zYw7hcyhgdQAAADSNgAYAABBDQAMAAAghoAEAAIQQ0AAAAEKManFs+0K7TKk9plCWtOumKmiapjkEtYV1Cb/JBL0eLKe66nf8PmubGkvbzfG3qmadu/ib2VW0ySY0rrENM8yaJdtfi8fy53M6U8+hmv31wx86Ew8AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACFG1exXq6zghz1o25wPeV88GTck6dYFXFztx3Jgu1IV/dJV+iW7qMxfSk2FdUmh3vqBmUPbU/m2T32uVs8hM2NdU8+hAlfQAAAAQghoAAAAIQQ0AACAEAIaAABACAENAAAghIAGAAAQYnzN/tQVnypDgbG6gMGhYhtOW8IcSmImwmRcQQMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQIhRNfuX3vigufazl+Zay4M7HocfawuZs++mX0uIf/zoybWX0DRN0xyvVG5Y21Zc2+pbOt7QPmu2aZrmK88EnDOfijgHVFPv1UdfyHhvi3OoLayxn6EmfOnjbUDfj/+ctO1pvlaMN8cc6qe+nFE7F050Zsxh6jlUs7/SFq6gAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABACAENAAAgxKiafYKVqldLda5T77Oyij5KRlM4MIela6rVYt9DZT6Rlvzbby6sbuo5VLO/0hauoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIISa/VNXW/U6uN2JdtSXnvYhqE63S/hNJuj14PSUbjsyR/X10scD6ix5OpoL3EfCtzUAAAAaAQ0AACCGgAYAABBCQAMAAAghoAEAAIQQ0AAAAEKMr9nvTrRGPVx/Vvu+TPt+9meT7q5pmpw22up1OGeml3TrAuA0mUPATFxBAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABAiPE1+2QqNblrAr5HSnX/4vpuuWO1M/7+k3DrAhXb27T0Cb7rgXLiEuZQkq3PxCXfTnOB+3AFDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQoxsceyXbYFjXQs2Q5ZaFdmowVmxjze7PTtbewlxrvwno5nseKVuu9o5VFvIVtMmu4sG2kuX1l5B03y4lRdrO1JmYu0cKp1bXeGp1cyN2nP18jt123GO8DnkChoAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEKMrNlvm6YNyHSq/u81R2Pw1PssdNH2czyBiY+3mQrrFIOzwgu5Vym3y5hlHYV9zjJ+h4638Dpga+Y4/6feZ/X+Frz9EesKSFsAAAA0jYAGAAAQQ0ADAAAIIaABAACEENAAAABCCGgAAAAhRtbs9yruqVfZU1+qoy3ucuJe/Op1HIK6b49rL2BHuo3MwsPA73BbWX+KlNM4ZR2wplOdX87/k+EKGgAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQoyq2W+btmkfGtnMP4P+zp3Cg6dZvXr5nYzu1ePltVeQp20z3pumaZq+9ZvMZIbq67dihvVPfGeLWZTWWLqVRu0+S2qPN/U6YHMK82uOc3xqW1gj69r4NwwAAID9ENAAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACDE+p35AFvUneYtPUq2UA89xxpTnnfKOjhRITNxC+fBFtbIulxBAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABAiFE1+33TNH0f0A3aLpwrD+3wY13h9RjarrRNpS///MXx65hpLTFKz3tBz736h7WXcNfjz944/4HSx2Dil/H2L34/7Q7nVPoMHfy+9VmP3DSH7hEyh3772h/XXsJdP37ribWX0PzpG3fWXsKDm/o7yGzbZczE4vehkpQZVVhH29bNk+eCzv8U6XMo42wC
AABAQAMAAEghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBCjavabplmuajSkmvikpVTOMr8ZTrehT8huPjldt/YKToM5BNtgJsJkXEEDAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAECI8TX7nA4V1qej9Fa74wWnQqX/Jh0KA6wzwICJlGZNSc0ccgUNAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCjG9xLLVcsS8azU6H03q8g9+37nHs1l7BxZhrm6SpMUTKTNz6HCLWkrMm5GwCAABAQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBDja/aXqiFW578+ldOno/atLpymQw85s4nl1iKLOAwMnFKF9dA2ADXSZ4oraAAAACEENAAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACDG+Zh/YrqlbZUv720OffqFavT/eXnAhO3BceLuUYy25/oKnfnBj7SXc9fm3139R3v33S+scuO+GH2sHfjOvvcXDsfJ17up+u+9vb3wmrv+x/FhhHbV/wp/6fuH8L/2tHjpgzTb3s/B3hvQ55AoaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCqNln2KHQeVpb+5twrFO2YI3t0Lu2m3dzqBL7lJUqxFlX0omXtJalLTk3lp5RKTPRHBqn5nysPIfbwnaLj4XwORRyNgEAACCgAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABACDX7ZFClz9aUqpxT6qYh0dDtPkp/Bha8RQik6jf+XalPOo+XXEvF2+ZbBAAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQqjZB6ihSh+G1VRYJ1VwQ6D2MMNJUtrlkq3+tU9tC3ceqHhuvmEAAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEmn2AGn239goAcpiJrCHlNgETcwUNAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhcmv2u8puzFLNa1uZR2vXUrtdjdLzPha2K70mx8KGta/lkNr3rbRdl/H7w+PP3lh7CXd98+vPn/vPu0JP7WHintpf/urdi+/kUOrVXUjI5yuKmu1tqq2pDjgNV1Uzh0rfC6aea7XfQWrXkTITzSE+teEZFXI2AQAAIKABAACEENAAAABCCGgAAAAhBDQAAIAQuS2OS5uj/XHL62AyC3Z5no4lG1Jh72pPp6lPw601Rk49h1LmWso6uEc/x3uT8nZvYR21c6jiufnGDwAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEONr9g8BXbd7roBVpb87AWcMALBxbcJ38L0LeYmlAQAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhBhfs79X6u1hOxKqhvd8uw9YQ8BpPUrCHEpiJsL5KkaFVAIAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBBq9j/Vd8OPLVnBn7IOdqmr6HotbXNohmuVa471YPqmOR5n2jewmlJLe1yjvTnE8nq3MtimirfNN34AAIAQAhoAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIRo+/7Bux/btn2raZrX5lsOcAKu931/tXZjcwiYgDkErG1wDo0KaAAAAMzHf+IIAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABDifzVd4tPUgMkaAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 864x360 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Rank-5 NMF of each of X_1, X_2, X_3 (one per mode, identical settings): X_k ~= A_k_0 @ S_k_1.\n",
    "# Results are unpacked into the explicit names used by later cells (A_k_0, S_k_1).\n",
    "factors_rank5 = []\n",
    "for X_k in (X_1, X_2, X_3):\n",
    "    model = NMF(n_components=5, init='random', random_state=0)\n",
    "    A_k = model.fit_transform(X_k)\n",
    "    factors_rank5.append((A_k, model.components_))\n",
    "(A_1_0, S_1_1), (A_2_0, S_2_1), (A_3_0, S_3_1) = factors_rank5\n",
    "\n",
    "X_approx_5 = outer_product_np(A_1_0 @ S_1_1, A_2_0 @ S_2_1, A_3_0 @ S_3_1)\n",
    "\n",
    "# With ord=None, np.linalg.norm returns the 2-norm of the flattened array (Frobenius norm).\n",
    "err_abs_5 = np.linalg.norm(X_np - X_approx_5)\n",
    "print('rank 5 absolute error:', err_abs_5)\n",
    "print('rank 5 relative error:', err_abs_5 / np.linalg.norm(X_np))\n",
    "\n",
    "# Maximum projection of the reconstruction along each mode, on the shared vmin/vmax colour scale.\n",
    "fig, ax = plt.subplots(1, 3, constrained_layout=True, figsize=(12,5))\n",
    "for mode, panel in enumerate(ax):\n",
    "    X_max = np.max(X_approx_5, axis=mode)\n",
    "    panel.imshow(X_max, vmin=vmin, vmax=vmax)\n",
    "    panel.set_xticks([])\n",
    "    panel.set_yticks([])\n",
    "    panel.set_title('rank 5: max over mode {}'.format(mode + 1))\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Rank 3 (for HNCPD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "88.49126581829329\n",
      "0.6661441026870379\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2gAAAEkCAYAAABaExIDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAQ70lEQVR4nO3dT68k110G4FN9Z+7YQtkgLBZmwLJRVkGBKJoPQIRA4hvkGzg7PkGIxIZNIBuExAK84AMQCQmkLAJCBEVJECIQKXKEgJDIBAsyHjx/7NuHRbiO4pk6c+t0/Xm7+3mWblfV73ZXnb7vLfutodZaAAAA2N5u6wEAAAD4EQENAAAghIAGAAAQQkADAAAIIaABAACEENAAAABC3JryL//MT1/UV+7eXmqWG/un/3xp6xHi3H7r3a1HYMQwDFuPEOXh/kF5Uh91vynWoVyXrXWodR3M/bgX11y2gI/n4f5BebI//nUoyTffzlgT73z/4dYjbKO1jloTnxbwlrTWoUkB7ZW7t8tX//LuPFMd4GNf+Mz8O219UEfwqLif+/zXxl/cdZ6F+8YP3rvPlDlW/NmGi4tZ93eQgFn+7sEXD9r+pNehI3f3C98Yf3HX+A829vt5B2kdi80Ntyb96rGIrzz4s4O2T1mHknz0jdfXO1jjK/y13/mH9eZI0lpHW9/9PX8gW/MPbgtJX4d8iwEAAIQQ0AAAAEIIaAAAACEENAAAgBDb/x9ynJZW+caxHG/ufW7fy/Fjc5cxMK81y4qWOFZvOYdSj7NSE0oEAkY4K2Prjc+BjaSvQ74VAQAAQghoAAAAIQQ0AACAEAIaAABACAENAAAghIAGAAAQQs3+tYC2zVJKGRpz1FYtdordMQx5xtSZcw4G69BTEiql/9+Q8PkEjMAB1v78Es7ZUtrXse/3SdLXIZ8mAABACAENAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCqNm/1mrbnLuduHGsnCJkYHVrLgCnvNgEVcrztJrw+QSMcFbmfr/X/vwSztnn2e/HX7u4WG+OI5G+DrmDBgAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIbQ4XgtpTxsar9VW0yTATKxDALAdd9AAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABBCzf61VnX03BX8jWOt2fbP+RmGZ598tc5/5o0dq3mt3cA/f++l8snPvn7YTmbw83/79tYjxNk/ebL1CByB4VbArx4Hrnkp61CSV792f+sRSimlVOvQU+qT97YeIc7u8vbWIzTXIXfQAAAAQghoAAAAIQQ0AACAEAIaAABACAENAAAghIAGAAAQIqDrNkRn4+7QsV1rk579wU0tUac/+VjO8Xyt82Ts8QmllDL4m1+qenW19QgfOPBJG9xU71p77B/QKa9Ddb/1BKzkhM9iAACA4yKgAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABACDX7C6ozV9XOvT/gRPTW4s99rJQK6FOu2e407HyBRGg+Z2fGbYCj5lsMAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAh1OwfSJX+h+xbfcBn6mLrATh5c1fp9+5PvT03sQs4T+a+Zm583JW2OWQ7OAfh61DAdAAAAJQioAEAAMQQ0AAAAEIIaAAAACEENAAAgBACGgAAQAg1+8DxSXiaQ11giFb199zHW+JYdd+33dzU/bOGhHXoTA2NNar5sfSuUdYUVuaMAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABAiEk1+998+6Xy0TdeX2qWG3vlb97deoQ4uxdf2HqEH9k1qrvP1P7ho61H+MDuzp2tR2jXu9/A7R8+KT/7F/820zD99v/9P1uPEOfhb3xi6xFKKaXsLzvPsd7a9N5Teu6a9sYcP/XFr898sAMcuAYkSFmHkux/eH/rEUop
pbz7m788+z7r3Odsa3eNdeHWw77HBNz50t93bXfSwtchd9AAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABBiUs1+KWW8GrRVF5zdZAnno/ZV9EYZSnw9LhvzfcTSrEPAgtxBAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAENNbHFvtWHNuA8xv8DcZNtIqvJv7O6K3XG/NGZ93vDG+T2F9axZ2KgeluIMGAAAQQ0ADAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQ02v2oWXf6IDeNbpjW9u1tPbZ0nu8uecA+uvtz7H6OmWOQ1TPCmA5dRi/SAbnHtfCzwV30AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIISABgAAEGJ6zf7crZRjbajZ7Zf0mLvafql99mjMMTQqf1eX8DiAQ0eoZf563KTPiJsLufyZ6BSutzXXodZxWu9leI04bCp8HXIHDQAAIISABgAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAISbV7F9+793y6ue+sdQsN7a7c2f2fdZGHW1UTfqIf/mtj42/2Bp/5hbeqxfU+n7Yq7+9/TVzrT74361HKPVqf9gO3n+/7P/r7XmGOcBweTn+YutxBimPhljAW/cuth6hlFLKVeOjWWQ97P2KWOuxNaWU1/4847MppZxG/XvIOpRkeGH+381GNdbRt+7Nf++hzv1rYOc6dHm/7zq++6WuzU5b+DrkDhoAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEJMqtlneYvU/Wc3iQKnYvYu6l6NRW+J9TBljU2ZA+Zwwo8k6b5WW9ulLL/Mwh00AACAEAIaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEmFSzP5QDqt6hV0qtbMocB0i4fg+doJb24yjWsv07eWSGVvX9iu9m61C9p1Xv+HOfxk7K1aSsQ0mcfhP0rkPe5LPhDhoAAEAIAQ0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEJMqtlfpFZ2P7K/3XiXaLNltLHd6LFKRv14Kf1z1JCovUhbdsZHkzPHmfO4j2CtKv0lthvd37y7gw+zDq2g8/e5o+e0oriDBgAAEENAAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAEJNq9oHjNvtjMnpm2HoAltN6zkZvlX7PPluHUmENJPMlSXEHDQAAIIaABgAAEEJAAwAACCGgAQAAhBDQAAAAQuS2OO7V2ADPcHFRho98ZOspyv7+/a1HiPPyXz8efW2Jgscx+9v+9viUq6utJ/ixIaBK89BG25B1KEl9552tRyillPLyl9+bf6cJ52wp5eJR33Vc/U79lGG/33qE5jrkWwwAACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACFya/Z3GZWmR2PunupevR9b7/hrH2+M0/XsDCG1y8di1SWqVaHe+tx6q9d7z4VDq97nmgPmsFvxb/6dFem961Cd+5eGznVomHvNIJY7aAAAACEENAAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACJFbs79XJTpJDalXXrsCtnW4Nd8Sp+uKaim1r2J51inUHefqrZtfu6ZeLf4Ry1iHonRW36+p+1elNa/VzmNV68lJcQcNAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhJtXsD8NQhouLeSfo2F29umq82Jk5W/WwuxVzbOcct99ZYJYOF487a157G8t7W2XnbkhvzRH0yIjhYvu/yRxeBDyUMmz/cwCd1vxOHXNwJfkRrEOtxwCkz84zqdKfUfg6FDAdAAAApQhoAAAAMQQ0AACAEAIaAABACAENAAAghIAGAAAQYlLNfqm1XQO/lrmr/tleqzl2iZb6nqbanLb8brVu/0McPkFt10fDElrXzprV1ylznL0jX4eOefaFDI1Lq4ZcWkPAdzjrcAcNAAAghIAGAAAQQkAD
AAAIIaABAACEENAAAABCCGgAAAAhptXsD0NExX198t7WI8R5+fe/uvUIjHj0a7+y9Qgf+O6nRq7fVodwq3u4w+Pf+6vDdlBrKSNrQOsxAkNHBXlrf7sXXxjfcN94z3aNOXq3Gxp/axur024dq6Uxx+VXvtW3z9Zns2atfO+xjqD6/tt/8kvjL67c3P2dT/3xugd8hnu//uCwHexrqY+fjLzWqLDfne7fxYcXXxx/cWwdaq1dLY116M6X/7Frlz3fEaWs+/ia1oytOd783XuNnc48/xLPJJh7xlLKm5/+w9n3OVVrHTrdlQIAAODICGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABBiWs1+Ke06YeBprcbZlMtpgQpbAACmcwcNAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCCGgAAAAhptfsA9NosOemdq1nMqx4rH3jpB38XY8bWPFUBp7jGB6lcwwzrsg3LQAAQAgBDQAAIISABgAAEEJAAwAACCGgAQAAhJje4jioZoJJWpfM2qVFY7O05ui95BUyTddqT5y74bF1LDiU0wty1Mb3h/bESO6gAQAAhBDQAAAAQghoAAAAIQQ0AACAEAIaAABACAENAAAgxPSa/aqOEyZJumT2HVXtSfOfurmr9HuPpYKfm/DUnc0NjUcfVb+vcW3uKv21Hx+U9LiilbiDBgAAEEJAAwAACCGgAQAAhBDQAAAAQghoAAAAIQQ0AACAENNr9gE2VPe17B8+mr7d3HPU/cx7PH5VPX+s1/4g53z9+Nc/M+v+aqOCe6xd/M23Pn/YMWst9dHjg/Yxh6QrrtaHW49QSulfh5LeyzG9M/7in74z6xyn4OP/MX0d6llrWlrrkDtoAAAAIQQ0AACAEAIaAABACAENAAAghIAGAAAQQkADAAAIoWYfzklPDywcmWHX6EI+U1FX/szDND/tsWNFvSGkS1lTeh8hMNTx7eow/rONbdezzSF6j9farmcN6FprOrdxBw0AACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACHU7MPSWr2sS1Q9t4439lprjox2Ybix3ipqVrLmmjJ2LOsaE8SsKXW/7uFaNfUhx+qece4frWd/jW3cQQMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQAg1+3BqWm3APU3BIe3CP2EX0JF9tfUAwKYS1qEk1sTlDX33VY6hLp+f5A4aAABACAENAAAghIAGAAAQQkADAAAIIaABAACEENAAAABCqNkHjs8+sfsfOCvWIWAh7qABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACDEtJr9oZTh1rM3qXW8bnYYhkmHed7+dpe3xzds1d7ups/xXEMj49b9s/95bzXvxcX4oR4/7ttn67NpfAbN7Xr0HmvNGTv9+6ffH39xiZbmxo/9nV/9owUOOM29N35w+E6WuJanutp6AGBTCetQEmsizMYdNAAAgBACGgAAQAgBDQAAIISABgAAEEJAAwAACDGtxfGUrd3+CEtZohkSADhKQ6Pxus7ceN06Vsvccxw7d9AAAABCCGgAAAAhBDQAAIAQAhoAAEAIAQ0AACCEgAYAABBCzf61Nav0W8dq1f3DNW20AAAnyR00AACAEAIaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEULMPAAAnqg6ezXNs3EEDAAAIIaABAACEENAAAABCCGgAAAAhBDQAAIAQAhoAAEAINfvX9nX8td3M9aStY8FNnPsp5BoCtmYd4kgMdfxcVcGfyR00AACAEAIaAABACAENAAAghIAGAAAQQkADAAAIIaABAACEGGqjevOpf3kYflBK+dflxgHOwC/UWl/q3dg6BMzAOgRsbXQdmhTQAAAAWI7/xBEAACCEgAYAABBCQAMAAAghoAEAAIQQ0AAAAEIIaAAAACEENAAAgBACGgAAQAgBDQAAIMT/AYQ48lUm4a8O
AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 864x360 with 3 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Second layer: rank-3 NMF of each rank-5 coefficient matrix (identical settings): S_k_1 ~= A_k_1 @ S_k_2.\n",
    "# Results are unpacked into the explicit names used by later cells (A_k_1, S_k_2).\n",
    "factors_rank3 = []\n",
    "for S_k in (S_1_1, S_2_1, S_3_1):\n",
    "    model = NMF(n_components=3, init='random', random_state=0)\n",
    "    A_k = model.fit_transform(S_k)\n",
    "    factors_rank3.append((A_k, model.components_))\n",
    "(A_1_1, S_1_2), (A_2_1, S_2_2), (A_3_1, S_3_2) = factors_rank3\n",
    "\n",
    "X_approx_3 = outer_product_np(A_1_0 @ A_1_1 @ S_1_2, A_2_0 @ A_2_1 @ S_2_2, A_3_0 @ A_3_1 @ S_3_2)\n",
    "\n",
    "# With ord=None, np.linalg.norm returns the 2-norm of the flattened array (Frobenius norm).\n",
    "err_abs_3 = np.linalg.norm(X_np - X_approx_3)\n",
    "print('rank 3 absolute error:', err_abs_3)\n",
    "print('rank 3 relative error:', err_abs_3 / np.linalg.norm(X_np))\n",
    "\n",
    "# Maximum projection of the reconstruction along each mode, on the shared vmin/vmax colour scale.\n",
    "fig, ax = plt.subplots(1, 3, constrained_layout=True, figsize=(12,5))\n",
    "for mode, panel in enumerate(ax):\n",
    "    X_max = np.max(X_approx_3, axis=mode)\n",
    "    panel.imshow(X_max, vmin=vmin, vmax=vmax)\n",
    "    panel.set_xticks([])\n",
    "    panel.set_yticks([])\n",
    "    panel.set_title('rank 3: max over mode {}'.format(mode + 1))\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
