{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3dcf99d2-d5c9-4cce-973a-09ff3bffb81a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import time\n",
    "import os\n",
    "from CL_BRUNO_TMLR import *\n",
    "import pandas as pd\n",
    "from torch.utils.data import Subset\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from tiny_imagenet_torch import TinyImageNet\n",
    "# Single seed shared by torch and numpy so both generators start from a known state.\n",
    "SEED = 314159\n",
    "# Use the GPU when one is visible; otherwise fall back to the CPU so the notebook still runs.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fd51741e-a677-46d7-9cc6-a84846f51dfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-channel normalisation shared by the train and test pipelines.\n",
    "# NOTE(review): these mean/std values match the commonly quoted CIFAR-100 statistics -\n",
    "# confirm they are the intended ones for Tiny ImageNet.\n",
    "normalize = transforms.Normalize(mean=[0.5070, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2762])\n",
    "\n",
    "# Training pipeline: random horizontal and vertical flips, then tensor conversion and normalisation.\n",
    "preprocess_train = transforms.Compose([\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.RandomVerticalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    normalize,\n",
    "])\n",
    "\n",
    "# Test pipeline: tensor conversion and normalisation only, no augmentation.\n",
    "preprocess_test = transforms.Compose([transforms.ToTensor(), normalize])\n",
    "\n",
    "# Both splits are read from the same local root; nothing is downloaded.\n",
    "tinyimg_train = TinyImageNet(root='./tinyimagenet', train=True, download=False, transform=preprocess_train)\n",
    "tinyimg_test = TinyImageNet(root='./tinyimagenet', train=False, download=False, transform=preprocess_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6661d5a9-a86a-45cc-98c1-ef3364e55797",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\test\\CL_BRUNO_TMLR.py:610: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  idx = torch.tensor(dataset.targets) == i\n"
     ]
    }
   ],
   "source": [
    "# data for pretraining the feature extractor: the first batch of 20 classes (labels 0-19)\n",
    "N_FIRST_TASK_CLASSES = 20\n",
    "tinyimg_01_train = get_subset(list(range(N_FIRST_TASK_CLASSES)), tinyimg_train)\n",
    "tinyimg_01_test = get_subset(list(range(N_FIRST_TASK_CLASSES)), tinyimg_test)\n",
    "tinyimg_01_train_loader = DataLoader(tinyimg_01_train, batch_size=128, shuffle=True)\n",
    "# held-out data is only scored, so it is iterated in a fixed order rather than shuffled\n",
    "tinyimg_01_test_loader = DataLoader(tinyimg_01_test, batch_size=128, shuffle=False)\n",
    "# sizes and loaders keyed by phase name, in the form passed to the (commented-out) init_extractor call\n",
    "dataset_sizes = {'train': len(tinyimg_01_train), 'val': len(tinyimg_01_test)}\n",
    "my_loader = {'train': tinyimg_01_train_loader, 'val': tinyimg_01_test_loader}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5ec68314-4b59-40d0-a015-63c51b4a917c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialise the model (20 output categories for the first task) and move it to the target device\n",
    "test_model = CLBruno(\n",
    "    x_dim=512, y_dim=128, task_dim=1, cond_dim=129,\n",
    "    conv=False, task_num=1, y_cat_num=[20], single_task=True,\n",
    "    n_dense_block=6, n_hidden_dense=128, activation=nn.Tanh(),\n",
    "    mu_init=0., var_init=1., corr_init=0.1,\n",
    "    extractor=True, init_out=20, device=device,\n",
    ").to(device)\n",
    "# test_model.init_extractor(my_loader, 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "684676c9-42bb-48a2-8546-cb0f6f733909",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\86304\\AppData\\Local\\Temp\\ipykernel_24984\\4056311831.py:27: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  transformed_y_test = torch.tensor(tinyimg_test.targets, dtype=torch.long)\n",
      "C:\\Users\\86304\\AppData\\Local\\Temp\\ipykernel_24984\\4056311831.py:28: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  transformed_y_train = torch.tensor(tinyimg_train.targets, dtype=torch.long)\n"
     ]
    }
   ],
   "source": [
    "# full-dataset loaders; only the commented-out feature extraction below iterates over them\n",
    "tinyimg_loader_train = DataLoader(tinyimg_train, batch_size=2000)\n",
    "tinyimg_loader_test = DataLoader(tinyimg_test, batch_size=2000)\n",
    "\n",
    "# one-off feature extraction that produced the CSV files read further down (kept for reference)\n",
    "# with torch.no_grad():\n",
    "#     transformed_x_train = []\n",
    "#     transformed_y_train = torch.tensor([])\n",
    "#     for x, y in tinyimg_loader_train:\n",
    "#         transformed_x_train += [test_model.my_extractor(x.to(device))]\n",
    "#         transformed_y_train = torch.cat((transformed_y_train, y))\n",
    "#     transformed_x_train = torch.vstack(transformed_x_train).cpu()\n",
    "#     transformed_y_train = transformed_y_train.to(torch.long)\n",
    "\n",
    "#     transformed_x_test = []\n",
    "#     transformed_y_test = torch.tensor([])\n",
    "#     for x, y in tinyimg_loader_test:\n",
    "#         transformed_x_test += [test_model.my_extractor(x.to(device))]\n",
    "#         transformed_y_test = torch.cat((transformed_y_test, y))\n",
    "#     transformed_x_test = torch.vstack(transformed_x_test).cpu()\n",
    "#     transformed_y_test = transformed_y_test.to(torch.long)\n",
    "\n",
    "# pd.DataFrame(transformed_x_train.numpy()).to_csv('tinyimg_pretrained_feature_train.csv', index=False)\n",
    "# pd.DataFrame(transformed_x_test.numpy()).to_csv('tinyimg_pretrained_feature_test.csv', index=False)\n",
    "\n",
    "# load the cached features as float32 tensors\n",
    "transformed_x_train = torch.tensor(pd.read_csv('tinyimg_pretrained_feature_train.csv').to_numpy(), dtype=torch.float)\n",
    "transformed_x_test = torch.tensor(pd.read_csv('tinyimg_pretrained_feature_test.csv').to_numpy(), dtype=torch.float)\n",
    "\n",
    "# labels come straight from the datasets; as_tensor(...).clone().detach() still yields an independent\n",
    "# long tensor but avoids the copy-construct UserWarning that torch.tensor(<tensor>) raises\n",
    "transformed_y_test = torch.as_tensor(tinyimg_test.targets, dtype=torch.long).clone().detach()\n",
    "transformed_y_train = torch.as_tensor(tinyimg_train.targets, dtype=torch.long).clone().detach()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a2e594c3-d957-44a1-ac72-9b88963efb32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 10 tasks of 20 consecutive class labels each: [0..19], [20..39], ..., [180..199]\n",
    "task_split = [list(range(i*20, (i+1)*20)) for i in range(10)]\n",
    "CIL_tinyimg_train = {}\n",
    "CIL_tinyimg_test = {}\n",
    "for task_idx, task_classes in enumerate(task_split):\n",
    "    class_ids = torch.tensor(task_classes, dtype=torch.long)\n",
    "    # vectorised membership test; replaces a Python loop that compared every label against the class list\n",
    "    train_idx = torch.isin(transformed_y_train, class_ids).nonzero(as_tuple=True)[0]\n",
    "    # shuffle this task's training samples (exactly one randperm call per task)\n",
    "    perm = torch.randperm(len(train_idx))\n",
    "    CIL_tinyimg_train['X_{}'.format(task_idx)] = transformed_x_train[train_idx][perm].to(device)\n",
    "    CIL_tinyimg_train['y_{}'.format(task_idx)] = transformed_y_train[train_idx][perm].to(device)\n",
    "    # test samples keep their original order\n",
    "    test_idx = torch.isin(transformed_y_test, class_ids).nonzero(as_tuple=True)[0]\n",
    "    CIL_tinyimg_test['X_{}'.format(task_idx)] = transformed_x_test[test_idx].to(device)\n",
    "    CIL_tinyimg_test['y_{}'.format(task_idx)] = transformed_y_test[test_idx].to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4f09801-6325-47eb-ae1e-3cda22f4cc21",
   "metadata": {},
   "source": [
    "# with functional regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d33058d0-b344-4cca-8c8c-cedcc9caa45c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.3530)\n"
     ]
    }
   ],
   "source": [
    "# set alignment_reg=0. to turn off alignment regularizer\n",
    "# third positional argument: an all-zero long tensor with one entry per training sample\n",
    "# (presumably the task index of each sample - confirm against train_init)\n",
    "init_task_index = torch.zeros(CIL_tinyimg_train['y_0'].shape[0], dtype=torch.long).to(device)\n",
    "train_loss, test_loss = test_model.train_init(\n",
    "    CIL_tinyimg_train['X_0'], CIL_tinyimg_train['y_0'], init_task_index,\n",
    "    batch_size=128, epoch=30, weight_decay=0., lr=1e-3, embedding_reg=0.1,\n",
    "    # context_portion=0.2\n",
    ")\n",
    "\n",
    "# score the first-task test set one sample at a time\n",
    "X_eval, y_eval = CIL_tinyimg_test['X_0'], CIL_tinyimg_test['y_0']\n",
    "N = len(y_eval)\n",
    "my_id_test = range(N)\n",
    "# runnable, check outputs\n",
    "q = torch.zeros((N, 20))  # first value returned by prediction(), one row of 20 per sample\n",
    "p = torch.zeros(N)        # second value returned by prediction()\n",
    "for i, j in enumerate(my_id_test):\n",
    "    a, b = test_model.prediction(X_eval[j], 0)\n",
    "    q[i], p[i] = a.cpu(), b.cpu()\n",
    "# fraction of samples whose argmax over q differs from the label\n",
    "print(torch.sum(q.argmax(1) != y_eval[my_id_test].cpu()) / N)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "abde646b-3f35-4752-93bd-4e7558bb33f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 1\n",
      "misclassification: 0.521\n",
      "Batch 2\n",
      "misclassification: 0.605\n",
      "Batch 3\n",
      "misclassification: 0.693\n",
      "Batch 4\n",
      "misclassification: 0.774\n",
      "Batch 5\n",
      "misclassification: 0.838\n",
      "Batch 6\n",
      "misclassification: 0.867\n",
      "Batch 7\n",
      "misclassification: 0.869\n",
      "Batch 8\n",
      "misclassification: 0.871\n",
      "Batch 9\n",
      "misclassification: 0.883\n"
     ]
    }
   ],
   "source": [
    "batch_sizes = [128]*9\n",
    "# doing CIL: train on tasks 1-9 in turn and, after each one, re-score every task seen so far\n",
    "for batch_id in range(1, 10):\n",
    "    train_loss1, test_loss1, reg_loss1 = test_model.train_continual_task(\n",
    "        X_new=CIL_tinyimg_train[f'X_{batch_id}'],\n",
    "        y_new=CIL_tinyimg_train[f'y_{batch_id}'],\n",
    "        task_id=0, epoch=30, batch_size=int(batch_sizes[batch_id-1]),\n",
    "        weight_decay=0., lr=1e-3, n_pseudo=128, embedding_reg=0.1,\n",
    "    )\n",
    "\n",
    "    n_seen = batch_id + 1\n",
    "    acc = 0.  # running sum of per-task misclassification rates (despite the name)\n",
    "    for hist_id in range(n_seen):\n",
    "        X_hist = CIL_tinyimg_test[f'X_{hist_id}']\n",
    "        y_hist = CIL_tinyimg_test[f'y_{hist_id}']\n",
    "        N = len(y_hist)\n",
    "        my_id_test = range(N)\n",
    "        # one row per test sample, 20 columns for every task seen so far\n",
    "        q = torch.zeros((N, n_seen * 20))\n",
    "        p = torch.zeros(N)\n",
    "        for i, j in enumerate(my_id_test):\n",
    "            a, b = test_model.prediction(X_hist[j], 0)\n",
    "            q[i], p[i] = a.cpu(), b.cpu()\n",
    "        acc += torch.sum(q.argmax(1) != y_hist[my_id_test].cpu()) / N\n",
    "    # report the mean of the per-task misclassification rates\n",
    "    print('Batch {}'.format(batch_id))\n",
    "    print('misclassification: {}'.format(np.round(acc.numpy()/n_seen, 3)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23f43296-c84f-41f2-b262-344d8f342adc",
   "metadata": {},
   "source": [
    "# without functional regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ca7f821f-aeb3-48bc-868d-2ff3b44dd2fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 1\n",
      "misclassification: 0.558\n",
      "Batch 2\n",
      "misclassification: 0.665\n",
      "Batch 3\n",
      "misclassification: 0.743\n",
      "Batch 4\n",
      "misclassification: 0.824\n",
      "Batch 5\n",
      "misclassification: 0.888\n",
      "Batch 6\n",
      "misclassification: 0.892\n",
      "Batch 7\n",
      "misclassification: 0.913\n",
      "Batch 8\n",
      "misclassification: 0.934\n",
      "Batch 9\n",
      "misclassification: 0.944\n"
     ]
    }
   ],
   "source": [
    "# initialise a fresh model with the same configuration and move it to the target device\n",
    "test_model = CLBruno(\n",
    "    x_dim=512, y_dim=128, task_dim=1, cond_dim=129,\n",
    "    conv=False, task_num=1, y_cat_num=[20], single_task=True,\n",
    "    n_dense_block=6, n_hidden_dense=128, activation=nn.Tanh(),\n",
    "    mu_init=0., var_init=1., corr_init=0.1,\n",
    "    extractor=True, init_out=20, device=device,\n",
    ").to(device)\n",
    "\n",
    "# initial training on the first task; the third argument is an all-zero long tensor, one entry per sample\n",
    "init_task_index = torch.zeros(CIL_tinyimg_train['y_0'].shape[0], dtype=torch.long).to(device)\n",
    "train_loss, test_loss = test_model.train_init(\n",
    "    CIL_tinyimg_train['X_0'], CIL_tinyimg_train['y_0'], init_task_index,\n",
    "    batch_size=128, epoch=30, weight_decay=0., lr=1e-3, embedding_reg=0.1,\n",
    "    # context_portion=0.2\n",
    ")\n",
    "\n",
    "batch_sizes = [128]*9\n",
    "# doing CIL: train on tasks 1-9 in turn and, after each one, re-score every task seen so far\n",
    "for batch_id in range(1, 10):\n",
    "    # set alignment_reg=0. to turn off alignment regularizer\n",
    "    train_loss1, test_loss1, reg_loss1 = test_model.train_continual_task(\n",
    "        X_new=CIL_tinyimg_train[f'X_{batch_id}'],\n",
    "        y_new=CIL_tinyimg_train[f'y_{batch_id}'],\n",
    "        task_id=0, epoch=30, batch_size=int(batch_sizes[batch_id-1]),\n",
    "        weight_decay=0., lr=1e-3, n_pseudo=128,\n",
    "        embedding_reg=0.1, alignment_reg=0.,\n",
    "    )\n",
    "\n",
    "    n_seen = batch_id + 1\n",
    "    acc = 0.  # running sum of per-task misclassification rates (despite the name)\n",
    "    for hist_id in range(n_seen):\n",
    "        X_hist = CIL_tinyimg_test[f'X_{hist_id}']\n",
    "        y_hist = CIL_tinyimg_test[f'y_{hist_id}']\n",
    "        N = len(y_hist)\n",
    "        my_id_test = range(N)\n",
    "        # one row per test sample, 20 columns for every task seen so far\n",
    "        q = torch.zeros((N, n_seen * 20))\n",
    "        p = torch.zeros(N)\n",
    "        for i, j in enumerate(my_id_test):\n",
    "            a, b = test_model.prediction(X_hist[j], 0)\n",
    "            q[i], p[i] = a.cpu(), b.cpu()\n",
    "        acc += torch.sum(q.argmax(1) != y_hist[my_id_test].cpu()) / N\n",
    "    # report the mean of the per-task misclassification rates\n",
    "    print('Batch {}'.format(batch_id))\n",
    "    print('misclassification: {}'.format(np.round(acc.numpy()/n_seen, 3)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3769cc0c-0a9e-4fa7-ad26-1b614a560061",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
