{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "wu_direct_training.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "95063105a8d446069469dfb5f37f9133": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_ee5a1e2b7fe340cda50b221b3a1ecc61",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_36bb237a606c4a0cbccade61d225f587",
              "IPY_MODEL_08cbd0cf8fdb4dc7be64a729d0f4fa1e"
            ]
          }
        },
        "ee5a1e2b7fe340cda50b221b3a1ecc61": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "36bb237a606c4a0cbccade61d225f587": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_27e692aa494348809eabe4bbbfed3743",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_3ad3f86915e540b9b5727cb8b9fd45da"
          }
        },
        "08cbd0cf8fdb4dc7be64a729d0f4fa1e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_597b1cefb60049f08853bac7cf9a86a1",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 170500096/? [00:20&lt;00:00, 53521684.87it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_e456505ab5964797a7939ac0d6750df7"
          }
        },
        "27e692aa494348809eabe4bbbfed3743": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "3ad3f86915e540b9b5727cb8b9fd45da": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "597b1cefb60049f08853bac7cf9a86a1": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "e456505ab5964797a7939ac0d6750df7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "ashapd-OQ691"
      },
      "source": [
        "# own implementation of Wu (2019) Direct training..paper\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import math\n",
        "import sys\n",
        "import numpy as np\n",
        "import numpy.linalg as LA\n",
        "\n",
        "import torch.optim as optim\n",
        "import torchvision\n",
        "from   torch.utils.data.dataloader import DataLoader\n",
        "import time\n",
        "import shutil\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import scipy.io\n",
        "import numpy\n",
        "import torch.nn.parallel\n",
        "\n",
        "from torch.autograd import Variable\n",
        "import torch.utils.data\n",
        "\n",
        "import torchvision.transforms as transforms\n",
        "import torchvision.datasets as datasets\n",
        "from random import randrange\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "thresh = 0.5  # firing threshold of the spiking neurons\n",
        "lens = 0.5  # half-width of the rectangular surrogate-gradient window\n",
        "decay = 0.2  # membrane-potential decay constant\n",
        "decay2 = 0.2  # decay constant of the NeuNorm auxiliary traces\n",
        "num_classes = 10\n",
        "batch_size  = 16\n",
        "learning_rate = 1e-3\n",
        "num_epochs = 100  # max epoch\n",
        "\n",
        "# Seed every RNG this notebook draws from -- Python `random` (neuron-to-class\n",
        "# assignment), torch (weight / NeuNorm-scale init, data shuffling) and numpy --\n",
        "# so that Restart & Run All rebuilds the same model.\n",
        "import random\n",
        "SEED = 0\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "# Spike generation with a surrogate gradient.\n",
        "class ActFun(torch.autograd.Function):\n",
        "    \"\"\"Heaviside spike nonlinearity trained through a rectangular surrogate gradient.\n",
        "\n",
        "    forward: 1.0 where the membrane potential exceeds ``thresh``, else 0.0.\n",
        "    backward: passes ``grad_output`` through where ``|input - thresh| < lens``\n",
        "    and zeroes it elsewhere.\n",
        "    \"\"\"\n",
        "\n",
        "    @staticmethod\n",
        "    def forward(ctx, input):\n",
        "        ctx.save_for_backward(input)\n",
        "        fired = input > thresh\n",
        "        return fired.float()\n",
        "\n",
        "    @staticmethod\n",
        "    def backward(ctx, grad_output):\n",
        "        (membrane,) = ctx.saved_tensors\n",
        "        in_window = (membrane - thresh).abs() < lens\n",
        "        return grad_output * in_window.float()\n",
        "\n",
        "\n",
        "act_fun = ActFun.apply\n",
        "# Leaky integrate-and-fire membrane update for one time step.\n",
        "def mem_update(ops, x, mem, spike):\n",
        "    \"\"\"Advance one layer by one time step and return ``(mem, spike)``.\n",
        "\n",
        "    The previous potential is scaled by ``decay`` and zeroed wherever the\n",
        "    neuron spiked on the previous step, then the input current ``ops(x)`` is\n",
        "    added and ``act_fun`` thresholds the result into new spikes.\n",
        "    \"\"\"\n",
        "    leaked = mem * decay * (1. - spike)\n",
        "    mem = leaked + ops(x)\n",
        "    return mem, act_fun(mem)\n",
        "\n",
        "# cfg_cnn: one (in_planes, out_planes, stride, padding, kernel_size) tuple per conv layer.\n",
        "# NOTE(review): cfg_cnn is not read by the model class in this cell.\n",
        "cfg_cnn = [(1, 32, 1, 1, 3),\n",
        "           (32, 32, 1, 1, 3),]\n",
        "# Side length of the square feature maps at each resolution stage: the 32x32\n",
        "# input and the maps after each 2x2 average pool (these are NOT kernel sizes).\n",
        "cfg_kernel = [32, 16, 8]\n",
        "# Widths of the two spiking fully connected layers.\n",
        "cfg_fc = [1024, 512]\n",
        "\n",
        "# Step-decay schedule for the learning rate.\n",
        "def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):\n",
        "    \"\"\"Scale every param group's learning rate by 0.1 when ``epoch`` is a multiple of ``lr_decay_epoch``.\n",
        "\n",
        "    Epochs <= 1 never trigger a decay. The optimizer is updated in place and\n",
        "    also returned. ``init_lr`` is accepted for call compatibility but not used.\n",
        "    \"\"\"\n",
        "    if not (epoch % lr_decay_epoch == 0 and epoch > 1):\n",
        "        return optimizer\n",
        "    for group in optimizer.param_groups:\n",
        "        group['lr'] = group['lr'] * 0.1\n",
        "    return optimizer\n",
        "\n",
        "class cifarnet(nn.Module):\n",
        "    \"\"\"Spiking CNN for CIFAR-10 with NeuNorm, following Wu et al. (2019).\n",
        "\n",
        "    Five spiking 3x3 conv layers (3 -> 128 -> 256 -> 512 -> 1024 -> 512 channels,\n",
        "    stride 1, padding 1) with 2x2 average pooling after conv2 and conv3, then two\n",
        "    spiking fully connected layers (fc1, fc2). Each conv layer's spikes are\n",
        "    corrected by a learnable map ``u_i`` times a NeuNorm auxiliary trace before\n",
        "    they reach the next layer; the trace coefficients are 0.8 divided by the\n",
        "    layer's channel count.\n",
        "\n",
        "    The cfg_fc[1] output neurons are randomly assigned to classes once, at\n",
        "    construction; the fixed 0/1 voting matrix ``M`` sums the firing rates of the\n",
        "    neurons assigned to each class.\n",
        "\n",
        "    Args:\n",
        "        tst: number of simulation time steps per forward pass.\n",
        "\n",
        "    NOTE(review): the u_i parameters carry a leading ``batch_size`` dimension, so\n",
        "    every batch passed to forward must hold exactly ``batch_size`` samples.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, tst):\n",
        "        super(cifarnet, self).__init__()\n",
        "        # u_i: learnable NeuNorm scale, one value per batch slot and feature-map pixel.\n",
        "        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1)\n",
        "        self.u1 = nn.Parameter(nn.init.normal_(torch.zeros(batch_size, 1, cfg_kernel[0], cfg_kernel[0]), mean=0.0, std=1.0))\n",
        "\n",
        "        self.conv2 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)\n",
        "        self.u2 = nn.Parameter(nn.init.normal_(torch.zeros(batch_size, 1, cfg_kernel[0], cfg_kernel[0]), mean=0.0, std=1.0))\n",
        "\n",
        "        self.conv3 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)\n",
        "        self.u3 = nn.Parameter(nn.init.normal_(torch.zeros(batch_size, 1, cfg_kernel[1], cfg_kernel[1]), mean=0.0, std=1.0))\n",
        "\n",
        "        self.conv4 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1)\n",
        "        self.u4 = nn.Parameter(nn.init.normal_(torch.zeros(batch_size, 1, cfg_kernel[2], cfg_kernel[2]), mean=0.0, std=1.0))\n",
        "\n",
        "        self.conv5 = nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1)\n",
        "        self.u5 = nn.Parameter(nn.init.normal_(torch.zeros(batch_size, 1, cfg_kernel[2], cfg_kernel[2]), mean=0.0, std=1.0))\n",
        "\n",
        "        # fc1 consumes the flattened conv5 output: 512 channels x 8 x 8.\n",
        "        self.fc1 = nn.Linear(512 * cfg_kernel[2] * cfg_kernel[2], cfg_fc[0])\n",
        "        self.fc2 = nn.Linear(cfg_fc[0], cfg_fc[1])\n",
        "\n",
        "        self.num_steps = tst\n",
        "\n",
        "        # Random class assignment of each output neuron (Python `random` RNG).\n",
        "        assignments = [randrange(num_classes) for _ in range(cfg_fc[1])]\n",
        "        # Voting matrix M, shape (cfg_fc[1], num_classes): M[l, k] = 1 iff neuron l\n",
        "        # is assigned to class k. Registered as a buffer rather than a plain\n",
        "        # attribute so it is saved in state_dict (a reloaded checkpoint keeps its\n",
        "        # class assignment) and follows model.to(...); .to(device) instead of\n",
        "        # .cuda() keeps CPU-only runs working.\n",
        "        # NOTE(review): checkpoints written before M became a buffer lack an 'M'\n",
        "        # entry; load those with load_state_dict(..., strict=False).\n",
        "        voting = F.one_hot(torch.tensor(assignments), num_classes=num_classes).float()\n",
        "        self.register_buffer('M', voting.to(device))\n",
        "\n",
        "        # He-normal init (std = sqrt(2 / fan_in)) for every conv and linear weight;\n",
        "        # biases keep PyTorch's default initialisation.\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, nn.Conv2d):\n",
        "                fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels\n",
        "                m.weight.data.normal_(0, math.sqrt(2. / fan_in))\n",
        "            elif isinstance(m, nn.Linear):\n",
        "                fan_in = m.weight.size(1)\n",
        "                m.weight.data.normal_(0.0, math.sqrt(2.0 / fan_in))\n",
        "\n",
        "    def forward(self, input):\n",
        "        \"\"\"Run ``num_steps`` simulation steps and return per-class scores.\n",
        "\n",
        "        Args:\n",
        "            input: image batch fed unchanged to conv1 at every step; conv1's output\n",
        "                must match the (batch_size, 128, 32, 32) state below, so this is\n",
        "                expected to be (batch_size, 3, 32, 32).\n",
        "\n",
        "        Returns:\n",
        "            (batch_size, num_classes) tensor: output spike counts summed per class\n",
        "            via ``M`` and divided by ``num_steps``.\n",
        "        \"\"\"\n",
        "        # Membrane potential and spikes of every layer, plus one NeuNorm auxiliary\n",
        "        # trace per conv layer; all start at zero.\n",
        "        c1_mem = c1_spike = torch.zeros(batch_size, 128, cfg_kernel[0], cfg_kernel[0], device=device)\n",
        "        aux_1 = torch.zeros(batch_size, 1, cfg_kernel[0], cfg_kernel[0], device=device)\n",
        "        c2_mem = c2_spike = torch.zeros(batch_size, 256, cfg_kernel[0], cfg_kernel[0], device=device)\n",
        "        aux_2 = torch.zeros(batch_size, 1, cfg_kernel[0], cfg_kernel[0], device=device)\n",
        "        c3_mem = c3_spike = torch.zeros(batch_size, 512, cfg_kernel[1], cfg_kernel[1], device=device)\n",
        "        aux_3 = torch.zeros(batch_size, 1, cfg_kernel[1], cfg_kernel[1], device=device)\n",
        "        c4_mem = c4_spike = torch.zeros(batch_size, 1024, cfg_kernel[2], cfg_kernel[2], device=device)\n",
        "        aux_4 = torch.zeros(batch_size, 1, cfg_kernel[2], cfg_kernel[2], device=device)\n",
        "        c5_mem = c5_spike = torch.zeros(batch_size, 512, cfg_kernel[2], cfg_kernel[2], device=device)\n",
        "        aux_5 = torch.zeros(batch_size, 1, cfg_kernel[2], cfg_kernel[2], device=device)\n",
        "\n",
        "        h1_mem = h1_spike = torch.zeros(batch_size, cfg_fc[0], device=device)\n",
        "        h2_mem = h2_spike = torch.zeros(batch_size, cfg_fc[1], device=device)\n",
        "        # Own tensor, not aliased with h2_mem/h2_spike: it is accumulated in place below.\n",
        "        h2_sumspike = torch.zeros(batch_size, cfg_fc[1], device=device)\n",
        "\n",
        "        for step in range(self.num_steps):  # simulation time steps\n",
        "            # NeuNorm: aux_i = decay2 * aux_i + (0.8 / C_i) * (spikes summed over the\n",
        "            # layer's C_i channels); the next layer sees spike_i - u_i * aux_i.\n",
        "            c1_mem, c1_spike = mem_update(self.conv1, input, c1_mem, c1_spike)\n",
        "            aux_1 = decay2 * aux_1 + .00625 * c1_spike.sum(dim=1).unsqueeze(dim=1)  # 0.8 / 128\n",
        "            x1 = c1_spike - (self.u1 * aux_1)\n",
        "\n",
        "            c2_mem, c2_spike = mem_update(self.conv2, x1, c2_mem, c2_spike)\n",
        "            aux_2 = decay2 * aux_2 + .003125 * c2_spike.sum(dim=1).unsqueeze(dim=1)  # 0.8 / 256\n",
        "            x2 = c2_spike - (self.u2 * aux_2)\n",
        "            x3 = F.avg_pool2d(x2, 2)\n",
        "\n",
        "            c3_mem, c3_spike = mem_update(self.conv3, x3, c3_mem, c3_spike)\n",
        "            aux_3 = decay2 * aux_3 + 0.0015625 * c3_spike.sum(dim=1).unsqueeze(dim=1)  # 0.8 / 512\n",
        "            x4 = c3_spike - (self.u3 * aux_3)\n",
        "            x5 = F.avg_pool2d(x4, 2)\n",
        "\n",
        "            c4_mem, c4_spike = mem_update(self.conv4, x5, c4_mem, c4_spike)\n",
        "            aux_4 = decay2 * aux_4 + 0.00078125 * c4_spike.sum(dim=1).unsqueeze(dim=1)  # 0.8 / 1024\n",
        "            x6 = c4_spike - (self.u4 * aux_4)\n",
        "\n",
        "            c5_mem, c5_spike = mem_update(self.conv5, x6, c5_mem, c5_spike)\n",
        "            aux_5 = decay2 * aux_5 + 0.0015625 * c5_spike.sum(dim=1).unsqueeze(dim=1)  # 0.8 / 512\n",
        "            x7 = c5_spike - (self.u5 * aux_5)\n",
        "            x7 = x7.view(batch_size, -1)\n",
        "\n",
        "            h1_mem, h1_spike = mem_update(self.fc1, x7, h1_mem, h1_spike)\n",
        "            h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)\n",
        "            h2_sumspike += h2_spike\n",
        "\n",
        "        # Sum output spike counts per class, then convert counts to rates.\n",
        "        h = torch.matmul(h2_sumspike, self.M)\n",
        "        outputs = h / self.num_steps\n",
        "        return outputs"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CnjW5lRGJNmz"
      },
      "source": [
        "# own implementation of Wu (2019) Direct training..paper with eqn(9) k_tau+V/F=1\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import math\n",
        "import sys\n",
        "import numpy as np\n",
        "import numpy.linalg as LA\n",
        "\n",
        "import torch.optim as optim\n",
        "import torchvision\n",
        "from   torch.utils.data.dataloader import DataLoader\n",
        "import time\n",
        "import shutil\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import scipy.io\n",
        "import numpy\n",
        "import torch.nn.parallel\n",
        "\n",
        "from torch.autograd import Variable\n",
        "import torch.utils.data\n",
        "\n",
        "import torchvision.transforms as transforms\n",
        "import torchvision.datasets as datasets\n",
        "from random import randrange\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "thresh = 0.5  # firing threshold of the spiking neurons\n",
        "lens = 0.5  # half-width of the rectangular surrogate-gradient window\n",
        "decay = 0.2  # membrane-potential decay constant\n",
        "decay2 = 0.2  # decay constant of the NeuNorm auxiliary traces\n",
        "num_classes = 10\n",
        "batch_size  = 16\n",
        "learning_rate = 1e-3\n",
        "num_epochs = 100  # max epoch\n",
        "\n",
        "# Seed every RNG this notebook draws from -- Python `random` (neuron-to-class\n",
        "# assignment), torch (weight / NeuNorm-scale init, data shuffling) and numpy --\n",
        "# so that Restart & Run All rebuilds the same model.\n",
        "import random\n",
        "SEED = 0\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "# Spike generation with a surrogate gradient.\n",
        "class ActFun(torch.autograd.Function):\n",
        "    \"\"\"Heaviside spike nonlinearity trained through a rectangular surrogate gradient.\n",
        "\n",
        "    forward: 1.0 where the membrane potential exceeds ``thresh``, else 0.0.\n",
        "    backward: passes ``grad_output`` through where ``|input - thresh| < lens``\n",
        "    and zeroes it elsewhere.\n",
        "    \"\"\"\n",
        "\n",
        "    @staticmethod\n",
        "    def forward(ctx, input):\n",
        "        ctx.save_for_backward(input)\n",
        "        fired = input > thresh\n",
        "        return fired.float()\n",
        "\n",
        "    @staticmethod\n",
        "    def backward(ctx, grad_output):\n",
        "        (membrane,) = ctx.saved_tensors\n",
        "        in_window = (membrane - thresh).abs() < lens\n",
        "        return grad_output * in_window.float()\n",
        "\n",
        "\n",
        "act_fun = ActFun.apply\n",
        "# Leaky integrate-and-fire membrane update for one time step.\n",
        "def mem_update(ops, x, mem, spike):\n",
        "    \"\"\"Advance one layer by one time step and return ``(mem, spike)``.\n",
        "\n",
        "    The previous potential is scaled by ``decay`` and zeroed wherever the\n",
        "    neuron spiked on the previous step, then the input current ``ops(x)`` is\n",
        "    added and ``act_fun`` thresholds the result into new spikes.\n",
        "    \"\"\"\n",
        "    leaked = mem * decay * (1. - spike)\n",
        "    mem = leaked + ops(x)\n",
        "    return mem, act_fun(mem)\n",
        "\n",
        "# cfg_cnn: one (in_planes, out_planes, stride, padding, kernel_size) tuple per conv layer.\n",
        "# NOTE(review): cfg_cnn is not read by the model class in this cell.\n",
        "cfg_cnn = [(1, 32, 1, 1, 3),\n",
        "           (32, 32, 1, 1, 3),]\n",
        "# Side length of the square feature maps at each resolution stage: the 32x32\n",
        "# input and the maps after each 2x2 average pool (these are NOT kernel sizes).\n",
        "cfg_kernel = [32, 16, 8]\n",
        "# Widths of the two spiking fully connected layers.\n",
        "cfg_fc = [1024, 512]\n",
        "\n",
        "# Step-decay schedule for the learning rate.\n",
        "def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):\n",
        "    \"\"\"Scale every param group's learning rate by 0.1 when ``epoch`` is a multiple of ``lr_decay_epoch``.\n",
        "\n",
        "    Epochs <= 1 never trigger a decay. The optimizer is updated in place and\n",
        "    also returned. ``init_lr`` is accepted for call compatibility but not used.\n",
        "    \"\"\"\n",
        "    if not (epoch % lr_decay_epoch == 0 and epoch > 1):\n",
        "        return optimizer\n",
        "    for group in optimizer.param_groups:\n",
        "        group['lr'] = group['lr'] * 0.1\n",
        "    return optimizer\n",
        "\n",
        "class cifarnet2(nn.Module):\n",
        "    \"\"\"Spiking CNN for CIFAR-10 with NeuNorm (variant with a fixed 0.8 trace coefficient).\n",
        "\n",
        "    Five spiking 3x3 conv layers (3 -> 128 -> 256 -> 512 -> 1024 -> 512 channels,\n",
        "    stride 1, padding 1) with 2x2 average pooling after conv2 and conv3, then two\n",
        "    spiking fully connected layers (fc1, fc2). Each conv layer's spikes are\n",
        "    corrected by a learnable map ``u_i`` times a NeuNorm auxiliary trace before\n",
        "    they reach the next layer. Here every trace adds 0.8 times the raw\n",
        "    channel-summed spikes, so decay2 + 0.8 = 1 (the k_tau + V/F = 1 variant).\n",
        "\n",
        "    The cfg_fc[1] output neurons are randomly assigned to classes once, at\n",
        "    construction; the fixed 0/1 voting matrix ``M`` sums the firing rates of the\n",
        "    neurons assigned to each class.\n",
        "\n",
        "    Args:\n",
        "        tst: number of simulation time steps per forward pass.\n",
        "\n",
        "    NOTE(review): the u_i parameters carry a leading ``batch_size`` dimension, so\n",
        "    every batch passed to forward must hold exactly ``batch_size`` samples.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, tst):\n",
        "        super(cifarnet2, self).__init__()\n",
        "        # u_i: learnable NeuNorm scale, one value per batch slot and feature-map pixel.\n",
        "        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1)\n",
        "        self.u1 = nn.Parameter(nn.init.normal_(torch.zeros(batch_size, 1, cfg_kernel[0], cfg_kernel[0]), mean=0.0, std=1.0))\n",
        "\n",
        "        self.conv2 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)\n",
        "        self.u2 = nn.Parameter(nn.init.normal_(torch.zeros(batch_size, 1, cfg_kernel[0], cfg_kernel[0]), mean=0.0, std=1.0))\n",
        "\n",
        "        self.conv3 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)\n",
        "        self.u3 = nn.Parameter(nn.init.normal_(torch.zeros(batch_size, 1, cfg_kernel[1], cfg_kernel[1]), mean=0.0, std=1.0))\n",
        "\n",
        "        self.conv4 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1)\n",
        "        self.u4 = nn.Parameter(nn.init.normal_(torch.zeros(batch_size, 1, cfg_kernel[2], cfg_kernel[2]), mean=0.0, std=1.0))\n",
        "\n",
        "        self.conv5 = nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1)\n",
        "        self.u5 = nn.Parameter(nn.init.normal_(torch.zeros(batch_size, 1, cfg_kernel[2], cfg_kernel[2]), mean=0.0, std=1.0))\n",
        "\n",
        "        # fc1 consumes the flattened conv5 output: 512 channels x 8 x 8.\n",
        "        self.fc1 = nn.Linear(512 * cfg_kernel[2] * cfg_kernel[2], cfg_fc[0])\n",
        "        self.fc2 = nn.Linear(cfg_fc[0], cfg_fc[1])\n",
        "\n",
        "        self.num_steps = tst\n",
        "\n",
        "        # Random class assignment of each output neuron (Python `random` RNG).\n",
        "        assignments = [randrange(num_classes) for _ in range(cfg_fc[1])]\n",
        "        # Voting matrix M, shape (cfg_fc[1], num_classes): M[l, k] = 1 iff neuron l\n",
        "        # is assigned to class k. Registered as a buffer rather than a plain\n",
        "        # attribute so it is saved in state_dict (a reloaded checkpoint keeps its\n",
        "        # class assignment) and follows model.to(...); .to(device) instead of\n",
        "        # .cuda() keeps CPU-only runs working.\n",
        "        # NOTE(review): checkpoints written before M became a buffer lack an 'M'\n",
        "        # entry; load those with load_state_dict(..., strict=False).\n",
        "        voting = F.one_hot(torch.tensor(assignments), num_classes=num_classes).float()\n",
        "        self.register_buffer('M', voting.to(device))\n",
        "\n",
        "        # He-normal init (std = sqrt(2 / fan_in)) for every conv and linear weight;\n",
        "        # biases keep PyTorch's default initialisation.\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, nn.Conv2d):\n",
        "                fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels\n",
        "                m.weight.data.normal_(0, math.sqrt(2. / fan_in))\n",
        "            elif isinstance(m, nn.Linear):\n",
        "                fan_in = m.weight.size(1)\n",
        "                m.weight.data.normal_(0.0, math.sqrt(2.0 / fan_in))\n",
        "\n",
        "    def forward(self, input):\n",
        "        \"\"\"Run ``num_steps`` simulation steps and return per-class scores.\n",
        "\n",
        "        Args:\n",
        "            input: image batch fed unchanged to conv1 at every step; conv1's output\n",
        "                must match the (batch_size, 128, 32, 32) state below, so this is\n",
        "                expected to be (batch_size, 3, 32, 32).\n",
        "\n",
        "        Returns:\n",
        "            (batch_size, num_classes) tensor: output spike counts summed per class\n",
        "            via ``M`` and divided by ``num_steps``.\n",
        "        \"\"\"\n",
        "        # Membrane potential and spikes of every layer, plus one NeuNorm auxiliary\n",
        "        # trace per conv layer; all start at zero.\n",
        "        c1_mem = c1_spike = torch.zeros(batch_size, 128, cfg_kernel[0], cfg_kernel[0], device=device)\n",
        "        aux_1 = torch.zeros(batch_size, 1, cfg_kernel[0], cfg_kernel[0], device=device)\n",
        "        c2_mem = c2_spike = torch.zeros(batch_size, 256, cfg_kernel[0], cfg_kernel[0], device=device)\n",
        "        aux_2 = torch.zeros(batch_size, 1, cfg_kernel[0], cfg_kernel[0], device=device)\n",
        "        c3_mem = c3_spike = torch.zeros(batch_size, 512, cfg_kernel[1], cfg_kernel[1], device=device)\n",
        "        aux_3 = torch.zeros(batch_size, 1, cfg_kernel[1], cfg_kernel[1], device=device)\n",
        "        c4_mem = c4_spike = torch.zeros(batch_size, 1024, cfg_kernel[2], cfg_kernel[2], device=device)\n",
        "        aux_4 = torch.zeros(batch_size, 1, cfg_kernel[2], cfg_kernel[2], device=device)\n",
        "        c5_mem = c5_spike = torch.zeros(batch_size, 512, cfg_kernel[2], cfg_kernel[2], device=device)\n",
        "        aux_5 = torch.zeros(batch_size, 1, cfg_kernel[2], cfg_kernel[2], device=device)\n",
        "\n",
        "        h1_mem = h1_spike = torch.zeros(batch_size, cfg_fc[0], device=device)\n",
        "        h2_mem = h2_spike = torch.zeros(batch_size, cfg_fc[1], device=device)\n",
        "        # Own tensor, not aliased with h2_mem/h2_spike: it is accumulated in place below.\n",
        "        h2_sumspike = torch.zeros(batch_size, cfg_fc[1], device=device)\n",
        "\n",
        "        for step in range(self.num_steps):  # simulation time steps\n",
        "            # NeuNorm: aux_i = decay2 * aux_i + 0.8 * (spikes summed over the layer's\n",
        "            # channels); the next layer sees spike_i - u_i * aux_i.\n",
        "            c1_mem, c1_spike = mem_update(self.conv1, input, c1_mem, c1_spike)\n",
        "            aux_1 = decay2 * aux_1 + .8 * c1_spike.sum(dim=1).unsqueeze(dim=1)\n",
        "            x1 = c1_spike - (self.u1 * aux_1)\n",
        "\n",
        "            c2_mem, c2_spike = mem_update(self.conv2, x1, c2_mem, c2_spike)\n",
        "            aux_2 = decay2 * aux_2 + .8 * c2_spike.sum(dim=1).unsqueeze(dim=1)\n",
        "            x2 = c2_spike - (self.u2 * aux_2)\n",
        "            x3 = F.avg_pool2d(x2, 2)\n",
        "\n",
        "            c3_mem, c3_spike = mem_update(self.conv3, x3, c3_mem, c3_spike)\n",
        "            aux_3 = decay2 * aux_3 + 0.8 * c3_spike.sum(dim=1).unsqueeze(dim=1)\n",
        "            x4 = c3_spike - (self.u3 * aux_3)\n",
        "            x5 = F.avg_pool2d(x4, 2)\n",
        "\n",
        "            c4_mem, c4_spike = mem_update(self.conv4, x5, c4_mem, c4_spike)\n",
        "            aux_4 = decay2 * aux_4 + 0.8 * c4_spike.sum(dim=1).unsqueeze(dim=1)\n",
        "            x6 = c4_spike - (self.u4 * aux_4)\n",
        "\n",
        "            c5_mem, c5_spike = mem_update(self.conv5, x6, c5_mem, c5_spike)\n",
        "            aux_5 = decay2 * aux_5 + 0.8 * c5_spike.sum(dim=1).unsqueeze(dim=1)\n",
        "            x7 = c5_spike - (self.u5 * aux_5)\n",
        "            x7 = x7.view(batch_size, -1)\n",
        "\n",
        "            h1_mem, h1_spike = mem_update(self.fc1, x7, h1_mem, h1_spike)\n",
        "            h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)\n",
        "            h2_sumspike += h2_spike\n",
        "\n",
        "        # Sum output spike counts per class, then convert counts to rates.\n",
        "        h = torch.matmul(h2_sumspike, self.M)\n",
        "        outputs = h / self.num_steps\n",
        "        return outputs"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rFVnEjobK49y",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101,
          "referenced_widgets": [
            "95063105a8d446069469dfb5f37f9133",
            "ee5a1e2b7fe340cda50b221b3a1ecc61",
            "36bb237a606c4a0cbccade61d225f587",
            "08cbd0cf8fdb4dc7be64a729d0f4fa1e",
            "27e692aa494348809eabe4bbbfed3743",
            "3ad3f86915e540b9b5727cb8b9fd45da",
            "597b1cefb60049f08853bac7cf9a86a1",
            "e456505ab5964797a7939ac0d6750df7"
          ]
        },
        "outputId": "1f2bfe60-b7cb-470c-db17-ab4915cb0fac"
      },
      "source": [
        "normalize_usual = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
        "# NOTE(review): these are the widely copied CIFAR-10 std values; the per-pixel std of\n",
        "# the training set is often quoted as ~(0.247, 0.243, 0.261) -- verify before using.\n",
        "normalize_cifar = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])\n",
        "\n",
        "# usual=1: map each channel from [0, 1] to [-1, 1]; usual=0: CIFAR-10 mean/std.\n",
        "usual=1\n",
        "\n",
        "# CIFAR-10 loaders; random crop + horizontal flip augmentation for training only.\n",
        "transform_train = transforms.Compose([\n",
        "    transforms.RandomCrop(32, padding=4),\n",
        "    transforms.RandomHorizontalFlip(),\n",
        "    transforms.ToTensor(),\n",
        "    normalize_usual if usual else normalize_cifar,\n",
        "])\n",
        "transform_test = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    normalize_usual if usual else normalize_cifar,\n",
        "])\n",
        "\n",
        "trainset = torchvision.datasets.CIFAR10(\n",
        "    root='./data', train=True, download=True, transform=transform_train)\n",
        "trainloader = torch.utils.data.DataLoader(\n",
        "    trainset, batch_size=batch_size, shuffle=True, num_workers=2)\n",
        "\n",
        "testset = torchvision.datasets.CIFAR10(\n",
        "    root='./data', train=False, download=True, transform=transform_test)\n",
        "testloader = torch.utils.data.DataLoader(\n",
        "    testset, batch_size=batch_size, shuffle=False, num_workers=2)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "95063105a8d446069469dfb5f37f9133",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting ./data/cifar-10-python.tar.gz to ./data\n",
            "Files already downloaded and verified\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "btUUTxhnLakK",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "76e8b989-c724-4dd0-e852-4ee8acfd4b6d"
      },
      "source": [
        "t_end=20  # passed to the model as `tst`; presumably the number of SNN time steps (confirm)\n",
        "\n",
        "# Run tag built from the key hyper-parameters; it also names the checkpoint file.\n",
        "# Keep '_weightdecay5e-4' in sync with the Adam weight_decay used in the training cell.\n",
        "model_str_use = f'wu_direct_t_{t_end}_lr{learning_rate}_bs{batch_size}_weightdecay5e-4_testloss'\n",
        "\n",
        "ckpt_fname = model_str_use+'ckpt.pth'\n",
        "\n",
        "model = cifarnet(tst=t_end)\n",
        "# model = cifarnet2(tst=t_end)  # alternative architecture\n",
        "# .cuda() on the DataParallel wrapper also moves the wrapped model's parameters,\n",
        "# so no separate model.cuda() call is needed.\n",
        "model = torch.nn.DataParallel(model).cuda()\n",
        "\n",
        "# Print the SNN model\n",
        "print('********** SNN model **********')\n",
        "print(model)\n"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "********** SNN model **********\n",
            "DataParallel(\n",
            "  (module): cifarnet2(\n",
            "    (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (conv3): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (conv4): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (conv5): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
            "    (fc1): Linear(in_features=32768, out_features=1024, bias=True)\n",
            "    (fc2): Linear(in_features=1024, out_features=512, bias=True)\n",
            "  )\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LTYsu-gzPDR_"
      },
      "source": [
        "NUM_CLASSES = 10  # CIFAR-10 classes; width of the one-hot MSE targets\n",
        "\n",
        "best_acc = 0  # best test accuracy\n",
        "start_epoch = 0  # start from epoch 0 or last checkpoint epoch\n",
        "acc_record = list([])  # test accuracy (%) per epoch\n",
        "loss_train_record = list([])  # mean training loss per epoch\n",
        "loss_test_record = list([])  # mean test loss per epoch\n",
        "\n",
        "criterion = nn.MSELoss()\n",
        "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=5e-4)\n",
        "\n",
        "\n",
        "def to_one_hot(labels, device):\n",
        "    \"\"\"Return a float (len(labels), NUM_CLASSES) one-hot target matrix on `device`.\n",
        "\n",
        "    The row count is taken from the batch itself rather than `batch_size`, so the\n",
        "    last, smaller batch of a loader (dataset size not a multiple of batch_size) works.\n",
        "    \"\"\"\n",
        "    labels = labels.to(device)\n",
        "    one_hot = torch.zeros(labels.size(0), NUM_CLASSES, device=device)\n",
        "    return one_hot.scatter_(1, labels.view(-1, 1), 1)\n",
        "\n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "    running_loss = 0  # summed over the last 100 steps, for progress printing\n",
        "    epoch_train_loss = 0  # summed over the whole epoch, for loss_train_record\n",
        "    start_time = time.time()\n",
        "    model.train()\n",
        "    for i, (images, labels) in enumerate(trainloader):\n",
        "        # the optimizer was built over model.parameters(), so this clears every gradient\n",
        "        optimizer.zero_grad()\n",
        "\n",
        "        images = images.cuda()\n",
        "        outputs = model(images)\n",
        "        # build targets on the outputs' device instead of copying outputs to the CPU each step\n",
        "        loss = criterion(outputs, to_one_hot(labels, outputs.device))\n",
        "        step_loss = loss.item()\n",
        "        running_loss += step_loss\n",
        "        epoch_train_loss += step_loss\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        if (i+1)%100 == 0:\n",
        "            print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f'\n",
        "                  %(epoch+1, num_epochs, i+1, len(trainloader), running_loss))\n",
        "            running_loss = 0\n",
        "            print('Time elapsed:', time.time()-start_time)\n",
        "    loss_train_record.append(epoch_train_loss / len(trainloader))\n",
        "\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    test_loss = 0\n",
        "    # lr_scheduler is defined elsewhere; presumably adjusts the LR on a 40-epoch schedule (confirm)\n",
        "    optimizer = lr_scheduler(optimizer, epoch, learning_rate, 40)\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        for batch_idx, (inputs, targets) in enumerate(testloader):\n",
        "            inputs = inputs.cuda()  # same device handling as the training loop\n",
        "            outputs = model(inputs)\n",
        "            targets = targets.to(outputs.device)\n",
        "            test_loss += criterion(outputs, to_one_hot(targets, outputs.device)).item()\n",
        "            _, predicted = outputs.max(1)\n",
        "            total += float(targets.size(0))\n",
        "            correct += float(predicted.eq(targets).sum().item())\n",
        "            if batch_idx %100 ==0:\n",
        "                acc = 100. * float(correct) / float(total)\n",
        "                print(batch_idx, len(testloader),' Acc: %.5f' % acc)\n",
        "\n",
        "    acc = 100. * float(correct) / float(total)\n",
        "    acc_record.append(acc)\n",
        "    loss_test_record.append(test_loss / len(testloader))\n",
        "    best_acc = max(best_acc, acc)\n",
        "    print('Iters:', epoch,'\\n\\n\\n')\n",
        "    print('Test Accuracy of the model on the %d test images: %.3f' % (total, acc))"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}