{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0787a534",
   "metadata": {},
   "source": [
    "This notebook illustrates how to train Fully Connected models with FTP. We train and test the model on MNIST."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae7a36dc",
   "metadata": {},
   "source": [
    "#### Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a56a3e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from torch.autograd import Variable\n",
    "\n",
    "import copy\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab6701be",
   "metadata": {},
   "source": [
    "#### Define Network architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9d192b26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# models with Dropout\n",
    "class NetFC1x1024x128mn(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(28*28*1,1024,bias=False)  # 32x32x3 for CIFAR10\n",
    "        self.fc2 = nn.Linear(1024, 128,bias=False)\n",
    "        self.fc3 = nn.Linear(128, 10,bias=False)\n",
    "        \n",
    "        # initialize the layers using the He uniform initialization scheme\n",
    "        fc1_nin = 28*28*1 # Note: if dataset is CIFAR10 --> fc1_nin = 32*32*3\n",
    "        fc1_limit = np.sqrt(6.0 / fc1_nin)\n",
    "        torch.nn.init.uniform_(self.fc1.weight, a=-fc1_limit, b=fc1_limit)\n",
    "        fc2_nin = 1024\n",
    "        fc2_limit = np.sqrt(6.0 / fc2_nin)\n",
    "        torch.nn.init.uniform_(self.fc2.weight, a=-fc2_limit, b=fc2_limit)\n",
    "        fc3_nin = 128\n",
    "        fc3_limit = np.sqrt(6.0 / fc3_nin)\n",
    "        torch.nn.init.uniform_(self.fc3.weight, a=-fc3_limit, b=fc3_limit)\n",
    "        \n",
    "\n",
    "    def forward(self, x, do_masks):\n",
    "        x = F.tanh(self.fc1(x))\n",
    "        # apply dropout --> we use a custom dropout implementation because we need to present the same dropout mask in the two forward passes\n",
    "        if do_masks is not None:\n",
    "            x = x * do_masks[0]   \n",
    "\n",
    "        x = F.tanh(self.fc2(x))\n",
    "        # apply dropout --> we use a custom dropout implementation because we need to present the same dropout mask in the two forward passes\n",
    "        if do_masks is not None:\n",
    "            x = x * do_masks[1]\n",
    "        x = F.softmax(self.fc3(x))\n",
    "        return x\n",
    "    \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ad144c5",
   "metadata": {},
   "source": [
    "#### Set hyperparameters and train+test the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2a085cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set hyperparameters\n",
    "## learning rate\n",
    "eta = 0.01  \n",
    "## dropout keep rate\n",
    "keep_rate = 0.9\n",
    "## loss --> used to monitor performance, but not for parameter updates (PEPITA does not backpropagate the loss)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "## optimizer (choose 'SGD' o 'mom')\n",
    "optim = 'mom' # --> default in the paper\n",
    "if optim == 'SGD':\n",
    "    gamma = 0\n",
    "elif optim == 'mom':\n",
    "    gamma = 0.9\n",
    "## batch size\n",
    "batch_size = 64 # --> default in the paper\n",
    "\n",
    "# initialize the network\n",
    "net = NetFC1x1024x128mn()\n",
    "\n",
    "# load the dataset, replace MNIST by CIFAR10 for CIFAR10\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor()]) # this normalizes to [0,1]\n",
    "trainset = torchvision.datasets.MNIST(root='./data', train=True,\n",
    "                                        download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
    "                                          shuffle=True, num_workers=2)\n",
    "testset = torchvision.datasets.MNIST(root='./data', train=False,\n",
    "                                       download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
    "                                         shuffle=False, num_workers=2)\n",
    "\n",
    "\n",
    "# define function to register the activations --> we need this to compare the activations in the two forward passes\n",
    "activation = {}\n",
    "def get_activation(name):\n",
    "    def hook(model, input, output):\n",
    "        activation[name] = output.detach()\n",
    "    return hook\n",
    "for name, layer in net.named_modules():\n",
    "    layer.register_forward_hook(get_activation(name))\n",
    "\n",
    "\n",
    "# define G --> this is the G projection matrix \n",
    "nin = 1024\n",
    "sd = np.sqrt(6/nin)\n",
    "G = (torch.rand(nin,10)*2*sd-sd)*0.05  # G is initialized with the He uniform initialization (like the forward weights)\n",
    "\n",
    "\n",
    "# check cosine similarity before training AND matrix norm\n",
    "angles = []\n",
    "w_all = []\n",
    "norm_w0 = []\n",
    "for l_idx,w in enumerate(net.parameters()):\n",
    "    with torch.no_grad():\n",
    "        w_all.append(copy.deepcopy(w))\n",
    "        if l_idx == 0:\n",
    "            norm_w0.append(torch.norm(w))\n",
    "        print('norm of w at layer {} is {}'.format(l_idx,torch.norm(w)))\n",
    "w_prod = w_all[1].T\n",
    "for idx in range(2,len(w_all)):\n",
    "    w_prod = torch.matmul(w_prod,w_all[idx].T)\n",
    "    print(w_prod.size())\n",
    "\n",
    "# do one forward pass to get the activation size needed for setting up the dropout masks\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = next(dataiter)\n",
    "images = torch.flatten(images, 1) # flatten all dimensions except batch        \n",
    "outputs = net(images,do_masks=None)\n",
    "layers_act = []\n",
    "for key in activation:\n",
    "    if 'fc' in key or 'conv' in key:\n",
    "        layers_act.append(F.tanh(activation[key]))\n",
    "        \n",
    "# set up for momentum\n",
    "if optim == 'mom':\n",
    "    gamma = 0.9\n",
    "    v_w_all = []\n",
    "    for l_idx,w in enumerate(net.parameters()):\n",
    "        if len(w.shape)>1:\n",
    "            with torch.no_grad():\n",
    "                v_w_all.append(torch.zeros(w.shape))\n",
    "\n",
    "# Train and test the model\n",
    "test_accs = []\n",
    "for epoch in range(100):  # loop over the dataset multiple times\n",
    "    \n",
    "    # learning rate decay\n",
    "    if epoch in [60,90]: \n",
    "        eta = eta*0.1\n",
    "        print('eta decreased to ',eta)\n",
    "    \n",
    "    # loop over batches\n",
    "    running_loss = 0.0\n",
    "    for i, data in enumerate(trainloader, 0):\n",
    "        # get the inputs; data is a list of [inputs, labels]\n",
    "        inputs, target = data\n",
    "        inputs = torch.flatten(inputs, 1) # flatten all dimensions except batch\n",
    "        target_onehot = F.one_hot(target,num_classes=10)\n",
    "        \n",
    "        # create dropout mask for the two forward passes --> we need to use the same mask for the two passes\n",
    "        do_masks = []\n",
    "        if keep_rate < 1:\n",
    "            for l in layers_act[:-1]:\n",
    "                input1 = l\n",
    "                do_mask = Variable(torch.ones(inputs.shape[0],input1.data.new(input1.data.size()).shape[1]).bernoulli_(keep_rate))/keep_rate\n",
    "                do_masks.append(do_mask)\n",
    "            do_masks.append(1) # for the last layer we don't use dropout --> just set a scalar 1 (needed for when we register activation layer)\n",
    "     \n",
    "        # forward pass 1 with original input --> keep track of activations\n",
    "        outputs = net(inputs,do_masks)\n",
    "        layers_act = []\n",
    "        cnt_act = 0\n",
    "        for key in activation:\n",
    "            if 'fc' in key or 'conv' in key:\n",
    "                layers_act.append(F.tanh(activation[key])* do_masks[cnt_act]) # Note: we need to register the activations taking into account non-linearity and dropout mask\n",
    "                cnt_act += 1\n",
    "                \n",
    "        # Convert target_onehot to float\n",
    "        target_onehot = target_onehot.float()\n",
    "\n",
    "        # Convert outputs to float\n",
    "        outputs = outputs.float()\n",
    "        error = outputs - target_onehot  \n",
    "        error_input = F.tanh(target_onehot @ G.T) - F.tanh(outputs @ G.T)\n",
    "\n",
    "        \n",
    "        target_h1 = F.tanh(activation[\"fc1\"]) + error_input\n",
    "        \n",
    "        # forward pass 2 with targets --> keep track of target estimations\n",
    "        with torch.no_grad():\n",
    "            h_mod= {}\n",
    "            # Step 1: Get first-layer target\n",
    "            h1_mod = target_h1 \n",
    "            if do_masks is not None:\n",
    "                h1_mod = h1_mod * do_masks[0]\n",
    "            h_mod['0'] = h1_mod\n",
    "            # Step 2: Forward through fc2\n",
    "            h2_mod = F.tanh(net.fc2(h1_mod))\n",
    "            if do_masks is not None:\n",
    "                h2_mod = h2_mod * do_masks[1]\n",
    "            h_mod['1'] = h2_mod\n",
    "\n",
    "            # Step 3: Forward through fc3\n",
    "            mod_outputs = F.softmax(net.fc3(h2_mod), dim=1)\n",
    "            h_mod['2'] = mod_outputs\n",
    "\n",
    "\n",
    "        mod_layers_act = []\n",
    "        cnt_act = 0\n",
    "        for key in activation:\n",
    "            if 'fc' in key or 'conv' in key:\n",
    "                mod_layers_act.append(h_mod[str(cnt_act)]) \n",
    "                cnt_act += 1\n",
    "        mod_error = mod_outputs - target_onehot\n",
    "        \n",
    "        # compute the delta_w for the batch\n",
    "        delta_w_all = []\n",
    "        v_w = []\n",
    "        for l_idx,w in enumerate(net.parameters()):\n",
    "            v_w.append(torch.zeros(w.shape))\n",
    "            \n",
    "        for l in range(len(layers_act)):\n",
    "            \n",
    "            # update for the last layer\n",
    "            if l == len(layers_act)-1:\n",
    "                \n",
    "                if len(layers_act)>1:\n",
    "                    delta_w = -error.T @ layers_act[-2]\n",
    "                else:\n",
    "                    delta_w = -error.T @ inputs\n",
    "            \n",
    "            # update for the first layer\n",
    "            elif l == 0:\n",
    "                delta_w = -((layers_act[l] - mod_layers_act[l]).T * (1-layers_act[l].T**2)) @ inputs\n",
    "            \n",
    "            # update for the hidden layers (not first, not last)\n",
    "            elif l>0 and l<len(layers_act)-1:\n",
    "                delta_w = -(layers_act[l] - mod_layers_act[l]).T * (1-layers_act[l].T**2) @ layers_act[l-1]\n",
    "            \n",
    "            delta_w_all.append(delta_w)\n",
    "                \n",
    "        # apply the weight change\n",
    "        if optim == 'SGD':\n",
    "            for l_idx,w in enumerate(net.parameters()):\n",
    "                with torch.no_grad():\n",
    "                    w += eta * delta_w_all[l_idx]/batch_size # specify for which layer\n",
    "                    \n",
    "        elif optim == 'mom':\n",
    "            for l_idx,w in enumerate(net.parameters()):\n",
    "                with torch.no_grad():\n",
    "                    v_w_all[l_idx] = gamma * v_w_all[l_idx] + eta * delta_w_all[l_idx]/batch_size\n",
    "                    w += v_w_all[l_idx]\n",
    "                    \n",
    "        \n",
    "        # keep track of the loss\n",
    "        loss = criterion(outputs, target)\n",
    "        # print statistics\n",
    "        running_loss += loss.item()\n",
    "        if i%500 == 499:\n",
    "            print('[%d, %5d] loss: %.3f' %\n",
    "                  (epoch + 1, i + 1, running_loss / 500))\n",
    "            running_loss = 0.0\n",
    "                    \n",
    "    print('Testing...')\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    # since we're not training, we don't need to calculate the gradients for our outputs\n",
    "    with torch.no_grad():\n",
    "        for test_data in testloader:\n",
    "            test_images, test_labels = test_data\n",
    "            test_images = torch.flatten(test_images, 1) # flatten all dimensions except batch\n",
    "            # calculate outputs by running images through the network\n",
    "            test_outputs = net(test_images,do_masks=None)\n",
    "            # the class with the highest energy is what we choose as prediction\n",
    "            _, predicted = torch.max(test_outputs.data, 1)\n",
    "            total += test_labels.size(0)\n",
    "            correct += (predicted == test_labels).sum().item()\n",
    "\n",
    "    print('Test accuracy epoch {}: {} %'.format(epoch, 100 * correct / total))\n",
    "    test_accs.append(100 * correct / total)\n",
    "\n",
    "print('Finished Training')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
