{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WjC-i_eftowA"
      },
      "source": [
        "# Install the stable version of BindsNET, then restart your runtime before running the next cell"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "7RmmHgLl7LuD"
      },
      "outputs": [],
      "source": [
        "# Install BindsNET from its GitHub repository.\n",
        "# The %pip magic installs into the environment of the running kernel,\n",
        "# whereas `! pip` runs whichever pip the subshell finds first on PATH.\n",
        "%pip install git+https://github.com/BindsNET/bindsnet.git\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CN438ZO1tmJt"
      },
      "source": [
        "# Training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "79J03CFKgy73"
      },
      "outputs": [],
      "source": [
        "# Standard library\n",
        "import os\n",
        "\n",
        "# Third-party\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import torch\n",
        "\n",
        "# Project-local modules\n",
        "from biolcnet import BioLCNet\n",
        "from reward import DynamicDopamineInjection\n",
        "from dataset import ClassSelector, load_datasets\n",
        "\n",
        "# ---------------------------------------------------------------------------\n",
        "# Reproducibility and device selection\n",
        "# ---------------------------------------------------------------------------\n",
        "gpu = True  # request CUDA; downgraded to False below when CUDA is unavailable\n",
        "seed = 2045 # The singularity is near!\n",
        "\n",
        "# Resolve the request against the hardware once, so `gpu` and `device` always agree\n",
        "# and `device` is always a torch.device (never a bare string).\n",
        "gpu = gpu and torch.cuda.is_available()\n",
        "device = torch.device(\"cuda\" if gpu else \"cpu\")\n",
        "\n",
        "# Seed every generator regardless of the selected device, so that any CPU-side\n",
        "# randomness is reproducible when the model itself runs on CUDA as well.\n",
        "torch.manual_seed(seed)\n",
        "np.random.seed(seed)\n",
        "if gpu:\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "# Leave one core free, but never ask for fewer than one thread:\n",
        "# os.cpu_count() may return None, and a single-core host would otherwise yield 0.\n",
        "torch.set_num_threads(max(1, (os.cpu_count() or 1) - 1))\n",
        "print(\"Running on\", device)\n",
        "\n",
        "# ---------------------------------------------------------------------------\n",
        "# Hyperparameters\n",
        "# ---------------------------------------------------------------------------\n",
        "n_neurons = 1000\n",
        "n_classes = 10\n",
        "neuron_per_class = n_neurons // n_classes\n",
        "\n",
        "# `time` and `crop_size` appear in both the network and the dataset settings, and\n",
        "# both dicts are merged into one `hparams` dict below; defining them once keeps\n",
        "# the two from drifting apart.\n",
        "sim_time = 256*3\n",
        "crop_size = 22\n",
        "\n",
        "train_hparams = {\n",
        "    'n_train' : 10000,\n",
        "    'n_test' : 10000,\n",
        "    'n_val' : 1,\n",
        "    'val_interval' : 250,\n",
        "    'running_window_length': 250,\n",
        "}\n",
        "\n",
        "network_hparams = {\n",
        "    # net structure\n",
        "    'crop_size': crop_size,\n",
        "    'neuron_per_c': neuron_per_class,\n",
        "    'in_channels':1,\n",
        "    'n_channels_lc': 100,\n",
        "    'filter_size': 15,\n",
        "    'stride': 4,\n",
        "    'n_neurons' : n_neurons,\n",
        "    'n_classes': n_classes,\n",
        "\n",
        "    # time & Phase\n",
        "    'dt' : 1,\n",
        "    'pre_observation': True,\n",
        "    'has_decision_period': True,\n",
        "    'observation_period': 256,\n",
        "    'decision_period': 256,\n",
        "    'time': sim_time,\n",
        "    'online_rewarding': False,\n",
        "\n",
        "    # Nodes\n",
        "    'theta_plus': 0.05,\n",
        "    'tc_theta_decay': 1e6,\n",
        "    'tc_trace':20,\n",
        "    'trace_additive' : False,\n",
        "\n",
        "    # Learning\n",
        "    'nu_LC': (0.0001,0.01),\n",
        "    'nu_Output':0.1,\n",
        "\n",
        "    # weights\n",
        "    'wmin': 0.0,\n",
        "    'wmax': 1.0,\n",
        "\n",
        "    # Inhibition\n",
        "    'inh_type_FC': 'between_layers',\n",
        "    'inh_factor_FC': 100,\n",
        "    'inh_LC': True,\n",
        "    'inh_factor_LC': 100,\n",
        "\n",
        "    # Normalization\n",
        "    'norm_factor_LC': 0.25*15*15,\n",
        "\n",
        "    # clamping\n",
        "    'clamp_intensity': None,\n",
        "\n",
        "    # Save\n",
        "    'save_path': None,  # Specify for saving the model (Especially for pre-training the lc layer)\n",
        "    'load_path': None,\n",
        "    'LC_weights_path': 'BioLCNet_layer1_Shallow_f15_s4_inh100_norm25_ch100_pretrained_LCweights.pth', # Specify for loading the pre-trained lc weights\n",
        "\n",
        "    # Plot:\n",
        "    'confusion_matrix' : False,\n",
        "    'lc_weights_vis': False,\n",
        "    'out_weights_vis': False,\n",
        "    'lc_convergence_vis': False,\n",
        "    'out_convergence_vis': False,\n",
        "}\n",
        "\n",
        "reward_hparams= {\n",
        "    'n_labels': n_classes,\n",
        "    'neuron_per_class': neuron_per_class,\n",
        "    'variant': 'scalar',\n",
        "    'tc_reward':0,\n",
        "    'dopamine_base': 0.0,\n",
        "    'reward_base': 1.,\n",
        "    'punishment_base': 1.,\n",
        "    'sub_variant': 'RPE',\n",
        "    'td_nu': 0.0005,  #RPE\n",
        "    'ema_window': 10, #RPE\n",
        "    }\n",
        "\n",
        "# Dataset Hyperparameters\n",
        "mask = None\n",
        "mask_test = None\n",
        "\n",
        "data_hparams = {\n",
        "    'intensity': 128,\n",
        "    'time': sim_time,\n",
        "    'crop_size': crop_size,\n",
        "    'round_input': False,\n",
        "}\n",
        "\n",
        "# ---------------------------------------------------------------------------\n",
        "# Data, model and training\n",
        "# ---------------------------------------------------------------------------\n",
        "dataloader, val_loader, test_loader = load_datasets(data_hparams, target_classes=None, mask=mask, mask_test=mask_test)\n",
        "\n",
        "# Later dicts win on duplicate keys, so the values in data_hparams take precedence.\n",
        "hparams = {**reward_hparams, **network_hparams, **train_hparams, **data_hparams}\n",
        "net = BioLCNet(**hparams, reward_fn = DynamicDopamineInjection, gpu = gpu)\n",
        "net.testing = False\n",
        "net.fit(dataloader = dataloader, val_loader = val_loader, reward_hparams = reward_hparams, **train_hparams)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RN95m7FOtie2"
      },
      "source": [
        "# Test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BVsfU4Aqthz7"
      },
      "outputs": [],
      "source": [
        "# --- Test ---\n",
        "# Rebuild the dataset settings with a 256*2 time window for evaluation.\n",
        "data_hparams = dict(\n",
        "    intensity=128,\n",
        "    time=256*2,\n",
        "    crop_size=22,\n",
        "    round_input=False,\n",
        ")\n",
        "\n",
        "# Chained assignment binds left to right: the hparams dict entry is updated first,\n",
        "# then the attribute on the live network, both to the same time value.\n",
        "network_hparams['time'] = net.time = data_hparams['time']\n",
        "\n",
        "# Only the third loader returned by load_datasets is needed here.\n",
        "_, _, test_loader = load_datasets(data_hparams, target_classes=None, mask=mask, mask_test=mask_test)\n",
        "\n",
        "net.testing = True\n",
        "net.evaluate(test_loader, train_hparams['n_test'], **reward_hparams)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "main.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
