{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f11f8c13-fc85-4684-9bda-795d5acaed30",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import time\n",
    "import os\n",
    "from CL_BRUNO_TMLR import *\n",
    "import pandas as pd\n",
    "from torch.utils.data import Subset\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "device='cuda'\n",
    "torch.manual_seed(314159)\n",
    "np.random.seed(314159)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1fc3cdce-9b44-4ed8-9c3d-ada1f641238e",
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocess_train = transforms.Compose([\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomVerticalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.5070, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762])\n",
    "])\n",
    "\n",
    "preprocess_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.5070, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762])\n",
    "])\n",
    "\n",
    "\n",
    "cifar_train = torchvision.datasets.CIFAR10(root='./CIFAR10', train=True, transform=preprocess_train)\n",
    "cifar_test = torchvision.datasets.CIFAR10(root='./CIFAR10', train=False, transform=preprocess_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2c19059b-d699-4545-a519-75604f143b92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pretrain feature vector using the first batch of 10 classes\n",
    "cifar_01_train = get_subset(list(range(2)), cifar_train)\n",
    "cifar_01_test = get_subset(list(range(2)), cifar_test)\n",
    "cifar_01_train_loader = DataLoader(cifar_01_train, batch_size=64)\n",
    "cifar_01_test_loader = DataLoader(cifar_01_test, batch_size=64)\n",
    "dataset_sizes = {'train': len(cifar_01_train), 'val': len(cifar_01_test)}\n",
    "my_loader = {'train': cifar_01_train_loader, 'val': cifar_01_test_loader}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "266de119-1663-491c-a165-1fbdcddf4bdb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/29\n",
      "Epoch 1/29\n",
      "Epoch 2/29\n",
      "Epoch 3/29\n",
      "Epoch 4/29\n",
      "Epoch 5/29\n",
      "Epoch 6/29\n",
      "Epoch 7/29\n",
      "Epoch 8/29\n",
      "Epoch 9/29\n",
      "Epoch 10/29\n",
      "Epoch 11/29\n",
      "Epoch 12/29\n",
      "Epoch 13/29\n",
      "Epoch 14/29\n",
      "Epoch 15/29\n",
      "Epoch 16/29\n",
      "Epoch 17/29\n",
      "Epoch 18/29\n",
      "Epoch 19/29\n",
      "Epoch 20/29\n",
      "Epoch 21/29\n",
      "Epoch 22/29\n",
      "Epoch 23/29\n",
      "Epoch 24/29\n",
      "Epoch 25/29\n",
      "Epoch 26/29\n",
      "Epoch 27/29\n",
      "Epoch 28/29\n",
      "Epoch 29/29\n"
     ]
    }
   ],
   "source": [
    "# initialise the model\n",
    "# test_model = CLBruno(x_dim=512, y_dim=128, task_dim=1, cond_dim=129, conv=False, task_num=1,\n",
    "#                      y_cat_num=[2], single_task=True, n_dense_block=6, n_hidden_dense=128,\n",
    "#                      activation=nn.Tanh(), mu_init=0., var_init=1., corr_init=0.1, extractor=True, init_out=2, device=device)\n",
    "# test_model = test_model.to(device)\n",
    "# test_model.init_extractor(my_loader, 30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a2009236-f5ca-4eb0-a69c-5ece248f69ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar_loader_train = DataLoader(cifar_train, batch_size=500)\n",
    "cifar_loader_test = DataLoader(cifar_test, batch_size=500)\n",
    "\n",
    "# with torch.no_grad():\n",
    "#     transformed_x_train = []\n",
    "#     transformed_y_train = torch.tensor([])\n",
    "#     for x, y in cifar_loader_train:\n",
    "#         transformed_x_train += [test_model.my_extractor(x.to(device))]\n",
    "#         transformed_y_train = torch.cat((transformed_y_train, y))\n",
    "#     transformed_x_train = torch.vstack(transformed_x_train).cpu()\n",
    "#     transformed_y_train = transformed_y_train.to(torch.long)\n",
    "\n",
    "#     transformed_x_test = []\n",
    "#     transformed_y_test = torch.tensor([])\n",
    "#     for x, y in cifar_loader_test:\n",
    "#         transformed_x_test += [test_model.my_extractor(x.to(device))]\n",
    "#         transformed_y_test = torch.cat((transformed_y_test, y))\n",
    "#     transformed_x_test = torch.vstack(transformed_x_test).cpu()\n",
    "#     transformed_y_test = transformed_y_test.to(torch.long)\n",
    "\n",
    "# pd.DataFrame(transformed_x_train.numpy()).to_csv('iCIFAR10_pretrained_feature_train.csv', index=False)\n",
    "# pd.DataFrame(transformed_x_test.numpy()).to_csv('iCIFAR10_pretrained_feature_test.csv', index=False)\n",
    "\n",
    "transformed_x_train = torch.tensor(pd.read_csv('iCIFAR10_pretrained_feature_train.csv').to_numpy(), dtype=torch.float)\n",
    "transformed_x_test = torch.tensor(pd.read_csv('iCIFAR10_pretrained_feature_test.csv').to_numpy(), dtype=torch.float)\n",
    "\n",
    "transformed_y_test = torch.tensor(cifar_test.targets, dtype=torch.long)\n",
    "transformed_y_train = torch.tensor(cifar_train.targets, dtype=torch.long)\n",
    "\n",
    "test_model = CLBruno(x_dim=512, y_dim=128, task_dim=1, cond_dim=129, conv=False, task_num=1,\n",
    "                     y_cat_num=[2], single_task=True, n_dense_block=6, n_hidden_dense=128,\n",
    "                     activation=nn.Tanh(), mu_init=0., var_init=1., corr_init=0.1, extractor=True, init_out=2, device=device)\n",
    "test_model = test_model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fc959bf9-ded9-457c-a4c6-26781fecc4b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "task_split = [list(range(i*2, (i+1)*2)) for i in range(5)]\n",
    "CIL_cifar_train = {}\n",
    "CIL_cifar_test = {}\n",
    "# nhwc to nchw\n",
    "for _ in range(len(task_split)):\n",
    "    task_id = np.array([i for i,j in enumerate(transformed_y_train) if j in task_split[_]])\n",
    "    perm = torch.randperm(len(task_id))\n",
    "    CIL_cifar_train['X_{}'.format(_)] = transformed_x_train[task_id][perm].to(device)\n",
    "    CIL_cifar_train['y_{}'.format(_)] = transformed_y_train[task_id][perm].to(device)\n",
    "    task_id = np.array([i for i, j in enumerate(transformed_y_test) if j in task_split[_]])\n",
    "    CIL_cifar_test['X_{}'.format(_)] = transformed_x_test[task_id].to(device)\n",
    "    CIL_cifar_test['y_{}'.format(_)] = transformed_y_test[task_id].to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d32c2aba-310d-4070-ae5e-1242c6556816",
   "metadata": {},
   "source": [
    "# with functional regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bf42e620-8f1f-4861-90b1-81ec54b125f8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0230)\n"
     ]
    }
   ],
   "source": [
    "# set alignment_reg=0. to turn off alignment regularizer\n",
    "train_loss, test_loss = test_model.train_init(CIL_cifar_train['X_0'], CIL_cifar_train['y_0'],\n",
    "                                              torch.zeros(CIL_cifar_train['y_0'].shape[0], dtype=torch.long).to(device),\n",
    "                                              batch_size=128, epoch=30, weight_decay=0., lr=1e-3)\n",
    "\n",
    "\n",
    "N = len(CIL_cifar_test['y_0'])\n",
    "my_id_test = range(len(CIL_cifar_test['y_0']))\n",
    "# runnable, check outputs\n",
    "q = torch.zeros((N, 2))\n",
    "p = torch.zeros(N)\n",
    "for i,j in enumerate(my_id_test):\n",
    "    a, b = test_model.prediction(CIL_cifar_test['X_0'][j], 0)\n",
    "    q[i] = a.cpu()\n",
    "    p[i] = b.cpu()\n",
    "print(torch.sum(q.cpu().argmax(1) != CIL_cifar_test['y_0'][my_id_test].cpu())/N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "65f03c06-ab6a-4868-a38c-439aa6e7e836",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 1\n",
      "Misclassification: 0.278\n",
      "Batch 2\n",
      "Misclassification: 0.482\n",
      "Batch 3\n",
      "Misclassification: 0.566\n",
      "Batch 4\n",
      "Misclassification: 0.579\n"
     ]
    }
   ],
   "source": [
    "batch_sizes = [128]*4\n",
    "# doing CIL\n",
    "for batch_id in range(1, 5):\n",
    "    train_loss1, test_loss1, reg_loss1 = test_model.train_continual_task(X_new=CIL_cifar_train['X_{}'.format(batch_id)],\n",
    "                                                                         y_new=CIL_cifar_train['y_{}'.format(batch_id)],\n",
    "                                                                         task_id=0, epoch=30, batch_size=int(batch_sizes[batch_id-1]),\n",
    "                                                                         weight_decay=0., lr=1e-3, n_pseudo=128)\n",
    "\n",
    "    acc = 0.\n",
    "    for hist_id in range(batch_id+1):\n",
    "        N = len(CIL_cifar_test['y_{}'.format(hist_id)])\n",
    "        my_id_test = range(len(CIL_cifar_test['y_{}'.format(hist_id)]))\n",
    "        q = torch.zeros((N, (batch_id + 1) * 2))\n",
    "        p = torch.zeros(N)\n",
    "        for i, j in enumerate(my_id_test):\n",
    "            a, b = test_model.prediction(CIL_cifar_test['X_{}'.format(hist_id)][j], 0)\n",
    "            q[i] = a.cpu()\n",
    "            p[i] = b.cpu()\n",
    "        acc += torch.sum(q.cpu().argmax(1) != CIL_cifar_test['y_{}'.format(hist_id)][my_id_test].cpu()) / N\n",
    "    print('Batch {}'.format(batch_id))\n",
    "    print('Misclassification: {}'.format(np.round(acc.numpy()/(batch_id+1), 3)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55215a3d-6c23-40f4-8d9b-ebad87729f96",
   "metadata": {},
   "source": [
    "# Working out ECE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ecc53c17-c283-4b06-b0bc-e13c8001b288",
   "metadata": {},
   "outputs": [],
   "source": [
    "# working out ECE, code taken from https://github.com/Jonathan-Pearce/calibration-toolbox/blob/master/metrics.py\n",
    "from scipy.special import softmax\n",
    "class CELoss(object):\n",
    "\n",
    "    def compute_bin_boundaries(self, probabilities = np.array([])):\n",
    "\n",
    "        #uniform bin spacing\n",
    "        if probabilities.size == 0:\n",
    "            bin_boundaries = np.linspace(0, 1, self.n_bins + 1)\n",
    "            self.bin_lowers = bin_boundaries[:-1]\n",
    "            self.bin_uppers = bin_boundaries[1:]\n",
    "        else:\n",
    "            #size of bins \n",
    "            bin_n = int(self.n_data/self.n_bins)\n",
    "\n",
    "            bin_boundaries = np.array([])\n",
    "\n",
    "            probabilities_sort = np.sort(probabilities)  \n",
    "\n",
    "            for i in range(0,self.n_bins):\n",
    "                bin_boundaries = np.append(bin_boundaries,probabilities_sort[i*bin_n])\n",
    "            bin_boundaries = np.append(bin_boundaries,1.0)\n",
    "\n",
    "            self.bin_lowers = bin_boundaries[:-1]\n",
    "            self.bin_uppers = bin_boundaries[1:]\n",
    "\n",
    "\n",
    "    def get_probabilities(self, output, labels, logits):\n",
    "        #If not probabilities apply softmax!\n",
    "        if logits:\n",
    "            self.probabilities = softmax(output, axis=1)\n",
    "        else:\n",
    "            self.probabilities = output\n",
    "\n",
    "        self.labels = labels\n",
    "        self.confidences = np.max(self.probabilities, axis=1)\n",
    "        self.predictions = np.argmax(self.probabilities, axis=1)\n",
    "        self.accuracies = np.equal(self.predictions,labels)\n",
    "\n",
    "    def binary_matrices(self):\n",
    "        idx = np.arange(self.n_data)\n",
    "        #make matrices of zeros\n",
    "        pred_matrix = np.zeros([self.n_data,self.n_class])\n",
    "        label_matrix = np.zeros([self.n_data,self.n_class])\n",
    "        #self.acc_matrix = np.zeros([self.n_data,self.n_class])\n",
    "        pred_matrix[idx,self.predictions] = 1\n",
    "        label_matrix[idx,self.labels] = 1\n",
    "\n",
    "        self.acc_matrix = np.equal(pred_matrix, label_matrix)\n",
    "\n",
    "\n",
    "    def compute_bins(self, index = None):\n",
    "        self.bin_prop = np.zeros(self.n_bins)\n",
    "        self.bin_acc = np.zeros(self.n_bins)\n",
    "        self.bin_conf = np.zeros(self.n_bins)\n",
    "        self.bin_score = np.zeros(self.n_bins)\n",
    "\n",
    "        if index == None:\n",
    "            confidences = self.confidences\n",
    "            accuracies = self.accuracies\n",
    "        else:\n",
    "            confidences = self.probabilities[:,index]\n",
    "            accuracies = self.acc_matrix[:,index]\n",
    "\n",
    "\n",
    "        for i, (bin_lower, bin_upper) in enumerate(zip(self.bin_lowers, self.bin_uppers)):\n",
    "            # Calculated |confidence - accuracy| in each bin\n",
    "            in_bin = np.greater(confidences,bin_lower.item()) * np.less_equal(confidences,bin_upper.item())\n",
    "            self.bin_prop[i] = np.mean(in_bin)\n",
    "\n",
    "            if self.bin_prop[i].item() > 0:\n",
    "                self.bin_acc[i] = np.mean(accuracies[in_bin])\n",
    "                self.bin_conf[i] = np.mean(confidences[in_bin])\n",
    "                self.bin_score[i] = np.abs(self.bin_conf[i] - self.bin_acc[i])\n",
    "\n",
    "class MaxProbCELoss(CELoss):\n",
    "    def loss(self, output, labels, n_bins = 15, logits = True):\n",
    "        self.n_bins = n_bins\n",
    "        super().compute_bin_boundaries()\n",
    "        super().get_probabilities(output, labels, logits)\n",
    "        super().compute_bins()\n",
    "\n",
    "#http://people.cs.pitt.edu/~milos/research/AAAI_Calibration.pdf\n",
    "class ECELoss(MaxProbCELoss):\n",
    "\n",
    "    def loss(self, output, labels, n_bins = 15, logits = True):\n",
    "        super().loss(output, labels, n_bins, logits)\n",
    "        return np.dot(self.bin_prop,self.bin_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7db04748-2016-4354-80d0-9f0b6b9e6c14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "0.24494490709900857\n"
     ]
    }
   ],
   "source": [
    "# working out ECE\n",
    "ansss = []\n",
    "true_label = []\n",
    "batch_id=4\n",
    "for hist_id in range(batch_id+1):\n",
    "    print(hist_id)\n",
    "    N = len(CIL_cifar_test['y_{}'.format(hist_id)])\n",
    "    my_id_test = range(len(CIL_cifar_test['y_{}'.format(hist_id)]))\n",
    "    q = torch.zeros((N, (batch_id + 1) * 2))\n",
    "    p = torch.zeros(N)\n",
    "    for i, j in enumerate(my_id_test):\n",
    "        a, b = test_model.prediction(CIL_cifar_test['X_{}'.format(hist_id)][j], 0)\n",
    "        q[i] = a.cpu()\n",
    "    ansss += [q * 1.0]\n",
    "    true_label += [CIL_cifar_test['y_{}'.format(hist_id)][my_id_test].cpu()]\n",
    "    \n",
    "pred = torch.vstack(ansss).numpy()\n",
    "label = torch.concatenate(true_label).numpy()\n",
    "print(ECELoss().loss(output=pred, labels=label, n_bins = 10, logits = False))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62af035f-7810-4f92-af7d-71a2caf00824",
   "metadata": {},
   "source": [
    "# without funtional regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "03d2f665-2fb5-4d29-9925-d9ca565cea6f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 1\n",
      "Misclassification: 0.278\n",
      "Batch 2\n",
      "Misclassification: 0.587\n",
      "Batch 3\n",
      "Misclassification: 0.663\n",
      "Batch 4\n",
      "Misclassification: 0.743\n"
     ]
    }
   ],
   "source": [
    "# set alignment_reg=0. to turn off alignment regularizer\n",
    "test_model = CLBruno(x_dim=512, y_dim=128, task_dim=1, cond_dim=129, conv=False, task_num=1,\n",
    "                     y_cat_num=[2], single_task=True, n_dense_block=6, n_hidden_dense=128,\n",
    "                     activation=nn.Tanh(), mu_init=0., var_init=1., corr_init=0.1, extractor=True, init_out=2, device=device)\n",
    "test_model = test_model.to(device)\n",
    "\n",
    "train_loss, test_loss = test_model.train_init(CIL_cifar_train['X_0'], CIL_cifar_train['y_0'],\n",
    "                                              torch.zeros(CIL_cifar_train['y_0'].shape[0], dtype=torch.long).to(device),\n",
    "                                              batch_size=128, epoch=30, weight_decay=0., lr=1e-3, embedding_reg=0.1)\n",
    "                                              # context_portion=0.2)\n",
    "batch_sizes = [128]*4\n",
    "# doing CIL\n",
    "for batch_id in range(1, 5):\n",
    "    train_loss1, test_loss1, reg_loss1 = test_model.train_continual_task(X_new=CIL_cifar_train['X_{}'.format(batch_id)],\n",
    "                                                                         y_new=CIL_cifar_train['y_{}'.format(batch_id)],\n",
    "                                                                         task_id=0, epoch=30, batch_size=int(batch_sizes[batch_id-1]),\n",
    "                                                                         weight_decay=0., lr=1e-3, n_pseudo=128,\n",
    "                                                                         embedding_reg=0.1, alignment_reg=0.)\n",
    "\n",
    "    acc = 0.\n",
    "    for hist_id in range(batch_id+1):\n",
    "        N = len(CIL_cifar_test['y_{}'.format(hist_id)])\n",
    "        my_id_test = range(len(CIL_cifar_test['y_{}'.format(hist_id)]))\n",
    "        q = torch.zeros((N, (batch_id + 1) * 2))\n",
    "        p = torch.zeros(N)\n",
    "        for i, j in enumerate(my_id_test):\n",
    "            a, b = test_model.prediction(CIL_cifar_test['X_{}'.format(hist_id)][j], 0)\n",
    "            q[i] = a.cpu()\n",
    "            p[i] = b.cpu()\n",
    "        acc += torch.sum(q.cpu().argmax(1) != CIL_cifar_test['y_{}'.format(hist_id)][my_id_test].cpu()) / N\n",
    "    print('Batch {}'.format(batch_id))\n",
    "    print('Misclassification: {}'.format(np.round(acc.numpy()/(batch_id+1), 3)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2930f96f-7ec0-4f32-8d13-1f293d6424a3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
