{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchsummary import summary\n",
    "from tqdm import tqdm\n",
    "\n",
    "import load_split_mnist\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# optimisation hyper-parameters shared by both tasks\n",
    "BATCHSIZE = 256\n",
    "LR = 0.01\n",
    "MOMENTUM = 0.9\n",
    "EPOCH = 10  # epochs of training per task\n",
    "\n",
    "# Pick the best available backend instead of hard-coding 'mps', so the notebook\n",
    "# also runs on machines without Apple-silicon GPU support.\n",
    "if torch.backends.mps.is_available():\n",
    "    DEVICE = 'mps'\n",
    "elif torch.cuda.is_available():\n",
    "    DEVICE = 'cuda'\n",
    "else:\n",
    "    DEVICE = 'cpu'\n",
    "\n",
    "# gradient-subspace estimation: batch size and number of passes over task 1\n",
    "GRAD_EST_BATCHSIZE = 32\n",
    "GRAD_EST_EPOCHS = 1\n",
    "\n",
    "# EWC settings\n",
    "EWC_LAMBDA = 1000000.  # weight of the quadratic penalty\n",
    "ROTATED = True  # True: Fisher in the PCA-rotated basis, False: plain parameter basis\n",
    "\n",
    "SEED = 42\n",
    "\n",
    "# seed all the things\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)\n",
    "random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load both tasks: load() is unpacked as (train, test), each a pair of per-task datasets.\n",
    "train, test = load_split_mnist.load()\n",
    "(train_t1, train_t2), (test_t1, test_t2) = train, test\n",
    "\n",
    "_make_loader = torch.utils.data.DataLoader\n",
    "\n",
    "# Task 1: a shuffled loader for training, plus a fixed-order, small-batch loader over\n",
    "# the same data that is used to collect gradients for the subspace estimate.\n",
    "train_loader_t1 = _make_loader(train_t1, batch_size=BATCHSIZE, shuffle=True)\n",
    "train_loader_subspace_est_t1 = _make_loader(train_t1, batch_size=GRAD_EST_BATCHSIZE, shuffle=False)\n",
    "\n",
    "# Task 2 training data.\n",
    "train_loader_t2 = _make_loader(train_t2, batch_size=BATCHSIZE, shuffle=True)\n",
    "\n",
    "# Test data must not be shuffled, otherwise embeddings are not paired.\n",
    "test_loader_t1 = _make_loader(test_t1, batch_size=BATCHSIZE, shuffle=False)\n",
    "test_loader_t2 = _make_loader(test_t2, batch_size=BATCHSIZE, shuffle=False)\n",
    "\n",
    "# Report the size of each training set.\n",
    "n_train_t1, n_train_t2 = len(train_t1), len(train_t2)\n",
    "\n",
    "print(f\"n_train_t1: {n_train_t1}\")\n",
    "print(f\"n_train_t2: {n_train_t2}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def label_smoothing(y, alpha, num_classes=5):\n",
    "    '''\n",
    "    Turn integer class labels into label-smoothed one-hot targets.\n",
    "\n",
    "    y           - integer tensor of class indices, any shape (including 0-dim)\n",
    "    alpha       - smoothing strength; 0 gives plain one-hot targets\n",
    "    num_classes - number of classes; the default of 5 matches the 5-way readouts\n",
    "                  built in this notebook\n",
    "\n",
    "    Returns a float tensor of shape y.shape + (num_classes,) on y's device in which\n",
    "    every row sums to 1.\n",
    "    '''\n",
    "    # build the identity on y's device and index it to get the one-hot rows\n",
    "    one_hot = torch.eye(num_classes, device=y.device)[y]\n",
    "    # spread alpha uniformly over the classes; dividing by num_classes (rather than\n",
    "    # by size(1)) keeps this correct for label tensors that are not 1-D\n",
    "    return one_hot * (1 - alpha) + alpha / num_classes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# small convolutional network with a shared trunk and one linear readout per task\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    '''\n",
    "    Two conv + max-pool stages, one hidden linear layer and two task-specific\n",
    "    readout heads; every layer is created without a bias term.\n",
    "\n",
    "    input_dim   - accepted but not used by any layer\n",
    "    hidden_dim  - width of the hidden linear layer fc1\n",
    "    output_dim  - number of outputs of each readout head\n",
    "    '''\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim):\n",
    "        super().__init__()\n",
    "        # feature extractor: 1 -> 16 -> 16 channels, 3x3 kernels, stride 1\n",
    "        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, bias=False)\n",
    "        self.mpool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, bias=False)\n",
    "        self.mpool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        # fc1 consumes the 400 features produced by the view in forward()\n",
    "        self.fc1 = nn.Linear(400, hidden_dim, bias=False)\n",
    "\n",
    "        # one readout head per task; readouts[t] selects the head for task t\n",
    "        self.readout1 = nn.Linear(hidden_dim, output_dim, bias=False)\n",
    "        self.readout2 = nn.Linear(hidden_dim, output_dim, bias=False)\n",
    "        self.readouts = nn.ModuleList([self.readout1, self.readout2])\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        '''\n",
    "        x - input batch fed to conv1 (single-channel images)\n",
    "        t - index of the task whose readout head is applied\n",
    "        '''\n",
    "        features = self.mpool1(F.relu(self.conv1(x)))\n",
    "        features = self.mpool2(F.relu(self.conv2(features)))\n",
    "        hidden = F.relu(self.fc1(features.view(-1, 400)))\n",
    "        return self.readouts[t](hidden)\n",
    "    \n",
    "# build the network (hidden width 128, 5 outputs per readout), print its summary\n",
    "# and move it to the training device\n",
    "cnn = CNN(input_dim=28 * 28, hidden_dim=128, output_dim=5)\n",
    "summary(cnn)\n",
    "cnn.to(DEVICE)\n",
    "\n",
    "# plain SGD with momentum over all parameters, cross-entropy as the task loss\n",
    "optimizer = torch.optim.SGD(params=cnn.parameters(), lr=LR, momentum=MOMENTUM)\n",
    "loss_func = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# single optimisation step\n",
    "def train_step(model, optimizer, loss_func, x, y, t):\n",
    "    '''\n",
    "    Run one optimiser update on a single batch and return the batch loss as a float.\n",
    "\n",
    "    x, y, t are moved to DEVICE first; t is passed to the model as the task index.\n",
    "    '''\n",
    "    model.train()\n",
    "    inputs, targets, task = (tensor.to(DEVICE) for tensor in (x, y, t))\n",
    "    optimizer.zero_grad()\n",
    "    batch_loss = loss_func(model(inputs, task), targets)\n",
    "    batch_loss.backward()\n",
    "    optimizer.step()\n",
    "    return batch_loss.item()\n",
    "\n",
    "# single evaluation step\n",
    "@torch.no_grad()\n",
    "def test_step(model, loss_func, x, y, t):\n",
    "    '''\n",
    "    Return the loss of one batch as a float, with the model in eval mode and\n",
    "    gradient tracking disabled. x, y, t are moved to DEVICE first.\n",
    "    '''\n",
    "    model.eval()\n",
    "    inputs, targets, task = (tensor.to(DEVICE) for tensor in (x, y, t))\n",
    "    batch_loss = loss_func(model(inputs, task), targets)\n",
    "    return batch_loss.item()\n",
    "\n",
    "def train_epoch(model, optimizer, loss_func, train_loader, t):\n",
    "    '''\n",
    "    Train for one pass over train_loader (with a progress bar) and return the\n",
    "    mean of the per-batch losses.\n",
    "    '''\n",
    "    batch_losses = (\n",
    "        train_step(model, optimizer, loss_func, x, y, t)\n",
    "        for x, y in tqdm(train_loader)\n",
    "    )\n",
    "    return sum(batch_losses, 0.0) / len(train_loader)\n",
    "\n",
    "@torch.no_grad()\n",
    "def test_epoch(model, loss_func, test_loader, t):\n",
    "    '''\n",
    "    Evaluate on every batch of test_loader and return the mean of the per-batch losses.\n",
    "    '''\n",
    "    batch_losses = (test_step(model, loss_func, x, y, t) for x, y in test_loader)\n",
    "    return sum(batch_losses, 0.0) / len(test_loader)\n",
    "\n",
    "@torch.no_grad()\n",
    "def compute_accuracy_for_dataset(model, test_loader, t):\n",
    "    '''\n",
    "    Fraction of samples in test_loader whose arg-max prediction equals the label.\n",
    "\n",
    "    model       - module called as model(x, t)\n",
    "    test_loader - iterable of (x, y) batches\n",
    "    t           - task index, either a plain int or a tensor\n",
    "\n",
    "    Returns correct / total as a Python float. The model is left in eval mode.\n",
    "    '''\n",
    "    # callers pass tensors as well as ints; as_tensor accepts both and avoids the\n",
    "    # copy-construct UserWarning that torch.tensor(<tensor>) emits\n",
    "    t = torch.as_tensor(t, device=DEVICE)\n",
    "    # evaluate in eval mode, consistent with test_step\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for x, y in test_loader:\n",
    "        x, y = x.to(DEVICE), y.to(DEVICE)\n",
    "        output = model(x, t)\n",
    "        # index of the largest logit along the class dimension\n",
    "        _, predicted = torch.max(output, 1)\n",
    "        total += y.size(0)\n",
    "        correct += (predicted == y).sum().item()\n",
    "    return correct / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train for EPOCH epochs on the first task, tracking test accuracy after every epoch\n",
    "accs_t1, accs_t2 = [], []\n",
    "\n",
    "for epoch in range(EPOCH):\n",
    "    first_task = torch.tensor(0)\n",
    "    train_loss = train_epoch(cnn, optimizer, loss_func, train_loader_t1, first_task)\n",
    "    test_loss = test_epoch(cnn, loss_func, test_loader_t1, first_task)\n",
    "    test_acc = compute_accuracy_for_dataset(cnn, test_loader_t1, first_task)\n",
    "    accs_t1.append(test_acc)\n",
    "    # task 2 is not evaluated yet; NaN keeps both curves aligned on the epoch axis\n",
    "    accs_t2.append(np.nan)\n",
    "    print(f\"Epoch: {epoch}, train_loss: {train_loss}, test_loss: {test_loss}, test_acc: {test_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "estimate the subspace spanned by gradients for the first task at each layer\n",
    "'''\n",
    "\n",
    "layer_1_grads = []\n",
    "layer_2_grads = []\n",
    "layer_3_grads = []\n",
    "\n",
    "loss_func_grad_est = nn.CrossEntropyLoss()\n",
    "\n",
    "t = torch.tensor(0).to(DEVICE)\n",
    "\n",
    "\n",
    "def _flat_grad_copy(layer):\n",
    "    '''flattened numpy copy of the current gradient of layer.weight'''\n",
    "    # .cpu() is a no-op for CPU tensors and .numpy() shares memory with the tensor,\n",
    "    # so without the copy the stored row could alias the live .grad buffer and be\n",
    "    # overwritten by a later zero_grad()/backward()\n",
    "    return layer.weight.grad.detach().cpu().numpy().reshape(-1).copy()\n",
    "\n",
    "\n",
    "# one gradient sample per mini-batch of the fixed-order subspace-estimation loader\n",
    "for _ in range(GRAD_EST_EPOCHS):\n",
    "    for x, y in tqdm(train_loader_subspace_est_t1):\n",
    "        x, y = x.to(DEVICE), y.to(DEVICE)\n",
    "        # clear gradients only; no optimizer.step() is taken, the weights stay fixed\n",
    "        optimizer.zero_grad()\n",
    "        model_output = cnn(x, t)\n",
    "        loss = loss_func_grad_est(model_output, y)\n",
    "        loss.backward()\n",
    "        layer_1_grads.append(_flat_grad_copy(cnn.conv1))\n",
    "        layer_2_grads.append(_flat_grad_copy(cnn.conv2))\n",
    "        layer_3_grads.append(_flat_grad_copy(cnn.fc1))\n",
    "\n",
    "# stack into arrays of shape (n_batches, n_weights_in_layer)\n",
    "layer_1_grads = np.array(layer_1_grads)\n",
    "layer_2_grads = np.array(layer_2_grads)\n",
    "layer_3_grads = np.array(layer_3_grads)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shape of the collected gradient matrix of each layer\n",
    "for collected in (layer_1_grads, layer_2_grads, layer_3_grads):\n",
    "    print(collected.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "from torch.linalg import svd, svdvals\n",
    "from scipy.linalg import svd as scipy_svd\n",
    "from sklearn.utils.extmath import svd_flip\n",
    "# import utilities for timing\n",
    "import time\n",
    "\n",
    "# pca as done by sklearn\n",
    "from sklearn.decomposition import PCA\n",
    "# reference: fit sklearn's PCA on the layer-2 gradient matrix and time it\n",
    "A = layer_2_grads\n",
    "start = time.time()\n",
    "pca = PCA()\n",
    "pca.fit(A)\n",
    "end = time.time()\n",
    "\n",
    "print(f\"sklearn PCA took {end-start} seconds\")\n",
    "\n",
    "\n",
    "# manual pca\n",
    "# work on a copy so the in-place centring below does not modify layer_2_grads\n",
    "M = deepcopy(A)\n",
    "\n",
    "# the timed region covers centring, the thin SVD, the sign flip and the variances\n",
    "tick = time.time()\n",
    "mean_ = np.mean(M, axis=0)\n",
    "M -= mean_\n",
    "\n",
    "U, S, Vt = scipy_svd(M, full_matrices=False)\n",
    "# flip eigenvectors' sign to enforce deterministic output\n",
    "U, Vt = svd_flip(U, Vt)\n",
    "# variance along each component, with the n - 1 normalisation\n",
    "explained_variance_ = (S**2) / (len(M) - 1)\n",
    "tock = time.time()\n",
    "print(f\"manual PCA took {tock-tick} seconds\")\n",
    "\n",
    "\n",
    "def svd_trick(M):\n",
    "    '''\n",
    "    Principal directions of M via the (n_samples x n_samples) Gram matrix instead\n",
    "    of the (n_features x n_features) covariance, which is cheaper when there are\n",
    "    far fewer samples than features.\n",
    "\n",
    "    M - array of shape (n_samples, n_features); centre it beforehand for PCA\n",
    "\n",
    "    Returns (u, s):\n",
    "        u - array of shape (n_features, k) whose columns are unit-norm directions\n",
    "        s - array of shape (k,), the matching eigenvalues of M @ M.T / n_samples\n",
    "    Directions with a numerically zero eigenvalue are dropped (k <= n_samples),\n",
    "    because normalising them would divide by zero.\n",
    "    '''\n",
    "    n_samples = M.shape[0]\n",
    "    sig = np.dot(M, M.T) / n_samples\n",
    "    u2, s, _ = scipy_svd(sig)\n",
    "    # centred data always has at least one zero eigenvalue; keep only the rest\n",
    "    tol = s.max() * max(M.shape) * np.finfo(s.dtype).eps\n",
    "    keep = s > tol\n",
    "    u2, s = u2[:, keep], s[keep]\n",
    "    # map Gram-matrix eigenvectors back to feature space and rescale to unit norm\n",
    "    u = np.dot(M.T, u2) / np.sqrt(s * n_samples)\n",
    "    return u, s\n",
    "\n",
    "# compute again with svd trick because number of samples is much smaller than number of features\n",
    "# NOTE(review): this part looks unfinished - M2 is centred but never passed to\n",
    "# svd_trick, and tick is restarted without a matching end time or printout.\n",
    "M2 = deepcopy(A)\n",
    "tick = time.time()\n",
    "mean_ = np.mean(M2, axis=0)\n",
    "M2 -= mean_\n",
    "\n",
    "\n",
    "\n",
    "# plot eigenvalues for both methods\n",
    "# (the two curves are sklearn's PCA and the manual SVD-based PCA)\n",
    "plt.figure()\n",
    "plt.plot(pca.explained_variance_, label=\"sklearn\")\n",
    "plt.plot(explained_variance_, label=\"manual\")\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "# are the eigenvectors the same?\n",
    "# the boolean is shown as the cell output\n",
    "np.allclose(pca.components_, Vt)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "\n",
    "# NOTE: A is rebound here, from the layer-2 numpy array to a torch tensor holding\n",
    "# the layer-1 gradients\n",
    "A = torch.tensor(layer_1_grads)\n",
    "start = time.time()\n",
    "# svd is torch.linalg.svd, which returns (U, S, Vh); the value bound to V below is\n",
    "# therefore Vh, with the right singular vectors as its rows\n",
    "U, S, V = svd(A)\n",
    "end = time.time()\n",
    "print(f\"torch.linalg.svd took {end-start} seconds\")\n",
    "print(U.shape)\n",
    "print(V.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# average the collected task-1 gradients over the batch axis, separately per layer\n",
    "layer_1_grads_mean = layer_1_grads.mean(axis=0)\n",
    "layer_2_grads_mean = layer_2_grads.mean(axis=0)\n",
    "layer_3_grads_mean = layer_3_grads.mean(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "# one full PCA per layer on its gradient matrix (fit returns the fitted estimator)\n",
    "pca_layer1 = PCA().fit(layer_1_grads)\n",
    "pca_layer2 = PCA().fit(layer_2_grads)\n",
    "pca_layer3 = PCA().fit(layer_3_grads)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# cumulative explained-variance ratio of the gradient PCA, one curve per layer\n",
    "plt.figure()\n",
    "for curve_label, fitted_pca in (('layer 1', pca_layer1), ('layer 2', pca_layer2), ('layer 3', pca_layer3)):\n",
    "    plt.plot(np.cumsum(fitted_pca.explained_variance_ratio_), label=curve_label)\n",
    "plt.xlabel('number of components')\n",
    "plt.ylabel('cumulative explained variance')\n",
    "plt.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def project_gradient(subspace, grad):\n",
    "    '''\n",
    "    remove from grad its component that lies inside the subspace\n",
    "\n",
    "    subspace        - torch tensor of shape (subspace_dim, layer_dim); its rows are\n",
    "                      assumed to be orthonormal\n",
    "    grad            - torch tensor of shape (layer_dim,) or (layer_dim, n)\n",
    "\n",
    "    returns a tensor shaped like grad: grad minus its projection onto the rows of\n",
    "    subspace\n",
    "    '''\n",
    "    subspace = subspace.to(grad.device)\n",
    "    # coefficients along each basis row, mapped back to parameter space; written\n",
    "    # without .T on a 1-D tensor, which PyTorch deprecates\n",
    "    projection = subspace.T @ (subspace @ grad)\n",
    "    return grad - projection\n",
    "\n",
    "def _unit_norm_columns(fitted_pca):\n",
    "    '''PCA components rescaled to unit length and transposed, so each component is a column'''\n",
    "    rows = fitted_pca.components_\n",
    "    return (rows / np.linalg.norm(rows, axis=1, keepdims=True)).T\n",
    "\n",
    "\n",
    "# per-layer basis of shape (layer_dim, n_components)\n",
    "layer_1_components = _unit_norm_columns(pca_layer1)\n",
    "layer_2_components = _unit_norm_columns(pca_layer2)\n",
    "layer_3_components = _unit_norm_columns(pca_layer3)\n",
    "\n",
    "for basis in (layer_1_components, layer_2_components, layer_3_components):\n",
    "    print(basis.shape)\n",
    "\n",
    "\n",
    "# the same bases as torch tensors on the training device\n",
    "proj_l1, proj_l2, proj_l3 = (\n",
    "    torch.tensor(basis).to(DEVICE)\n",
    "    for basis in (layer_1_components, layer_2_components, layer_3_components)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shape of each projection tensor\n",
    "for projection_matrix in (proj_l1, proj_l2, proj_l3):\n",
    "    print(projection_matrix.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def compute_rotated_diagonal_fisher(model, layers, data, criterion, task, projections):\n",
    "    '''\n",
    "    Estimate a diagonal Fisher information for each layer in a rotated basis.\n",
    "\n",
    "    For every batch, the gradient of the batch loss w.r.t. layer.weight is flattened,\n",
    "    rotated with grad @ projection, squared and accumulated; the sum is finally\n",
    "    divided by the number of samples seen. A single pass over data serves all\n",
    "    layers, so data may also be a one-shot iterator.\n",
    "\n",
    "    model       - module called as model(x, task)\n",
    "    layers      - modules that own a .weight parameter\n",
    "    data        - iterable of (x, y) batches\n",
    "    criterion   - loss called as criterion(model_output, y)\n",
    "    task        - task index, a plain int or a tensor\n",
    "    projections - one entry per layer: a tensor of shape (layer_dim, subspace_dim),\n",
    "                  or None to stay in the plain parameter basis without having to\n",
    "                  build an identity matrix\n",
    "\n",
    "    Returns a list with one 1-D tensor per layer.\n",
    "    Raises ValueError if data yields no samples.\n",
    "\n",
    "    NOTE(review): if criterion averages over the batch (as nn.CrossEntropyLoss does\n",
    "    by default) this squares batch-mean gradients yet normalises by the sample\n",
    "    count, so the scale differs from a per-sample Fisher - confirm that the EWC\n",
    "    strength was tuned with this in mind.\n",
    "    '''\n",
    "    pairs = list(zip(layers, projections))\n",
    "    if not pairs:\n",
    "        return []\n",
    "\n",
    "    device = pairs[0][0].weight.device\n",
    "    # convert once, outside the loops; as_tensor accepts ints and tensors alike and\n",
    "    # does not emit the copy-construct warning of torch.tensor(<tensor>)\n",
    "    task = torch.as_tensor(task, device=device)\n",
    "\n",
    "    infos = [\n",
    "        torch.zeros(\n",
    "            layer.weight.numel() if projection is None else projection.shape[1],\n",
    "            device=layer.weight.device,\n",
    "        )\n",
    "        for layer, projection in pairs\n",
    "    ]\n",
    "\n",
    "    n_samples = 0\n",
    "    for x, y in data:\n",
    "        # make sure x and y live on the same device as the layer parameters\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        model.zero_grad()\n",
    "        loss = criterion(model(x, task), y)\n",
    "        loss.backward()\n",
    "        # one backward pass provides the gradients of every requested layer\n",
    "        for (layer, projection), fish_info in zip(pairs, infos):\n",
    "            grads = layer.weight.grad.reshape(-1)\n",
    "            if projection is not None:\n",
    "                grads = grads @ projection\n",
    "            fish_info += grads ** 2\n",
    "        n_samples += len(x)\n",
    "\n",
    "    if n_samples == 0:\n",
    "        # dividing by zero would silently return NaN Fisher values\n",
    "        raise ValueError('data yielded no samples')\n",
    "\n",
    "    for fish_info in infos:\n",
    "        fish_info /= n_samples\n",
    "    return infos\n",
    "\n",
    "# diagonal Fisher information of task 1 (task index 0) in each layer's rotated basis\n",
    "fisher_infos = compute_rotated_diagonal_fisher(\n",
    "    cnn,\n",
    "    [cnn.conv1, cnn.conv2, cnn.fc1],\n",
    "    train_loader_t1,\n",
    "    loss_func,\n",
    "    0,\n",
    "    [proj_l1, proj_l2, proj_l3],\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# histogram of the rotated Fisher values, one overlaid histogram per layer\n",
    "plt.figure()\n",
    "for idx, hist_label, hist_style in ((0, 'layer 1', {}), (1, 'layer 2', {'alpha': 0.5}), (2, 'layer 3', {'alpha': .5})):\n",
    "    plt.hist(fisher_infos[idx].cpu().numpy(), label=hist_label, **hist_style)\n",
    "plt.xlabel('fisher information')\n",
    "plt.ylabel('count')\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# recompute fisher info without projection\n",
    "# number of weights in each layer, i.e. the side length of its identity projection\n",
    "dim_l1 = np.prod(cnn.conv1.weight.shape)\n",
    "dim_l2 = np.prod(cnn.conv2.weight.shape)\n",
    "dim_l3 = np.prod(cnn.fc1.weight.shape)\n",
    "\n",
    "# NOTE(review): the identity projections are dense; fc1 has 400 * 128 = 51200 weights,\n",
    "# so torch.eye(dim_l3) alone is a 51200 x 51200 float32 matrix (about 10.5 GB).\n",
    "fisher_infos_unrotated = compute_rotated_diagonal_fisher(cnn, [cnn.conv1, cnn.conv2, cnn.fc1], train_loader_t1, loss_func, 0, [torch.eye(dim_l1).to(DEVICE), torch.eye(dim_l2).to(DEVICE), torch.eye(dim_l3).to(DEVICE)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# histogram of the unrotated Fisher values, one overlaid histogram per layer\n",
    "plt.figure()\n",
    "for idx, hist_label, hist_style in ((0, 'layer 1', {}), (1, 'layer 2', {'alpha': 0.5}), (2, 'layer 3', {'alpha': .5})):\n",
    "    plt.hist(fisher_infos_unrotated[idx].cpu().numpy(), bins=100, label=hist_label, **hist_style)\n",
    "plt.xlabel('fisher information')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compare max fisher info for each layer between methods\n",
    "# aligned = Fisher in the rotated basis, unaligned = Fisher computed with identity projections\n",
    "print(f\"layer 1: aligned: {torch.max(fisher_infos[0])}, unaligned: {torch.max(fisher_infos_unrotated[0])}\")\n",
    "print(f\"layer 2: aligned: {torch.max(fisher_infos[1])}, unaligned: {torch.max(fisher_infos_unrotated[1])}\")\n",
    "print(f\"layer 3: aligned: {torch.max(fisher_infos[2])}, unaligned: {torch.max(fisher_infos_unrotated[2])}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_regularisation_term(w_new, w_old, fisher_info, lambda_ewc, projection):\n",
    "    '''\n",
    "    EWC penalty of one layer: (lambda_ewc / 2) * sum_i F_i * (P^T (w_old - w_new))_i ** 2\n",
    "\n",
    "    w_new       - current weights of the layer (gradients flow through these)\n",
    "    w_old       - reference weights; detached, so they receive no gradient\n",
    "    fisher_info - 1-D tensor of per-direction importances F\n",
    "    lambda_ewc  - strength of the penalty\n",
    "    projection  - tensor P of shape (layer_dim, subspace_dim), or None to work in\n",
    "                  the plain parameter basis without building an identity matrix\n",
    "\n",
    "    Returns a 0-dim tensor.\n",
    "    '''\n",
    "    w_old = w_old.detach().reshape(-1)\n",
    "    w_new = w_new.reshape(-1)\n",
    "    if projection is None:\n",
    "        # identity rotation: the flattened weights are already the coordinates\n",
    "        w_old_proj, w_new_proj = w_old, w_new\n",
    "    else:\n",
    "        w_old_proj = w_old @ projection\n",
    "        w_new_proj = w_new @ projection\n",
    "    square_diff = (w_old_proj - w_new_proj)**2\n",
    "    weighted_square_diff = fisher_info * square_diff\n",
    "    return (lambda_ewc / 2) * weighted_square_diff.sum()\n",
    "\n",
    "\n",
    "def train_step_ewc(model, old_model, optimizer, loss_func, x, y, t, fisher_infos, lambda_ewc, projections):\n",
    "    '''\n",
    "    one optimiser update on a batch using the task loss plus an EWC penalty\n",
    "\n",
    "    The penalty is the sum of compute_regularisation_term over conv1, conv2 and fc1,\n",
    "    comparing model against old_model; fisher_infos and projections are indexed in\n",
    "    that layer order. Returns the total loss (task loss + penalty) as a float.\n",
    "    '''\n",
    "    x, y, t = (tensor.to(DEVICE) for tensor in (x, y, t))\n",
    "\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    total_loss = loss_func(model(x, t), y)\n",
    "\n",
    "    # accumulate the quadratic penalty layer by layer\n",
    "    penalty = 0.0\n",
    "    for idx, layer_name in enumerate(('conv1', 'conv2', 'fc1')):\n",
    "        penalty += compute_regularisation_term(\n",
    "            getattr(model, layer_name).weight,\n",
    "            getattr(old_model, layer_name).weight,\n",
    "            fisher_infos[idx],\n",
    "            lambda_ewc,\n",
    "            projections[idx],\n",
    "        )\n",
    "\n",
    "    total_loss += penalty\n",
    "    total_loss.backward()\n",
    "    optimizer.step()\n",
    "    return total_loss.item()\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def train_epoch_ewc(model, old_model, optimizer, loss_func, train_loader, t, fisher_infos, lambda_ewc, projections):\n",
    "    '''\n",
    "    one pass of EWC-regularised training over train_loader (with a progress bar);\n",
    "    returns the mean of the per-batch total losses\n",
    "    '''\n",
    "    batch_losses = (\n",
    "        train_step_ewc(model, old_model, optimizer, loss_func, x, y, t, fisher_infos, lambda_ewc, projections)\n",
    "        for x, y in tqdm(train_loader)\n",
    "    )\n",
    "    return sum(batch_losses, 0.0) / len(train_loader)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# task indices as tensors on the training device\n",
    "t0, t1 = (torch.tensor(task_index).to(DEVICE) for task_index in (0, 1))\n",
    "\n",
    "# accuracy of the current model on each task's test set, using that task's readout\n",
    "test_acc_1 = compute_accuracy_for_dataset(cnn, test_loader_t1, t0)\n",
    "test_acc_2 = compute_accuracy_for_dataset(cnn, test_loader_t2, t1)\n",
    "\n",
    "print(f\"Task 1 accuracy: {test_acc_1}, Task 2 accuracy: {test_acc_2}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "\n",
    "# reset optimizer so that we do not carry over momentum from previous training\n",
    "optimizer = torch.optim.SGD(cnn.parameters(), lr=LR, momentum=MOMENTUM)\n",
    "# snapshot of the weights after task 1; the EWC penalty pulls towards these\n",
    "old_model = deepcopy(cnn)\n",
    "\n",
    "# Choose the Fisher estimate and its basis once, before the loop: the unrotated\n",
    "# branch builds dense identity matrices, which must not be rebuilt every epoch.\n",
    "if ROTATED:\n",
    "    ewc_fisher_infos = fisher_infos\n",
    "    ewc_projections = [proj_l1, proj_l2, proj_l3]\n",
    "else:\n",
    "    ewc_fisher_infos = fisher_infos_unrotated\n",
    "    ewc_projections = [torch.eye(dim_l1).to(DEVICE), torch.eye(dim_l2).to(DEVICE), torch.eye(dim_l3).to(DEVICE)]\n",
    "\n",
    "# train on task 2\n",
    "for i in range(EPOCH):\n",
    "    train_loss = train_epoch_ewc(cnn, old_model, optimizer, loss_func, train_loader_t2, torch.tensor(1), ewc_fisher_infos, EWC_LAMBDA, ewc_projections)\n",
    "    # track both tasks after every epoch to see forgetting on task 1\n",
    "    test_acc_t1 = compute_accuracy_for_dataset(cnn, test_loader_t1, torch.tensor(0))\n",
    "    test_acc_t2 = compute_accuracy_for_dataset(cnn, test_loader_t2, torch.tensor(1))\n",
    "    accs_t1.append(test_acc_t1)\n",
    "    accs_t2.append(test_acc_t2)\n",
    "    print(f\"Epoch: {i}, train_loss: {train_loss}, test_acc_t1: {test_acc_t1}, test_acc_t2: {test_acc_t2}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot accuracies\n",
    "\n",
    "# ticks style, talk context\n",
    "sns.set_context('talk')\n",
    "sns.set_style('ticks')\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(accs_t1, label='task 1', marker='o')\n",
    "plt.plot(accs_t2, label='task 2', marker='o')\n",
    "# vertical dashed line at EPOCH - 0.5, between the last task-1 epoch and the first task-2 epoch\n",
    "plt.axvline(EPOCH - 0.5, linestyle='--', color='k')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('accuracy')\n",
    "plt.legend()\n",
    "# plt.ylim(0.5, 1)\n",
    "sns.despine(trim=False)\n",
    "plt.tight_layout()\n",
    "\n",
    "# save the plot to the current user's Desktop (created if missing) rather than to\n",
    "# an absolute path that only exists on one machine\n",
    "fig_dir = Path.home() / 'Desktop'\n",
    "fig_dir.mkdir(parents=True, exist_ok=True)\n",
    "fig_name = f\"ewc_rotated_{ROTATED}_seed_{SEED}_lambda_{EWC_LAMBDA}_grad_est_epochs_{GRAD_EST_EPOCHS}_grad_est_bs_{GRAD_EST_BATCHSIZE}.pdf\"\n",
    "plt.savefig(fig_dir / fig_name, dpi=300, bbox_inches='tight')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NOTE\n",
    "\n",
    "It seems that we are making errors in the nullspace projection. As we approach the optimum, these errors become large relative to the gradient, which may be why Adam then degrades the solution.\n",
    "\n",
    "Alternative explanation: Adam keeps separate moment estimates PER PARAMETER. This means the actual parameter updates can fall into a subspace that is forbidden by the projection, because the moment-based scaling is applied after the projection (check this). If we want to use Adam, we would have to insert the projection step into the moment computation done by Adam."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pt",
   "language": "python",
   "name": "pt"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
