{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "sparse_combo_RNNs.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "ntDs_OIQpyqR"
      },
      "source": [
        "# export of Google Colab notebook which can be used to run trials reported on in \"RECURSIVE CONSTRUCTION OF STABLE ASSEMBLIES OF RECURRENT NEURAL NETWORKS\"\n",
        "# major user settings are found in the next two cells after this\n",
        "# desired trial can then be run by simply running all cells\n",
        "# code as is requires a single GPU\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "import torch.jit as jit\n",
        "from torch.autograd import Variable\n",
        "from torch.utils.data import Dataset, DataLoader, Subset\n",
        "from torch._utils import _accumulate\n",
        "from torchvision import datasets, transforms\n",
        "\n",
        "from sklearn.model_selection import train_test_split \n",
        "from tqdm.auto import tqdm\n",
        "\n",
        "import numpy as np\n",
        "import numpy.random as npr\n",
        "import scipy.sparse\n",
        "import scipy.integrate\n",
        "\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "import os"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SKuhEHMqKt5T"
      },
      "source": [
        "# major model settings\n",
        "task = \"cifar10\" # dataset name handed to getData further down: will accept cifar10, pmnist, or mnist\n",
        "ns = [32 for x in range(16)] # list of integers, with each integer being the size for a component RNN. the number of component RNNs is then the length of the list\n",
        "rnn_density = 0.033 # density setting for each component RNN\n",
        "pre_scalar = 30.0 # sample non-zero entries for potential component RNN uniformly from -pre_scalar to pre_scalar, keep only if satisfies theorem 1\n",
        "# (note that if the density and pre_scalar settings are both too large the code will stall: create_modules keeps\n",
        "# resampling until a candidate meets the condition, and with such settings it may never find one)\n",
        "post_scalar = 0.2 # multiply chosen component RNNs by post_scalar before beginning training (should be a positive number <= 1)\n",
        "\n",
        "# training settings - uses Adam optimizer\n",
        "num_epochs = 200\n",
        "lr = 1e-3\n",
        "lr_scale_epochs = [140, 190] # the epochs (1-based, matching the epoch numbers printed during training) after which to perform a learning rate cut\n",
        "lr_scalar = 0.1 # the scalar to multiply the learning rate by at the specified epochs\n",
        "weight_decay = 1e-5"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oZlVeZlEHKQx"
      },
      "source": [
        "# settings for saving model stats and checkpoints\n",
        "# model checkpoint is saved after each epoch, as well as an updated CSV with per-epoch training loss and test accuracy\n",
        "\n",
        "# if not running on colab, comment out the following two lines (they mount Google Drive)\n",
        "# and point output_root below at an existing local folder instead\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# output file path info\n",
        "output_root = \"/content/drive/My Drive/\" # folder path everything will be saved to\n",
        "model_name = \"test\" # specific trial name, used as the prefix of every saved file (e.g. <model_name>-cifar10-stats.csv)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "UbecPCk3vcYA"
      },
      "source": [
        "# use the GPU when one is available, otherwise fall back to the CPU device object\n",
        "# (note: several later cells call .cuda() / device='cuda' directly, so a GPU is still needed to run a trial)\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "if device.type == 'cuda':\n",
        "    print('Default tensor type is now cuda')\n",
        "    # from here on, tensors created without an explicit type are CUDA float tensors\n",
        "    torch.set_default_tensor_type('torch.cuda.FloatTensor')"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RlMls_I4aA6-"
      },
      "source": [
        "# functions for initializing the sparse matrices\n",
        "\n",
        "# function that checks the Theorem 1 condition for an input square matrix\n",
        "def check_cond(W):\n",
        "    W_diag_only = np.diag(np.diag(W))\n",
        "    W_diag_pos_only = W_diag_only.copy()\n",
        "    W_diag_pos_only[W_diag_pos_only < 0] = 0.0\n",
        "    W_abs_cond = np.abs(W - W_diag_only) + W_diag_pos_only\n",
        "    max_eig_abs_cond = np.max(np.real(np.linalg.eigvals(W_abs_cond)))\n",
        "    if max_eig_abs_cond < 1:\n",
        "        return True\n",
        "    else:\n",
        "        return False\n",
        "\n",
        "# sampling function to use for each non-zero element in a generated matrix\n",
        "# (uniform but between -1 and 1 instead of default 0 to 1)\n",
        "def uniform_with_neg(x):\n",
        "    return np.random.uniform(low=-1.0,high=1.0, size=x)\n",
        "\n",
        "# function that creates a square matrix of a given size with a given density and distribution (+ scalar for the distribution)\n",
        "# also zeroes out the diagonal after generation\n",
        "# the uniform_with_neg function is used as the sampling function by default\n",
        "def generate_random_sparse_matrix(num_units, density, dist_multiplier, dist_func=uniform_with_neg):\n",
        "    '''Sample a dense num_units x num_units array with sparse random entries and a zero diagonal.\n",
        "\n",
        "    Non-zero entries are drawn with dist_func and scaled by dist_multiplier;\n",
        "    density is forwarded to scipy.sparse.random.\n",
        "    '''\n",
        "    sparse_sample = scipy.sparse.random(num_units, num_units, density=density,\n",
        "                                        format='csr', data_rvs=dist_func)\n",
        "    dense = dist_multiplier * sparse_sample.toarray()\n",
        "    # entries that landed on the diagonal are discarded\n",
        "    np.fill_diagonal(dense, 0)\n",
        "    return dense\n",
        "\n",
        "# create a full set of RNN modules using the above functions\n",
        "# will only keep a given random RNN if it meets the condition for theorem 1\n",
        "def create_modules(module_sizes, density, dist_multiplier, post_select_multiplier):\n",
        "    modules = []\n",
        "    for m in module_sizes:\n",
        "        okay = False\n",
        "        while not okay:\n",
        "            cur_matrix = generate_random_sparse_matrix(m, density, dist_multiplier)\n",
        "            okay = check_cond(cur_matrix)\n",
        "        cur_matrix = post_select_multiplier * cur_matrix # once reach here the current matrix is one of the ones selected\n",
        "        modules.append(cur_matrix)\n",
        "    return modules\n",
        "\n",
        "# function to put together the generated component RNNs into one big block diagonal weight matrix\n",
        "def combine_W(W_list):\n",
        "    shapes = [w.shape[0] for w in W_list]\n",
        "    total_size = np.sum(shapes)\n",
        "    full_W = np.zeros((total_size, total_size))\n",
        "    for i in range(len(W_list)):\n",
        "        cur_W = W_list[i]\n",
        "        cur_size = shapes[i]\n",
        "        first_index = int(np.sum(shapes[0:i]))\n",
        "        last_index = first_index + cur_size\n",
        "        full_W[first_index:last_index, first_index:last_index] = cur_W\n",
        "    return full_W\n",
        "\n",
        "# use theorem 1 + info about linear stability to find a metric for a given weight matrix (expected to be generated from the above)\n",
        "def find_M(W_inp):\n",
        "    # what we actually want to find metric for is W - I -> just W won't be stable here\n",
        "    # also first need to focus on abs(W), not W itself! that is what linear stable test can find, the same metric will then work for the other (per Thm 1)\n",
        "    W = np.abs(W_inp) # diagonal is set to 0 already so no need to worry about that\n",
        "    W = W - np.identity(W.shape[0])\n",
        "    # this just finds some M that will work, could be many others\n",
        "    Q = np.identity(W.shape[0])\n",
        "    # solve for M in -Q = M * W + np.transpose(W) * M\n",
        "    # using integration formula for LTI system\n",
        "    P = np.zeros(W.shape)\n",
        "    for i in range(W.shape[0]):\n",
        "        # integrating elementwise\n",
        "        # keep off-diags as 0 to save time with larger martrix, as know there will be some diagonal metric, expect good odds that will find one with Q = I\n",
        "        # will confirm the metric works before moving forward though (done in final function below), to be sure with stability guarantee\n",
        "        def func_to_integrate(t):\n",
        "            og_func = np.exp(np.transpose(W) * t) * Q * np.exp(W * t)\n",
        "            return og_func[i, i]\n",
        "        P[i,i] = scipy.integrate.quad(func_to_integrate, 0, np.inf)[0]\n",
        "    if np.max(np.linalg.eigvals(P)) <= 0:\n",
        "        # guaranteed M will be symmetric as it is definitely diagonal here, but also need it be PD for it to be a valid metric\n",
        "        # should never reach this in theory, but add as a check to be safe\n",
        "        print(\"returned metric not PD, problem!\")\n",
        "        return None\n",
        "    return P\n",
        "\n",
        "# put everything together to get W and M to perform the training with\n",
        "# W is of course part of the network\n",
        "# M is used in finding negative feedback connections between components of W that maintain stability\n",
        "# neither are themselves updated over the course of training\n",
        "def generate_initial_W_M(module_sizes, density, dist_multiplier, post_select_multiplier):\n",
        "    individual_networks = create_modules(module_sizes, density, dist_multiplier, post_select_multiplier)\n",
        "    full_W = combine_W(individual_networks)\n",
        "\n",
        "    # in prior experiments 0.5 * I was metric generally found\n",
        "    # so just use that if it will work - only integrate if necessary\n",
        "    matching_M = 0.5*np.identity(full_W.shape[0])\n",
        "    check_formula = matching_M * (np.abs(full_W) - np.identity(full_W.shape[0])) + np.transpose(np.abs(full_W) - np.identity(full_W.shape[0])) * matching_M\n",
        "    if np.max(np.linalg.eigvals(check_formula)) >= 0:\n",
        "        matching_M = find_M(full_W)\n",
        "\n",
        "        # confirm that M does work to satisfy the Theorem 1 condition with this W\n",
        "        check_formula = matching_M * (np.abs(full_W) - np.identity(full_W.shape[0])) + np.transpose(np.abs(full_W) - np.identity(full_W.shape[0])) * matching_M\n",
        "        if np.max(np.linalg.eigvals(check_formula)) >= 0:\n",
        "            print(\"problem with found metric!\")\n",
        "            return None\n",
        "\n",
        "    # return the final W and M\n",
        "    # need these to be tensors for computation to work - but they aren't parameters!\n",
        "    return torch.from_numpy(full_W).float().cuda(), torch.from_numpy(matching_M).float().cuda()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "He1tUpOKrNG1"
      },
      "source": [
        "# helper functions for training\n",
        "\n",
        "def add_channels(X):\n",
        "    # reshaping necessary when loading the training data\n",
        "    if len(X.shape) == 2:\n",
        "        return X.reshape(X.shape[0], 1, X.shape[1], 1)\n",
        "    elif len(X.shape) == 3:\n",
        "        return X.reshape(X.shape[0], 1, X.shape[1], X.shape[2])\n",
        "    else:\n",
        "        return \"dimenional error\"\n",
        " \n",
        "def exp_lr_scheduler(epoch,\n",
        "                     optimizer,\n",
        "                     strategy='normal',\n",
        "                     decay_eff=0.1,\n",
        "                     decayEpoch=[]):\n",
        "    \"\"\"Decay learning rate by a factor of lr_decay every lr_decay_epoch epochs\"\"\"\n",
        " \n",
        "    if strategy == 'normal':\n",
        "        if epoch in decayEpoch:\n",
        "            for param_group in optimizer.param_groups:\n",
        "                param_group['lr'] *= decay_eff\n",
        "            print('New learning rate is: ', param_group['lr'])\n",
        "    else:\n",
        "        print('wrong strategy')\n",
        "        raise ValueError('A very specific bad thing happened.')\n",
        " \n",
        "    return optimizer\n",
        " \n",
        "# get adjacency matrix specifying which modules should be in negative feedback with each other\n",
        "# (this is part of initialization, we currently always use \"fully connected\" scheme here)\n",
        "def create_random_A(ns,frac_zeros=0):\n",
        "    num_networks = len(ns)\n",
        "    A = torch.cuda.FloatTensor(num_networks, num_networks).uniform_() > frac_zeros\n",
        "    A_tril = torch.tril(A)\n",
        "    A_tril.fill_diagonal_(0)\n",
        "    # note only lower triangular needs to be trained, as bidirectional version of connection determined by negative feedback stability cond\n",
        "    return A_tril\n",
        " \n",
        "def create_mask_given_A(A,ns):\n",
        "    '''\n",
        "    Creates 'hidden' mask for training, given an arbitrary adjacency matrix.\n",
        "    \n",
        "    ARGS:\n",
        "        - A: adjacency matrix; only entries on or below the diagonal are read\n",
        "        - ns: list of neural population sizes (e.g ns = [5,4,18,2]).\n",
        "    OUTS:\n",
        "        - mask: (sum(ns), sum(ns)) tensor on `device`; the (ns[i], ns[j]) block is all\n",
        "          ones when i >= j and A[i,j] == 1, and all zeros otherwise\n",
        "    '''\n",
        "\n",
        "    N_nets = len(ns)\n",
        "    mask_rows = []\n",
        "\n",
        "    for i in range(N_nets):\n",
        "        # testing i >= j first means upper triangular entries of A are never inspected\n",
        "        row_blocks = [\n",
        "            torch.ones((ns[i],ns[j])) if (i >= j and A[i,j] == 1) else torch.zeros((ns[i],ns[j]))\n",
        "            for j in range(N_nets)\n",
        "        ]\n",
        "        mask_rows.append(torch.cat(row_blocks,1))\n",
        "\n",
        "    final = torch.cat(mask_rows,0)\n",
        "    # Tensor.to returns the moved tensor rather than moving in place, so its result has to be returned\n",
        "    return final.to(device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fqtjWg23qP-e"
      },
      "source": [
        "# get one of the 3 sequential image classification datasets we use\n",
        "# (sequential MNIST, permuted sequential MNIST, and sequential CIFAR10 - accepted arguments are mnist, pmnist, and cifar10 respectively)\n",
        "# always uses the same batch size settings, 64 for train and 1024 for test\n",
        "def getData(name, train_bs=64, test_bs=1024):\n",
        "\n",
        "    if name == 'cifar10':\n",
        "        transform_train = transforms.Compose([\n",
        "        transforms.RandomCrop(32, padding=4),\n",
        "        transforms.RandomHorizontalFlip(),\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
        "    ])\n",
        "\n",
        "        transform_test = transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
        "    ])\n",
        "\n",
        "        train_loader = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)\n",
        "\n",
        "        offset = 2000\n",
        "        rng = np.random.RandomState(1234)\n",
        "        R = rng.permutation(len(train_loader))\n",
        "        lengths = (len(train_loader) - offset, offset)\n",
        "        train_loader, val_loader = [Subset(train_loader, R[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]\n",
        "\n",
        "        # specifying the generator here is necessary for the code to work on colab\n",
        "        generator = torch.Generator(device=device)\n",
        "        generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))\n",
        "        train_loader = torch.utils.data.DataLoader(train_loader, batch_size=train_bs, shuffle=True, generator=generator)\n",
        "        val_loader = torch.utils.data.DataLoader(val_loader, batch_size=test_bs, shuffle=False, generator=generator)\n",
        "        testset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)\n",
        "        test_loader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False, generator=generator)\n",
        "\n",
        "    if name == 'mnist':\n",
        "        train_loader = datasets.MNIST('./data', train=True, download=True,\n",
        "                           transform=transforms.Compose([transforms.ToTensor(),]))\n",
        " \n",
        "        val_loader = datasets.MNIST('./data', train=True, download=True,\n",
        "                           transform=transforms.Compose([transforms.ToTensor(),]))\n",
        " \n",
        "        offset = 2000\n",
        "        rng = np.random.RandomState(1234)\n",
        "        R = rng.permutation(len(train_loader))\n",
        "        lengths = (len(train_loader) - offset, offset)\n",
        "        train_loader, val_loader = [Subset(train_loader, R[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]\n",
        " \n",
        "        # specifying the generator here is necessary for the code to work on colab\n",
        "        generator = torch.Generator(device=device)\n",
        "        generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))\n",
        "        train_loader = torch.utils.data.DataLoader(train_loader, batch_size=train_bs, shuffle=True, generator=generator)\n",
        "        val_loader = torch.utils.data.DataLoader(val_loader, batch_size=test_bs, shuffle=False, generator=generator)\n",
        "        test_loader = torch.utils.data.DataLoader(\n",
        "            datasets.MNIST('./data', train=False, download=False,\n",
        "            transform=transforms.Compose([transforms.ToTensor(),])),\n",
        "            batch_size=test_bs, shuffle=False, generator=generator)\n",
        " \n",
        "    if name == 'pmnist':\n",
        "        trainset = datasets.MNIST(root='./data', train=True, download=True,\n",
        "                            transform=transforms.Compose([transforms.ToTensor(),]))\n",
        "        testset = datasets.MNIST(root='./data', train=False, download=False,\n",
        "                            transform=transforms.Compose([ transforms.ToTensor(),]))\n",
        "        \n",
        "        x_train = trainset.train_data\n",
        "        y_train = trainset.targets\n",
        "        x_test = testset.test_data        \n",
        "        y_test = testset.targets\n",
        " \n",
        "        torch.manual_seed(12008)        \n",
        "        perm = torch.randperm(784)\n",
        "\n",
        "        x_train_permuted = x_train.reshape(x_train.shape[0],-1)\n",
        "        x_train_permuted = x_train_permuted[:, perm]\n",
        "        x_train_permuted = x_train_permuted.reshape(x_train.shape[0], 28, 28)\n",
        "        \n",
        "        x_test_permuted = x_test.reshape(x_test.shape[0],-1)\n",
        "        x_test_permuted = x_test_permuted[:, perm]\n",
        "        x_test_permuted = x_test_permuted.reshape(x_test.shape[0], 28, 28)        \n",
        " \n",
        "        x_train_permuted = add_channels(x_train_permuted)\n",
        "        x_test_permuted = add_channels(x_test_permuted)\n",
        "        \n",
        "        train_loader = torch.utils.data.TensorDataset(x_train_permuted.float(), y_train)\n",
        "        offset = 2000\n",
        "        rng = np.random.RandomState(1234)\n",
        "        R = rng.permutation(len(train_loader))\n",
        "        lengths = (len(train_loader) - offset, offset)\n",
        "        train_loader, val_loader = [Subset(train_loader, R[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]\n",
        " \n",
        "        # specifying the generator here is necessary for the code to work on colab\n",
        "        generator = torch.Generator(device=device)\n",
        "        generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))\n",
        "        train_loader = torch.utils.data.DataLoader(train_loader, batch_size=train_bs, shuffle=True, generator=generator)\n",
        "        val_loader = torch.utils.data.DataLoader(val_loader, batch_size=test_bs, shuffle=False, generator=generator) \n",
        "        test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_permuted.float(), y_test),\n",
        "                                                batch_size=test_bs, shuffle=False, generator=generator)\n",
        " \n",
        "    return train_loader, val_loader, test_loader"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OIQl_59qre9q"
      },
      "source": [
        "# load the train / validation / test loaders for the dataset named by the task setting\n",
        "train_loader,val_loader,test_loader = getData(task)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MBXEGSFvxafh"
      },
      "source": [
        "# class definitions for the sparse combo networks\n",
        "\n",
        "class rnnAssemblyCell_Thm1(jit.ScriptModule):\n",
        "    '''\n",
        "    Pytorch module for one discrete update step of an assembly of RNNs.\n",
        "\n",
        "    As implemented in forward, with x the state and u the input:\n",
        "        fx = -x + relu(W x + W_ih u + b_hh) + (L - M^-1 L^T M) x\n",
        "        x_next = x + alpha*fx,    output = W_ho x_next\n",
        "    W is the fixed block diagonal matrix and M the fixed metric, both returned by generate_initial_W_M.\n",
        "    L holds the trained connections between subsystems and is restricted by L_mask.\n",
        "    (the L term is written for a symmetric M; in the code the state is a row vector, so everything appears transposed)\n",
        "       \n",
        "    ''' \n",
        "    \n",
        "    def __init__(self, input_size, hidden_sizes, output_size, alpha, A, density, pre_select_mult, post_select_mult):\n",
        "        # hidden_sizes: list with the size of each component RNN; A: adjacency matrix handed to create_mask_given_A\n",
        "        # density, pre_select_mult and post_select_mult are forwarded to generate_initial_W_M\n",
        "        super(rnnAssemblyCell_Thm1, self).__init__()\n",
        "        self.input_size = input_size\n",
        "        self.hidden_sizes = hidden_sizes\n",
        "        self.hidden_size = int(np.sum(hidden_sizes))\n",
        "        self.output_size = output_size\n",
        "        self.alpha = alpha \n",
        "        \n",
        "        # initialize linear input and output layers, to be trained, along with biases\n",
        "        self.weight_ih = nn.Parameter(torch.normal(0,1/np.sqrt(self.hidden_size),(self.hidden_size, self.input_size)))        \n",
        "        self.weight_ho = nn.Parameter(torch.normal(0,1/np.sqrt(self.hidden_size),(self.output_size, self.hidden_size)))\n",
        "        # NOTE(review): bias_oh is registered as a parameter but forward never uses it\n",
        "        self.bias_oh = nn.Parameter(torch.normal(0,1/np.sqrt(self.hidden_size),(1,self.output_size)))\n",
        "        self.bias_hh = nn.Parameter(torch.normal(0,1/np.sqrt(self.hidden_size),(1,self.hidden_size)))               \n",
        " \n",
        "        # specify W and M here based on the random initialization mentioned     \n",
        "        # (plain tensors, not nn.Parameter, so they do not appear in parameters())\n",
        "        self.W, self.M = generate_initial_W_M(self.hidden_sizes, density, pre_select_mult, post_select_mult)\n",
        "        self.M_inv = torch.inverse(self.M)\n",
        " \n",
        "        # L contains the connections between subsystems. this will be trained. \n",
        "        self.L_mask = create_mask_given_A(A,self.hidden_sizes).bool()\n",
        "        self.L_train = nn.Parameter(self.L_mask*torch.normal(0,1/np.sqrt(np.mean(self.hidden_sizes)),(np.sum(self.hidden_sizes), np.sum(self.hidden_sizes))))\n",
        "        \n",
        "    @jit.script_method\n",
        "    def forward(self, input, state):\n",
        "        # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]\n",
        "        # the mask is re-applied on every call, so entries outside the mask never contribute\n",
        "        L_masked = self.L_train*self.L_mask\n",
        "        \n",
        "        # leak term + nonlinear recurrent drive + masked coupling between subsystems\n",
        "        fx = -state + F.relu(state @ self.W.T + input @ (self.weight_ih.T) + self.bias_hh) + state @ (L_masked.T - (self.M @ L_masked) @ self.M_inv) \n",
        "              \n",
        "        # one explicit step of size alpha\n",
        "        hx =  state + self.alpha*fx\n",
        "        \n",
        "        # linear readout of the updated state (no bias is added)\n",
        "        hy = hx @ (self.weight_ho.T)\n",
        " \n",
        "        return hy, hx\n",
        " \n",
        "class rnnAssemblyLayer_Thm1(jit.ScriptModule):\n",
        "    # runs a cell over a whole input sequence, one step per slice along dimension 1 of the input\n",
        "    def __init__(self, cell, *cell_args):\n",
        "        # cell is the cell class (not an instance); it is constructed here from cell_args\n",
        "        super(rnnAssemblyLayer_Thm1, self).__init__()\n",
        "        self.cell = cell(*cell_args)\n",
        " \n",
        "    @jit.script_method\n",
        "    def forward(self, input):\n",
        "        # type: (Tensor) -> Tuple[Tensor, Tensor]\n",
        "        # the leading factor of 0 makes the initial state all zeros; it is always created on the GPU\n",
        "        state = 0*0.1*torch.randn(input.shape[0],self.cell.hidden_size, device='cuda')\n",
        "        \n",
        "        # tuple with one slice per index along dimension 1 (the sequence steps)\n",
        "        inputs = input.unbind(1)              \n",
        "        \n",
        "        outputs = torch.jit.annotate(List[Tensor], [])\n",
        "        for i in range(len(inputs)):\n",
        "            out, state = self.cell(inputs[i], state)\n",
        "            outputs += [out]\n",
        " \n",
        "        # stacking puts the step axis first; permute moves it back to dimension 1\n",
        "        # the second return value is the state after the last step\n",
        "        return torch.stack(outputs).permute(1,0,2), state"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G8TZyXiLr4-Q"
      },
      "source": [
        "# training function for sequential CIFAR\n",
        "# next cell contains the slightly different function for when training is on one of the MNIST tasks instead\n",
        "def train_CIFAR(rnn,optimizer,train_loader,test_loader,max_epoch,decay_eff,decay_epochs):\n",
        "    '''\n",
        "    Main training loop for sequential CIFAR-10.\n",
        "\n",
        "    After every epoch the network is evaluated on test_loader, a checkpoint is saved and the\n",
        "    stats CSV is rewritten, both under output_root with model_name as the file prefix.\n",
        "\n",
        "    ARGS:\n",
        "        - rnn: network called on a (batch, 1024, 3) input; the last step of its first output is classified\n",
        "        - optimizer: optimizer stepping the parameters of rnn\n",
        "        - train_loader, test_loader: loaders yielding (input, target) batches\n",
        "        - max_epoch: number of epochs to run\n",
        "        - decay_eff: factor applied to the learning rate at the epochs listed in decay_epochs\n",
        "        - decay_epochs: 1-based epoch numbers after which the learning rate is cut\n",
        "    OUTS:\n",
        "        - rnn, optimizer, list of per-epoch test accuracies (percent), list of per-epoch mean training losses\n",
        "    '''\n",
        "\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    test_accs = []\n",
        "    train_losses = []\n",
        "    epochs_list = [] # just grabbing numbers for the sake of dataframe\n",
        "\n",
        "    # the loop index is 0-based while decay_epochs is 1-based, so shift once up front\n",
        "    decay_at = [x-1 for x in decay_epochs]\n",
        "\n",
        "    # Train for some epochs\n",
        "    for epoch in tqdm(range(max_epoch),total = max_epoch):\n",
        "        rnn.train()\n",
        "        loss_epoch = []\n",
        "        for inp,target in tqdm(train_loader,total = len(train_loader)):\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "            # had to add these lines for colab GPU\n",
        "            inp = inp.cuda()\n",
        "            target = target.cuda()\n",
        "\n",
        "            # CIFAR-10 is 32x32, so 1024 pixels\n",
        "            # it is also color, so the original format would be (batch_size, 3, 32, 32)\n",
        "            # for the RNN input, want input size to be the last variable though\n",
        "            output,_ = rnn(inp.view(-1, 3, 1024).permute(0,2,1))\n",
        "\n",
        "            # only the output at the last step is used for classification\n",
        "            loss = criterion(output[:,-1,:], target)\n",
        "            loss_epoch.append(loss.item())\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "        # track loss over time, and have it reflect the mean loss over the batches so it is more reflective of training trends than last batch loss\n",
        "        mean_loss = np.mean(loss_epoch)\n",
        "        train_losses.append(mean_loss)\n",
        "\n",
        "        # calling the network with no training epoch 0, want each epoch number to reflect how many have been run so far, so add 1 here\n",
        "        print('Epoch {}, mean batch loss {}'.format(epoch+1, mean_loss))\n",
        "\n",
        "        # use specified scalar to learning rate after specified epochs\n",
        "        optimizer = exp_lr_scheduler(epoch, optimizer, decay_eff=decay_eff, decayEpoch=decay_at)\n",
        "\n",
        "        # testing is fast so no problem doing it every epoch\n",
        "        rnn.eval()\n",
        "        with torch.no_grad():\n",
        "            total = 0\n",
        "            correct = 0\n",
        "            for inp,target in test_loader:\n",
        "                inp = inp.cuda() # adding here too\n",
        "                target = target.cuda()\n",
        "                # reformat input in same way for CIFAR here\n",
        "                output,_ = rnn(inp.view(-1, 3, 1024).permute(0,2,1))\n",
        "\n",
        "                # the class with the highest energy is what we choose as prediction\n",
        "                _, predicted = torch.max(output[:,-1,:], 1)\n",
        "                total += target.size(0)\n",
        "                correct += (predicted == target).sum().item()\n",
        "            print('Accuracy of the network on the 10000 test images: %d %%' % (\n",
        "                  100 * correct / total))\n",
        "            test_accs.append(100.0 * float(correct) / float(total)) # exact (float) test accuracy for the stats\n",
        "        epochs_list.append(epoch+1) # end of epoch so label +1\n",
        "\n",
        "        # save model for every epoch as storage space required is quite small, can have training disruptions with colab\n",
        "        model_path = os.path.join(output_root, model_name + \"-cifar10-\" + \"epoch\" + str(epoch+1) + \".pt\") # end of epoch so label +1\n",
        "        rnn.save(model_path) # would use torch.jit.load to reload in the future\n",
        "\n",
        "        # save stats so far, the CSV is overwritten every epoch with one more row\n",
        "        stats_path = os.path.join(output_root, model_name + \"-cifar10-stats.csv\")\n",
        "        cur_stats = pd.DataFrame({\"epoch\": epochs_list, \"loss\": train_losses, \"test-acc\": test_accs})\n",
        "        cur_stats.to_csv(stats_path, index=False)\n",
        "\n",
        "    return rnn, optimizer, test_accs, train_losses"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "egqmOxE4WX-y"
      },
      "source": [
        "# analogous training function for sequential or permuted MNIST\n",
        "def train_MNIST(rnn,optimizer,train_loader,test_loader,max_epoch,decay_eff,decay_epochs,task):\n",
        "    '''\n",
        "    Main training loop for sequential or permuted MNIST.\n",
        "\n",
        "    Every image is flattened into a 784 step sequence with a single input feature per step,\n",
        "    and the class scores are read from the network output at the last time step.\n",
        "    After each epoch the test accuracy is measured, the model is saved and the stats csv is rewritten.\n",
        "\n",
        "    Uses the notebook level names device, output_root, model_name and exp_lr_scheduler.\n",
        "\n",
        "    Args:\n",
        "        rnn: network to train. Called on a (batch, 784, 1) tensor, returns a tuple whose first\n",
        "            element is indexed as output[:,-1,:]. Must also provide save(path).\n",
        "        optimizer: optimizer that updates rnn, passed through exp_lr_scheduler after every epoch.\n",
        "        train_loader: loader of (input, target) training batches, must support len().\n",
        "        test_loader: loader of (input, target) test batches.\n",
        "        max_epoch: number of epochs to run.\n",
        "        decay_eff: learning rate scalar handed to exp_lr_scheduler.\n",
        "        decay_epochs: epoch numbers (in the 1-based naming used for logs) handed to exp_lr_scheduler\n",
        "            after subtracting 1 from each.\n",
        "        task: task name, only used in the names of the saved model and stats files.\n",
        "\n",
        "    Returns:\n",
        "        rnn, optimizer, test_accs, train_losses. The two lists have one entry per finished epoch:\n",
        "        test accuracy in percent and mean training loss over the batches.\n",
        "    '''\n",
        "\n",
        "    criterion = nn.CrossEntropyLoss()\n",
        "    test_accs = []\n",
        "    train_losses = []\n",
        "    epochs_list = [] # just grabbing numbers for the sake of dataframe\n",
        "\n",
        "    # the loop counts epochs from 0 while the user facing numbering starts at 1,\n",
        "    # so shift the requested decay epochs once up front instead of on every epoch\n",
        "    decay_epoch_idx = [x-1 for x in decay_epochs]\n",
        "\n",
        "    # Train for some epochs\n",
        "    for epoch in tqdm(range(max_epoch),total = max_epoch):\n",
        "        rnn.train()\n",
        "        loss_epoch = []\n",
        "        for inp,target in tqdm(train_loader,total = len(train_loader)):\n",
        "            optimizer.zero_grad()\n",
        "\n",
        "            # move the batch to the device selected in the notebook settings (instead of a hard coded .cuda())\n",
        "            inp = inp.to(device)\n",
        "            target = target.to(device)\n",
        "\n",
        "            # flatten each image to a sequence of 784 steps with 1 feature each\n",
        "            output,_ = rnn(inp.view(-1, 784).unsqueeze(dim = 2))\n",
        "\n",
        "            # only the output at the final time step is used for classification\n",
        "            loss = criterion(output[:,-1,:], target)\n",
        "            loss_epoch.append(loss.item())\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "        # track loss over time, and have it reflect the mean loss over the batches so it is more reflective of training trends than last batch loss\n",
        "        mean_loss = np.mean(loss_epoch)\n",
        "        train_losses.append(mean_loss)\n",
        "\n",
        "        # calling the network with no training epoch 0, want each epoch number to reflect how many have been run so far, so add 1 here\n",
        "        print('Epoch {}, mean batch loss {}'.format(epoch+1, mean_loss))\n",
        "\n",
        "        # use specified scalar to learning rate after specified epochs\n",
        "        optimizer = exp_lr_scheduler(epoch, optimizer, decay_eff=decay_eff, decayEpoch=decay_epoch_idx)\n",
        "\n",
        "        # testing is fast so no problem doing it every epoch\n",
        "        rnn.eval()\n",
        "        with torch.no_grad():\n",
        "            total = 0\n",
        "            correct = 0\n",
        "            for inp,target in test_loader:\n",
        "                inp = inp.to(device)\n",
        "                target = target.to(device)\n",
        "                output,_ = rnn(inp.view(-1, 784).unsqueeze(dim = 2))\n",
        "\n",
        "                # the class with the highest energy is what we choose as prediction\n",
        "                # (no .data needed, gradients are already disabled in this block)\n",
        "                _, predicted = torch.max(output[:,-1,:], 1)\n",
        "                total += target.size(0)\n",
        "                correct += (predicted == target).sum().item()\n",
        "            print('Accuracy of the network on the 10000 test images: %d %%' % (\n",
        "                  100 * correct / total))\n",
        "            test_accs.append((100.0 * float(correct) / float(total))) # for the actual list make sure we are getting exact test accuracy, so convert to floats!\n",
        "        epochs_list.append(epoch+1) # end of epoch so label +1\n",
        "\n",
        "        # save model for every epoch as storage space required is quite small, can have training disruptions with colab\n",
        "        model_path = os.path.join(output_root, model_name + \"-\" + task + \"-\" + \"epoch\" + str(epoch+1) + \".pt\") # end of epoch so label +1\n",
        "        rnn.save(model_path) # would use torch.jit.load to reload in the future\n",
        "\n",
        "        # save stats so far, will overwrite every time as rows are added to dataframe\n",
        "        stats_path = os.path.join(output_root, model_name + \"-\" + task + \"-stats.csv\")\n",
        "        cur_stats = pd.DataFrame({\"epoch\": epochs_list, \"loss\": train_losses, \"test-acc\": test_accs})\n",
        "        cur_stats.to_csv(stats_path, index=False)\n",
        "\n",
        "    return rnn, optimizer, test_accs, train_losses"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "o5hzUm9kr9vN"
      },
      "source": [
        "'''\n",
        "Define network and use main training loop\n",
        "'''\n",
        "\n",
        "# --- fixed RNN settings ---\n",
        "output_size = 10 # one output per class label, all supported tasks have 10\n",
        "dt_rnn = .03\n",
        "tau_rnn = 1\n",
        "alpha = dt_rnn/tau_rnn\n",
        "\n",
        "# --- settings that follow from the user settings in the earlier cells ---\n",
        "input_size = 3 if task == \"cifar10\" else 1 # rgb channels for cifar10, a single black and white channel otherwise\n",
        "# original note: with a fully connected A this is not actually random, the function always gives an all 1 lower triangular matrix for the given sizes\n",
        "A = create_random_A(ns)\n",
        "\n",
        "# --- network ---\n",
        "rnn = rnnAssemblyLayer_Thm1(rnnAssemblyCell_Thm1,input_size,ns,output_size,alpha,A,rnn_density,pre_scalar,post_scalar)\n",
        "rnn.to(device)\n",
        "\n",
        "# --- optimizer over all network parameters ---\n",
        "optim_params = list(rnn.parameters())\n",
        "optimizer = torch.optim.Adam(optim_params, lr=lr, weight_decay=weight_decay)\n",
        "\n",
        "# --- store the untrained network as epoch 0 ---\n",
        "model_path = os.path.join(output_root, \"-\".join([model_name, task, \"epoch0.pt\"]))\n",
        "rnn.save(model_path)\n",
        "\n",
        "# --- run the training function that matches the task on the loaded dataset ---\n",
        "if task != \"cifar10\":\n",
        "    rnn,optimizer,test_accs,train_losses = train_MNIST(rnn,optimizer,train_loader,test_loader,num_epochs,lr_scalar,lr_scale_epochs,task)\n",
        "else:\n",
        "    rnn,optimizer,test_accs,train_losses = train_CIFAR(rnn,optimizer,train_loader,test_loader,num_epochs,lr_scalar,lr_scale_epochs)\n",
        "# the two cells after this one plot the resulting accuracies and losses"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sYO8-zoId5S5"
      },
      "source": [
        "# test accuracy after each epoch; x axis uses the same 1-based epoch numbers as the logs and the saved stats\n",
        "fig_acc, ax_acc = plt.subplots()\n",
        "ax_acc.plot(range(1, len(test_accs) + 1), test_accs)\n",
        "ax_acc.set(title='Test accuracy per epoch', xlabel='epoch', ylabel='test accuracy')\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Tb2JHCIhd7eE"
      },
      "source": [
        "# training loss recorded for each epoch; x axis uses the same 1-based epoch numbers as the logs and the saved stats\n",
        "fig_loss, ax_loss = plt.subplots()\n",
        "ax_loss.plot(range(1, len(train_losses) + 1), train_losses)\n",
        "ax_loss.set(title='Training loss per epoch', xlabel='epoch', ylabel='training loss')\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}