{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import medmnist\n",
    "import numpy as np\n",
    "\n",
    "from configuration import config_jup\n",
    "from utils.data_loader import get_loader_with_assignment\n",
    "from utils.train_utils import get_logger, initialize_model, save_results\n",
    "from utils.utils_memory import Memory\n",
    "from utils.cl_utils import Client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = config_jup.base_parser()\n",
    "client_id = 0\n",
    "args.n_runs = 1\n",
    "logger = get_logger(args)\n",
    "\n",
    "# Use the first GPU when available. args.cuda is assigned in both cases so a\n",
    "# CPU-only run never keeps whatever value the parser may have set.\n",
    "args.cuda = torch.cuda.is_available()\n",
    "args.device = 'cuda:0' if args.cuda else 'cpu'\n",
    "\n",
    "print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for run in range(args.n_runs):\n",
    "    # Task loaders for this run: MedMNIST datasets return their own class\n",
    "    # assignment; other datasets use a seeded random permutation of the classes.\n",
    "    if args.dataset_name in medmnist.INFO.keys():\n",
    "        loader_client, cls_assignment = get_loader_with_assignment(args, None, None)\n",
    "        print(cls_assignment)\n",
    "    else:\n",
    "        np.random.seed(run)\n",
    "        cls_assignment = np.arange(args.n_classes)\n",
    "        np.random.shuffle(cls_assignment)\n",
    "        loader_client, _ = get_loader_with_assignment(args, cls_assignment.tolist(), run)\n",
    "\n",
    "    # Re-seed right before building the model so every run is reproducible.\n",
    "    np.random.seed(run)\n",
    "    torch.manual_seed(run)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    model, optimizer, criterion = initialize_model(args)\n",
    "    memory_client = Memory(args)\n",
    "    client = Client(args, loader_client, model, optimizer, criterion, memory_client, client_id, cls_assignment)\n",
    "\n",
    "    while not client.train_completed:\n",
    "        samples, labels = client.get_next_batch()\n",
    "\n",
    "        if samples is None:\n",
    "            # No batch left for the current task: evaluate, then advance or stop.\n",
    "            print(f'Run {run} - Client {client.client_id} - Task {client.task_id} completed - {client.get_current_task()}')\n",
    "            logger = client.compute_loss(logger, run)  # training loss\n",
    "            print(f'Run {run} - Client {client.client_id} - Test time - Task {client.task_id}')\n",
    "            if args.model_name == 'resnetmc':\n",
    "                logger = client.testMC(logger, run)\n",
    "                logger = client.validationMC(logger, run)\n",
    "            else:\n",
    "                logger = client.test(logger, run)\n",
    "                logger = client.validation(logger, run)\n",
    "            logger = client.forgetting(logger, run)\n",
    "\n",
    "            if client.task_id + 1 < args.n_tasks:\n",
    "                client.task_id += 1\n",
    "            else:\n",
    "                client.train_completed = True\n",
    "                print(f'Run {run} - Client {client.client_id} - Train completed')\n",
    "            continue\n",
    "\n",
    "        # Without memory: plain training. With memory: train_with_update on\n",
    "        # task 0, train_with_memory on every later task.\n",
    "        if not args.with_memory:\n",
    "            client.train(samples, labels)\n",
    "        elif client.task_id == 0:\n",
    "            client.train_with_update(samples, labels)\n",
    "        else:\n",
    "            client.train_with_memory(samples, labels)\n",
    "\n",
    "    if args.model_name == 'resnetmc':\n",
    "        logger = client.balanced_accuracyMC(logger, run)\n",
    "    else:\n",
    "        logger = client.balanced_accuracy(logger, run)\n",
    "    print()\n",
    "    print(logger['test']['acc'][client_id][run])\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_log = logger['test']\n",
    "# axis 0 indexes runs (see logger['test']['acc'][client_id][run] above); the\n",
    "# row of the last task is then averaged over its entries.\n",
    "acc_over_runs = np.mean(test_log['acc'][client_id], 0)\n",
    "final_acc = np.mean(acc_over_runs[args.n_tasks - 1, :], 0)\n",
    "print(f'Final accuracy: {final_acc}')\n",
    "final_forget = np.mean(test_log['forget'][client_id])\n",
    "print(f'Final forgetting: {final_forget}')\n",
    "final_bal_acc = np.mean(test_log['bal_acc'][client_id])\n",
    "print(f'Final balanced accuracy: {final_bal_acc}')\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the collected metrics via the project helper save_results.\n",
    "save_results(args, logger)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Show the images stored in the memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "from torchvision import transforms\n",
    "from utils.data_loader import get_statistics\n",
    "\n",
    "\n",
    "mean, std, n_classes, inp_size, in_channels = get_statistics(args)\n",
    "\n",
    "# Inverse of transforms.Normalize(mean, std): multiply by std, then add mean.\n",
    "invTrans = transforms.Compose([\n",
    "    transforms.Normalize(mean=np.zeros_like(mean), std=np.divide(1, std)),\n",
    "    transforms.Normalize(mean=np.negative(mean), std=np.ones_like(std)),\n",
    "])\n",
    "\n",
    "def show_images(args, imgs, class_id, n_cols=10, save=False):\n",
    "    \"\"\"Plot the memory images of one class.\n",
    "\n",
    "    Args:\n",
    "        args: experiment configuration; dataset_name, memory_size,\n",
    "            uncertainty_score and balanced_step form the output folder.\n",
    "        imgs: normalized image tensors, e.g. client.memory.x[mask].\n",
    "        class_id: class shown (also the last folder level).\n",
    "        n_cols: images per row in the grid layout (default 10).\n",
    "        save: if True, also write the figure to <folder>/memory.png.\n",
    "\n",
    "    Fewer than 2 * n_cols images are drawn on a single row; otherwise a\n",
    "    grid with n_cols columns is used and unused cells are hidden.\n",
    "    \"\"\"\n",
    "    dir_plot = f'./images/{args.dataset_name}/{args.memory_size}/{args.uncertainty_score}/{args.balanced_step}/{class_id}'\n",
    "    os.makedirs(dir_plot, exist_ok=True)\n",
    "\n",
    "    n_imgs = len(imgs)\n",
    "    if n_imgs == 0:\n",
    "        # plt.subplots(ncols=0) would raise\n",
    "        print(f'No images in memory for class {class_id}')\n",
    "        return\n",
    "\n",
    "    if n_imgs // n_cols > 1:\n",
    "        # ceil division: with floor division the last partial row was\n",
    "        # missing from the figure and indexing it raised IndexError\n",
    "        n_rows = -(-n_imgs // n_cols)\n",
    "        fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False, figsize=(5, n_rows / 2))\n",
    "    else:\n",
    "        fig, axs = plt.subplots(ncols=n_imgs, squeeze=False)\n",
    "\n",
    "    for idx, ax in enumerate(axs.flat):\n",
    "        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])\n",
    "        if idx >= n_imgs:\n",
    "            ax.axis('off')  # unused cell in the last grid row\n",
    "            continue\n",
    "        # clamp: rounding in the inverse normalization can leave values just\n",
    "        # outside [0, 1], which do not map cleanly to uint8 in ToPILImage\n",
    "        img = transforms.ToPILImage()(invTrans(imgs[idx]).clamp(0, 1).to('cpu'))\n",
    "        arr = np.asarray(img)\n",
    "        # single-channel images come back 2-D; show them in grayscale\n",
    "        ax.imshow(arr, cmap='gray' if arr.ndim == 2 else None)\n",
    "\n",
    "    fig.subplots_adjust(hspace=0, wspace=0)\n",
    "    if save:\n",
    "        fig.savefig(os.path.join(dir_plot, 'memory.png'), bbox_inches='tight')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class to inspect: any class_id present in the memory works.\n",
    "class_id = 2\n",
    "in_class = client.memory.y == class_id\n",
    "mem_class = client.memory.x[in_class]\n",
    "show_images(args, mem_class, class_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a long-tailed (LT) version of CIFAR10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from torchvision import datasets, transforms\n",
    "from utils.data_loader import get_data_per_class\n",
    "\n",
    "data_transforms = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "    ])\n",
    "\n",
    "# NOTE(review): `train` is not needed for the weights below (its size cancels\n",
    "# out); the download is kept only in case later steps expect the raw CIFAR10\n",
    "# files under ./data/raw/ -- TODO confirm and drop.\n",
    "train = datasets.CIFAR10('./data/raw/', train=True,  download=True, transform=data_transforms)\n",
    "\n",
    "# Exponential long-tail profile: the class at position idx keeps a fraction\n",
    "# imb_factor ** (idx / (n_classes - 1)) of its samples, i.e. 1.0 for the first\n",
    "# class down to imb_factor for the last one.\n",
    "imb_factor = 0.1\n",
    "exp_denom = max(args.n_classes - 1, 1)  # n_classes == 1 used to raise ZeroDivisionError\n",
    "w_per_cls = [imb_factor ** (idx / exp_denom) for idx in range(args.n_classes)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "ds_dict = get_data_per_class(args)\n",
    "skip = args.n_classes_per_task\n",
    "\n",
    "\n",
    "def _build_lt_tasks(ds, cls_assignment, weights, n_classes, n_per_task):\n",
    "    \"\"\"Subsample every class of one data split and group the classes into tasks.\n",
    "\n",
    "    Classes are visited in cls_assignment order, n_per_task per task; the class at\n",
    "    position p keeps its first int(weights[p] * n) samples, n being its size.\n",
    "    Returns a list with one (x, y) pair of concatenated tensors per task.\n",
    "    \"\"\"\n",
    "    tasks = []\n",
    "    for start in range(0, n_classes, n_per_task):\n",
    "        task_w = weights[start:start + n_per_task]\n",
    "        task_cls = cls_assignment[start:start + n_per_task]\n",
    "        xs, ys = [], []\n",
    "        for pos, class_id in enumerate(task_cls):\n",
    "            class_x, class_y = ds[class_id]\n",
    "            n_keep = int(task_w[pos] * len(class_y))\n",
    "            xs.append(class_x[:n_keep])\n",
    "            ys.append(class_y[:n_keep])\n",
    "        tasks.append((torch.cat(xs), torch.cat(ys)))\n",
    "    return tasks\n",
    "\n",
    "\n",
    "def _shuffled_loaders(tasks, batch_size):\n",
    "    \"\"\"Wrap each (x, y) task in a DataLoader over a random permutation of its samples.\"\"\"\n",
    "    loaders = []\n",
    "    for images, label in tasks:\n",
    "        perm = torch.from_numpy(np.random.permutation(images.size(0)))\n",
    "        task_ds = torch.utils.data.TensorDataset(images[perm], label[perm])\n",
    "        loaders.append(torch.utils.data.DataLoader(task_ds, batch_size=batch_size, drop_last=True))\n",
    "    return loaders\n",
    "\n",
    "\n",
    "for run in range(args.n_runs):\n",
    "    # reuse the class order of the balanced split of this run\n",
    "    dir_input = f'{args.dir_data}/data_splits/CL/{args.dataset_name}/run{run}/'\n",
    "    # pickle can execute code on load: only read splits you generated yourself\n",
    "    with open(f'{dir_input}/{args.dataset_name}_cls_assignment.pkl', 'rb') as infile:\n",
    "        cls_assignment = pickle.load(infile)\n",
    "    print(cls_assignment)\n",
    "\n",
    "    # seed per run (as in the training loop) so the shuffled split is reproducible\n",
    "    np.random.seed(run)\n",
    "\n",
    "    ds_out = {name_ds: _build_lt_tasks(ds, cls_assignment, w_per_cls, args.n_classes, skip)\n",
    "              for name_ds, ds in ds_dict.items()}\n",
    "    loader_list = [_shuffled_loaders(ds_out[name], args.batch_size) for name in ('train', 'val', 'test')]\n",
    "\n",
    "    dir_output = f'{args.dir_data}/data_splits/CL/{args.dataset_name}LT/run{run}/'\n",
    "    loader_fn = f'{dir_output}/{args.dataset_name}LT_split.pkl'\n",
    "    cls_assignment_fn = f'{dir_output}/{args.dataset_name}LT_cls_assignment.pkl'\n",
    "    # exist_ok: a folder left without the split file used to make makedirs raise\n",
    "    os.makedirs(dir_output, exist_ok=True)\n",
    "\n",
    "    # save data splits and cls_assignment\n",
    "    with open(loader_fn, 'wb') as outfile:\n",
    "        pickle.dump(loader_list, outfile)\n",
    "    with open(cls_assignment_fn, 'wb') as outfile:\n",
    "        pickle.dump(cls_assignment, outfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
