{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "89b20c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../src\")\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torchvision\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import glob\n",
    "import os\n",
    "from datetime import datetime\n",
    "import time\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "\n",
    "from itertools import repeat\n",
    "from torch.nn.parameter import Parameter\n",
    "import collections\n",
    "import matplotlib\n",
    "from torch_utils import *\n",
    "from ContrastiveModels import ContrastiveCorInfoMaxHopfield\n",
    "from visualization import *\n",
    "# matplotlib.use('Agg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ef0c0a15",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prefer the first CUDA GPU; fall back to the CPU when no GPU is visible.\n",
    "device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "128dc72a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Per-channel CIFAR-10 normalization applied to both splits.\n",
    "# NOTE(review): each std is multiplied by 3, so normalized pixels are 3x smaller than a\n",
    "# plain z-score; presumably this keeps inputs mostly inside the [-1, 1] clip range -- confirm.\n",
    "transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),\n",
    "                                     std=(3*0.2023, 3*0.1994, 3*0.2010)),\n",
    "])\n",
    "\n",
    "# Training split: reshuffled each pass over the loader.\n",
    "cifar_dset_train = torchvision.datasets.CIFAR10(root='../../data', train=True, download=True,\n",
    "                                                transform=transform, target_transform=None)\n",
    "train_loader = torch.utils.data.DataLoader(dataset=cifar_dset_train, batch_size=20,\n",
    "                                           shuffle=True, num_workers=0)\n",
    "\n",
    "# Test split: fixed order so evaluation is repeatable.\n",
    "cifar_dset_test = torchvision.datasets.CIFAR10(root='../../data', train=False, download=True,\n",
    "                                               transform=transform, target_transform=None)\n",
    "test_loader = torch.utils.data.DataLoader(dataset=cifar_dset_test, batch_size=20,\n",
    "                                          shuffle=False, num_workers=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "64eeddcb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAYj0lEQVR4nO3da3Bc9XnH8e9jyTL4fpF8k21swFxssMGWHXKHkBADSQwB29BOk6btMJ6GXl4VZnp70b5JM+10OiX1eFqm7UwmaI0NcUCEJIRLkxSile8XDMIG70rCkm/y3bo9fbFrV8gr7ZG8u+fs7u8zo/Hunr92H/9X/vnof84+x9wdEREpfqPCLkBERHJDgS4iUiIU6CIiJUKBLiJSIhToIiIlojKsF66urvb58+eH9fIiIkWpqanpqLvXZNoWWqDPnz+feDwe1suLiBQlM/tosG1achERKREKdBGREqFAFxEpEQp0EZESoUAXESkRWQPdzJ41s3Yz2zPIdjOzfzGzZjPbZWbLcl+miIhkE2QP/T+BVUNsvx9YmP56Avi3qy9LRESGK+t56O7+lpnNH2LIauC/PdWH920zm2xms9y9LVdFikh0NX10nDcPdIRdRlGpmz+VL9yU8bNBVyUXHyyqBRL97ifTj10R6Gb2BKm9eObNm5eDlxaRMO1p6WTNhv+lz8Es7GqKx/ov3hDZQM/0Nma8aoa7bwQ2AtTV1enKGiJF7ofvHKaqchS/fupLTBs/Juxyyl4uznJJAnP73Z8DtObgeUUkws519fCTna08ePtshXlE5CLQtwLfSp/tchfQqfVzkdL3yu6POXOxh7V1c8IuRdKyLrmY2Y+Au4FqM0sCfwuMBnD3DUAD8ADQDJwDvpOvYkUkOurjCRZUj2PlgqlhlyJpQc5yeTzLdge+m7OKRCTyDh09y28PHecvVt2M6WhoZOiToiIybJviCUYZPLJMyy1RokAXkWHp6e1j87Yk99w8nRkTrwm7HOlHgS4iw/LW+x0cOXWRNXVzsw+WglKgi8iwxBqTVI+v4t5bp4ddigygQBeRwI6eucgv9h/h4TtrGV2h+IgavSMiEtiL21vo6XPWarklkhToIhKIu1PfmODOeZNZOGNC2OVIBgp0EQlke+Ik77efYZ32ziNLgS4igWyKJ7h2dAUPLpkVdikyCAW6iGSVasTVxoNLZjHhmtFhlyODUKCLSFYNlxtxabklyhToIpJVrDHB9dXjWDF/StilyBAU6CIypIMdZ/jth8dZUzdXjbgiToEuIkPa1JSkYpTxyLLasEuRLBToIjKont4+NjcluefmGqarEVfkKdBFZFBvvtdB+2k14ioWCnQRGVQsnqB6fBVfukWNuIqBAl1EMuo4fZHX9rfzzWVz1IirSOhdEpGM/r8Rl65KVCwU6CJyBXenPp5g2bzJ3DhdjbiKhQJdRK6w7fBJmtvPsG6FDoYWEwW6iFxhUzzB2KoKHlwyO+xSZBgU6CLyCWcv9vCTna08ePssxo+pDLscGQYFuoh8QsPuNs529Wq5pQgp0EXkE2LxBNfXjGP5dWrEVWwU6CJy2cGOMzR+eIK1asRVlBToInJZLJ5qxPVNNeIqSgp0EQHSjbi2Jbnn5ulMn6BGXMVIgS4iALxxoIOO0xf1ydAipkAXEeBSI64x3KNGXEVLgS4idJy+yC/fbeeRZbVqxFXEAr1zZrbKzA6YWbOZPZ1h+yQz+4mZ7TSzvWb2ndyXKiL5smVbkp4+V9/zIpc10M2sAngGuB9YBDxuZosGDPsusM/dlwJ3A/9oZlU5rlVE8sDdicUTLL9uCjdOHx92OXIVguyhrwSa3f2gu3cBzwGrB4xxYIKlTlwdDxwHenJaqYjkxbbDJ/ig4yzrtHde9IIEei2Q6Hc/mX6sv38FbgVagd3An7l738AnMrMnzCxuZvGOjo4RliwiuRRrTDK2qoIHlswKuxS5SkECPdPHxXzA/a8CO4DZ
wB3Av5rZxCu+yX2ju9e5e11NTc0wSxWRXDt7sYeXdrXytSVqxFUKggR6Euj/u9gcUnvi/X0H2OIpzcAh4JbclCgi+fKyGnGVlCCB3ggsNLMF6QOdjwFbB4w5DNwLYGYzgJuBg7ksVERyL9aYasS1bJ4acZWCrIHu7j3Ak8CrwH4g5u57zWy9ma1PD/s74DNmtht4DXjK3Y/mq2gRuXofdJwh/tEJ1qkRV8kItGjm7g1Aw4DHNvS73Qrcl9vSRCSfYvEEFaOMh9WIq2ToI2EiZai7t4/NTS186RY14iolCnSRMvTGgQ6OnrnIWp17XlIU6CJlqL4xQc2EMdxzs04fLiUKdJEy0376Aq8faOeby2qpVCOukqJ3U6TMbNnWQm+fa7mlBCnQRcrIpUZcdddN4YYaNeIqNQp0kTLS9NEJDnacZa0+GVqSFOgiZSQWTzCuqoIHb1cjrlKkQBcpE2cu9vDSrja+tmQ249SIqyQp0EXKRMOuNs519Wq5pYQp0EXKRH08wQ0141g2b3LYpUieKNBFykBz+xmaPjrBuhVqxFXKFOgiZWBTPEHlKOPhO+eEXYrkkQJdpMR19/axeVuqEVfNhDFhlyN5pEAXKXGvv9uuRlxlQoEuUuJi8VQjrrvViKvkKdBFSlj7qQu8fqCDR5bNUSOuMqB3WKSEbb7ciEsHQ8uBAl2kRLk7m+IJVs6fyvVqxFUWFOgiJSr+0QkOHj3LGu2dlw0FukiJijWmG3EtUSOucqFAFylBZy728PLuNr6+dDZjq9SIq1wo0EVK0Mu7WtWIqwwp0EVKUH1jghunj+fOuZPDLkUKSIEuUmKa20+z7fBJ1tWpEVe5UaCLlJhYPJlqxLWsNuxSpMAU6CIlpLu3jy3bktx763Sqx6sRV7lRoIuUkF++287RM11qxFWmFOgiJSTWmGD6hDF88SY14ipHCnSREnHk1AVeP9DOI8vViKtcBXrXzWyVmR0ws2Yze3qQMXeb2Q4z22tmb+a2TBHJZvO2JH2OllvKWNaPkJlZBfAM8BUgCTSa2VZ339dvzGTgB8Aqdz9sZtPzVK+IZJBqxJVk5YKpLKgeF3Y5EpIge+grgWZ3P+juXcBzwOoBY34H2OLuhwHcvT23ZYrIUBo/PMGho2e1d17mggR6LZDodz+Zfqy/m4ApZvaGmTWZ2bcyPZGZPWFmcTOLd3R0jKxiEblCLJ5g/JhKHrh9ZtilSIiCBHqmj5r5gPuVwHLgQeCrwF+b2U1XfJP7Rnevc/e6mhodhRfJhdMXunl5VxtfXzpLjbjKXJB3Pwn0/z1uDtCaYcxRdz8LnDWzt4ClwHs5qVJEBvXyrjbOd/dquUUC7aE3AgvNbIGZVQGPAVsHjPkx8HkzqzSzscCngP25LVVEMqmPJ1g4fTx3qBFX2cu6h+7uPWb2JPAqUAE86+57zWx9evsGd99vZj8FdgF9wL+7+558Fi4i8P6R02w/fJK/evBWNeKSQEsuuHsD0DDgsQ0D7n8f+H7uShORbGLxBJWjjIfuVCMu0SdFRYpWV08fW7a18OVbZ6gRlwAKdJGi9ct32zl2tou1K3QRaElRoIsUqVg8wYyJY/jCQp0CLCkKdJEidOTUBd440M6jasQl/egnQaQIPd+UasS1ZrnOPZf/p0AXKTKpRlwJPrVgKvPViEv6UaCLFJnfHjrOh8fO6ZOhcgUFukiRicWT6UZcs8IuRSJGgS5SRE5f6KZhdxtfXzqba6sqwi5HIkaBLlJEXko34lq3QsstciUFukgRqW9McNOM8SydMynsUiSCFOgiReK9I6fZkTjJ2rq5asQlGSnQRYpErDHB6ArjYTXikkEo0EWKQFdPH1u2pxpxTVMjLhmEAl2kCPzy3SMcP9ulc89lSAp0kSJQ35hg5sRr+MJNasQlg1Ogi0Tcx50XePO9Dh5dPoeKUToYKoNToItE3OZt6UZcdep7LkNToItEmLsTiye46/qpXDdNjbhkaAp0kQh759BxPlIj
LglIgS4SYbF4ggljKrn/NjXikuwU6CIRdepSI6471IhLglGgi0TUSzvbuNDdxzott0hACnSRiKqPJ7h5xgSWqBGXBKRAF4mgAx+fZmfiJGtXqBGXBKdAF4mgWFyNuGT4FOgiEdPV08cL21v4yqIZTB1XFXY5UkQU6CIR89p+NeKSkVGgi0RMfTzBrEnX8PmFasQlw6NAF4mQts7zvKVGXDJCCnSRCNnclG7EtVzLLTJ8gQLdzFaZ2QEzazazp4cYt8LMes3s0dyVKFIe+vqcWDzJp6+fxrxpY8MuR4pQ1kA3swrgGeB+YBHwuJktGmTc94BXc12kSDl459BxDh8/x9oVapMrIxNkD30l0OzuB929C3gOWJ1h3J8Am4H2HNYnUjY2xRNMuEaNuGTkggR6LZDodz+ZfuwyM6sFHgY2DPVEZvaEmcXNLN7R0THcWkVK1qkL3TTsaeMbS2dzzWg14pKRCRLomQ61+4D7/ww85e69Qz2Ru2909zp3r6up0SlZIpds3dGaasS1QgdDZeQqA4xJAv1/yuYArQPG1AHPpXtOVAMPmFmPu7+YiyJFSt2meIJbZk7g9lo14pKRC7KH3ggsNLMFZlYFPAZs7T/A3Re4+3x3nw88D/yxwlwkmHc/PsXOZCdr69SIS65O1j10d+8xsydJnb1SATzr7nvNbH16+5Dr5iIytFhjkqqKUWrEJVctyJIL7t4ANAx4LGOQu/vvX31ZIuXhYk8vL2xP8pVFM5iiRlxylfRJUZEQvba/nRPnulmrg6GSAwp0kRDVNyaYPekaPndjddilSAlQoIuEpPXked56X424JHcU6CIh2dyUxB0eVSMuyREFukgI+vqcTU1JPnODGnFJ7ijQRULw9qFjqUZcuiqR5JACXSQEscZUI65Vt80MuxQpIQp0kQLrPN/NK3s+ZvUdasQluaVAFymwrTtbudjTx7q6eWGXIiVGgS5SYJcacd1WOzHsUqTEKNBFCmh/2yl2JTtZt0KNuCT3FOgiBRSLJ6iqGMVDd6gRl+SeAl2kQFKNuFr4ymI14pL8UKCLFMgv9rVz8lw363TuueSJAl2kQOrjqUZcn1UjLskTBbpIAbSePM//vN/Bo3Vz1YhL8kaBLlIAz6cbca1ZPifsUqSEKdBF8izViCvBZ2+cxtypasQl+aNAF8mztw8eI3H8vBpxSd4p0EXyrD6eYOI1lXx1sRpxSX4p0EXyqPPcpUZctWrEJXmnQBfJo607W+jq6WOdLgItBaBAF8mjWDzJolkTua12UtilSBlQoIvkyb7WU+xu6WRtnU5VlMJQoIvkyeVGXHeqEZcUhgJdJA8u9vTy4o4W7ls8g8lj1YhLCkOBLpIHP993JNWISwdDpYAU6CJ5UN+YoHbytXz2BjXiksJRoIvkWMvJ8/yq+SiPLp/DKDXikgJSoIvk2PPxJACPqhGXFFigQDezVWZ2wMyazezpDNt/18x2pb9+Y2ZLc1+qSPRdbsR1Q7UacUnBZQ10M6sAngHuBxYBj5vZogHDDgFfdPclwN8BG3NdqEgx+N+Dx0ieOM8anXsuIQiyh74SaHb3g+7eBTwHrO4/wN1/4+4n0nffBvTTLGWpvlGNuCQ8QQK9Fkj0u59MPzaYPwReybTBzJ4ws7iZxTs6OoJXKVIEOs9189O9H/PQnWrEJeEIEuiZDtN7xoFm95AK9KcybXf3je5e5+51NTU1wasUKQI/TjfiUt9zCUtlgDFJoP9P6BygdeAgM1sC/Dtwv7sfy015IsUjFk+weLYacUl4guyhNwILzWyBmVUBjwFb+w8ws3nAFuD33P293JcpEm17WzvZ03JKe+cSqqx76O7eY2ZPAq8CFcCz7r7XzNant28A/gaYBvzAzAB63L0uf2WLRMumeJKqylGsvmN22KVIGQuy5IK7NwANAx7b0O/2HwF/lNvSRIrDhe5eXtjewlcXz1QjLgmVPikqcpV+vu8Inee7WaflFgmZAl3kKsXiqUZcn7lh
WtilSJlToItcheSJc/yq+Shr6tSIS8KnQBe5Cs83qRGXRIcCXWSE+vqcTfEkn7uxmjlT1IhLwqdAFxmh33xwjJaT51mjg6ESEQp0kRGqjyeYdO1o7ls0I+xSRAAFusiInDzXxat7P+ZhNeKSCFGgi4zAj3e00tXTp77nEikKdJERiMUT3FY7kcWz1YhLokOBLjJMe1o62duqRlwSPQp0kWHaFE+kGnEtHeo6LyKFp0AXGYYL3b28uKOVVYtnMmns6LDLEfkEBbrIMPzsUiOuFVpukehRoIsMQ6wxwZwp1/Lp69WIS6JHgS4SUOJ4uhHX8rlqxCWRpEAXCej5piRm8KjOPZeIUqCLBNDb5zzflGrEVTv52rDLEclIgS4SwG8+OErLyfM691wiTYEuEkB9Y4LJY0dz32I14pLoUqCLZHHyXBc/23uEh+6oZUylGnFJdCnQRbJ4cXsLXb19Wm6RyFOgi2QRiye5vXYSi2ZPDLsUkSEp0EWGsKelk31tp1irUxWlCCjQRYYQiycYUzmKb9yhRlwSfQp0kUFc6O7lxe0trLptJpOuVSMuiT4FusggXt37Macu9LBOB0OlSCjQRQYRiyeYO/Va7lIjLikSCnSRDBLHz/Hr5mNqxCVFRYEuksGmdCOuR5br7BYpHgp0kQF6+5zn4wk+v7BGjbikqAQKdDNbZWYHzKzZzJ7OsN3M7F/S23eZ2bLclypSGL9uPkpr5wUdDJWikzXQzawCeAa4H1gEPG5miwYMux9YmP56Avi3HNcpUjA/+u1hpowdzZcXTQ+7FJFhqQwwZiXQ7O4HAczsOWA1sK/fmNXAf7u7A2+b2WQzm+Xubbku+M33Ovj7l/ZlHygyAg40t5/hyXtuVCMuKTpBAr0WSPS7nwQ+FWBMLfCJQDezJ0jtwTNv3rzh1grA+DGVLJwxfkTfKxLEl26Zzp/euzDsMkSGLUigZzpny0cwBnffCGwEqKuru2J7EMuvm8Ly65aP5FtFREpakIOiSaD/0aE5QOsIxoiISB4FCfRGYKGZLTCzKuAxYOuAMVuBb6XPdrkL6MzH+rmIiAwu65KLu/eY2ZPAq0AF8Ky77zWz9entG4AG4AGgGTgHfCd/JYuISCZB1tBx9wZSod3/sQ39bjvw3dyWJiIiw6FPioqIlAgFuohIiVCgi4iUCAW6iEiJsNTxzBBe2KwD+GiE314NHM1hObkS1bogurWpruFRXcNTinVd5+41mTaEFuhXw8zi7l4Xdh0DRbUuiG5tqmt4VNfwlFtdWnIRESkRCnQRkRJRrIG+MewCBhHVuiC6tamu4VFdw1NWdRXlGrqIiFypWPfQRURkAAW6iEiJKIpAN7Pvm9m76QtQv2BmkwcZN+TFrPNQ1xoz22tmfWY26ClIZvahme02sx1mFo9QXYWer6lm9nMzez/955RBxhVkvqJ68fMAdd1tZp3p+dlhZn9ToLqeNbN2M9szyPaw5itbXWHN11wze93M9qf/Pf5ZhjG5nTN3j/wXcB9Qmb79PeB7GcZUAB8A1wNVwE5gUZ7ruhW4GXgDqBti3IdAdQHnK2tdIc3XPwBPp28/nel9LNR8Bfn7k2oJ/QqpK3LdBbxTgPcuSF13Ay8V6uep3+t+AVgG7Blke8HnK2BdYc3XLGBZ+vYE4L18/4wVxR66u//M3XvSd98mdUWkgS5fzNrdu4BLF7POZ1373f1APl9jJALWVfD5Sj//f6Vv/xfwUJ5fbyhB/v6XL37u7m8Dk81sVgTqCoW7vwUcH2JIGPMVpK5QuHubu29L3z4N7Cd1reX+cjpnRRHoA/wBqf/RBhrsQtVR4MDPzKwpfaHsKAhjvmZ4+kpW6T+nDzKuEPMV5O8fxhwFfc1Pm9lOM3vFzBbnuaagovxvMNT5MrP5wJ3AOwM25XTOAl3gohDM7BfAzAyb/tLdf5we85dAD/DDTE+R4bGrPiczSF0BfNbdW81sOvBzM3s3vVcRZl0Fn69hPE3O5yuDnF38
PMeCvOY2Uv08zpjZA8CLwMI81xVEGPMVRKjzZWbjgc3An7v7qYGbM3zLiOcsMoHu7l8earuZfRv4GnCvpxefBsjLhaqz1RXwOVrTf7ab2Qukfq2+qoDKQV0Fny8zO2Jms9y9Lf1rZfsgz5Hz+cogqhc/z/qa/UPB3RvM7AdmVu3uYTehiuTF4sOcLzMbTSrMf+juWzIMyemcFcWSi5mtAp4CvuHu5wYZFuRi1gVnZuPMbMKl26QO8GY8Gl9gYczXVuDb6dvfBq74TaKA8xXVi59nrcvMZpqZpW+vJPXv+Fie6woikheLD2u+0q/5H8B+d/+nQYblds4KfeR3JF+kLj6dAHakvzakH58NNPQb9wCpI8kfkFp6yHddD5P6H/YicAR4dWBdpM5W2Jn+2huVukKar2nAa8D76T+nhjlfmf7+wHpgffq2Ac+kt+9miDOZClzXk+m52UnqJIHPFKiuHwFtQHf65+sPIzJf2eoKa74+R2r5ZFe/7Hogn3Omj/6LiJSIolhyERGR7BToIiIlQoEuIlIiFOgiIiVCgS4iUiIU6CIiJUKBLiJSIv4PNnXKsEFAEmoAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAdL0lEQVR4nO3deXhU5d3/8ffXsMkaIGHfIWwqS4ggaK17ARceuyiLrVsfSgSX6qPSYm2fLr9WrVotCOWxtLYgCIqKFUVtVdpalSQQIEAggEAIkLCvIdv9+yNjrzQk5CTMzJmZfF7XlYuZOSdzPrmTfHNzz5nvMeccIiIS/c7zO4CIiASHCrqISIxQQRcRiREq6CIiMUIFXUQkRjTw68AJCQmuR48efh1eRCQqpaen73fOJVa1zbeC3qNHD9LS0vw6vIhIVDKzHdVt05KLiEiMUEEXEYkRKugiIjFCBV1EJEaooIuIxIgaC7qZzTOzfDNbX812M7PnzSzHzNaaWXLwY4qISE28zND/CIw+y/YxQFLgYzIw+9xjiYhIbdVY0J1zK4GDZ9llHPAnV+5TIN7MOgYroIhILHnugy1k5R0JyXMH441FnYFdFe7nBh7bU3lHM5tM+Syebt26BeHQIiLR47X0XJ79YDPFpWVc0KlV0J8/GC+KWhWPVXnVDOfcXOdcinMuJTGxyneuiojEpE17jzLjjXWM7NWWB65JCskxglHQc4GuFe53AfKC8LwiIjHhaGExqfMzaNmkIc9PGEqDuNCcYBiMZ10GfCdwtsslwBHn3BnLLSIi9ZFzjkeWrGXnwZPMnJhMYovGITtWjWvoZrYQuAJIMLNc4MdAw0DQOcByYCyQA5wE7gxVWBGRaPP7f2zn3ay9zBg7gOE924T0WDUWdOfchBq2O2Bq0BKJiMSIVV8c5JfvbGL0BR347ld6hvx4eqeoiEgIFBw7zdQFGXRtfT5PfmsQZlWdPxJcvvVDFxGJVSWlZdy3cDVHC4t56a7htGzSMCzHVUEXEQmyp9/fzL+2HeDX3xrMgI4tw3ZcLbmIiATR+xv2MfujrUwY3o1vDusS1mOroIuIBMnOAyd5cPEaLuzckh/fODDsx1dBFxEJgsLiUlIXpHOeGbMnDaNJw7iwZ9AauohIEPxkWRZZeUeZd0cKXds09SWDZugiIudoSdouFq3axdQre3NV//a+5VBBFxE5BxvyjvLYG+sZ1bstD17bz9csKugiInV0tLCYexakE9+0vOlW3Hmhf/PQ2WgNXUSkDpxz/M/iTHIPnWLR5EtIaB66plteaYYuIlIH//f3bby3YR/Tx/QnpUdom255pYIuIlJLn207wBPvZjP2og7cfVnom255pYIuIlIL+ccKmbZwNd3bNOWJb4Sn6ZZXWkMXEfGopLSMaS+v5lhhMX++ezgtwtR0yysVdBERj556L5vPtx/k2VsH079D+JpueaUlFxERD97L2svvPt7GpBHduHloeJtueaWCLiJSgx0HTvDQkkwGdWnF4z403fJKBV1E5CwKi0uZMj+D88yYNTGZxg3C33TLK62hi4icxeNvrmfjnqP84Y6LfWu65ZVm6CIi1Vi8aheL03K596o+XNm/nd9xaqSCLiJShay8I/zozfVc1ieBB67p63ccT1TQRUQqOXKqmNT5GbRu2ojnxg/xvemWV1pDFxGpwDnH/yzJJO/wKV753kjaRkDTLa80QxcRqeB3K7fx/oZ9/HDsAIZ1b+13nFpRQRcRCfjX1gM8+e4mrh/UkTsv7eF3nFpTQRcRAfKPFnLvwtX0SGgWcU23vNIauojUe8WBplsnTpfw8n+PoHnj6CyN0ZlaRCSInlqRzedfHOS58UPo276F33HqTEsuIlKvvbt+L3NXbuPbl3Rn3JDOfsc5JyroIlJvbd9/goeXZDK4azyP3TDA7zjnzFNBN7PRZpZtZjlmNr2K7a3M7C0zyzSzLDO7M/hRRUSC51RRKanz
04mLM2ZNHBrRTbe8qrGgm1kcMAsYAwwEJphZ5f6RU4ENzrnBwBXA02bWKMhZRUSCwjnHj95cT/a+Y/zm1iF0aR3ZTbe88jJDHw7kOOe2OeeKgEXAuEr7OKCFlZ/n0xw4CJQENamISJC8smoXr6bncu9VSVzRL/KbbnnlpaB3BnZVuJ8beKyimcAAIA9YB9zvnCur/ERmNtnM0swsraCgoI6RRUTqbv3uIzy+LIuvJCVw/9VJfscJKi8Fvaqz612l+18D1gCdgCHATDM744J7zrm5zrkU51xKYmJiLaOKiJybIyeLSV2QTttmjXhu/NCoabrllZeCngt0rXC/C+Uz8YruBJa6cjnAdqB/cCKKiJy7sjLHQ0vWsPdIIbMmJdOmWey9zOeloK8CksysZ+CFzvHAskr77ASuBjCz9kA/YFswg4qInIs5K7fywcZ8ZowdQHK36Gq65VWN7xR1zpWY2TRgBRAHzHPOZZnZlMD2OcDPgD+a2TrKl2gedc7tD2FuERHPPtm6n1+vyObGwZ24fVQPv+OEjKe3/jvnlgPLKz02p8LtPOC64EYTETl3e48Uct/C1fRKbM6vvn5RVDbd8kq9XEQkZpU33crgZFEpiyYn0yxKm255FdtfnYjUa0+8s4m0HYd4fsJQ+rSL3qZbXqmXi4jEpHfW7eHFf2zn9pHduWlwJ7/jhIUKuojEnG0Fx3n41bUM6RrPjOsrdyqJXSroIhJTThWVcs+CDBrGGbMmJdOoQf0pc1pDF5GY4ZxjxhvryN53jJfuHE7n+PP9jhRW9edPl4jEvIWf72Jpxm7uvzqJy/vWv/YiKugiEhPW5R7hJ8uyuLxvIvddFVtNt7xSQReRqHf4ZBGpC9JJaN6I39w6hPNirOmWV1pDF5GoVlbmeHBxJvuOFrJkyqiYbLrllWboIhLVZn+8lb9tyudHNwxkSNd4v+P4SgVdRKLWP3P28/R72Ywb0olvX9Ld7zi+U0EXkaj0ZdOt3onN+WWMN93ySgVdRKJOcWkZU1/OoLC4lNm3DaNpI70cCHpRVESi0C+XbyJ9xyFmThxKn3bN/Y4TMTRDF5Go8vbaPcz753buGNWDGwbVj6ZbXqmgi0jU2FpwnEdezSS5Wzw/HDvA7zgRRwVdRKLCyaISUuen07hhXL1ruuWV1tBFJOI555jx+nq25B/nT3cNp2Or+tV0yyv9iRORiLfgs528vno337+mL19Jqn9Nt7xSQReRiLY29zA/fWsDV/RLZNqVffyOE9FU0EUkYh06UUTq/AwSWzTm2Vvqb9Mtr7SGLiIRqazM8f3Fayg4dpolU0bSuh433fJKM3QRiUgzP8zho+wCHr9xIIPredMtr1TQRSTi/H1LAc9+sJmbh3Zm0ohufseJGiroIhJR8g6f4v5Fa0hq15xf3Hyhmm7Vggq6iESMopLypltFJWVqulUHGi0RiRj/b/lGVu88zAuTkumdqKZbtaUZuohEhLcy8/jjJ19w16U9GXtRR7/jRCUVdBHxXU7+caa/tpZh3Vvzg7H9/Y4TtTwVdDMbbWbZZpZjZtOr2ecKM1tjZllm9nFwY4pIrDpxurzpVpOGccyamEzDOM0z66rGNXQziwNmAdcCucAqM1vmnNtQYZ944AVgtHNup5m1C1FeEYkhzjl++Po6thYc5893j6BDqyZ+R4pqXv4UDgdynHPbnHNFwCJgXKV9JgJLnXM7AZxz+cGNKSKxaP6nO3hzTR4PXtuXS/sk+B0n6nkp6J2BXRXu5wYeq6gv0NrMPjKzdDP7TlVPZGaTzSzNzNIKCgrqllhEYsKaXYf56V82cFX/dtxzhZpuBYOXgl7VWf2u0v0GwDDgeuBrwI/MrO8Zn+TcXOdcinMuJTFRLTBF6quDJ4qYuiCD9i2b8Mwtg9V0K0i8nIeeC3StcL8LkFfFPvudcyeAE2a2EhgMbA5KShGJGaVljgdeKW+69VrqKOKbqulWsHiZoa8Cksysp5k1AsYDyyrt8ybw
FTNrYGZNgRHAxuBGFZFY8Nu/bWHl5gJ+ctMFXNSlld9xYkqNM3TnXImZTQNWAHHAPOdclplNCWyf45zbaGbvAmuBMuBF59z6UAYXkejz8eYCnvvrFr6e3JkJw7vW/AlSK+Zc5eXw8EhJSXFpaWm+HFtEwm/34VPc8Pzfad+yCa/fcynnN4rzO1JUMrN051xKVdt0Br+IhFxRSRlTF2RQXOp4YVKyinmIqDmXiITcL97ewJpdh5lzWzK91HQrZDRDF5GQWpaZx0v/2sF3L+vJ6AvVdCuUVNBFJGS27DvG9NfWcnGP1jw6Rk23Qk0FXURC4sTpElIXZNC0URwz1XQrLLSGLiJB55xj+tJ1bCs4zvzvjqB9SzXdCgf9yRSRoPvTv3bwVmYeD13Xj1G91XQrXFTQRSSoMnYe4udvb+Dq/u1I/Wpvv+PUKyroIhI0B46fZuqCDDq0asIztwxR060w0xq6iATFl023DpwoYmnqKFo1beh3pHpHM3QRCYrn/rqFv2/Zz09vuoALO6vplh9U0EXknH2Unc9v/7aFbw7rwq0Xq+mWX1TQReSc5B46yQOvrKFf+xb8bNyFmGnd3C8q6CJSZ6dLSpm6IIPSUsec24ap6ZbP9KKoiNTZz/+ykczcI8y5bRg9Epr5Hafe0wxdROrkzTW7+fOnO5h8eS9GX9jB7ziCCrqI1MHmfceY/to6hvdowyNf6+d3HAlQQReRWjl+uoQp89Np1rgBMycOpYGabkUMfSdExDPnHI++tpYv9p/gtxOG0k5NtyKKCrqIePbHT77g7bV7ePhr/RnZu63fcaQSFXQR8SR9x0F+8fZGrhnQnilf7eV3HKmCCrqI1Gj/8dNMXbCaTvHn8/Qtg/XmoQil89BF5KxKyxz3L1rNoZNFLL1nFK3OV9OtSKWCLiJn9ZsPNvPPnAM8+Y1BXNBJTbcimZZcRKRaH27K57d/y+GWlC7coqZbEU8FXUSqtOtgedOtgR1b8tNxF/odRzxQQReRM5wuKWXqyxmUOcfs25Jp0lBNt6KB1tBF5Aw/fWsDa3OPMPfbw+jeVk23ooVm6CLyH15fncuCz3byva/24roL1HQrmqigi8i/Ze89xg+WrmNEzzY8fJ2abkUbFXQRAeBYYTGp89Np0aQhv1XTrajk6TtmZqPNLNvMcsxs+ln2u9jMSs3sm8GLKCKh9mXTrR0HTzJzwlDatVDTrWhUY0E3szhgFjAGGAhMMLOB1ez3BLAi2CFFJLR+/4/tLF+3l0e+1o8RvdR0K1p5maEPB3Kcc9ucc0XAImBcFfvdC7wG5Acxn4iEWNoXB/nVO5u4bmB7Jl+uplvRzEtB7wzsqnA/N/DYv5lZZ+BmYM7ZnsjMJptZmpmlFRQU1DariATZ/uOnmfpyBl1an8+v1XQr6nkp6FV9h12l+78BHnXOlZ7tiZxzc51zKc65lMTERI8RRSQUSssc9y1czeGTxbwwaRgtm6jpVrTz8saiXKBiE4cuQF6lfVKARYG/7gnAWDMrcc69EYyQIhJ8z7yfzSdbD/DUNwcxsFNLv+NIEHgp6KuAJDPrCewGxgMTK+7gnOv55W0z+yPwFxVzkcj11437mPXhVsZf3JVvpajpVqyosaA750rMbBrlZ6/EAfOcc1lmNiWw/azr5iISWXYdPMn3X1nDBZ1a8pObLvA7jgSRp14uzrnlwPJKj1VZyJ1zd5x7LBEJhcLiUlIXpAMwe9IwNd2KMWrOJVKP/O9bG1i/+ygvfieFbm2b+h1Hgkzv7RWpJ15Lz2Xh5ztJvaI31wxs73ccCQEVdJF6YNPeo8x4Yx0je7XloWv7+h1HQkQFXSTGHS0sJnV+Bi2bNOT5CWq6Fcu0hi4Sw5xzPLJkLTsPnmThf19CYovGfkeSENKfapEY9uLft/Nu1l5+MKY/w3u28TuOhJgKukiM+nz7QX717ibGXNiBuy/rWfMnSNRTQReJQfnHCpn2cgbd2jTlyW8OUtOtekIFXSTGlJSW
cd/C1RwtLGb2bcm0UNOtekMviorEmKff38yn2w7y9LcG07+Dmm7VJ5qhi8SQ9zfsY/ZHW5kwvBvfGNbF7zgSZiroIjFi54GTPLh4DRd2bsmPbzzjKpFSD6igi8SAL5tunWemplv1mNbQRWLAT5ZlkZV3lHl3pNC1jZpu1VeaoYtEuSVpu1i0ahdTr+zNVf3VdKs+U0EXiWIb8o7y2BvrGdW7LQ9e28/vOOIzFXSRKHXkVDGpC9KJb1redCvuPL15qL7TGrpIFHLO8fCSTHYfOsWiyZeQ0FxNt0QzdJGoNHflNt7bsI8fjB1ASg813ZJyKugiUeazbQd4ckU211/Ukbsu7eF3HIkgKugiUST/aCHTFq6me5um/OobF6nplvwHraGLRImS0jKmLVzN8cIS5t89Qk235Awq6CJR4qn3svl8+0GevXUw/Tq08DuORCAtuYhEgfey9vK7j7cxaUQ3bh6qpltSNRV0kQi348AJHlqSyaAurXhcTbfkLFTQRSJYYXEpU+ZncJ4ZsyYm07iBmm5J9bSGLhLBHn9zPRv3HOUPd1yspltSI83QRSLUK6t2sjgtl3uv6sOV/dv5HUeigAq6SARav/sIP3ozi8v6JPDANX39jiNRQgVdJMIcOVXMPQsyaNusEc+NH6KmW+KZp4JuZqPNLNvMcsxsehXbJ5nZ2sDHJ2Y2OPhRRWJfWZnjocWZ5B0+xcyJybRV0y2phRoLupnFAbOAMcBAYIKZVT53ajvwVefcIOBnwNxgBxWpD363chsfbNzHjOsHMKx7a7/jSJTxMkMfDuQ457Y554qARcC4ijs45z5xzh0K3P0U0DsfRGrpX1sP8NSKTVw/qCN3jOrhdxyJQl4KemdgV4X7uYHHqnM38E5VG8xsspmlmVlaQUGB95QiMS7/aCH3LlxNz4RmPPGNQWq6JXXipaBX9ZPlqtzR7ErKC/qjVW13zs11zqU451ISExO9pxSJYcWlZUx7eTUnTpcw+7ZhNG+st4dI3Xj5yckFula43wXIq7yTmQ0CXgTGOOcOBCeeSOx7akU2n39xkOfGD6FvezXdkrrzMkNfBSSZWU8zawSMB5ZV3MHMugFLgW875zYHP6ZIbHp3/V7mrtzGty/pzrghZ1vJFKlZjTN051yJmU0DVgBxwDznXJaZTQlsnwM8DrQFXgis/ZU451JCF1sk+m3ff4KHl2QyuGs8j90wwO84EgM8LdY555YDyys9NqfC7e8C3w1uNJHYdaqolNT56cTFGbMmDlXTLQkKvfoiEmbOOR57Yz3Z+47xhzsupktrNd2S4NBb/0XCbNGqXbyWkcu9VyVxRT813ZLgUUEXCaP1u4/w42VZfCUpgfuvTvI7jsQYFXSRMDlyspgp89NJaNaI58YPVdMtCTqtoYuEQVmZ48HFa9h3tJDF3xtJm2aN/I4kMUgzdJEwmP3xVv66KZ/Hrh/I0G5quiWhoYIuEmKfbN3P0+9lc+PgTnxnZHe/40gMU0EXCaG9Rwq5b+FqeiU251dfv0hNtySktIYuEiLlTbcyOFlUyqLJyTRT0y0JMf2EiYTIE+9sIm3HIZ6fMJQ+7dR0S0JPSy4iIfDOuj28+I/t3D6yOzcN7uR3HKknVNBFgmxbwXEefnUtQ7rGM+P6yldrFAkdFXSRIDpZVELq/AwaxhmzJiXTqIF+xSR8tIYuEiTOOR57fT2b84/x0p3D6Rx/vt+RpJ7R9EEkSF7+fCdLV+/mgav7cnlfXWJRwk8FXSQI1uYe5n+XbeCrfRO596o+fseRekoFXeQcHT5ZROr8DBJbNOY3tw7hPDXdEp9oDV3kHJSVOb7/yhryjxWyZMooWqvplvhIM3SRc/DCRzl8mF3A4zcMZEjXeL/jSD2ngi5SR//M2c8z729m3JBO3HaJmm6J/1TQRergy6ZbvROb80s13ZIIoYIuUkvFpWVMfTmDwuJSZt82jKaN9FKURAb9JIrU0i+XbyJ9xyFmThxKn3bN/Y4j8m+a
oYvUwttr9zDvn9u5Y1QPbhikplsSWVTQRTzaWnCcR17NJLlbPD8cO8DvOCJnUEEX8aC86VY6jRvGqemWRCytoYvUwDnHD5euY0v+cf581wg6tlLTLYlMmmaI1GD+Zzt5Y00eD17Tl8uSEvyOI1ItFXSRs8jcdZifvbWBK/slMvVKNd2SyKaCLlKNQyeKuGdBedOtZ9V0S6KA1tBFqlBW5vj+4jUUHDvNq6kjiW+qplsS+TzN0M1stJllm1mOmU2vYruZ2fOB7WvNLDn4UUXCZ+aHOXyUXcDjNw5kUJd4v+OIeFJjQTezOGAWMAYYCEwws8pXvh0DJAU+JgOzg5xTJGxWbi7g2Q82c/PQzkwa0c3vOCKeeVlyGQ7kOOe2AZjZImAcsKHCPuOAPznnHPCpmcWbWUfn3J5gB/54cwE//8uGmncUqQMHfLH/BEntmvOLmy9U0y2JKl4KemdgV4X7ucAID/t0Bv6joJvZZMpn8HTrVreZT/PGDUhqr/4ZEjoje7Xl3qv7qOmWRB0vP7FVTVFcHfbBOTcXmAuQkpJyxnYvhnVvzbDuw+ryqSIiMc3Li6K5QNcK97sAeXXYR0REQshLQV8FJJlZTzNrBIwHllXaZxnwncDZLpcAR0Kxfi4iItWrccnFOVdiZtOAFUAcMM85l2VmUwLb5wDLgbFADnASuDN0kUVEpCqeXvVxzi2nvGhXfGxOhdsOmBrcaCIiUht667+ISIxQQRcRiREq6CIiMUIFXUQkRlj565k+HNisANhRx09PAPYHMU6wRGouiNxsylU7ylU7sZiru3MusaoNvhX0c2Fmac65FL9zVBapuSBysylX7ShX7dS3XFpyERGJESroIiIxIloL+ly/A1QjUnNB5GZTrtpRrtqpV7micg1dRETOFK0zdBERqUQFXUQkRkRFQTezp8xsU+AC1K+bWXw1+531YtYhyPUtM8syszIzq/YUJDP7wszWmdkaM0uLoFzhHq82Zva+mW0J/Nu6mv3CMl6RevFzD7muMLMjgfFZY2aPhynXPDPLN7P11Wz3a7xqyuXXeHU1sw/NbGPg9/H+KvYJ7pg55yL+A7gOaBC4/QTwRBX7xAFbgV5AIyATGBjiXAOAfsBHQMpZ9vsCSAjjeNWYy6fxehKYHrg9varvY7jGy8vXT3lL6HcovyLXJcBnYfjeecl1BfCXcP08VTju5UAysL6a7WEfL4+5/BqvjkBy4HYLYHOof8aiYobunHvPOVcSuPsp5VdEquzfF7N2zhUBX17MOpS5NjrnskN5jLrwmCvs4xV4/pcCt18C/ivExzsbL1//vy9+7pz7FIg3s44RkMsXzrmVwMGz7OLHeHnJ5Qvn3B7nXEbg9jFgI+XXWq4oqGMWFQW9krso/4tWWXUXqo4EDnjPzNIDF8qOBH6MV3sXuJJV4N921ewXjvHy8vX7MUZejznSzDLN7B0zuyDEmbyK5N9BX8fLzHoAQ4HPKm0K6phFzGXNzewDoEMVm2Y4594M7DMDKAEWVPUUVTx2zudkesnlwaXOuTwzawe8b2abArMKP3OFfbxq8TRBH68qBO3i50Hm5ZgZlPfzOG5mY4E3gKQQ5/LCj/HywtfxMrPmwGvAA865o5U3V/EpdR6ziCnozrlrzrbdzG4HbgCudoHFp0pCcqHqmnJ5fI68wL/5ZvY65f+tPqcCFYRcYR8vM9tnZh2dc3sC/63Mr+Y5gj5eVYjUi5/XeMyKRcE5t9zMXjCzBOec302oIvJi8X6Ol5k1pLyYL3DOLa1il6COWVQsuZjZaOBR4Cbn3MlqdvNyMeuwM7NmZtbiy9uUv8Bb5avxYebHeC0Dbg/cvh04438SYRyvSL34eY25zKyDmVng9nDKf48PhDiXFxF5sXi/xitwzN8DG51zz1SzW3DHLNyv/Nblg/KLT+8C1gQ+5gQe7wQsr7DfWMpfSd5K+dJDqHPdTPlf2NPAPmBF5VyUn62QGfjIipRcPo1XW+CvwJbAv238
HK+qvn5gCjAlcNuAWYHt6zjLmUxhzjUtMDaZlJ8kMCpMuRYCe4DiwM/X3REyXjXl8mu8LqN8+WRthdo1NpRjprf+i4jEiKhYchERkZqpoIuIxAgVdBGRGKGCLiISI1TQRURihAq6iEiMUEEXEYkR/x8PKQ7Cor1bHgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAh9ElEQVR4nO3deXyU5bn/8c8lgsoii4QdZBFRtAIhglvdsUiraFsVBUurHg4IrdpFaT3a9vTYWv2pdUEsPfVUD5sbVKq4IGqtWi0hhLBDRJYQJGHflyTX74889ExjQibMZJ5Zvu/Xa17zLPc9c+XOZK482/WYuyMiIpnrmLADEBGRcCkRiIhkOCUCEZEMp0QgIpLhlAhERDLcsWEHcDRat27tXbt2DTsMEZGUMn/+/M3unlV1eUomgq5du5Kbmxt2GCIiKcXM1la3XLuGREQynBKBiEiGUyIQEclwSgQiIhlOiUBEJMPFJRGY2bNmVmJmi2tYb2b2hJkVmlmBmWVHrBtsZiuCdePjEY+IiEQvXlsEfwIGH2H9lUDP4DEKmAhgZg2ACcH63sCNZtY7TjGJiEgU4nIdgbt/YGZdj9BkKPC8V9a8/sTMWphZe6ArUOjuqwHMbHrQdmk84hKR5LXvYDmv5BVRsnN/2KGklGuzO9GtdZO4vmaiLijrCKyPmC8KllW3fGB1L2Bmo6jcmqBLly71E6WIJERFhTN2ah7vLi8BwCzkgFJI9sktUzYRVPdr9iMs//JC90nAJICcnBzdTUckhU3862e8u7yE+77Rm1sv6BZ2OBkvUYmgCOgcMd8JKAYa1bBcRNLUR4WbeeTtFVzdpwO3nN817HCExJ0+Ogv4TnD20DnADnffCMwDeppZNzNrBAwL2opIGvpix35+MG0B3bOa8ptvfgXTPqGkEJctAjObBlwMtDazIuDnQEMAd38GmA0MAQqBvcD3gnVlZjYOeAtoADzr7kviEZOIJJdD5RWMm5rHvkPlvDAimybHpWTNy7QUr7OGbqxlvQNja1g3m8pEISJp7ME3lpO7dhtP3tiPU9o0CzsciaAri0Wk3s1etJE/fvg53z2vK1f16RB2OFKFEoGI1KvVpbu5++UC+nVpwc+GnB52OFINJQIRqTd7D5YxZnIejY49hgk3ZdPoWH3lJCMdrRGReuHu/MfMxaws2cXztwygQ4sTwg5JaqD0LCL1Yuo/1jFjwQbuvOxUvtrzS7fJlSSiRCAicVdQtJ1fzlrKRadm8f1LTwk7HKmFEoGIxNW2PQcZMzmPrGbH8bsb+nLMMbpoLNnpGIGIxE1FhXPXi/mU7jrAS6PPpWWTRmGHJFHQFoGIxM2E9wp5f0Up913Vmz6dW4QdjkRJiUBE4uLDVZt59J2VXNO3AyMGqlR8KlEiEJGYbdyxjx9MX0DPNk35tYrJpRwlAhGJycGyCsZOyePAoXImjuhP40Y69Jhq9BsTkZj85o1l5K3bzoSbsumR1TTscOQoaItARI7aawXF/M9Ha/je+V35+lntww5HjpISgYgclcKS3dzzcgHZXVrw0ytVTC6VKRGISJ3tOVDGmMnzOa5hAyYMVzG5VBeX356ZDTazFWZWaGbjq1n/EzPLDx6LzazczFoF69aY2aJgXW484hGR+uPu/GzmIgpLd/PEsH60b65icqku5oPFZtYAmAAMovIm9fPMbJa7Lz3cxt0fBh4O2l8F3OXuWyNe5hJ33xxrLCJS/yZ/spZX84v50aBTuaBn67DDkTiIxxbBAKDQ3Ve7+0FgOjD0CO1vBKbF4X1FJMHy12/nP19byiW9shh7iYrJpYt4JIKOwPqI+aJg2ZeYWWNgMPBKxGIH3jaz+WY2qqY3MbNRZpZrZrmlpaVxCFtE6mLbnoOMnZJH2xOP5zEVk0sr8UgE1X0avIa2VwEfVdktdL67ZwNXAmPN7MLqOrr7JHfPcfecrCzVNhdJpIoK584XKovJPT08mxaNVUwuncQj
ERQBnSPmOwHFNbQdRpXdQu5eHDyXADOp3NUkIknkyXcL+evKUn5+dW/O6tQi7HAkzuKRCOYBPc2sm5k1ovLLflbVRmbWHLgIeDViWRMza3Z4GrgCWByHmEQkTj5YWcrv5q7km/06ctMAFZNLRzGfNeTuZWY2DngLaAA86+5LzGx0sP6ZoOm1wNvuvieie1tgZlCg6lhgqru/GWtMIhIfxdv3ccf0BZzaphkPXKticunK3GvanZ+8cnJyPDdXlxyI1KeDZRVc//u/U1iym1njzqe76gilPDOb7+45VZer6JyIVOvXs5eRv347Tw/PVhJIc7ouXES+ZNbCYv708RpuvaAbQ76iYnLpTolARP5FYckuxr9SQM7JLRl/5WlhhyMJoEQgIv+050AZoyfn0bhRA566KZuGDfQVkQl0jEBEgMpicuNnLGJ16W4m3zqQds2PDzskSRClexEB4Pm/r+UvC4v50RW9OO8UFZPLJEoEIkLeum381+tLuey0Noy5qEfY4UiCKRGIZLitew4ybkoe7Zofz6PXq5hcJtIxApEMVl7h3DF9AZv3HGTGmPNo3rhh2CFJCLRFIJLBnpi7ir+t2swvrz6DMzs2DzscCYkSgUiGen9FCU+8u4pvZXdi2Nmda+8gaUuJQCQDbdi+jztfyKdX22b81zVnqphchlMiEMkwB8rKuX1KHuXlzsQR/TmhUYOwQ5KQ6WCxSIZ54PVlLFy/nWdGZNOtdZOww5EkoC0CkQzyav4Gnv/7Wv7tq90YfKaKyUklJQKRDLFy0y7Gv7KIs7u25O7BKiYn/ycuicDMBpvZCjMrNLPx1ay/2Mx2mFl+8Lg/2r4iErvdB8oYPXk+TY47VsXk5EtiPkZgZg2ACcAgKm9kP8/MZrn70ipN/+bu3zjKviJylNyde14pYM3mPUy57RzanqhicvKv4vFvwQCg0N1Xu/tBYDowNAF9RSQKf/p4Da8XbOQnXzuNc3ucFHY4koTikQg6Ausj5ouCZVWda2YLzewNMzujjn0xs1FmlmtmuaWlpXEIWyT9zV+7jQdeX8blp7dl9EXdww5HklQ8EkF1V6J4lfk84GR37wM8Cfy5Dn0rF7pPcvccd8/Jyso62lhFMsaW3QcYNzWPDi1O4JHr++iiMalRPBJBERB5fXonoDiygbvvdPfdwfRsoKGZtY6mr4jUXWUxuXy27DnI08OzaX6CislJzeKRCOYBPc2sm5k1AoYBsyIbmFk7C/4dMbMBwftuiaaviNTd4++s5MPCzfxqqIrJSe1iPmvI3cvMbBzwFtAAeNbdl5jZ6GD9M8C3gTFmVgbsA4a5uwPV9o01JpFM9t6KEp54t5Dr+nfihrO7hB2OpACr/D5OLTk5OZ6bmxt2GCJJZ/3WvVz11Ie0b34CM28/j+Mbqo6Q/B8zm+/uOVWX66oSkTRxoKycsVODYnLDs5UEJGoqOieSJv7zL0spKNrB72/uT1cVk5M60BaBSBqYuaCIKZ+u498v7M7XzmgXdjiSYpQIRFLcii928dMZixjQrRU/+VqvsMORFKREIJLCdu0/xJjJ82l6XEOeurEfx6qYnBwFHSMQSVGHi8mt3bqXqbcNpI2KyclR0r8PIinq2Y/WMHvRF9z9tV4M7K5icnL0lAhEUlDumq38ZvYyrujdllEXqpicxEaJQCTFbN59gLFT8+jY8gQevk7F5CR2OkYgkkIqi8ktYPveQ8y8fYCKyUlcKBGIpJDH5qzko8ItPPTts+jd4cSww5E0oV1DIini3eWbeOq9Qm7I6cz1OZ1r7yASJSUCkRSwfute7pyeT+/2J/LLoWfU3kGkDpQIRJLc/kPljJkyHweeGdFfxeQk7nSMQCTJ/fIvS1m8YSd/+E4OXU5qHHY4kobiskVgZoPNbIWZFZrZ+GrWDzezguDxsZn1iVi3xswWmVm+mekmAyIRXplfxLR/rGP0RT0Y1Ltt2OFImop5i8DMGgATgEFU3oN4npnNcvelEc0+By5y
921mdiUwCRgYsf4Sd98caywi6WT5Fzu598+LOKd7K358xalhhyNpLB5bBAOAQndf7e4HgenA0MgG7v6xu28LZj+h8ib1IlKDnfsPMWZyHice35AnVExO6lk8Pl0dgfUR80XBsprcCrwRMe/A22Y238xG1dTJzEaZWa6Z5ZaWlsYUsEgyc3fufqmAdVv38tRN2bRppmJyUr/icbC4uuvbq70RspldQmUiuCBi8fnuXmxmbYA5Zrbc3T/40gu6T6JylxI5OTmpd6NlkSj98cPPeXPJF9w75HQGdGsVdjiSAeKxRVAERF7d0gkortrIzM4C/hsY6u5bDi939+LguQSYSeWuJpGMNG/NVn7zxnIGn9GO277aLexwJEPEIxHMA3qaWTczawQMA2ZFNjCzLsAM4GZ3XxmxvImZNTs8DVwBLI5DTCIpp3TXAcZOyaNzyxN46LqzVExOEibmXUPuXmZm44C3gAbAs+6+xMxGB+ufAe4HTgKeDj7cZe6eA7QFZgbLjgWmuvubscYkkmrKyiv4wbQF7Nx/iOduGcCJx6uYnCROXC4oc/fZwOwqy56JmL4NuK2afquBPlWXi2SaR+as5O+rt/D/ruvD6e1VTE4SS+ekiYRsztJNTHz/M24c0Jlv99eZ1ZJ4SgQiIVq3ZS8/fDGfMzueyM+vUjE5CYcSgUhIDheTM2DicBWTk/Co6JxISH4xawlLinfyx5E5dG6lYnISHm0RiITgpdz1TJ+3ntsv7sFlp6uYnIRLiUAkwZYW7+Q//ryY83qcxA8HqZichE+JQCSBdu4/xO1T5tOisYrJSfLQMQKRBHF3fvziQoq27WP6qHNo3fS4sEMSAbRFIJIwf/jbat5euonxV55GTlcVk5PkoUQgkgCfrt7Cb99cwZCvtOPWC1RMTpKLEoFIPSvZuZ9x0xZwcqvG/PZbKiYnyUfHCETqUVl5BeOmLWDX/kP8760DaKZicpKElAhE6tHDb6/gH59v5dHr+3BaOxWTk+SkXUMi9eTtJV/w+7+u5qaBXfhmtorJSfJSIhCpB2u37OFHLy3kKx2bc/83eocdjsgRKRGIxNn+Q+WMnpzHMWY8PTxbxeQk6cUlEZjZYDNbYWaFZja+mvVmZk8E6wvMLDvaviKp5v5XF7Ns404eu6GPislJSog5EZhZA2ACcCXQG7jRzKpuC18J9Aweo4CJdegrkjJenLeeF3OL+P6lp3DpaSomJ6khHlsEA4BCd1/t7geB6cDQKm2GAs97pU+AFmbWPsq+IilhSfEO7nt1MRec0po7L1cxOUkd8UgEHYH1EfNFwbJo2kTTFwAzG2VmuWaWW1paGnPQIvG0Y98hxkzOo2XjRjw+rC8NjtFFY5I64pEIqvvEe5RtoulbudB9krvnuHtOVlZWHUMUqT/uzo9fWkjx9n1MGJ7NSSomJykmHheUFQGdI+Y7AcVRtmkURV+RpPb7D1YzZ+km7v9Gb/qf3DLscETqLB5bBPOAnmbWzcwaAcOAWVXazAK+E5w9dA6ww903RtlXJGn9/bMtPPTmcr5+Vnu+d37XsMMROSoxbxG4e5mZjQPeAhoAz7r7EjMbHax/BpgNDAEKgb3A947UN9aYRBKhZOd+vj9tAV1bN1ExOUlpcak15O6zqfyyj1z2TMS0A2Oj7SuS7A6VVzBu6gL2HChj6r8NpOlxKtslqUufXpGj8PBbK/jHmq387oa+nNq2WdjhiMREJSZE6ujNxV8w6YPVjDinC9f0q/ZsZ5GUokQgUgefb97DT15aSJ9OzblPxeQkTSgRiERp38FyxkyeT4MGxoTh2Rx3rIrJSXrQMQKRKLg79726mBWbdvE/3z2bTi1VTE7Sh7YIRKLwwrz1vDy/iO9f2pOLe7UJOxyRuFIiEKnF4g07uH/WEr7aszV3XNYz7HBE4k6JQOQIduw9xJgp8zmpSSMeH9ZPxeQkLekYgUgNKiqcH76Yzxc79vPCv59LqyaNwg5JpF5oi0CkBhP/+hlz
l5dw75DTye6iYnKSvpQIRKrx8WebeeTtFVzVpwMjz+sadjgi9UqJQKSKL3bs5wfTFtCtdRMe/OZXVExO0p6OEYhEqCwml8feg+VM+7dzaKJicpIB9CkXifDbN5aTu3Ybjw/rS08Vk5MMoV1DIoE3Fm3kvz/8nO+cezJD+6qYnGQOJQIRYHXpbn7ycgF9Orfg3q+fHnY4IgkVUyIws1ZmNsfMVgXPXzrHzsw6m9l7ZrbMzJaY2R0R635hZhvMLD94DIklHpGjse9gObdPyaNhA+NpFZOTDBTrFsF4YK679wTmBvNVlQE/cvfTgXOAsWYWWb/3MXfvGzx0pzJJKHfn3j8vYsWmXTw+rB8dW5wQdkgiCRdrIhgKPBdMPwdcU7WBu29097xgehewDNAOWEkK0/6xnhl5G7jjsp5ceGpW2OGIhCLWRNDW3TdC5Rc+cMSyjGbWFegHfBqxeJyZFZjZs9XtWoroO8rMcs0st7S0NMawRaCgaDu/mLWEC0/N4geXqpicZK5aE4GZvWNmi6t5DK3LG5lZU+AV4E533xksngj0APoCG4FHaurv7pPcPcfdc7Ky9J+bxGb73oOMmZxH66aN+N0NfTlGxeQkg9V6HYG7X17TOjPbZGbt3X2jmbUHSmpo15DKJDDF3WdEvPamiDZ/AF6rS/AiR6OiwrnrhXxKdu3npdHnqZicZLxYdw3NAkYG0yOBV6s2sMrr8/8ILHP3R6usax8xey2wOMZ4RGr19PuFvLeilPu+0Zu+nVuEHY5I6GJNBA8Cg8xsFTAomMfMOpjZ4TOAzgduBi6t5jTRh8xskZkVAJcAd8UYj8gRfVS4mUfnrOTqPh24+ZyTww5HJCnEVGLC3bcAl1WzvBgYEkx/CFS7A9bdb47l/UXq4nAxue5ZTfmNismJ/JNqDUlGOFRewdipeew7VM4LI7JVTE4kgv4aJCP8ZvZy5q/dxpM39uOUNiomJxJJtYYk7b1esJFnP/qc757Xlav6dAg7HJGko0Qgae2z0t3c/fJC+nVpwc+GqJicSHWUCCRt7T1YxpjJ8zmuYQOeHp5No2P1cRepjo4RSFpyd+6duZhVJbt5/pYBtG+uYnIiNdG/SJKWJn+6jpkLNnDX5afy1Z4qSSJyJEoEknYWrt/Or/6ylIt7ZTHuklPCDkck6SkRSFrZtucgt0/JI6vZcTx2vYrJiURDxwgkbVRUOHe9mE/prgO8NPpcWqqYnEhUtEUgaeOp9wp5f0Up913Vmz4qJicSNSUCSQt/W1XKY++s5Jq+HRgxsEvY4YikFCUCSXnF2/dxx/R8erZpyq9VTE6kzpQIJKUdLKssJnfgUDkTR/SncSMd9hKpK/3VSEr79exlLFi3nQk3ZdMjq2nY4YikpJi2CMyslZnNMbNVwXO1N583szXBDWjyzSy3rv1FqvOXhcX86eM13HJ+N75+VvvaO4hItWLdNTQemOvuPYG5wXxNLnH3vu6ec5T9Rf6psGQX418poP/JLfnpkNPCDkckpcWaCIYCzwXTzwHXJLi/ZKA9B8oYPTmP4xs2YMJN2TRsoENdIrGI9S+orbtvBAie29TQzoG3zWy+mY06iv6Y2SgzyzWz3NLS0hjDllTl7vx0xiJWl+7miRv70a758WGHJJLyaj1YbGbvAO2qWXVvHd7nfHcvNrM2wBwzW+7uH9ShP+4+CZgEkJOT43XpK+njfz9Zy6yFxfz4ilM5/5TWYYcjkhZqTQTufnlN68xsk5m1d/eNZtYeKKnhNYqD5xIzmwkMAD4AouovArBg3TZ+9dpSLj2tDbdfrGJyIvES666hWcDIYHok8GrVBmbWxMyaHZ4GrgAWR9tfBGDrnoOMnZJH2xOP59Hr+6iYnEgcxZoIHgQGmdkqYFAwj5l1MLPZQZu2wIdmthD4B/C6u795pP4ikcornDtfyGfz7oM8PTybFo1VTE4knmK6oMzdtwCXVbO8GBgSTK8G+tSlv0ikJ99dxQcrS3ng2jM5q1OL
sMMRSTs6706S2l9XlvL43FV8s19HbhqgYnIi9UGJQJLWhu37uHP6Ak5t04wHrlUxOZH6okQgSelgWQVjp+RxqNyZOCKbExo1CDskkbSlonOSlB54fSn567czcXg23VVMTqReaYtAks6r+Rt47u9rue2Cblz5FRWTE6lvSgSSVFZt2sX4VxZxdteW3HOlismJJIISgSSN3QfKGD15Pk2Oa8BTKiYnkjA6RiBJwd0Z/0oBn2/ew+TbBtL2RBWTE0kU/cslSeG5j9fwWsFGfnRFL87roWJyIomkRCChy1u3jQdmL+Oy09ow5qIeYYcjknGUCCRUW3YfYOyUPNo1P55Hr++rYnIiIdAxAgnN4WJyW/YcZMaY82jeuGHYIYlkJG0RSGgen7uKv63azC+vPoMzOzYPOxyRjKVEIKF4f0UJT767im9ld2LY2Z3DDkckoykRSMIVbdvLnS/k06ttM/7rmjNVTE4kZEoEklAHysoZOyWP8nJn4oj+KiYnkgRiSgRm1srM5pjZquC5ZTVteplZfsRjp5ndGaz7hZltiFg3JJZ4JPn96rWlLCzawcPX9aFb6yZhhyMixL5FMB6Y6+49gbnB/L9w9xXu3tfd+wL9gb3AzIgmjx1e7+6zq/aX9PHnBRuY/Mk6Rl3YncFntgs7HBEJxJoIhgLPBdPPAdfU0v4y4DN3Xxvj+0qKWblpFz+dsYgBXVtx99d6hR2OiESINRG0dfeNAMFzm1raDwOmVVk2zswKzOzZ6nYtHWZmo8ws18xyS0tLY4taEur/iskdy1M39eNYFZMTSSq1/kWa2Ttmtriax9C6vJGZNQKuBl6KWDwR6AH0BTYCj9TU390nuXuOu+dkZWXV5a0lRO7OPS8XsGbzHp68sR9tVExOJOnUemWxu19e0zoz22Rm7d19o5m1B0qO8FJXAnnuvinitf85bWZ/AF6LLmxJFf/z0RpeX7SRewafxrk9Tgo7HBGpRqzb6LOAkcH0SODVI7S9kSq7hYLkcdi1wOIY45EkMn/tVn49exmXn96W0Rd1DzscEalBrIngQWCQma0CBgXzmFkHM/vnGUBm1jhYP6NK/4fMbJGZFQCXAHfFGI8kic27DzB2ygI6tDiBR67vo4vGRJJYTEXn3H0LlWcCVV1eDAyJmN8LfGm/gLvfHMv7S3Iqr3DumL6ArXuDYnInqJicSDLT6RsSd797ZyUfFW7hV0NVTE4kFSgRSFy9u3wTT75byHX9O3HD2V3CDkdEoqBEIHGzfute7nphIae3P5FfXXNm2OGISJSUCCQu9h8q5/YpeVS488yIbI5vqGJyIqlCdyiTuPjP15ayaMMOJt3cn5NPUjE5kVSiLQKJ2Yy8IqZ+uo5/v6g7V5yhYnIiqUaJQGKy/Iud/GzmIgZ2a8VPrlAxOZFUpEQgR23X/kOMmZxHs+Mb8qSKyYmkLB0jkKPi7tz9cgHrtu5l6m0DadNMxeREUpX+hZOj8scPP+eNxV9w99d6MbC7ismJpDIlAqmz3DVbefCN5VzRuy2jLlQxOZFUp0QgdbJ59wHGTs2jY8sTePg6FZMTSQc6RiBRK69wfjBtAdv3HmLm7QNUTE4kTSgRSNQenbOCjz/bwkPfPoveHU4MOxwRiRPtGpKovLN0ExPe+4xhZ3fm+pzOYYcjInGkRCC1WrdlLz98MZ8zOpzIL64+I+xwRCTOYkoEZnadmS0xswozyzlCu8FmtsLMCs1sfMTyVmY2x8xWBc8tY4lH4m//oXJunzofgInD+6uYnEgainWLYDHwTeCDmhqYWQNgApU3r+8N3GhmvYPV44G57t4TmBvMSxL55V+WsHjDTh69vi9dTmocdjgiUg9ivVXlMqC2UwgHAIXuvjpoOx0YCiwNni8O2j0HvA/cE0tMR/Lk3FXMWlhcXy+fdg6UVbBu617GXNyDy3u3DTscEakniThrqCOwPmK+CBgYTLd1940A7r7RzNrU9CJmNgoYBdCly9Hd+Sqr2XH0bNv0
qPpmqhHndOGW87uFHYaI1KNaE4GZvQNUV1v4Xnd/NYr3qG5zwaPo968d3CcBkwBycnLq3B9g2IAuDBug2yeKiESqNRG4++UxvkcREHm+YSfg8P6ZTWbWPtgaaA+UxPheIiJSR4k4fXQe0NPMuplZI2AYMCtYNwsYGUyPBKLZwhARkTiK9fTRa82sCDgXeN3M3gqWdzCz2QDuXgaMA94ClgEvuvuS4CUeBAaZ2SpgUDAvIiIJZO5Htbs9VDk5OZ6bmxt2GCIiKcXM5rv7l6750pXFIiIZTolARCTDKRGIiGQ4JQIRkQyXkgeLzawUWHuU3VsDm+MYTrworrpRXHWjuOomWeOC2GI72d2zqi5MyUQQCzPLre6oedgUV90orrpRXHWTrHFB/cSmXUMiIhlOiUBEJMNlYiKYFHYANVBcdaO46kZx1U2yxgX1EFvGHSMQEZF/lYlbBCIiEkGJQEQkw6V9IjCzh81suZkVmNlMM2tRQ7vBZrbCzArNrN7vnWxm15nZEjOrMLMaTwUzszVmtsjM8s2s3ivt1SGuRI9XKzObY2argueWNbRLyHjV9vNbpSeC9QVmll1fsdQxrovNbEcwPvlmdn+C4nrWzErMbHEN68Mar9riSvh4mVlnM3vPzJYFf4t3VNMmvuPl7mn9AK4Ajg2mfwv8tpo2DYDPgO5AI2Ah0Lue4zod6EXlfZpzjtBuDdA6geNVa1whjddDwPhgenx1v8dEjVc0Pz8wBHiDyjv0nQN8moDfXTRxXQy8lqjPU8T7XghkA4trWJ/w8YoyroSPF9AeyA6mmwEr6/vzlfZbBO7+tlfeEwHgEyrvkFbVAKDQ3Ve7+0FgOjC0nuNa5u4r6vM9jkaUcSV8vILXfy6Yfg64pp7f70ii+fmHAs97pU+AFsFd+MKOKxTu/gGw9QhNwhivaOJKOHff6O55wfQuKu/j0rFKs7iOV9ongipuoTKLVtURWB8xX8SXBz4sDrxtZvPNbFTYwQTCGK+27r4RKv9QgDY1tEvEeEXz84cxRtG+57lmttDM3jCzM+o5pmgl899gaONlZl2BfsCnVVbFdbxqvWdxKjCzd4B21ay6191fDdrcC5QBU6p7iWqWxXxebTRxReF8dy82szbAHDNbHvwXE2ZcCR+vOrxM3MerGtH8/PUyRrWI5j3zqKw3s9vMhgB/BnrWc1zRCGO8ohHaeJlZU+AV4E5331l1dTVdjnq80iIRuPvlR1pvZiOBbwCXebCDrYoioHPEfCeguL7jivI1ioPnEjObSeXmf0xfbHGIK+HjZWabzKy9u28MNoFLaniNuI9XNaL5+etljGKNK/ILxd1nm9nTZtba3cMusBbGeNUqrPEys4ZUJoEp7j6jmiZxHa+03zVkZoOBe4Cr3X1vDc3mAT3NrJuZNQKGAbMSFWNNzKyJmTU7PE3lge9qz25IsDDGaxYwMpgeCXxpyyWB4xXNzz8L+E5wdsc5wI7Du7bqUa1xmVk7M7NgegCV3wFb6jmuaIQxXrUKY7yC9/sjsMzdH62hWXzHK5FHw8N4AIVU7kvLDx7PBMs7ALMj2g2h8uj8Z1TuIqnvuK6lMqsfADYBb1WNi8qzPxYGjyXJEldI43USMBdYFTy3CnO8qvv5gdHA6GDagAnB+kUc4cywBMc1LhibhVSePHFeguKaBmwEDgWfr1uTZLxqiyvh4wVcQOVunoKI760h9TleKjEhIpLh0n7XkIiIHJkSgYhIhlMiEBHJcEoEIiIZTolARCTDKRGIiGQ4JQIRkQz3/wE8Zpmuq67d9QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compare the candidate activations on [-2, 2]; each figure is titled and labelled so the\n",
    "# three otherwise near-identical plots can be told apart.\n",
    "x = torch.linspace(-2, 2, 500)\n",
    "hard_sig_output = hard_sigmoid(x)\n",
    "my_hard_sig_output = my_hard_sig(x)\n",
    "my_output = torch.clip(x, -1, 1)\n",
    "for name, output in [('hard_sigmoid(x)', hard_sig_output),\n",
    "                     ('my_hard_sig(x)', my_hard_sig_output),\n",
    "                     ('torch.clip(x, -1, 1)', my_output)]:\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(torch2numpy(x), torch2numpy(output))\n",
    "    ax.set(title=name, xlabel='x', ylabel='activation')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "599ac54d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_activation(x):\n",
    "    # Saturating linear activation: identity on [-1, 1], clamped to -1 / +1 outside that range.\n",
    "    return torch.clamp(x, min=-1, max=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8c77e36f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility: seed torch's global RNG (used e.g. for weight init / DataLoader shuffling)\n",
    "# and numpy's (used for the random beta signs in the training loop).\n",
    "seed = 0\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "activation = my_activation\n",
    "# architecture = [int(32*32*3), 500, 10]  # smaller alternative\n",
    "architecture = [int(32*32*3), 1000, 1000, 1000, 10]  # flattened 32x32x3 input, 3 hidden layers, 10 outputs\n",
    "\n",
    "beta = 0.5  # sign is randomized per batch in the training loop when use_random_sign_beta is True\n",
    "lambda_ = 0.99999\n",
    "epsilon = 0.15\n",
    "one_over_epsilon = 1 / epsilon\n",
    "# Learning rates: one entry per weight layer of `architecture` (4 here).\n",
    "# NOTE(review): 'fb'[0] is NaN - presumably there is no feedback weight into the input layer; confirm in ContrastiveModels.\n",
    "# lr_start = {'ff': np.array([0.2, 0.12, 0.065, 0.035]), 'fb': np.array([np.nan, 0.085, 0.065, 0.03])}  # earlier setting\n",
    "lr_start = {'ff': np.array([0.25, 0.14, 0.075, 0.045]), 'fb': np.array([np.nan, 0.095, 0.075, 0.04])}\n",
    "neural_lr_start = 0.03\n",
    "neural_lr_stop = 0.001\n",
    "neural_lr_rule = \"constant\"\n",
    "neural_lr_decay_multiplier = 0.01\n",
    "neural_dynamic_iterations_nudged = 15\n",
    "neural_dynamic_iterations_free = 80\n",
    "hopfield_g = 0.1\n",
    "use_random_sign_beta = True\n",
    "use_three_phase = False\n",
    "weight_decay = False\n",
    "\n",
    "model = ContrastiveCorInfoMaxHopfield(architecture = architecture, lambda_ = lambda_, \n",
    "                                      epsilon = epsilon, activation = activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3448d1d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train accuracy :\t 0.12798\n"
     ]
    }
   ],
   "source": [
    "# Baseline: training-set accuracy of the freshly constructed (not yet trained) model.\n",
    "_ = evaluateContrastiveCorInfoMaxHopfield(\n",
    "    model, train_loader, hopfield_g,\n",
    "    neural_lr_start, neural_lr_stop, neural_lr_rule,\n",
    "    neural_lr_decay_multiplier, neural_dynamic_iterations_free,\n",
    "    device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c6107ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.18it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 1, Train Accuracy : 0.35212, Test Accuracy : 0.3512\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.20it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 2, Train Accuracy : 0.38076, Test Accuracy : 0.3763\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.18it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 3, Train Accuracy : 0.40124, Test Accuracy : 0.3877\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.19it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 4, Train Accuracy : 0.41404, Test Accuracy : 0.4016\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.21it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 5, Train Accuracy : 0.42392, Test Accuracy : 0.411\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.20it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 6, Train Accuracy : 0.43582, Test Accuracy : 0.4215\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.17it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 7, Train Accuracy : 0.27182, Test Accuracy : 0.2657\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:42, 11.21it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 8, Train Accuracy : 0.11692, Test Accuracy : 0.1185\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:42, 11.24it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 9, Train Accuracy : 0.1301, Test Accuracy : 0.1309\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:42, 11.22it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 10, Train Accuracy : 0.1128, Test Accuracy : 0.1116\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:42, 11.24it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 11, Train Accuracy : 0.1495, Test Accuracy : 0.1502\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.19it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 12, Train Accuracy : 0.14968, Test Accuracy : 0.1498\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.20it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 13, Train Accuracy : 0.11072, Test Accuracy : 0.1146\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.20it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 14, Train Accuracy : 0.13052, Test Accuracy : 0.1302\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.21it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 15, Train Accuracy : 0.11414, Test Accuracy : 0.1145\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:45, 11.07it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 16, Train Accuracy : 0.09682, Test Accuracy : 0.1021\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.21it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 17, Train Accuracy : 0.11738, Test Accuracy : 0.116\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:43, 11.17it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 18, Train Accuracy : 0.12076, Test Accuracy : 0.1207\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:42, 11.24it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 19, Train Accuracy : 0.13814, Test Accuracy : 0.1361\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:42, 11.22it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 20, Train Accuracy : 0.12544, Test Accuracy : 0.1208\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:42, 11.22it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 21, Train Accuracy : 0.14362, Test Accuracy : 0.1403\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2500it [03:42, 11.21it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch : 22, Train Accuracy : 0.1496, Test Accuracy : 0.1442\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1709it [02:33, 11.75it/s]"
     ]
    }
   ],
   "source": [
    "trn_acc_list = []\n",
    "tst_acc_list = []\n",
    "\n",
    "n_epochs = 25\n",
    "lr_decay = 0.9  # per-epoch multiplicative learning-rate decay\n",
    "\n",
    "# NOTE(review): in the recorded run accuracy collapsed to near chance from epoch 7 onwards;\n",
    "# the learning-rate schedule / beta settings may need revisiting. TODO investigate.\n",
    "for epoch_ in range(n_epochs):\n",
    "    # Exponentially decayed learning rates for the 'ff' and 'fb' weight groups.\n",
    "    lr = {'ff': lr_start['ff'] * lr_decay**epoch_, 'fb': lr_start['fb'] * lr_decay**epoch_}\n",
    "    for idx, (x, y) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch_ + 1}')):\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        x = x.view(x.size(0), -1).T  # flatten each image; one column per sample\n",
    "        y_one_hot = F.one_hot(y, 10).T  # y is already on `device`\n",
    "        take_debug_logs_ = (idx % 500 == 0)\n",
    "        # Draw a random sign per batch without overwriting the configured `beta`, so the\n",
    "        # global keeps its value across batches, later cells and re-runs.\n",
    "        if use_random_sign_beta:\n",
    "            beta_step = (2*np.random.randint(2) - 1) * beta\n",
    "        else:\n",
    "            beta_step = beta\n",
    "\n",
    "        neurons = model.batch_step_hopfield(x, y_one_hot, hopfield_g,\n",
    "                                            lr, neural_lr_start, neural_lr_stop, neural_lr_rule,\n",
    "                                            neural_lr_decay_multiplier, neural_dynamic_iterations_free,\n",
    "                                            neural_dynamic_iterations_nudged, beta_step,\n",
    "                                            use_three_phase, take_debug_logs_, weight_decay)\n",
    "\n",
    "    trn_acc = evaluateContrastiveCorInfoMaxHopfield(model, train_loader, hopfield_g, neural_lr_start, \n",
    "                                                    neural_lr_stop, neural_lr_rule, \n",
    "                                                    neural_lr_decay_multiplier, \n",
    "                                                    neural_dynamic_iterations_free, \n",
    "                                                    device, printing = False)\n",
    "    tst_acc = evaluateContrastiveCorInfoMaxHopfield(model, test_loader, hopfield_g, neural_lr_start, \n",
    "                                                    neural_lr_stop, neural_lr_rule, \n",
    "                                                    neural_lr_decay_multiplier, \n",
    "                                                    neural_dynamic_iterations_free, \n",
    "                                                    device, printing = False)\n",
    "    trn_acc_list.append(trn_acc)\n",
    "    tst_acc_list.append(tst_acc)\n",
    "\n",
    "    print(\"Epoch : {}, Train Accuracy : {}, Test Accuracy : {}\".format(epoch_+1, trn_acc, tst_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d166d1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Continue training from where the previous training cell stopped. The epoch index (which drives\n",
    "# the learning-rate decay and the printed epoch numbers) starts at len(trn_acc_list) instead of a\n",
    "# hard-coded 10, which was stale once the first training cell ran 25 epochs.\n",
    "n_extra_epochs = 10\n",
    "lr_decay = 0.9  # per-epoch multiplicative learning-rate decay, same as the initial training cell\n",
    "start_epoch = len(trn_acc_list)\n",
    "\n",
    "for epoch_ in range(start_epoch, start_epoch + n_extra_epochs):\n",
    "    lr = {'ff': lr_start['ff'] * lr_decay**epoch_, 'fb': lr_start['fb'] * lr_decay**epoch_}\n",
    "    for idx, (x, y) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch_ + 1}')):\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        x = x.view(x.size(0), -1).T  # flatten each image; one column per sample\n",
    "        y_one_hot = F.one_hot(y, 10).T  # y is already on `device`\n",
    "        take_debug_logs_ = (idx % 500 == 0)\n",
    "        # Random sign per batch without overwriting the configured global `beta`.\n",
    "        if use_random_sign_beta:\n",
    "            beta_step = (2*np.random.randint(2) - 1) * beta\n",
    "        else:\n",
    "            beta_step = beta\n",
    "\n",
    "        neurons = model.batch_step_hopfield(x, y_one_hot, hopfield_g,\n",
    "                                            lr, neural_lr_start, neural_lr_stop, neural_lr_rule,\n",
    "                                            neural_lr_decay_multiplier, neural_dynamic_iterations_free,\n",
    "                                            neural_dynamic_iterations_nudged, beta_step,\n",
    "                                            use_three_phase, take_debug_logs_, weight_decay)\n",
    "\n",
    "    trn_acc = evaluateContrastiveCorInfoMaxHopfield(model, train_loader, hopfield_g, neural_lr_start, \n",
    "                                                    neural_lr_stop, neural_lr_rule, \n",
    "                                                    neural_lr_decay_multiplier, \n",
    "                                                    neural_dynamic_iterations_free, \n",
    "                                                    device, printing = False)\n",
    "    tst_acc = evaluateContrastiveCorInfoMaxHopfield(model, test_loader, hopfield_g, neural_lr_start, \n",
    "                                                    neural_lr_stop, neural_lr_rule, \n",
    "                                                    neural_lr_decay_multiplier, \n",
    "                                                    neural_dynamic_iterations_free, \n",
    "                                                    device, printing = False)\n",
    "    trn_acc_list.append(trn_acc)\n",
    "    tst_acc_list.append(tst_acc)\n",
    "\n",
    "    print(\"Epoch : {}, Train Accuracy : {}, Test Accuracy : {}\".format(epoch_+1, trn_acc, tst_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe20a069",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracies are fractions in [0, 1] (e.g. 0.35 in the training logs), so the axis is not a percentage.\n",
    "plot_convergence_plot(trn_acc_list, xlabel = 'Number of Epochs', ylabel = 'Accuracy',\n",
    "                      title = 'Contrastive CorInfoMax Train Accuracy w.r.t. Epochs', \n",
    "                      figsize = (12,8), fontsize = 25, linewidth = 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93e28c34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracies are fractions in [0, 1] (e.g. 0.35 in the training logs), so the axis is not a percentage.\n",
    "plot_convergence_plot(tst_acc_list, xlabel = 'Number of Epochs', ylabel = 'Accuracy',\n",
    "                      title = 'Contrastive CorInfoMax Test Accuracy w.r.t. Epochs', \n",
    "                      figsize = (12,8), fontsize = 25, linewidth = 3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
