{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
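    "This notebook provides empirical evidence of implicit regularization in two-layer networks with batch normalization (BN), as presented in Section 4.1 of the paper. It trains a standard BN network and an equivalent whitened network on CIFAR-10, then compares their accuracies and how the first-layer weights evolve along the singular directions of the training data."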
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preliminaries and Function Definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import math\n",
    "import time\n",
    "from datetime import datetime\n",
    "import pickle\n",
    "import re\n",
    "import gc\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.checkpoint as cp\n",
    "from collections import OrderedDict\n",
    "from torch import Tensor\n",
    "from torch.jit.annotations import List\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.autograd import Variable\n",
    "import importlib\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable, axes_size\n",
    "from matplotlib.colors import ListedColormap, LinearSegmentedColormap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare dataset\n",
    "class PrepareData(torch.utils.data.Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        if not torch.is_tensor(X):\n",
    "            self.X = torch.from_numpy(X)\n",
    "        else:\n",
    "            self.X = X\n",
    "        if not torch.is_tensor(y):\n",
    "            self.y = torch.from_numpy(y)\n",
    "        else:\n",
    "            self.y = y\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]\n",
    "    \n",
    "def one_hot(labels, num_classes=10):\n",
    "    \"\"\"Embedding labels to one-hot form.\n",
    "\n",
    "    Args:\n",
    "      labels: (LongTensor) class labels, sized [N,].\n",
    "      num_classes: (int) number of classes.\n",
    "\n",
    "    Returns:\n",
    "      (tensor) encoded labels, sized [N, #classes].\n",
    "    \"\"\"\n",
    "    y = torch.eye(num_classes) \n",
    "    return y[labels] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# functions for solving the standard nonconvex batchnorm problem\n",
    "class BatchNormNet(torch.nn.Module):\n",
    "    def __init__(self, d, num_neurons, d_out, momentum=1.0, bias=True):\n",
    "        super(BatchNormNet, self).__init__()\n",
    "        self.layer1 = torch.nn.Linear(d, num_neurons, bias=False)\n",
    "        self.gamma = torch.nn.Parameter(data=torch.empty(num_neurons), requires_grad=True)\n",
    "        torch.nn.init.ones_(self.gamma)\n",
    "        self.eta = torch.nn.Parameter(data=torch.empty(num_neurons), requires_grad=bias)\n",
    "        torch.nn.init.zeros_(self.eta)\n",
    "        \n",
    "        self.act = torch.nn.ReLU()\n",
    "        self.layer2 = torch.nn.Linear(num_neurons, d_out, bias=False)\n",
    "        \n",
    "        self.momentum = momentum\n",
    "        \n",
    "        self.mean = torch.nn.Parameter(data=torch.zeros(num_neurons), requires_grad=False)\n",
    "        self.std = torch.nn.Parameter(data=torch.ones(num_neurons), requires_grad=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        if self.training:\n",
    "            out = self.layer1(x)\n",
    "            next_mean = torch.mean(out, 0)\n",
    "            self.mean = torch.nn.Parameter(data =self.momentum*next_mean + (1-self.momentum)*self.mean.data, requires_grad=False)\n",
    "            \n",
    "            out = out - next_mean\n",
    "            next_std = torch.norm(out, dim=0)\n",
    "            self.std = torch.nn.Parameter(data=self.momentum*next_std + (1-self.momentum)*self.std.data, requires_grad=False)\n",
    "            \n",
    "            out = out/next_std\n",
    "            out = self.layer2(self.act(out*self.gamma + self.gamma*self.eta))\n",
    "        else:\n",
    "            out = self.layer1(x)\n",
    "            out = out - self.mean\n",
    "            out = out/self.std\n",
    "            out = self.layer2(self.act(out*self.gamma + self.gamma*self.eta))\n",
    "\n",
    "        return out\n",
    "\n",
    "def loss_func(yhat, y, model, beta):\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    \n",
    "    # standard weight decay regularization \n",
    "    for p in model.named_parameters():\n",
    "        if p[0] in ['gamma', 'eta', 'layer2.weight']:\n",
    "            loss = loss + beta/2 * torch.norm(p[1])**2\n",
    "    return loss\n",
    "\n",
    "def validation(model, testloader, beta, device):\n",
    "    \"\"\"Evaluate the model on every batch of testloader in eval mode.\n",
    "\n",
    "    Args:\n",
    "      model: network to evaluate; it is left in eval mode on return.\n",
    "      testloader: iterable of (inputs, targets) batches; targets are 2-D.\n",
    "      beta: weight-decay coefficient forwarded to loss_func.\n",
    "      device: device the batches are moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "      (test_loss, test_correct): regularized loss and number of correct\n",
    "      predictions, both summed over all batches and returned as Python floats\n",
    "      so callers can store them in NumPy arrays even when running on a GPU.\n",
    "    \"\"\"\n",
    "    test_loss = 0.0\n",
    "    test_correct = 0.0\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    # Evaluation needs no gradients, so skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        for _x, _y in testloader:\n",
    "            _x = _x.float().to(device)\n",
    "            _y = _y.float().to(device)\n",
    "\n",
    "            # a single forward pass per batch\n",
    "            yhat = model(_x).float()\n",
    "\n",
    "            test_loss += loss_func(yhat, _y, model, beta).item()\n",
    "            if _y.shape[1] == 1:\n",
    "                # single output column: threshold the prediction at 0.5\n",
    "                hits = torch.eq(yhat>0.5, _y)\n",
    "            else:\n",
    "                # several output columns: compare arg-max classes\n",
    "                hits = torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1))\n",
    "            test_correct += hits.float().sum().item()\n",
    "\n",
    "    return test_loss, test_correct\n",
    "\n",
    "def batchnorm_full_gd(A_train, y_train, A_test, y_test, num_epochs, num_neurons, beta,\n",
    "                      learning_rate, verbose=False, batch_size=None, bn_momentum=1.0,\n",
    "                      print_freq=5, bias=True, device='cpu', optimizer='SGD', momentum=0.9):\n",
    "    \"\"\"Train a BatchNormNet with (mini-batch) gradient descent and track its progress.\n",
    "\n",
    "    Args:\n",
    "      A_train, y_train: training inputs [N, d] and targets [N, D_out] (torch tensors).\n",
    "      A_test, y_test: held-out inputs and targets, evaluated as one batch.\n",
    "      num_epochs: number of passes over the training set.\n",
    "      num_neurons: hidden-layer width.\n",
    "      beta: weight-decay coefficient used by loss_func.\n",
    "      learning_rate: initial step size; ReduceLROnPlateau halves it (patience of\n",
    "        10 epochs) based on the last minibatch loss of each epoch.\n",
    "      verbose: forwarded to the learning-rate scheduler.\n",
    "      batch_size: minibatch size; None means full-batch gradient descent.\n",
    "      bn_momentum: running-statistics momentum of the BatchNormNet.\n",
    "      print_freq: print a progress line every print_freq epochs.\n",
    "      bias: whether the shift parameter eta is trainable.\n",
    "      device: torch device used for the model and the minibatches.\n",
    "      optimizer: 'SGD', 'Adagrad', or anything else for Adam.\n",
    "      momentum: SGD momentum (only used when optimizer == 'SGD').\n",
    "\n",
    "    Returns:\n",
    "      (losses, accs, losses_test, accs_test, times, first_layer_weight_diffs, model):\n",
    "      per-iteration training loss and accuracy, per-epoch test loss and accuracy\n",
    "      (index 0 is measured before training), one timestamp per iteration, the\n",
    "      per-epoch alignment between the current and initial first-layer weights\n",
    "      along each singular direction of the centered training data, and the\n",
    "      trained model moved to the CPU.\n",
    "    \"\"\"\n",
    "    device = torch.device(device)\n",
    "    # D_in is input dimension, H is hidden dimension, D_out is output dimension.\n",
    "    D_in, H, D_out = A_train.shape[1], num_neurons, y_train.shape[1]\n",
    "\n",
    "    # use gd if no batch size is passed in\n",
    "    if batch_size is None:\n",
    "        batch_size = A_train.shape[0]\n",
    "\n",
    "    # create the model\n",
    "    model = BatchNormNet(D_in, H, D_out, momentum=bn_momentum, bias=bias).to(device)\n",
    "    if optimizer=='SGD':\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)\n",
    "    elif optimizer=='Adagrad':\n",
    "        optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)\n",
    "    else:\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,\n",
    "                                                           mode='min',\n",
    "                                                           factor=0.5,\n",
    "                                                           patience=10,\n",
    "                                                           verbose=verbose)\n",
    "\n",
    "    # arrays for saving the loss and accuracy\n",
    "    losses = np.zeros((int(num_epochs*np.ceil(A_train.shape[0] / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    # actual size of every minibatch: the last one of an epoch may be smaller\n",
    "    batch_lens = np.ones(losses.shape)\n",
    "    losses_test = np.zeros((num_epochs+1))\n",
    "    accs_test = np.zeros((num_epochs+1))\n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    times[0] = time.time()\n",
    "\n",
    "    # dataset loaders (minibatch)\n",
    "    ds = PrepareData(X=A_train, y=y_train)\n",
    "    ds = DataLoader(ds, batch_size=batch_size, shuffle=True)\n",
    "    ds_test = PrepareData(X=A_test, y=y_test)\n",
    "    ds_test = DataLoader(ds_test, batch_size=A_test.shape[0], shuffle=False)\n",
    "\n",
    "    # loss on the entire test set; float() also accepts a 0-dim (GPU) tensor\n",
    "    test_loss, test_correct = validation(model, ds_test, beta, device)\n",
    "    losses_test[0], accs_test[0] = float(test_loss), float(test_correct)\n",
    "\n",
    "    # Express the first-layer weights in the singular basis of the centered data\n",
    "    _, s, v = torch.svd(A_train-torch.mean(A_train, dim=0))\n",
    "    sv = (torch.diag(s)@ v.t()).to(device)\n",
    "    init_first_layer = sv @ model.layer1.weight.data.clone().t()\n",
    "    init_first_layer = init_first_layer / torch.norm(init_first_layer, dim = 1).unsqueeze(1)\n",
    "    # one column per singular direction: min(N, d) of them, which equals D_in only when N >= d\n",
    "    first_layer_weight_diffs = np.zeros((num_epochs, sv.shape[0]))\n",
    "\n",
    "    iter_no = 0\n",
    "    for i in range(num_epochs):\n",
    "        model.train()\n",
    "        for _x, _y in ds:\n",
    "            _x = _x.float().to(device)\n",
    "            _y = _y.float().to(device)\n",
    "\n",
    "            yhat = model(_x).float()\n",
    "\n",
    "            loss = loss_func(yhat, _y, model, beta)\n",
    "            if D_out == 1:\n",
    "                correct = torch.eq(yhat>0.5, _y).float().sum()\n",
    "            else:\n",
    "                correct = torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1)).float().sum()\n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # update the parameters using the gradients\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            accs[iter_no] = correct.item() # number of correct predictions in the minibatch\n",
    "            batch_lens[iter_no] = _x.shape[0]\n",
    "\n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "\n",
    "        scheduler.step(loss.item())\n",
    "        # get test loss and accuracy on the entire test set\n",
    "        test_loss, test_correct = validation(model, ds_test, beta, device)\n",
    "        losses_test[i+1], accs_test[i+1] = float(test_loss), float(test_correct)\n",
    "        curr_first_layer = sv@model.layer1.weight.data.clone().t()\n",
    "        curr_first_layer = curr_first_layer / torch.norm(curr_first_layer, dim=1).unsqueeze(1)\n",
    "        # cosine similarity between current and initial weights, per singular direction\n",
    "        first_layer_weight_diffs[i] = torch.sum(curr_first_layer*init_first_layer, dim=1).cpu().numpy()\n",
    "\n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], loss: {} acc: {}, test loss: {} test acc: {}\".format(i, num_epochs,\n",
    "                    np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1]/batch_lens[iter_no-1], 3),\n",
    "                    np.round(losses_test[i+1], 3), np.round(accs_test[i+1]/A_test.shape[0], 3)))\n",
    "\n",
    "    # accuracies are fractions: counts divided by the true size of each minibatch\n",
    "    return losses, accs/batch_lens, losses_test, accs_test/A_test.shape[0], times, first_layer_weight_diffs, model.to('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# equivalent whitened (still nonconvex) model\n",
    "class FCNet(torch.nn.Module):\n",
    "    def __init__(self, d, num_neurons, d_out=1,bias=True, optimizer='SGD', wu=None):\n",
    "        super(FCNet, self).__init__()\n",
    "        self.layer1 = torch.nn.Linear(d, num_neurons, bias=False)\n",
    "        self.eta = torch.nn.Parameter(data=torch.empty(num_neurons), requires_grad=bias)\n",
    "        torch.nn.init.zeros_(self.eta)\n",
    "        \n",
    "        self.gamma = torch.nn.Parameter(data=torch.empty(num_neurons), requires_grad=True)\n",
    "        torch.nn.init.ones_(self.gamma)\n",
    "        \n",
    "        self.act = torch.nn.ReLU()\n",
    "        self.layer2 = torch.nn.Linear(num_neurons, d_out, bias=False)\n",
    "        \n",
    "        if optimizer!='SGD':\n",
    "            self.wu_factor = 1\n",
    "        else:\n",
    "            self.wu_factor = wu\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.layer2(self.act(self.gamma*self.layer1(x)/torch.norm(self.layer1.weight, dim=1) + self.eta/self.wu_factor))\n",
    "        return out\n",
    "\n",
    "def loss_func_no_bn(yhat, y, model, beta):\n",
    "    loss = 0.5 * torch.norm(yhat - y)**2\n",
    "    \n",
    "    # standard weight-decay\n",
    "    for p in model.named_parameters():\n",
    "        if p[0] in ['gamma', 'layer2.weight']:\n",
    "            loss = loss + beta/2 * torch.norm(p[1])**2\n",
    "        elif p[0] == 'eta':\n",
    "            loss = loss + beta/2*torch.norm(p[1]/model.wu_factor)**2\n",
    "    return loss\n",
    "\n",
    "def validation_no_bn(model, testloader, beta, device):\n",
    "    \"\"\"Evaluate the model without batch norm on every batch of testloader.\n",
    "\n",
    "    Args:\n",
    "      model: network to evaluate; it is left in eval mode on return.\n",
    "      testloader: iterable of (inputs, targets) batches; targets are 2-D.\n",
    "      beta: weight-decay coefficient forwarded to loss_func_no_bn.\n",
    "      device: device the batches are moved to before the forward pass.\n",
    "\n",
    "    Returns:\n",
    "      (test_loss, test_correct): regularized loss and number of correct\n",
    "      predictions, both summed over all batches and returned as Python floats\n",
    "      so callers can store them in NumPy arrays even when running on a GPU.\n",
    "    \"\"\"\n",
    "    test_loss = 0.0\n",
    "    test_correct = 0.0\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    # Evaluation needs no gradients, so skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        for _x, _y in testloader:\n",
    "            _x = _x.float().to(device)\n",
    "            _y = _y.float().to(device)\n",
    "\n",
    "            # a single forward pass per batch\n",
    "            yhat = model(_x).float()\n",
    "\n",
    "            test_loss += loss_func_no_bn(yhat, _y, model, beta).item()\n",
    "            if _y.shape[1] == 1:\n",
    "                # single output column: threshold the prediction at 0.5\n",
    "                hits = torch.eq(yhat>0.5, _y)\n",
    "            else:\n",
    "                # several output columns: compare arg-max classes\n",
    "                hits = torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1))\n",
    "            test_correct += hits.float().sum().item()\n",
    "\n",
    "    return test_loss, test_correct\n",
    "\n",
    "# take data matrices A_train and A_test before whitening (whitening occurs inside function)\n",
    "def no_batchnorm_full_gd(A_train, y_train, A_test, y_test, num_epochs, num_neurons, beta,\n",
    "                         learning_rate, verbose=False, batch_size=None, print_freq=5,\n",
    "                         bias=True, momentum=0.9,\n",
    "                         device='cpu', optimizer='SGD'):\n",
    "    \"\"\"Train an FCNet on PCA-whitened data with (mini-batch) gradient descent.\n",
    "\n",
    "    The data are centered with the training mean and whitened inside this\n",
    "    function: the left singular vectors of the centered training matrix become\n",
    "    the training inputs and the test inputs are mapped with V diag(1/s).\n",
    "\n",
    "    Args:\n",
    "      A_train, y_train: raw training inputs [N, d] and targets [N, D_out] (torch tensors).\n",
    "      A_test, y_test: raw held-out inputs and targets, evaluated as one batch.\n",
    "      num_epochs: number of passes over the training set.\n",
    "      num_neurons: hidden-layer width.\n",
    "      beta: weight-decay coefficient used by loss_func_no_bn.\n",
    "      learning_rate: initial step size; ReduceLROnPlateau halves it (patience of\n",
    "        10 epochs) based on the last minibatch loss of each epoch.\n",
    "      verbose: forwarded to the learning-rate scheduler.\n",
    "      batch_size: minibatch size; None means full-batch gradient descent.\n",
    "      print_freq: print a progress line every print_freq epochs.\n",
    "      bias: whether the shift parameter eta is trainable.\n",
    "      momentum: SGD momentum (only used when optimizer == 'SGD').\n",
    "      device: torch device used for the model and the minibatches.\n",
    "      optimizer: 'SGD', 'Adagrad', or anything else for Adam.\n",
    "\n",
    "    Returns:\n",
    "      (losses, accs, losses_test, accs_test, times, first_layer_weight_diffs, model):\n",
    "      per-iteration training loss and accuracy, per-epoch test loss and accuracy\n",
    "      (index 0 is measured before training), one timestamp per iteration, the\n",
    "      per-epoch alignment between the current and initial first-layer weights\n",
    "      for every input coordinate of the whitened data, and the trained model\n",
    "      moved to the CPU.\n",
    "    \"\"\"\n",
    "    device = torch.device(device)\n",
    "\n",
    "    print('preparing data')\n",
    "    # subtract the training means from both splits\n",
    "    train_means = torch.mean(A_train, dim=0)\n",
    "    A_train = A_train - train_means\n",
    "    A_test = A_test - train_means\n",
    "\n",
    "    # PCA whitening--u is our new data matrix\n",
    "    # NOTE: 1/s is infinite for zero singular values (rank-deficient training data)\n",
    "    u, s, v = torch.svd(A_train)\n",
    "    wu = torch.mean(s**2)\n",
    "    test_data = A_test @ v @  torch.diag(1/s)\n",
    "    train_data = u\n",
    "\n",
    "    D_in, H, D_out = train_data.shape[1], num_neurons, y_train.shape[1]\n",
    "    if batch_size is None:\n",
    "        batch_size = A_train.shape[0]\n",
    "    # create the model\n",
    "    model = FCNet(D_in, H, D_out, bias=bias, optimizer=optimizer, wu=wu).to(device)\n",
    "    if optimizer=='SGD':\n",
    "        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)\n",
    "    elif optimizer=='Adagrad':\n",
    "        optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)\n",
    "    else:\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,\n",
    "                                                           mode='min',\n",
    "                                                           factor=0.5,\n",
    "                                                           patience=10,\n",
    "                                                           verbose=verbose)\n",
    "\n",
    "    # arrays for saving the loss and accuracy\n",
    "    losses = np.zeros((int(num_epochs*np.ceil(A_train.shape[0] / batch_size))))\n",
    "    accs = np.zeros(losses.shape)\n",
    "    # actual size of every minibatch: the last one of an epoch may be smaller\n",
    "    batch_lens = np.ones(losses.shape)\n",
    "    losses_test = np.zeros((num_epochs+1))\n",
    "    accs_test = np.zeros((num_epochs+1))\n",
    "    times = np.zeros((losses.shape[0]+1))\n",
    "    times[0] = time.time()\n",
    "\n",
    "    # dataset loaders (minibatch)\n",
    "    ds = PrepareData(X=train_data, y=y_train)\n",
    "    ds = DataLoader(ds, batch_size=batch_size, shuffle=True)\n",
    "    ds_test = PrepareData(X=test_data, y=y_test)\n",
    "    ds_test = DataLoader(ds_test, batch_size=A_test.shape[0], shuffle=False)\n",
    "\n",
    "    # loss on the entire test set; float() also accepts a 0-dim (GPU) tensor\n",
    "    test_loss, test_correct = validation_no_bn(model, ds_test, beta, device)\n",
    "    losses_test[0], accs_test[0] = float(test_loss), float(test_correct)\n",
    "\n",
    "    init_first_layer = model.layer1.weight.data.clone().t()\n",
    "    init_first_layer = init_first_layer / torch.norm(init_first_layer, dim = 1).unsqueeze(1)\n",
    "\n",
    "    first_layer_weight_diffs = np.zeros((num_epochs, D_in))\n",
    "\n",
    "    iter_no = 0\n",
    "\n",
    "    print('starting training')\n",
    "    for i in range(num_epochs):\n",
    "        model.train()\n",
    "        for _x, _y in ds:\n",
    "            _x = _x.float().to(device)\n",
    "            _y = _y.float().to(device)\n",
    "\n",
    "            yhat = model(_x).float()\n",
    "\n",
    "            loss = loss_func_no_bn(yhat, _y, model, beta)\n",
    "            if _y.shape[1] == 1:\n",
    "                correct = torch.eq(yhat>0.5, _y).float().sum()\n",
    "            else:\n",
    "                correct = torch.eq(torch.argmax(yhat, 1), torch.argmax(_y, 1)).float().sum()\n",
    "            optimizer.zero_grad() # zero the gradients on each pass before the update\n",
    "            loss.backward() # backpropagate the loss through the model\n",
    "            optimizer.step() # update the parameters using the gradients\n",
    "\n",
    "            losses[iter_no] = loss.item() # loss on the minibatch\n",
    "            accs[iter_no] = correct.item() # number of correct predictions in the minibatch\n",
    "            batch_lens[iter_no] = _x.shape[0]\n",
    "\n",
    "            iter_no += 1\n",
    "            times[iter_no] = time.time()\n",
    "\n",
    "        scheduler.step(loss.item())\n",
    "\n",
    "        # get test loss and accuracy on the entire test set\n",
    "        test_loss, test_correct = validation_no_bn(model, ds_test, beta, device)\n",
    "        losses_test[i+1], accs_test[i+1] = float(test_loss), float(test_correct)\n",
    "        curr_first_layer= model.layer1.weight.data.t()/torch.norm(model.layer1.weight.data.t(), dim=1).unsqueeze(1) # d x P\n",
    "        # cosine similarity between current and initial weights, per input coordinate\n",
    "        first_layer_weight_diffs[i] = torch.sum(curr_first_layer*init_first_layer, dim=1).cpu().numpy()\n",
    "\n",
    "        if i % print_freq == 0:\n",
    "            print(\"Epoch [{}/{}], loss: {} acc: {}, test loss: {} test acc: {}\".format(i, num_epochs,\n",
    "                    np.round(losses[iter_no-1], 3), np.round(accs[iter_no-1]/batch_lens[iter_no-1], 3),\n",
    "                    np.round(losses_test[i+1], 3), np.round(accs_test[i+1]/A_test.shape[0], 3)))\n",
    "\n",
    "    # accuracies are fractions: counts divided by the true size of each minibatch\n",
    "    return losses, accs/batch_lens, losses_test, accs_test/A_test.shape[0], times, first_layer_weight_diffs, model.to('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load CIFAR-10 Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "directory = '' # CHANGE DIRECTORY HERE\n",
    "\n",
    "# Both splits share one preprocessing pipeline: convert to a tensor, then\n",
    "# standardize every channel with the fixed mean/std constants below\n",
    "normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))\n",
    "preprocess = transforms.Compose([transforms.ToTensor(), normalize])\n",
    "\n",
    "# download=True fetches CIFAR-10 into `directory` when it is not already there\n",
    "train_dataset = datasets.CIFAR10(root=directory, train=True, transform=preprocess, download=True)\n",
    "test_dataset = datasets.CIFAR10(root=directory, train=False, transform=preprocess, download=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A loader whose batch covers the whole split yields the full dataset as a\n",
    "# single (inputs, targets) pair on its first iteration\n",
    "dummy_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset, batch_size=50000, shuffle=False, pin_memory=True, sampler=None)\n",
    "\n",
    "print('loading train data')\n",
    "A, y = next(iter(dummy_loader))\n",
    "\n",
    "dummy_test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset, batch_size=10000, shuffle=False, pin_memory=True, sampler=None)\n",
    "\n",
    "print('loading test data')\n",
    "A_test, y_test = next(iter(dummy_test_loader))\n",
    "\n",
    "# Flatten every image into one row vector: [N, C, H, W] -> [N, C*H*W]\n",
    "A = A.view(A.size(0), -1)\n",
    "A_test = A_test.view(A_test.size(0), -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot targets (default of 10 classes) for the squared-error loss\n",
    "labels, labels_test = one_hot(y), one_hot(y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Network Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta = 1e-4\n",
    "num_neurons = 1000\n",
    "learning_rate = 1e-4\n",
    "batch_size = 1000\n",
    "bn_momentum = 0.1\n",
    "num_epochs = 501\n",
    "momentum = 0.9"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the batch-norm network; use the GPU when one is available and fall\n",
    "# back to the CPU otherwise so the notebook also runs on machines without CUDA\n",
    "results_bn = batchnorm_full_gd(A, labels, A_test,\n",
    "                              labels_test, num_epochs, num_neurons, beta,\n",
    "                              learning_rate, verbose=True, batch_size = batch_size,\n",
    "                              bn_momentum=bn_momentum, bias=True,\n",
    "                              device='cuda' if torch.cuda.is_available() else 'cpu',\n",
    "                              optimizer='SGD')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the whitened network; use the GPU when one is available and fall back\n",
    "# to the CPU otherwise so the notebook also runs on machines without CUDA\n",
    "results_whitened = no_batchnorm_full_gd(A, labels, A_test,\n",
    "                                        labels_test, num_epochs, num_neurons, beta,\n",
    "                                        learning_rate, verbose=True, batch_size = batch_size, bias=True,\n",
    "                                        momentum=momentum,\n",
    "                                        device='cuda' if torch.cuda.is_available() else 'cpu',\n",
    "                                        optimizer='SGD')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the font size first: rcParams only affect artists created afterwards\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "plt.grid()\n",
    "\n",
    "epochs = np.arange(num_epochs)\n",
    "# Entry 1 of each result tuple holds one training accuracy per minibatch. Keep the\n",
    "# first minibatch of every epoch; the stride is derived from the data instead of\n",
    "# being hard-coded, so the plot still works when batch_size changes.\n",
    "stride_bn = len(results_bn[1]) // num_epochs\n",
    "stride_whitened = len(results_whitened[1]) // num_epochs\n",
    "\n",
    "plt.plot(epochs, results_bn[1][::stride_bn], color='blue', linewidth=2, label='BN, training')\n",
    "plt.plot(epochs, results_whitened[1][::stride_whitened], color='orange', linewidth=2,label='Whitened, training')\n",
    "# Entry 3 holds the test accuracy per epoch; index 0 is the value before training\n",
    "plt.plot(epochs, results_bn[3][1:], color='blue',  linewidth=2,linestyle='dashed', label='BN, test')\n",
    "plt.plot(epochs, results_whitened[3][1:], color='orange', linewidth=2,\n",
    "         label='Whitened, test', linestyle='--')\n",
    "plt.ylabel('Accuracy')\n",
    "plt.xlabel('Epoch')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colorbar geometry is derived from the heat-map shape (directions / epochs)\n",
    "# instead of the hard-coded 3072/501\n",
    "n_epochs_plotted, n_directions = results_bn[-2].shape\n",
    "aspect = n_directions/n_epochs_plotted\n",
    "pad_fraction = 0.5\n",
    "# plt.get_cmap is available across Matplotlib versions, whereas\n",
    "# plt.cm.get_cmap was removed in Matplotlib 3.9\n",
    "hot = plt.get_cmap('hot_r')\n",
    "# Custom colormap: the first 80 entries follow 'hot_r', the last 20 are taken\n",
    "# from a finer sampling of it (entries 160-179 of 200)\n",
    "newcolors = hot(np.linspace(0, 1, 100))\n",
    "newcolors_2 =  hot(np.linspace(0, 1, 200))\n",
    "newcolors[80:, :] = newcolors_2[160:180]\n",
    "newcmp = ListedColormap(newcolors)\n",
    "plt.rcParams.update({'font.size': 20})\n",
    "\n",
    "\n",
    "plt.figure(figsize=(10, 8))\n",
    "ax = plt.gca()\n",
    "# Rows are singular directions, columns are epochs\n",
    "im = plt.imshow(results_bn[-2].T, cmap=newcmp, interpolation='nearest', aspect='auto')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Singular Value Direction')\n",
    "# Attach a narrow colorbar axis to the right of the image\n",
    "divider = make_axes_locatable(ax)\n",
    "width = axes_size.AxesY(ax, aspect=1./(15*aspect))\n",
    "pad = axes_size.Fraction(pad_fraction, width)\n",
    "cax = divider.append_axes(\"right\", size=width, pad=pad)\n",
    "plt.colorbar(im, cax=cax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heat map for the whitened network, reusing the colormap and colorbar\n",
    "# geometry (newcmp, aspect, pad_fraction) defined in the previous cell\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "im = ax.imshow(results_whitened[-2].T, cmap=newcmp, interpolation='nearest', aspect='auto')\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.set_ylabel('Singular Value Direction')\n",
    "\n",
    "# Attach a narrow colorbar axis to the right of the image\n",
    "divider = make_axes_locatable(ax)\n",
    "width = axes_size.AxesY(ax, aspect=1./(15*aspect))\n",
    "pad = axes_size.Fraction(pad_fraction, width)\n",
    "cax = divider.append_axes(\"right\", size=width, pad=pad)\n",
    "fig.colorbar(im, cax=cax)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
