{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "iJPwOE1mkbJd"
   },
   "outputs": [],
   "source": [
    "# NOTE(review): prefer importing only the names needed from DataGenerator. The star import is\n",
    "# kept first so that the explicit imports below take precedence over anything it exports.\n",
    "from DataGenerator import *\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "# torch is used directly in later cells (torch.ones_like, torch.tensor); import it explicitly\n",
    "# instead of relying on the star import above to provide it.\n",
    "import torch\n",
    "\n",
    "from BrainNet import BrainNet\n",
    "from LocalNetBase import Options, UpdateScheme\n",
    "from network import LocalNet\n",
    "from train import evaluate, train_given_rule, train_local_rule, train_vanilla\n",
    "\n",
    "# Reload all modules before each cell executes so edits to the project .py files are picked up.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "np.set_printoptions(precision=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "J-klaK8NkbJj"
   },
   "outputs": [],
   "source": [
    "image_size = 28  # images are square: width == length == image_size\n",
    "labels = 10  # number of digit classes\n",
    "image_pixels = image_size ** 2\n",
    "data_path = \"./mnist/\"\n",
    "\n",
    "def _load_mnist_csv(file_name):\n",
    "    \"\"\"Read one comma-separated MNIST file from data_path into a float array (one example per row).\"\"\"\n",
    "    return np.loadtxt(data_path + file_name, delimiter=\",\")\n",
    "\n",
    "train_data = _load_mnist_csv(\"mnist_train.csv\")\n",
    "test_data = _load_mnist_csv(\"mnist_test.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "-QdA1HmMkbJm",
    "outputId": "23f91680-e41b-47ed-d2b2-011e4156f297"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "60000"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of training examples (rows of the training CSV).\n",
    "train_data.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eAa6iOvakbJp"
   },
   "outputs": [],
   "source": [
    "dim = image_size * image_size  # length of a flattened image vector\n",
    "\n",
    "# Column 0 of each CSV row is the label; columns 1.. are pixels. Assuming raw pixels in 0..255,\n",
    "# this rescales them linearly onto [0.01, 1.00].\n",
    "fac = 0.99 / 255\n",
    "# np.asfarray was removed in NumPy 2.0; np.asarray(..., dtype=np.float64) is the equivalent cast.\n",
    "train_imgs = np.asarray(train_data[:, 1:], dtype=np.float64) * fac + 0.01\n",
    "test_imgs = np.asarray(test_data[:, 1:], dtype=np.float64) * fac + 0.01\n",
    "# Labels are kept as (N, 1) float columns here.\n",
    "train_labels = np.asarray(train_data[:, :1], dtype=np.float64)\n",
    "test_labels = np.asarray(test_data[:, :1], dtype=np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 193
    },
    "colab_type": "code",
    "id": "eAlIMqACkbJt",
    "outputId": "39cd9dd9-90b9-4462-e813-19138c0936c1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5923\n",
      "6742\n",
      "5958\n",
      "6131\n",
      "5842\n",
      "5421\n",
      "5918\n",
      "6265\n",
      "5851\n",
      "5949\n"
     ]
    }
   ],
   "source": [
    "data = 60000  # number of training examples to use (the whole training split)\n",
    "X = train_imgs[0:data]\n",
    "y = train_labels[0:data]\n",
    "X_test = test_imgs\n",
    "y_test = test_labels\n",
    "\n",
    "# Flatten the (N, 1) label columns into 1-D integer class ids.\n",
    "y = y.flatten().astype(int)\n",
    "y_test = y_test.flatten().astype(int)\n",
    "assert X.shape[1] == dim, \"pixel columns do not match image_size * image_size\"\n",
    "\n",
    "# Per-class example counts in one vectorised pass instead of one Python-level sum per class.\n",
    "class_counts = np.bincount(y, minlength=labels)\n",
    "assert len(class_counts) == labels, \"y contains class ids outside 0..labels-1\"\n",
    "for count in class_counts:\n",
    "    print(count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "21ukb_uqkbJ0"
   },
   "outputs": [],
   "source": [
    "# NOTE(review): flag semantics are defined in LocalNetBase. As configured, the output rule is the\n",
    "# only use_*_rule enabled, and every gd_* flag (presumably gradient-descent switches) is off.\n",
    "options = Options(\n",
    "    use_input_rule=False,\n",
    "    use_graph_rule=False,\n",
    "    use_output_rule=True,\n",
    "    gd_input=False,\n",
    "    gd_output=False,\n",
    "    gd_graph_rule=False,\n",
    "    gd_output_rule=False,\n",
    "    additive_rule=True,\n",
    ")\n",
    "\n",
    "scheme = UpdateScheme(\n",
    "    cross_entropy_loss=True,\n",
    "    mse_loss=False,\n",
    "    update_misclassified=True,\n",
    "    update_all_edges=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5E-EE6aakbJ3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch #: 0\n",
      "Train on 0  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 5000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 10000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 15000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 20000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 25000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 30000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 35000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 40000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 45000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 50000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 55000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 59999  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Epoch #: 1\n",
      "Train on 0  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 5000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 10000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n",
      "Train on 15000  examples.\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Train Accuracy: 0.0987\n",
      "Acc of class 0 :1.0000\n",
      "Acc of class 1 :0.0000\n",
      "Acc of class 2 :0.0000\n",
      "Acc of class 3 :0.0000\n",
      "Acc of class 4 :0.0000\n",
      "Acc of class 5 :0.0000\n",
      "Acc of class 6 :0.0000\n",
      "Acc of class 7 :0.0000\n",
      "Acc of class 8 :0.0000\n",
      "Acc of class 9 :0.0000\n",
      "Test Accuracy: 0.0980\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-8-42879d2fe175>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mlocal_net\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_output_rule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mtrain_acc_learned\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_acc_learned\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_given_rule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlocal_net\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecay\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m.96\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mX_test\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0my_test\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/Projects/Learning Plasticity Rules/train.py\u001b[0m in \u001b[0;36mtrain_given_rule\u001b[0;34m(X, y, meta_model, decay, epochs, verbose, X_test, y_test)\u001b[0m\n\u001b[1;32m     24\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mk\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcontinue_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     25\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcontinue_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m             \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmeta_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontinue_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcontinue_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     28\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mk\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mverbose\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mk\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m5000\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/ann/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    530\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    531\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    533\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    534\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/Learning Plasticity Rules/LocalNetBase.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, inputs, labels, epochs, batch, continue_)\u001b[0m\n\u001b[1;32m    151\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mell\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    152\u001b[0m                 \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_pass\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 153\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mell\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    154\u001b[0m         \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_pass\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    155\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/Projects/Learning Plasticity Rules/LocalNetBase.py\u001b[0m in \u001b[0;36mupdate_weights\u001b[0;34m(self, probs, label)\u001b[0m\n\u001b[1;32m    133\u001b[0m                         \u001b[0mupdate_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_weights\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_rule\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactivated\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    134\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m                 \u001b[0mupdate_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_weights\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mprediction\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_rule\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactivated\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    136\u001b[0m                 \u001b[0mupdate_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_weights\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlabel\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_rule\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactivated\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlong\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    137\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Run configuration, collected in one place so the experiment is easy to re-tune.\n",
    "SEED = 0\n",
    "LOCALNET_KWARGS = dict(num_v=1000, p=0.5, cap=500, rounds=1)\n",
    "TRAIN_KWARGS = dict(decay=0.96, epochs=20, verbose=True)\n",
    "\n",
    "# NOTE(review): LocalNet construction may be random (e.g. via `p`); seed torch and numpy so reruns are comparable.\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "local_net = LocalNet(dim, labels, options=options, update_scheme=scheme, **LOCALNET_KWARGS)\n",
    "losses = []  # NOTE(review): unused in the visible cells; kept in case a later cell appends to it\n",
    "\n",
    "# Overwrite every input weight with 1.0 (float64).\n",
    "local_net.input_weights = torch.ones_like(local_net.input_weights).double()\n",
    "# A list of Python ints would give an int64 tensor; build the rule as float64 to match the\n",
    "# double-precision weights above. NOTE(review): confirm output_weights are float64 as well.\n",
    "local_net.set_output_rule(torch.tensor([[-1, 1], [1, -1], [-1, 1], [1, -1]], dtype=torch.double))\n",
    "\n",
    "train_acc_learned, test_acc_learned = train_given_rule(\n",
    "    X, y, local_net, X_test=X_test, y_test=y_test, **TRAIN_KWARGS\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "y99rrYYAkbJ5"
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'train_acc_learned' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-12-844c72befcff>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_acc_learned\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      2\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmax\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest_acc_learned\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'train_acc_learned' is not defined"
     ]
    }
   ],
   "source": [
    "# Best accuracy among the values returned by train_given_rule.\n",
    "best_train_acc = max(train_acc_learned)\n",
    "print(best_train_acc)\n",
    "best_test_acc = max(test_acc_learned)\n",
    "print(best_test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 295
    },
    "colab_type": "code",
    "id": "VRGIEuMUUBZv",
    "outputId": "11467148-56ae-45a5-b965-61c26fbaabf1"
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'train_acc_learned' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-13-0d39470fcfd0>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mt1\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain_acc_learned\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m150\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      2\u001b[0m \u001b[0mt2\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtest_acc_learned\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m150\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m5000\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mt1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;36m10000\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mt1\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mt1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"Train\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m 
\u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m5000\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mt2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;36m10000\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mt2\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mt2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m10\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m\"Test\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[0mplt\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mxlabel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"# Training examples (starting at 5000)\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mNameError\u001b[0m: name 'train_acc_learned' is not defined"
     ]
    }
   ],
   "source": [
    "# NOTE(review): `train_acc_learned` / `test_acc_learned` are not defined anywhere in this\n",
    "# notebook (the saved output is a NameError); restore the cell that produces them.\n",
    "def moving_average(values, window=10):\n",
    "    \"\"\"Mean of every complete length-`window` slice of `values`.\n",
    "\n",
    "    Returns len(values) - window + 1 points (none if `values` is shorter than `window`).\n",
    "    \"\"\"\n",
    "    return [np.mean(values[i:i + window]) for i in range(len(values) - window + 1)]\n",
    "\n",
    "\n",
    "t1 = train_acc_learned[1:150]\n",
    "t2 = test_acc_learned[1:150]\n",
    "for curve, name in ((t1, \"Train\"), (t2, \"Test\")):\n",
    "    smoothed = moving_average(curve)\n",
    "    # NOTE(review): x positions start at 10000 but the axis label says 5000 -- confirm which is right\n",
    "    plt.plot(5000 * np.arange(len(smoothed)) + 10000, smoothed, label=name)\n",
    "plt.xlabel(\"# Training examples (starting at 5000)\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.title(\"MNIST Brain Net (T=1) Decaying Step Size\")\n",
    "plt.legend()\n",
    "plt.savefig(\"MNIST_full_data_96_acc_n_1000_decaying_step.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "E6kvIMKmvaGf",
    "outputId": "2cd4f543-e884-455b-d325-3b127dc80c57"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.01\n"
     ]
    }
   ],
   "source": [
    "# Inspect the rule-trained net's current step size (step_sz)\n",
    "current_step = local_net.step_sz\n",
    "print(current_step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 336
    },
    "colab_type": "code",
    "id": "TDc4AEb8gt59",
    "outputId": "c21e740b-fe7d-4655-9390-89e7d80a66a9"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA3sAAAE/CAYAAAD/m9qwAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3df7hld10f+vfHDKACEiDDD5Pg0MsAhjzW0LkhXvyBBiFJbzNoiU/ylBJo6vQPolV4rKG2wIXbXkF7I/SJ2GgooRVCpCpzNTUXQxC0JM0gFEkwD3MjJWMiGSBEaQoh+rl/7DWwOTmTOXPOnH1Ovuf1ep7znL2+67vX/sx39uw1772+a63q7gAAADCWb9roAgAAADj2hD0AAIABCXsAAAADEvYAAAAGJOwBAAAMSNgDAAAYkLDHUKqqq+ppG13HiKrq+6rq1hX2fV5VHVjvmgAAODxhj02lqj5dVf+zqr5UVX9RVW+vqkdtdF1JUlWPqKr/q6o+M9X4qar6maqqFT5/xxRGtx2jeh50e1X15Gn9E+fafu4wbb93pNfr7g919zOOUe1vr6r/81hsCwCA5Ql7bEZ/r7sfleS7k5yW5NUbXM8hv5HkzCTnJHl0kn+YZE+SN29kUYfT3Xcm2Z/k++eavz/Jny7T9sEFlgYAwAIIe2xa3f0XSa7NLPQlSarqA1X1j+eWX1ZVf7jc86cjcb84HYn7bFX9SlV9y2pqqaozk7wgyd/v7k909/3dfUOSlyR5xaGpo9ORyefPPe91VfUfp8VDgeqL05HL75nq/6Oq+rdVdU9V/en0WlnN9pYp/YOZgl1VHZdZeH7zkrbvObStBxuzpVMzq+rZVfXRqvqrqvqNqnr30qN1VfWqqrqrqu6sqpdPbXuS/IMk/2yq+/+Z2n+2qv582t6t8+MAAMDRE/bYtKrqpCRnZ3Z0ajXemOTpmYXFpyU5MclrVrmtH05yY3ffPt/Y3TcmOZDZEb8jOXQ07fjuflR3f3hafk6S25KckOS1SX6zqh63hu3N++Bcv9MyO6p33ZK2hyX5r9Pyisasqh6e5LeSvD3J45K8K8mPLOn2pCSPmbZxUZLLquqx3X15kl9P8qap7r9XVc9IcnGS/7W7H53khUk+vYIxAADgMIQ9NqPfrqq/SnJ7krsyC0BHZTqP7seT/HR3f6G7/yrJv05y/iprOiHJnYdZd+e0frXuSvJL3f3V7n53kluT/N01bG/eHyQ5taoem+T7knyouz+V5IS5thu6+76jHLMzkmxL8pap7t/M1wPjIV9N8vpp/TVJvpTkcOf8/XWSRyQ5paoe1t2f7u7/b01/cgCALU7YYzN60XR053lJnpnVBantSb41yUeq6otV9cUkvze1P0BV3TxNKfxSVX3fMl0+l+TJh3mtJ0/rV+vPu7vnlv97km9fw/a+prs/ndmRx+/N7Gjeh6ZVH55rOzQd9GjG7NuXqfv2JX0+3933zy3fm2TZi+109/4kP5XkdUnuqqqrquqYjAEAwFYl7LFpdfcfZDZN8Bfnmv9HZoHkkCcd5umfS/I/kzyru4+ffh4zXfhludd61jSl8FHd/aFluvx+kudU1cnzjVV1epKTk7x/BfXNB6N5Jy65oudTktyxhu0t9aHMQt33JPkvS9q+N18Pe0czZncuU/fJy/Q7nAfU3t3v7O7vTfId0/o3HsX2AABYQthjs/ulJD9cVYcu0vKxJD9aVd86XRTlouWe1N1/k+RXk1xaVU9Ikqo6sapeuJoiuvv3MzvX7T9V1bOq6riqOiOzc8/eOk2NPFTf+VX1sKraleTFc5s5mORvkvytJZt/QpKfnJ5zXpLvTHLNGra31AeTvDTJHd39l1PbH05tj8nsKN/RjtmHM5t6eXFVbauq3UlOP0Id8z47X3dVPaOqfqiqHpHky5mFzr8+iu0BALCEsMem1t0Hk7wjyb+cmi5Ncl9mYeHK
