{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9df91294-23c4-4d1e-9b0d-f890873696f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import time\n",
    "import os\n",
    "from CL_BRUNO_TMLR import *\n",
    "import pandas as pd\n",
    "from torch.utils.data import Subset\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "device='cuda'\n",
    "torch.manual_seed(314159)\n",
    "np.random.seed(314159)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9821024f-7be6-4450-95b7-4b5a9e8935dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocess_train = transforms.Compose([\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomVerticalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.5070, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762])\n",
    "])\n",
    "\n",
    "preprocess_test = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.5070, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762])\n",
    "])\n",
    "\n",
    "\n",
    "cifar_train = torchvision.datasets.CIFAR100(root='./CIFAR100', train=True, transform=preprocess_train)\n",
    "cifar_test = torchvision.datasets.CIFAR100(root='./CIFAR100', train=False, transform=preprocess_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "69f3a34a-3a6a-475d-8927-2225d69e29e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pretrain feature vector using the first batch of 10 classes\n",
    "cifar_01_train = get_subset(list(range(10)), cifar_train)\n",
    "cifar_01_test = get_subset(list(range(10)), cifar_test)\n",
    "cifar_01_train_loader = DataLoader(cifar_01_train, batch_size=64)\n",
    "cifar_01_test_loader = DataLoader(cifar_01_test, batch_size=64)\n",
    "dataset_sizes = {'train': len(cifar_01_train), 'val': len(cifar_01_test)}\n",
    "my_loader = {'train': cifar_01_train_loader, 'val': cifar_01_test_loader}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0af4d346-f493-4f25-8c25-bf50d6df4ecc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/29\n",
      "Epoch 1/29\n",
      "Epoch 2/29\n",
      "Epoch 3/29\n",
      "Epoch 4/29\n",
      "Epoch 5/29\n",
      "Epoch 6/29\n",
      "Epoch 7/29\n",
      "Epoch 8/29\n",
      "Epoch 9/29\n",
      "Epoch 10/29\n",
      "Epoch 11/29\n",
      "Epoch 12/29\n",
      "Epoch 13/29\n",
      "Epoch 14/29\n",
      "Epoch 15/29\n",
      "Epoch 16/29\n",
      "Epoch 17/29\n",
      "Epoch 18/29\n",
      "Epoch 19/29\n",
      "Epoch 20/29\n",
      "Epoch 21/29\n",
      "Epoch 22/29\n",
      "Epoch 23/29\n",
      "Epoch 24/29\n",
      "Epoch 25/29\n",
      "Epoch 26/29\n",
      "Epoch 27/29\n",
      "Epoch 28/29\n",
      "Epoch 29/29\n"
     ]
    }
   ],
   "source": [
    "# initialise the model\n",
    "test_model = CLBruno(x_dim=512, y_dim=128, task_dim=1, cond_dim=129, conv=False, task_num=1,\n",
    "                     y_cat_num=[10], single_task=True, n_dense_block=6, n_hidden_dense=128,\n",
    "                     activation=nn.Tanh(), mu_init=0., var_init=1., corr_init=0.1, extractor=True, init_out=10, device=device)\n",
    "test_model = test_model.to(device)\n",
    "test_model.init_extractor(my_loader, 30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0a1dac3d-7bd8-4580-9a63-282feefcdf87",
   "metadata": {},
   "outputs": [],
   "source": [
    "cifar_loader_train = DataLoader(cifar_train, batch_size=64)\n",
    "cifar_loader_test = DataLoader(cifar_test, batch_size=64)\n",
    "\n",
    "# with torch.no_grad():\n",
    "#     transformed_x_train = []\n",
    "#     transformed_y_train = torch.tensor([])\n",
    "#     for x, y in cifar_loader_train:\n",
    "#         transformed_x_train += [test_model.my_extractor(x.to(device))]\n",
    "#         transformed_y_train = torch.cat((transformed_y_train, y))\n",
    "#     transformed_x_train = torch.vstack(transformed_x_train).cpu()\n",
    "#     transformed_y_train = transformed_y_train.to(torch.long)\n",
    "\n",
    "#     transformed_x_test = []\n",
    "#     transformed_y_test = torch.tensor([])\n",
    "#     for x, y in cifar_loader_test:\n",
    "#         transformed_x_test += [test_model.my_extractor(x.to(device))]\n",
    "#         transformed_y_test = torch.cat((transformed_y_test, y))\n",
    "#     transformed_x_test = torch.vstack(transformed_x_test).cpu()\n",
    "#     transformed_y_test = transformed_y_test.to(torch.long)\n",
    "\n",
    "# pd.DataFrame(transformed_x_train.numpy()).to_csv('iCIFAR100_pretrained_feature_train.csv', index=False)\n",
    "# pd.DataFrame(transformed_x_test.numpy()).to_csv('iCIFAR100_pretrained_feature_test.csv', index=False)\n",
    "\n",
    "transformed_x_train = torch.tensor(pd.read_csv('iCIFAR100_pretrained_feature_train.csv').to_numpy(), dtype=torch.float)\n",
    "transformed_x_test = torch.tensor(pd.read_csv('iCIFAR100_pretrained_feature_test.csv').to_numpy(), dtype=torch.float)\n",
    "\n",
    "transformed_y_test = torch.tensor(cifar_test.targets, dtype=torch.long)\n",
    "transformed_y_train = torch.tensor(cifar_train.targets, dtype=torch.long)\n",
    "\n",
    "# initialise the model\n",
    "test_model = CLBruno(x_dim=512, y_dim=128, task_dim=1, cond_dim=129, conv=False, task_num=1,\n",
    "                     y_cat_num=[10], single_task=True, n_dense_block=6, n_hidden_dense=128,\n",
    "                     activation=nn.Tanh(), mu_init=0., var_init=1., corr_init=0.1, extractor=True, init_out=10, device=device)\n",
    "test_model = test_model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fe1d39d5-ae54-452d-ac81-10086281fb51",
   "metadata": {},
   "outputs": [],
   "source": [
    "task_split = [list(range(i*10, (i+1)*10)) for i in range(10)]\n",
    "CIL_cifar_train = {}\n",
    "CIL_cifar_test = {}\n",
    "# nhwc to nchw\n",
    "for _ in range(len(task_split)):\n",
    "    task_id = np.array([i for i,j in enumerate(transformed_y_train) if j in task_split[_]])\n",
    "    CIL_cifar_train['X_{}'.format(_)] = transformed_x_train[task_id].to(device)\n",
    "    CIL_cifar_train['y_{}'.format(_)] = transformed_y_train[task_id].to(device)\n",
    "    task_id = np.array([i for i, j in enumerate(transformed_y_test) if j in task_split[_]])\n",
    "    CIL_cifar_test['X_{}'.format(_)] = transformed_x_test[task_id].to(device)\n",
    "    CIL_cifar_test['y_{}'.format(_)] = transformed_y_test[task_id].to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c72d49e-8228-4d83-900e-428c1bc129ec",
   "metadata": {},
   "source": [
    "# CL-BRUNO with functional regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "fff54399-d28c-4bd0-9973-29e86a678e09",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.1540)\n"
     ]
    }
   ],
   "source": [
    "# set alignment_reg=0. to turn off alignment regularizer\n",
    "train_loss, test_loss = test_model.train_init(CIL_cifar_train['X_0'], CIL_cifar_train['y_0'],\n",
    "                                              torch.zeros(CIL_cifar_train['y_0'].shape[0], dtype=torch.long).to(device),\n",
    "                                              batch_size=128, epoch=30, weight_decay=0., lr=1e-3, embedding_reg=0.1)\n",
    "                                              # context_portion=0.2)\n",
    "\n",
    "\n",
    "N = len(CIL_cifar_test['y_0'])\n",
    "my_id_test = range(len(CIL_cifar_test['y_0']))\n",
    "# runnable, check outputs\n",
    "q = torch.zeros((N, 10))\n",
    "p = torch.zeros(N)\n",
    "for i,j in enumerate(my_id_test):\n",
    "    a, b = test_model.prediction(CIL_cifar_test['X_0'][j], 0)\n",
    "    q[i] = a.cpu()\n",
    "    p[i] = b.cpu()\n",
    "print(torch.sum(q.cpu().argmax(1) != CIL_cifar_test['y_0'][my_id_test].cpu())/N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "73797c3c-454d-4590-9234-83eaedc72c25",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 1\n",
      "misclassification: 0.368\n",
      "Batch 2\n",
      "misclassification: 0.457\n",
      "Batch 3\n",
      "misclassification: 0.522\n",
      "Batch 4\n",
      "misclassification: 0.631\n",
      "Batch 5\n",
      "misclassification: 0.672\n",
      "Batch 6\n",
      "misclassification: 0.708\n",
      "Batch 7\n",
      "misclassification: 0.74\n",
      "Batch 8\n",
      "misclassification: 0.762\n",
      "Batch 9\n",
      "misclassification: 0.788\n"
     ]
    }
   ],
   "source": [
    "batch_sizes = [128]*9\n",
    "# doing CIL\n",
    "for batch_id in range(1, 10):\n",
    "    train_loss1, test_loss1, reg_loss1 = test_model.train_continual_task(X_new=CIL_cifar_train['X_{}'.format(batch_id)],\n",
    "                                                                         y_new=CIL_cifar_train['y_{}'.format(batch_id)],\n",
    "                                                                         task_id=0, epoch=30, batch_size=int(batch_sizes[batch_id-1]),\n",
    "                                                                         weight_decay=0., lr=1e-3, n_pseudo=128,\n",
    "                                                                         embedding_reg=0.1)\n",
    "\n",
    "    acc = 0.\n",
    "    for hist_id in range(batch_id+1):\n",
    "        N = len(CIL_cifar_test['y_{}'.format(hist_id)])\n",
    "        my_id_test = range(len(CIL_cifar_test['y_{}'.format(hist_id)]))\n",
    "        q = torch.zeros((N, (batch_id + 1) * 10))\n",
    "        p = torch.zeros(N)\n",
    "        for i, j in enumerate(my_id_test):\n",
    "            a, b = test_model.prediction(CIL_cifar_test['X_{}'.format(hist_id)][j], 0)\n",
    "            q[i] = a.cpu()\n",
    "            p[i] = b.cpu()\n",
    "        acc += torch.sum(q.cpu().argmax(1) != CIL_cifar_test['y_{}'.format(hist_id)][my_id_test].cpu()) / N\n",
    "    print('Batch {}'.format(batch_id))\n",
    "    print('misclassification: {}'.format(np.round(acc.numpy()/(batch_id+1), 3)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "431f71e9-d465-42ae-b004-3b6a8205735c",
   "metadata": {},
   "source": [
    "# working out ECE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "3c72303d-72f4-4f66-9c54-829a0eb6086e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# working out ECE, code taken from https://github.com/Jonathan-Pearce/calibration-toolbox/blob/master/metrics.py\n",
    "from scipy.special import softmax\n",
    "class CELoss(object):\n",
    "\n",
    "    def compute_bin_boundaries(self, probabilities = np.array([])):\n",
    "\n",
    "        #uniform bin spacing\n",
    "        if probabilities.size == 0:\n",
    "            bin_boundaries = np.linspace(0, 1, self.n_bins + 1)\n",
    "            self.bin_lowers = bin_boundaries[:-1]\n",
    "            self.bin_uppers = bin_boundaries[1:]\n",
    "        else:\n",
    "            #size of bins \n",
    "            bin_n = int(self.n_data/self.n_bins)\n",
    "\n",
    "            bin_boundaries = np.array([])\n",
    "\n",
    "            probabilities_sort = np.sort(probabilities)  \n",
    "\n",
    "            for i in range(0,self.n_bins):\n",
    "                bin_boundaries = np.append(bin_boundaries,probabilities_sort[i*bin_n])\n",
    "            bin_boundaries = np.append(bin_boundaries,1.0)\n",
    "\n",
    "            self.bin_lowers = bin_boundaries[:-1]\n",
    "            self.bin_uppers = bin_boundaries[1:]\n",
    "\n",
    "\n",
    "    def get_probabilities(self, output, labels, logits):\n",
    "        #If not probabilities apply softmax!\n",
    "        if logits:\n",
    "            self.probabilities = softmax(output, axis=1)\n",
    "        else:\n",
    "            self.probabilities = output\n",
    "\n",
    "        self.labels = labels\n",
    "        self.confidences = np.max(self.probabilities, axis=1)\n",
    "        self.predictions = np.argmax(self.probabilities, axis=1)\n",
    "        self.accuracies = np.equal(self.predictions,labels)\n",
    "\n",
    "    def binary_matrices(self):\n",
    "        idx = np.arange(self.n_data)\n",
    "        #make matrices of zeros\n",
    "        pred_matrix = np.zeros([self.n_data,self.n_class])\n",
    "        label_matrix = np.zeros([self.n_data,self.n_class])\n",
    "        #self.acc_matrix = np.zeros([self.n_data,self.n_class])\n",
    "        pred_matrix[idx,self.predictions] = 1\n",
    "        label_matrix[idx,self.labels] = 1\n",
    "\n",
    "        self.acc_matrix = np.equal(pred_matrix, label_matrix)\n",
    "\n",
    "\n",
    "    def compute_bins(self, index = None):\n",
    "        self.bin_prop = np.zeros(self.n_bins)\n",
    "        self.bin_acc = np.zeros(self.n_bins)\n",
    "        self.bin_conf = np.zeros(self.n_bins)\n",
    "        self.bin_score = np.zeros(self.n_bins)\n",
    "\n",
    "        if index == None:\n",
    "            confidences = self.confidences\n",
    "            accuracies = self.accuracies\n",
    "        else:\n",
    "            confidences = self.probabilities[:,index]\n",
    "            accuracies = self.acc_matrix[:,index]\n",
    "\n",
    "\n",
    "        for i, (bin_lower, bin_upper) in enumerate(zip(self.bin_lowers, self.bin_uppers)):\n",
    "            # Calculated |confidence - accuracy| in each bin\n",
    "            in_bin = np.greater(confidences,bin_lower.item()) * np.less_equal(confidences,bin_upper.item())\n",
    "            self.bin_prop[i] = np.mean(in_bin)\n",
    "\n",
    "            if self.bin_prop[i].item() > 0:\n",
    "                self.bin_acc[i] = np.mean(accuracies[in_bin])\n",
    "                self.bin_conf[i] = np.mean(confidences[in_bin])\n",
    "                self.bin_score[i] = np.abs(self.bin_conf[i] - self.bin_acc[i])\n",
    "\n",
    "class MaxProbCELoss(CELoss):\n",
    "    def loss(self, output, labels, n_bins = 15, logits = True):\n",
    "        self.n_bins = n_bins\n",
    "        super().compute_bin_boundaries()\n",
    "        super().get_probabilities(output, labels, logits)\n",
    "        super().compute_bins()\n",
    "\n",
    "#http://people.cs.pitt.edu/~milos/research/AAAI_Calibration.pdf\n",
    "class ECELoss(MaxProbCELoss):\n",
    "\n",
    "    def loss(self, output, labels, n_bins = 15, logits = True):\n",
    "        super().loss(output, labels, n_bins, logits)\n",
    "        return np.dot(self.bin_prop,self.bin_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "88ecea01-8782-439b-bca6-040af91723a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3031906019434333\n"
     ]
    }
   ],
   "source": [
    "# working out ECE\n",
    "ansss = []\n",
    "true_label = []\n",
    "for hist_id in range(batch_id+1):\n",
    "    print(hist_id)\n",
    "    N = len(CIL_cifar_test['y_{}'.format(hist_id)])\n",
    "    my_id_test = range(len(CIL_cifar_test['y_{}'.format(hist_id)]))\n",
    "    q = torch.zeros((N, (batch_id + 1) * 10))\n",
    "    p = torch.zeros(N)\n",
    "    for i, j in enumerate(my_id_test):\n",
    "        a = test_model.prediction(test_model, CIL_cifar_test['X_{}'.format(hist_id)][j], 0)\n",
    "        q[i] = a.cpu()\n",
    "    ansss += [q * 1.0]\n",
    "    true_label += [CIL_cifar_test['y_{}'.format(hist_id)][my_id_test].cpu()]\n",
    "    \n",
    "pred = torch.vstack(ansss).numpy()\n",
    "label = torch.concatenate(true_label).numpy()\n",
    "print(ECELoss().loss(output=pred, labels=label, n_bins = 10, logits = False))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54c785ee-0685-46f6-8cca-419f9bdac450",
   "metadata": {},
   "source": [
    "# without functional regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4781c780-7b5f-4304-b2ba-0472f01a0bfa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 1\n",
      "misclassification: 0.394\n",
      "Batch 2\n",
      "misclassification: 0.544\n",
      "Batch 3\n",
      "misclassification: 0.606\n",
      "Batch 4\n",
      "misclassification: 0.697\n",
      "Batch 5\n",
      "misclassification: 0.742\n",
      "Batch 6\n",
      "misclassification: 0.783\n",
      "Batch 7\n",
      "misclassification: 0.808\n",
      "Batch 8\n",
      "misclassification: 0.831\n",
      "Batch 9\n",
      "misclassification: 0.854\n"
     ]
    }
   ],
   "source": [
    "# set alignment_reg=0. to turn off alignment regularizer\n",
    "test_model = CLBruno(x_dim=512, y_dim=128, task_dim=1, cond_dim=129, conv=False, task_num=1,\n",
    "                     y_cat_num=[10], single_task=True, n_dense_block=6, n_hidden_dense=128,\n",
    "                     activation=nn.Tanh(), mu_init=0., var_init=1., corr_init=0.1, extractor=True, init_out=10, device=device)\n",
    "test_model = test_model.to(device)\n",
    "\n",
    "train_loss, test_loss = test_model.train_init(CIL_cifar_train['X_0'], CIL_cifar_train['y_0'],\n",
    "                                              torch.zeros(CIL_cifar_train['y_0'].shape[0], dtype=torch.long).to(device),\n",
    "                                              batch_size=128, epoch=30, weight_decay=0., lr=1e-3, embedding_reg=0.1)\n",
    "                                              # context_portion=0.2)\n",
    "batch_sizes = [128]*9\n",
    "# doing CIL\n",
    "for batch_id in range(1, 10):\n",
    "    train_loss1, test_loss1, reg_loss1 = test_model.train_continual_task(X_new=CIL_cifar_train['X_{}'.format(batch_id)],\n",
    "                                                                         y_new=CIL_cifar_train['y_{}'.format(batch_id)],\n",
    "                                                                         task_id=0, epoch=30, batch_size=int(batch_sizes[batch_id-1]),\n",
    "                                                                         weight_decay=0., lr=1e-3, n_pseudo=128,\n",
    "                                                                         embedding_reg=0.1, alignment_reg=0.)\n",
    "\n",
    "    acc = 0.\n",
    "    for hist_id in range(batch_id+1):\n",
    "        N = len(CIL_cifar_test['y_{}'.format(hist_id)])\n",
    "        my_id_test = range(len(CIL_cifar_test['y_{}'.format(hist_id)]))\n",
    "        q = torch.zeros((N, (batch_id + 1) * 10))\n",
    "        p = torch.zeros(N)\n",
    "        for i, j in enumerate(my_id_test):\n",
    "            a, b = test_model.prediction(CIL_cifar_test['X_{}'.format(hist_id)][j], 0)\n",
    "            q[i] = a.cpu()\n",
    "            p[i] = b.cpu()\n",
    "        acc += torch.sum(q.cpu().argmax(1) != CIL_cifar_test['y_{}'.format(hist_id)][my_id_test].cpu()) / N\n",
    "    print('Batch {}'.format(batch_id))\n",
    "    print('misclassification: {}'.format(np.round(acc.numpy()/(batch_id+1), 3)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0afa5d14-52f3-4742-8323-c422f0f50be7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