zMLW4fxsZhd3uaGq/jKzo3NruU/c309yfWZTG7+U5D8muSLJT8z1+ZdJ/pckdyf5P5K8c+7Pcm+Sf5Xkj6ZpkmdMq25MsjOzI2v/KsmLu/vza9jeUn+QWaCcv2rpx5J8S5KPTNs5ZEVj1t33JfnRzML2FzO7KunvJPnKYWpY6orMzs/7YlX9dmbn6/38NAZ/MdX7z1e4LQAAllHfeMoNsEhV9bIk/3iavviQVlU3JvmV7v73G10LAACO7AGrVFU/UFVPmqZxXpjkuzI76gkAwCYg7AGr9Ywk/y3JPUleldn008PdngK2rKp6W1XdVVWfOMz6qqq3VNX+qvp4VT170TUCMCbTOAFgHVXV92d2nu87uvvUZdafk9m5v+ckeU6SN3f3cxZbJQAjcmQPANZRd38wyRcepMvuzIJgd/cNSY6vqsPd1xMAVkzYA4CNdWKS2+eWD0xtALAm2za6gAdzwgkn9I4dOza6DAAW4CMf+cjnunv7RtexAWqZtmXPsaiqPUn2JMkjH/nIv/PMZz5zPesCYBNYy/5xU4e9HTt2ZN++fRtdBgALUFX/faNr2CAHkpw8t3xSkjuW69jdlye5PEl27drV9pEA41vL/tE0TgDYWHuTvHS6KucZSe5xZVsAjoVNfWQPAB7qqupdSZ6X5ISqOpDktUkeliTd/StJrsnsSpz7k9yb5OUbUykAoxH2AGAddfcFR1jfSV6xoHIA2EJM4wQAABiQsAcAADAgYQ8AAGBAwh4AAMCAjhj2quptVXVXVX1iru1xVfW+qvrU9PuxU3tV1Vuqan9Vfbyqnj33nAun/p+qqgvX548DAABAsrIje29PctaStkuSXNfdO5NcNy0nydlJdk4/e5K8NZmFw8wuNf2cJKcnee2hgAgAAMCxd8Sw190fTPKFJc27k1w5Pb4yyYvm2t/RMzckOb6qnpzkhUne191f6O67k7wvDwyQAAAAHCOrPWfvid19Z5JMv58wtZ+Y5Pa5fgemtsO1AwAAsA6O9QVaapm2fpD2B26gak9V7auqfQcPHjymxcFWsuOS393oEgAA2ECrDXufnaZnZvp919R+IMnJc/1OSnLHg7Q/QHdf3t27unvX9u3bV1keAADA1rbasLc3yaEral6Y5L1z7S+drsp5RpJ7pmme1yZ5QVU9drowywumNgAAANbBtiN1qKp3JXlekhOq6kBmV9X8+SRXV9VFST6T5Lyp+zVJzkmyP8m9SV6eJN39hap6Q5Kbpn6v7+6lF30BAADgGDli2OvuCw6z6sxl+naSVxxmO29L8rajqg4AAIBVOdYXaAEAAGATEPYAAAAGJOwBAAAMSNgDAAAYkLAHAAAwIGEPAABgQMIeAADAgIQ9AACAAQl7AAAAAxL2AAAABiTsAQAADEjYAwAAGJCwBwAAMCBhDwAAYEDCHgAAwICEPQAAgAEJewAAAAMS9gAAAAYk7AEAAAxI2AMAABiQsAcAADAgYQ8AAGBAwh4AAMCAhD0AAIABCXsAAAADEvYAAAAGJOwBAAAMSNgDAAAYkLAHAAAwIGEPAABgQMIeAADAgIQ9AACAAQl7AAAAAxL2AGCdVdVZVXVrVe2vqkuWWf+Uqrq+qj5aVR+vqnM2ok4AxiLsAcA6qqrjklyW5OwkpyS5oKpOWdLtXyS5urtPS3J+kl9ebJUAjEjYA4D1dXqS/d19W3ffl+SqJLuX9Okk3zY9fkySOxZYHwCDEvYAYH2dmOT2ueUDU9u81yV5SVUdSHJNkp9YbkNVtaeq9lXVvoMHD65HrQAMRNgDgPVVy7T1kuULkry9u09Kck6S/1BVD9hHd/fl3b2ru3dt3759HUoFYCTCHgCsrwNJTp5bPikPnKZ5UZKrk6S7P5zkm5OcsJDqABiWsAcA6+umJDur6qlV9fDMLsCyd0mfzyQ5M0mq6jszC3vmaQKwJsIeAKyj7r4/ycVJrk3yycyuunlzVb2+qs6dur0qyY9X1X9L8q4kL+vupVM9AeCo
bNvoAgBgdN19TWYXXplve83c41uSPHfRdQEwNkf2AAAABrSmsFdVP11VN1fVJ6rqXVX1zdM5CTdW1aeq6t3T+QmpqkdMy/un9TuOxR8AAACAB1p12KuqE5P8ZJJd3X1qkuMyO+n8jUku7e6dSe7O7ApjmX7f3d1PS3Lp1A8AAIB1sNZpnNuSfEtVbUvyrUnuTPJDSd4zrb8yyYumx7un5Uzrz6yq5e49BAAAwBqtOux1958n+cXMLhd9Z5J7knwkyRenK48ls3sLnTg9PjHJ7dNz75/6P361rw8AAMDhrWUa52MzO1r31CTfnuSRSc5epuuhS0cvdxTvAZeVrqo9VbWvqvYdPOgWQwAAAKuxlmmcz0/yZ919sLu/muQ3k/xvSY6fpnUmyUlJ7pgeH0hycpJM6x+T5AtLN9rdl3f3ru7etX379jWUBwAAsHWtJex9JskZVfWt07l3Zya5Jcn1SV489bkwyXunx3un5Uzr3++GsQAAAOtjLefs3ZjZhVb+OMmfTNu6PMnPJnllVe3P7Jy8K6anXJHk8VP7K5Ncsoa6AQAAeBDbjtzl8Lr7tUleu6T5tiSnL9P3y0nOW8vrAQAAsDJrvfUCAAAAm5CwBwAAMCBhDwAAYEDCHgAAwICEPQAAgAEJewAAAAMS9gAAAAYk7AEAAAxI2AMAABiQsAcAADAgYQ8AAGBAwh4AAMCAhD0AAIABCXsAAAADEvYAAAAGJOwBAAAMSNgDAAAYkLAHAAAwIGEPAABgQMIeAADAgIQ9AACAAQl7AAAAAxL2AAAABiTsAQAADEjYAwAAGJCwBwAAMCBhDwAAYEDCHgAAwICEPQAAgAEJewAAAAMS9gAAAAYk7AEAAAxI2AMAABiQsAcAADAgYQ8AAGBAwh4ArLOqOquqbq2q/VV1yWH6/FhV3VJVN1fVOxddIwDj2bbRBQDAyKrquCSXJfnhJAeS3FRVe7v7lrk+O5O8Oslzu/vuqnrCxlQLwEgc2QOA9XV6kv3dfVt335fkqiS7l/T58SSXdffdSdLddy24RgAGJOwBwPo6Mcntc8sHprZ5T0/y9Kr6o6q6oarOWlh1AAzLNE4AWF+1TFsvWd6WZGeS5yU5KcmHqurU7v7iN2yoak+SPUnylKc85dhXCsBQHNkDgPV1IMnJc8snJbljmT7v7e6vdvefJbk1s/D3Dbr78u7e1d27tm/fvm4FAzAGYQ8A1tdNSXZW1VOr6uFJzk+yd0mf307yg0lSVSdkNq3ztoVWCcBwhD0AWEfdfX+Si5Ncm+STSa7u7pur6vVVde7U7dokn6+qW5Jcn+RnuvvzG1MxAKNY0zl7VXV8kl9Lcmpm5x/8o8ymnrw7yY4kn07yY9NlpCvJm5Ock+TeJC/r7j9ey+sDwENBd1+T5Jolba+Ze9xJXjn9AMAxsdYje29O8nvd/cwkfzuzbywvSXJdd+9Mct20nCRnZ3b+wc7MTi5/6xpfGwAAgMNYddirqm9L8v1JrkiS7r5vumrY7iRXTt2uTPKi6fHuJO/omRuSHF9VT1515QAAABzWWo7s/a0kB5P8+6r6aFX9WlU9MskTu/vOJJl+P2Hqv5L7DKWq9lTVvqrad/DgwTWUBwAAsHWtJextS/LsJG/t7tOS/I98fcrmclZynyGXlQYAADgG1hL2DiQ50N03TsvvySz8ffbQ9Mzp911z/Y90nyEAAACOgVWHve7+iyS3V9UzpqYzk9yS2b2DLpzaLkzy3unx3iQvrZkzktxzaLonAAAAx9aabr2Q5CeS/Pp0k9jbkrw8swB5dVVdlOQzSc6b+l6T2W0X9md264WXr/G1AQAAOIw1hb3u/liSXcusOnOZvp3kFWt5PQAAAFZmrffZAwAAYBMS9gAAAAYk7AEAAAxI2AMAABiQsAcAADAgYQ8AAGBAwh4AAMCAhD0AAIABCXsAAAADEvYAAAAGJOwBAAAMSNgDAAAYkLAHAAAwIGEPAABgQMIeAADAgIQ9AACAAQl7
AAAAAxL2AAAABiTsAQAADEjYAwAAGJCwB6u045Lf3egSjspDrV4AANZG2AMAABiQsAcAADAgYQ8AAGBAwh4AAMCAhD0AAIABCXsAAAADEvYAAAAGJOwBAAAMSNgDAAAYkLAHAAAwIGEPAABgQMIeAADAgIQ9AACAAQl7AAAAAxL2AAAABiTsAQAADEjYA4B1VlVnVdWtVbW/qi55kH4vrqquql2LrA+AMQl7ALCOquq4JJclOTvJKUkuqKpTlun36CQ/meTGxVYIwKiEPQBYX6cn2d/dt3X3fUmuSrJ7mX5vSPKmJF9eZHEAjEvYA4D1dWKS2+eWD0xtX1NVpyU5ubt/Z5GFATA2YQ8A1lct09ZfW1n1TUkuTfKqI26oak9V7auqfQcPHjyGJQIwojWHvao6rqo+WlW/My0/tapurKpPVdW7q+rhU/sjpuX90/oda31tAHgIOJDk5Lnlk5LcMbf86CSnJvlAVX06yRlJ9i53kZbuvry7d3X3ru3bt69jyQCM4Fgc2funST45t/zGJJd2984kdye5aGq/KMnd3f20zL7BfOMxeG0A2OxuSrJz+jL04UnOT7L30Mruvqe7T+juHd29I8kNSc7t7n0bUy4Ao1hT2Kuqk5L83SS/Ni1Xkh9K8p6py5VJXjQ93j0tZ1p/5tQfAIbV3fcnuTjJtZl9OXp1d99cVa+vqnM3tjoARrZtjc//pST/LLMpKEny+CRfnHZsyTeehP61E9S7+/6qumfq/7k11gAAm1p3X5PkmiVtrzlM3+ctoiYAxrfqI3tV9b8nuau7PzLfvEzXXsG6+e06+RwAAGCN1jKN87lJzp1OJr8qs+mbv5Tk+Ko6dMRw/iT0r52gPq1/TJIvLN2ok88BAADWbtVhr7tf3d0nTSeTn5/k/d39D5Jcn+TFU7cLk7x3erx3Ws60/v3d/YAjewAAAKzdetxn72eTvLKq9md2Tt4VU/sVSR4/tb8yySXr8NoAAABk7RdoSZJ09weSfGB6fFuS05fp8+Uk5x2L1wMAAODBrceRPQAAADaYsAcAADAgYQ8AAGBAwh4AAMCAhD0AAIABCXsAAAADEvYAAAAGJOwBAAAMSNgDAAAYkLAHAAAwIGEPAABgQMIeAADAgIQ9AACAAQl7AAAAAxL2AAAABiTsAQAADEjYAwAAGJCwBwAAMCBhDwAAYEDCHgAAwICEPQAAgAEJewAAAAMS9gAAAAYk7AEAAAxI2AMAABiQsAcAADAgYQ8AAGBAwh4AAMCAhD0AAIABCXsAAAADEvYAAAAGJOwBAAAMSNgDAAAYkLAHAAAwIGEPAABgQMIeAADAgIQ9AACAAQl7AAAAAxL2AAAABiTsAQAADEjYAwAAGJCwBwDrrKrOqqpbq2p/VV2yzPpXVtUtVfXxqrquqr5jI+oEYCzCHgCso6o6LsllSc5OckqSC6rqlCXdPppkV3d/V5L3JHnTYqsEYESrDntVdXJVXV9Vn6yqm6vqn07tj6uq91XVp6bfj53aq6reMn2r+fGqevax+kMAwCZ2epL93X1bd9+X5Koku+c7dPf13X3vtHhDkpMWXCMAA1rLkb37k7yqu78zyRlJXjF9U3lJkuu6e2eS66blZPaN5s7pZ0+St67htQHgoeLEJLfPLR+Y2g7noiT/ebkVVbWnqvZV1b6DBw8ewxIBGNGqw15339ndfzw9/qskn8xs57U7yZVTtyuTvGh6vDvJO3rmhiTHV9WTV105ADw01DJtvWzHqpck2ZXkF5Zb392Xd/eu7t61ffv2Y1giACM6JufsVdWOJKcluTHJE7v7zmQWCJM8Yep2tN9sAsAIDiQ5eW75pCR3LO1UVc9P8nNJzu3uryyoNgAGtuawV1WPSvKfkvxUd//lg3Vdpu0B32yaogLAYG5KsrOqnlpVD09yfpK98x2q6rQk/y6zoHfXBtQIwIDWFPaq6mGZBb1f7+7fnJo/e2h65vT70E5rRd9smqICwEi6+/4kFye5NrNT
Hq7u7pur6vVVde7U7ReSPCrJb1TVx6pq72E2BwArtm21T6yqSnJFkk929/89t2pvkguT/Pz0+71z7RdX1VVJnpPknkPTPQFgZN19TZJrlrS9Zu7x8xdeFADDW3XYS/LcJP8wyZ9U1cemtn+eWci7uqouSvKZJOdN665Jck6S/UnuTfLyNbw2AAAAD2LVYa+7/zDLn4eXJGcu07+TvGK1rwcAAMDKHZOrcQIAALC5CHsAAAADEvYAAAAGJOwBAAAMSNgDAAAYkLAHAAAwIGEPAABgQMIeAADAgIQ9AACAAQl7AAAAAxL2AAAABiTsAQAADEjYAwAAGJCwBwAAMCBhDwAAYEDCHgAAwICEPQAAgAEJewAAAAMS9gAAAAYk7AEAAAxI2AMAABiQsAcAADAgYQ8AAGBAwh4AAMCAhD0AAIABCXsAAAADEvYAAAAGJOwBAAAMSNgDAAAYkLAHAAAwIGEPAABgQMIeAADAgIQ9AACAAQl7AAAAAxL2AAAABiTsAQAADEjYAwAAGJCwBwAAMCBhDwAAYEDCHgAAwICEPQAAgAEJewAAAAMS9gAAAAa08LBXVWdV1a1Vtb+qLln06wPAoh1p31dVj6iqd0/rb6yqHYuvEoDRLDTsVdVxSS5LcnaSU5JcUFWnLLIGAFikFe77Lkpyd3c/LcmlSd642CoBGNGij+ydnmR/d9/W3fcluSrJ7gXXAACLtJJ93+4kV06P35PkzKqqBdYIwIAWHfZOTHL73PKBqQ0ARrWSfd/X+nT3/UnuSfL4hVQHwLC2Lfj1lvuWsr+hQ9WeJHumxS9V1a3rXtXGOSHJ5za6iE3iITkWtX4TrY7JeCytbx3rXU8PyffGOhl9LL5jowtYJ0fc962wz9J95Feq6hNrrG0rGf3fz7FmvI6O8To6xuvoPGO1T1x02DuQ5OS55ZOS3DHfobsvT3L5IovaKFW1r7t3bXQdm4Gx+EbG4+uMxdcZi4esI+775vocqKptSR6T5AtLNzS/j/R+ODrG6+gYr6NjvI6O8To6VbVvtc9d9DTOm5LsrKqnVtXDk5yfZO+CawCARVrJvm9vkgunxy9O8v7ufsCRPQA4Ggs9stfd91fVxUmuTXJckrd1982LrAEAFulw+76qen2Sfd29N8kVSf5DVe3P7Ije+RtXMQCjWPQ0znT3NUmuWfTrblJbYrrqChmLb2Q8vs5YfJ2xeIhabt/X3a+Ze/zlJOcd5Wa9H46O8To6xuvoGK+jY7yOzqrHq8wSAQAAGM+iz9kDAABgAYS9Bauq86rq5qr6m6ratWTdq6tqf1XdWlUv3KgaN0pVva6q/ryqPjb9nLPRNS1aVZ01/f3vr6pLNrqejVZVn66qP5neD6u+EtVDUVW9rarumr+0flU9rqreV1Wfmn4/diNrZDGO9LlQVY+oqndP62+sqh2Lr3LzWMF4vbKqbqmqj1fVdVU16i0/VmSl+52qenFV9dL/u2w1Kxmvqvqx6T12c1W9c9E1biYr+Pf4lKq6vqo+Ov2b3HL/95u33L5/yfqqqrdM4/nxqnr2kbYp7C3eJ5L8aJIPzjdW1SmZnZD/rCRnJfnlqjpu8eVtuEu7+7unny11buf0931ZkrOTnJLkgul9sdX94PR+2Gr/wXh7Zp8F8y5Jcl1370xy3bTMwFb4uXBRkru7+2lJLk3y0Lyj5jGwwvH6aJJd3f1dSd6T5E2LrXLzWOl+p6oeneQnk9y42Ao3l5WMV1XtTPLqJM/t7mcl+amFF7pJrPD99S+SXN3dp2X2/+BfXmyVm87b88B9/7yzk+ycfvYkeeuRNijsLVh3f7K7l7tR/O4kV3X3V7r7z5LsT3L6Yqtjg52eZH9339bd9yW5KrP3BVtQd38wD7zP2u4kV06Pr0zyooUWxUZYyefC/PviPUnOrKrlbtK+FRxxvLr7+u6+d1q8IbP7Hm5VK93vvCGzUPzlRRa3Ca1kvH48yWXdfXeSdPddC65xM1nJeHWSb5sePyYP
vAfplnKYff+83Une0TM3JDm+qp78YNsU9jaPE5PcPrd8YGrbai6eDku/bQtOUfMeeKBO8v9W1Ueqas9GF7MJPLG770yS6fcTNrge1t9KPhe+1qe7709yT5LHL6S6zedoP0cvSvKf17Wize2I41VVpyU5ubt/Z5GFbVIreX89PcnTq+qPquqGqnqwozSjW8l4vS7JS6rqQGZXLP6JxZT2kHXU/1dc+K0XtoKq+v0kT1pm1c9193sP97Rl2oa7VOqDjU1mh6LfkNmf+w1J/k2Sf7S46jbclngPHKXndvcdVfWEJO+rqj+dvvWCrWIlnws+O75uxWNRVS9JsivJD6xrRZvbg45XVX1TZlODX7aogja5lby/tmU2xe55mR01/lBVndrdX1zn2jajlYzXBUne3t3/pqq+J7P7jZ7a3X+z/uU9JB31572wtw66+/mreNqBJCfPLZ+UAQ9lr3RsqupXk2y1bxG3xHvgaHT3HdPvu6rqtzKbErKVw95nq+rJ3X3nNG1jK08P2ipW8rlwqM+BqtqW2VSoB5sGNLIVfY5W1fMz+5LxB7r7KwuqbTM60ng9OsmpST4wzQx+UpK9VXVud2+pi2ZNVvrv8Ybu/mqSP6uqWzMLfzctpsRNZSXjdVGmc9S6+8NV9c1JToj92+Ec9f8VTePcPPYmOX+6qtpTM/tg+K8bXNNCLZlz/COZXcxmK7kpyc6qempVPTyzE5X3bnBNG6aqHjldFCBV9cgkL8jWe08stTfJhdPjC5McbqYA41jJ58L8++LFSd7fW/cmukccr2la4r9Lcu4WP58qOcJ4dfc93X1Cd+/o7h2ZneO4VYNesrJ/j7+d5AeTpKpOyGxa520LrXLzWMl4fSbJmUlSVd+Z5JuTHFxolQ8te5O8dLoq5xlJ7jl0esfhOLK3YFX1I0n+bZLtSX63qj7W3S/s7pur6uoktyS5P8kruvuvN7LWDfCmqvruzA5HfzrJP9nYcharu++vqouTXJvkuCRv6+6bN7isjfTEJL81fZu8Lck7u/v3Nrakxamqd2U2DeiE6VyG1yb5+SRXV9VFme0gz9u4ClmEw30uVNXrk+zr7r1Jrshs6tP+zI7onb9xFW+sFY7XLyR5VJLfmD5fPtPd525Y0VVdwmsAAACFSURBVBtohePFZIXjdW2SF1TVLUn+OsnPdPfnN67qjbPC8XpVkl+tqp/O7P9/L9vCX1Ydbt//sCTp7l/J7LzGczK7kOO9SV5+xG1u4fEEAAAYlmmcAAAAAxL2AAAABiTsAQAADEjYAwAAGJCwBwAAMCBhDwAAYEDCHgAAwICEPQAAgAH9/9bJCgz00GAbAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 1080x360 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Distribution of the rule-trained net's output weights.\n",
    "# gd_net is only created in a later cell, so no GD histogram is drawn here.\n",
    "fig, ax = plt.subplots(figsize=(7.5, 5))\n",
    "ax.hist(local_net.output_weights.flatten(), bins=1000)\n",
    "ax.set_title(\"Rule - Output Weights\")\n",
    "ax.set_xlabel(\"Weight value\")\n",
    "ax.set_ylabel(\"Count\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "42zTrFK6aaMS"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating\n",
      "Acc of class 0 :0.7147\n",
      "Acc of class 1 :0.7128\n",
      "Acc of class 2 :0.4822\n",
      "Acc of class 3 :0.3508\n",
      "Acc of class 4 :0.4281\n",
      "Acc of class 5 :0.3343\n",
      "Acc of class 6 :0.6088\n",
      "Acc of class 7 :0.6129\n",
      "Acc of class 8 :0.4979\n",
      "Acc of class 9 :0.4443\n",
      "epoch  1 Loss: 731.4810 Accuracy: 0.5229\n",
      "Finished Training\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[731.4810106019102]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline net for comparison; gd_output=True presumably trains the output layer by GD.\n",
    "# The input layer/weights are assigned *by reference* from local_net, so both nets share them.\n",
    "# NOTE(review): if train_vanilla updates these in place, local_net changes too -- confirm.\n",
    "gd_net = BrainNet(\n",
    "    dim, labels,\n",
    "    num_v=1000, p=.5, cap=500, rounds=0,\n",
    "    gd_output=True,\n",
    ")\n",
    "gd_net.input_layer = local_net.input_layer\n",
    "gd_net.input_weights = local_net.input_weights\n",
    "train_vanilla(X, y, gd_net, epochs=100, batch=100, lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 395
    },
    "colab_type": "code",
    "id": "bZBPZVxvkbJ8",
    "outputId": "eb531c7a-9821-4a59-acd7-98bd2a3477be"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original Label       1\n",
      "Original Prediction: 9\n",
      "Target Label         0\n"
     ]
    }
   ],
   "source": [
    "from AdversarialExamples import adversarial_example\n",
    "\n",
    "input_label = 1\n",
    "target_label = 0\n",
    "image_num = 0  # 0-based: use the (image_num)-th image in X whose label is input_label\n",
    "\n",
    "# Positions of all images carrying the requested label; pick the image_num-th one.\n",
    "candidates = np.flatnonzero(np.asarray(y) == input_label)\n",
    "if image_num >= len(candidates):\n",
    "    raise ValueError(f\"only {len(candidates)} images with label {input_label}; image_num={image_num} is out of range\")\n",
    "index = candidates[image_num]\n",
    "initial = X[index]\n",
    "\n",
    "def as_batch(image):\n",
    "    \"\"\"Fresh float64 batch-of-one tensor (new tensor per call in case a consumer modifies its input).\"\"\"\n",
    "    return torch.tensor([image]).double()\n",
    "\n",
    "print(\"Original Label      \", input_label)\n",
    "print(\"Original Prediction:\", torch.argmax(local_net.forward_pass(as_batch(initial))).item())\n",
    "print(\"Target Label        \", target_label)\n",
    "\n",
    "## Original / Rule Adversarial / GD Adversarial\n",
    "rule_adversarial = adversarial_example(as_batch(initial), torch.tensor([target_label]).long(), local_net.forward_pass)\n",
    "gd_adversarial = adversarial_example(as_batch(initial), torch.tensor([target_label]).long(), gd_net.forward_pass)\n",
    "\n",
    "original_image = initial.reshape((28, 28))\n",
    "panels = [\n",
    "    (original_image, \"Original\"),\n",
    "    (rule_adversarial.detach().numpy().reshape((28, 28)), \"Adversarial For Rule Net\"),\n",
    "    (gd_adversarial.detach().numpy().reshape((28, 28)), \"Adversarial For GD Net\"),\n",
    "]\n",
    "fig, axes = plt.subplots(1, 3, figsize=(15, 15))\n",
    "for ax, (img, title) in zip(axes, panels):\n",
    "    ax.imshow(1 - img, cmap='gray', vmin=0, vmax=1)  # inverted (1 - img): high pixel values render dark\n",
    "    ax.axis('off')\n",
    "    ax.set_title(title)\n",
    "plt.show()\n",
    "\n",
    "print(\"Distance from original\")\n",
    "print(\"To Rule Net Adversarial\", np.linalg.norm(rule_adversarial.detach().numpy() - original_image.flatten()))\n",
    "print(\"To GD Net Adversarial  \", np.linalg.norm(gd_adversarial.detach().numpy() - original_image.flatten()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "collapsed": true,
    "id": "08R2tgOFkbKj"
   },
   "outputs": [],
   "source": [
    "## Entry (i, j) of the matrix is the average distance between an original image\n",
    "## with label i and its adversarial version targeted at label j.\n",
    "### Only uses examples that are initially predicted correctly.\n",
    "\n",
    "# This cell is for the net trained with a *rule*\n",
    "\n",
    "N_CORRECT = 1000  # number of correctly classified examples to attack\n",
    "\n",
    "dist = np.zeros((10, 10))\n",
    "cnt = np.zeros((10, 10))\n",
    "n_used = 0\n",
    "for _x, _y in zip(X, y):\n",
    "    # stop once exactly N_CORRECT examples have been attacked\n",
    "    if n_used == N_CORRECT:\n",
    "        break\n",
    "    true_label = int(_y)\n",
    "    prediction = torch.argmax(local_net.forward_pass(torch.tensor([_x]).double())).item()\n",
    "    if prediction != true_label:\n",
    "        continue\n",
    "    n_used += 1\n",
    "    for target in range(10):\n",
    "        if target == true_label:\n",
    "            continue\n",
    "        final = adversarial_example(torch.tensor([_x]).double(), torch.tensor([target]).long(), local_net.forward_pass)\n",
    "        dist[true_label, target] += np.linalg.norm(final.detach().numpy() - _x)\n",
    "        cnt[true_label, target] += 1\n",
    "print(f\"Attacked {n_used} correctly classified examples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 221
    },
    "colab_type": "code",
    "id": "kYOwRww-kbKm",
    "outputId": "b626365b-4fe8-451e-f7b6-205e214b0fb9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[   nan 4.3036 3.2477 3.5496 3.6716 3.9076 3.7798 3.2153 3.9242 4.0978]\n",
      " [2.9324    nan 2.5584 2.6141 2.4728 2.157  2.8401 2.1941 2.8192 2.5238]\n",
      " [3.1814 5.829     nan 5.3067 2.941  3.0644 3.5074 3.1032 3.6461 3.5598]\n",
      " [2.2091 3.3291 2.4211    nan 2.4762 1.8944 3.8523 2.3516 2.3504 2.6818]\n",
      " [2.7141 3.5266 3.1825 3.1607    nan 2.3399 3.2641 2.4749 3.1212 2.6029]\n",
      " [3.0668 3.3671 2.7559 3.1453 2.5844    nan 2.7161 2.6993 3.1519 3.0661]\n",
      " [2.8032 4.5101 2.4459 3.044  3.145  2.749     nan 3.0398 2.9833 3.0956]\n",
      " [2.6483 3.5383 3.6278 3.6236 3.0109 2.7692 3.6626    nan 2.7612 3.9348]\n",
      " [2.0901 2.2108 1.8041 1.9799 1.8602 1.4258 2.3519 1.9798    nan 2.0997]\n",
      " [2.2845 2.9895 2.7733 2.1868 1.6665 1.7096 2.5372 1.6702 2.31      nan]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:2: RuntimeWarning: invalid value encountered in true_divide\n",
      "  \n"
     ]
    }
   ],
   "source": [
    "# BRAIN NET: mean adversarial distance per (true label, target label).\n",
    "# Pairs never attacked (the diagonal, cnt == 0) give 0/0 = nan; silence that expected warning.\n",
    "with np.errstate(invalid='ignore', divide='ignore'):\n",
    "    print(dist / cnt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "collapsed": true,
    "id": "OxeD9u0EkbKo"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "collapsed": true,
    "id": "Jqbf7UZ0kbKs"
   },
   "outputs": [],
   "source": [
    "## Entry (i, j) of the matrix is the average distance between an original image\n",
    "## with label i and its adversarial version targeted at label j.\n",
    "### Only uses examples that are initially predicted correctly.\n",
    "\n",
    "# This cell is for the net trained with *GD*\n",
    "\n",
    "N_CORRECT = 1000  # number of correctly classified examples to attack\n",
    "\n",
    "dist_gd = np.zeros((10, 10))\n",
    "cnt_gd = np.zeros((10, 10))\n",
    "n_used = 0\n",
    "for _x, _y in zip(X, y):\n",
    "    # stop once exactly N_CORRECT examples have been attacked\n",
    "    if n_used == N_CORRECT:\n",
    "        break\n",
    "    true_label = int(_y)\n",
    "    prediction = torch.argmax(gd_net.forward_pass(torch.tensor([_x]).double())).item()\n",
    "    if prediction != true_label:\n",
    "        continue\n",
    "    n_used += 1\n",
    "    for target in range(10):\n",
    "        if target == true_label:\n",
    "            continue\n",
    "        final = adversarial_example(torch.tensor([_x]).double(), torch.tensor([target]).long(), gd_net.forward_pass)\n",
    "        dist_gd[true_label, target] += np.linalg.norm(final.detach().numpy() - _x)\n",
    "        cnt_gd[true_label, target] += 1\n",
    "print(f\"Attacked {n_used} correctly classified examples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 221
    },
    "colab_type": "code",
    "id": "rNbTsglgkbKu",
    "outputId": "7242a0f2-879b-4d2f-9d73-c8b88bec0028"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[   nan 4.1997 2.1194 2.209  3.8363 1.7579 2.167  2.9692 2.2788 3.9838]\n",
      " [2.8477    nan 0.9763 1.022  2.1902 1.5335 1.8226 1.0837 0.9125 1.4458]\n",
      " [2.842  3.6837    nan 1.7819 3.6504 2.8917 1.7057 3.3463 1.9472 3.4332]\n",
      " [2.5515 2.5379 1.73      nan 2.8838 1.4506 3.9166 2.2613 1.4828 2.3556]\n",
      " [2.5248 3.093  1.8978 2.5504    nan 1.8244 1.2795 1.3411 1.5894 1.2177]\n",
      " [1.461  2.7471 2.0995 1.3323 2.0136    nan 2.098  1.9482 1.3422 1.8807]\n",
      " [2.3852 3.4558 1.4126 2.5065 2.2194 2.1019    nan 2.9214 1.8163 2.104 ]\n",
      " [2.4509 3.0542 2.7269 1.9061 2.1382 2.3714 3.2698    nan 1.7991 1.2074]\n",
      " [2.8093 2.5325 1.6151 1.4135 2.45   1.7137 2.5334 1.9904    nan 1.8814]\n",
      " [2.5486 2.9526 1.9784 1.5137 0.9892 1.5577 2.1608 0.9196 1.2389    nan]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:2: RuntimeWarning: invalid value encountered in true_divide\n",
      "  \n"
     ]
    }
   ],
   "source": [
    "# GD NET: mean adversarial distance per (true label, target label).\n",
    "# Pairs never attacked (the diagonal, cnt_gd == 0) give 0/0 = nan; silence that expected warning.\n",
    "with np.errstate(invalid='ignore', divide='ignore'):\n",
    "    print(dist_gd / cnt_gd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "collapsed": true,
    "id": "Qc9CJwCbmg2M"
   },
   "outputs": [],
   "source": [
    "# Average distance per (true label, target label) pair. Pairs with a zero count\n",
    "# (always the diagonal) are set to 0 instead of nan. np.divide's `where` leaves\n",
    "# cnt / cnt_gd untouched, so this cell is idempotent and the raw counts stay intact.\n",
    "rule_res = np.divide(dist, cnt, out=np.zeros_like(dist), where=cnt > 0)\n",
    "gd_res = np.divide(dist_gd, cnt_gd, out=np.zeros_like(dist_gd), where=cnt_gd > 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 187
    },
    "colab_type": "code",
    "id": "zm8XAVjKpAfb",
    "outputId": "8c681107-2972-4b5f-b0b6-17f326a2c443"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.000 0.104 1.128 1.341 -0.165 2.150 1.613 0.246 1.645 0.114]\n",
      " [0.085 0.000 1.582 1.592 0.283 0.623 1.018 1.110 1.907 1.078]\n",
      " [0.339 2.145 0.000 3.525 -0.709 0.173 1.802 -0.243 1.699 0.127]\n",
      " [-0.342 0.791 0.691 0.000 -0.408 0.444 -0.064 0.090 0.868 0.326]\n",
      " [0.189 0.434 1.285 0.610 0.000 0.515 1.985 1.134 1.532 1.385]\n",
      " [1.606 0.620 0.656 1.813 0.571 0.000 0.618 0.751 1.810 1.185]\n",
      " [0.418 1.054 1.033 0.537 0.926 0.647 0.000 0.118 1.167 0.992]\n",
      " [0.197 0.484 0.901 1.718 0.873 0.398 0.393 0.000 0.962 2.727]\n",
      " [-0.719 -0.322 0.189 0.566 -0.590 -0.288 -0.181 -0.011 0.000 0.218]\n",
      " [-0.264 0.037 0.795 0.673 0.677 0.152 0.376 0.751 1.071 0.000]]\n"
     ]
    }
   ],
   "source": [
    "# Rule net minus GD net: positive entries mean the rule net's adversarial examples\n",
    "# were farther from the original. The 3-decimal formatter is scoped to this print\n",
    "# rather than changing numpy's global print options for the rest of the notebook.\n",
    "with np.printoptions(formatter={'float': lambda x: \"{0:0.3f}\".format(x)}):\n",
    "    print(rule_res - gd_res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "collapsed": true,
    "id": "o0oO9ScFpFX4",
    "outputId": "93f7f8b5-a10f-4fdc-b4b9-3efaf0e6bc1c"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.6277634741060685\n",
      "1.9669432183894444\n"
     ]
    }
   ],
   "source": [
    "# Overall mean adversarial distance per net, over the 90 off-diagonal\n",
    "# (true label != target label) pairs only: the diagonal holds 0 by construction\n",
    "# and would drag the average down. nanmean skips any pair left as nan.\n",
    "off_diagonal = ~np.eye(10, dtype=bool)\n",
    "print(np.nanmean(rule_res[off_diagonal]))\n",
    "print(np.nanmean(gd_res[off_diagonal]))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "mnist_adersarial.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
