{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b370459-d5ea-41c6-9000-eb11321bf953",
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightning as L\n",
    "\n",
    "from machine_annotators_training import LitMyModule, LitCIFAR10, LitCIFAR100"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a21947f-7e04-4aa2-aeb4-85bc57f839ce",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# CIFAR10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a5736d3-64d0-41e0-8c7d-2e1f46f518b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_module = LitCIFAR10(512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81f9101a-293e-4de3-a023-8514c6f7e684",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = L.Trainer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3c14e68-c636-4622-83a4-83296fdcdc36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# my_module = LitMyModule(lr=1e-2, num_epochs=100, batch_size=512)\n",
    "# trainer.test(model=my_module, datamodule=data_module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8e80175-1428-416d-88e3-abb90fe57d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/sandy-sun-327/syn-epoch=29-global_step=0.ckpt' # 0.645\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/sandy-sun-327/syn-epoch=59-global_step=0.ckpt' # 0.741\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/sandy-sun-327/syn-epoch=99-global_step=0.ckpt' # 0.775\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/apricot-sun-326/syn-epoch=09-global_step=0.ckpt' #  0.840\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/apricot-sun-326/syn-epoch=19-global_step=0.ckpt' # 0.904\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/magic-pyramid-358/syn-epoch=09-global_step=0.ckpt' # 0.5487\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/sandy-sun-327/syn-epoch=09-global_step=0.ckpt' #  \n",
    "\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/divine-eon-489/syn-epoch=19-global_step=0.ckpt' #  \n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/stoic-glade-490/syn-epoch=19-global_step=0.ckpt' #  \n",
    "\n",
    "# The followings could potentially work \n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/glorious-breeze-526/syn-epoch=19-global_step=0.ckpt' #  0.67\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/swift-dew-527/syn-epoch=19-global_step=0.ckpt' #  0.702\n",
    "checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/serene-wind-528/syn-epoch=19-global_step=0.ckpt' # 0.666\n",
    "\n",
    "my_module = LitMyModule.load_from_checkpoint(checkpoint_path, lr=1e-2, num_epochs=100, batch_size=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "771d2f09-ae6b-4c8c-897a-72cbe590519d",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.test(model=my_module, datamodule=data_module)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37da6cbd-2b07-4357-a7f6-a37b47d8671b",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# CIFAR100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3875b895-3109-4242-a6bd-afbfb3a4b3d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_module = LitCIFAR100(512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb3ce7e3-a031-4ed9-ac72-888b63cdb765",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = L.Trainer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa9f3f57-9123-454e-9a3d-e39d713f2c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/twilight-bee-1135/syn-epoch=34-global_step=0.ckpt' # 0.8164\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/twilight-bee-1135/syn-epoch=14-global_step=0.ckpt' # 0.69499\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/twilight-bee-1135/syn-epoch=09-global_step=0.ckpt' # 0.60214\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/twilight-bee-1135/syn-epoch=04-global_step=0.ckpt' # 0.4295\n",
    "\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/serene-cloud-1133/syn-epoch=04-global_step=0.ckpt' # 0.35032\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/serene-cloud-1133/syn-epoch=19-global_step=0.ckpt' # 0.76212\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/serene-cloud-1133/syn-epoch=09-global_step=0.ckpt' # 0.5393\n",
    "\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/olive-music-1139/syn-epoch=09-global_step=0.ckpt' # 0.74128\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/olive-music-1139/syn-epoch=04-global_step=0.ckpt' # 0.43081\n",
    "checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/olive-music-1139/syn-epoch=14-global_step=0.ckpt' # 0.91035\n",
    "\n",
    "my_module = LitMyModule.load_from_checkpoint(checkpoint_path, lr=1e-2, num_epochs=100, batch_size=512, K=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5917ed61-0d0e-48e5-a429-9daac66fa5d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.test(model=my_module, datamodule=data_module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "499ea5fc-4a78-43af-807f-6f31d6fa004c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4a37fa72-7bb7-4d51-8ca8-9fa036f3b77a",
   "metadata": {},
   "source": [
    "# Fashion MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "60114c9f-1a08-48cd-a2f2-71ce229b6b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightning as L\n",
    "from machine_annotators_training_fmnist import LitFashionMNIST, LitMyModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7abf9a0f-0b91-4265-8faa-46de3cb329da",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/scratch/tri/venvs/pytorch12/lib/python3.8/site-packages/torchvision/transforms/v2/_deprecated.py:41: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.\n",
      "  warnings.warn(\n",
      "Trainer will use only 1 of 3 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=3)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "/scratch/tri/venvs/pytorch12/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n"
     ]
    }
   ],
   "source": [
    "data_module = LitFashionMNIST(512)\n",
    "trainer = L.Trainer()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6a1f58eb-88c3-4778-a203-4abd90f4227e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fmnist_regression_annotator.pkl: 0.75193333333\n",
    "# fmnist_gaussiannb_annotator.pkl: 0.5877833\n",
    "# fmnist_kmeans_annotator.pkl: 0.4675 \n",
    "# fmnist_knn_annotator.pkl: 0.87836\n",
    "\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/confused-meadow-1190/syn-epoch=01-global_step=0.ckpt' # 0.7014\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/zany-field-1192/epoch=0-step=40.ckpt' # 0.463\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/zany-field-1192/epoch=1-step=80.ckpt' # 0.60218\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/zany-field-1192/epoch=1-step=100.ckpt' # 0.65745\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/astral-wood-1194/epoch=1-step=100.ckpt' # 0.7077\n",
    "# checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/astral-wood-1194/epoch=2-step=160.ckpt' # 0.74589\n",
    "checkpoint_path = f'/scratch/tri/shahana_outlier/lightning_saved_models/machine_annotator/astral-wood-1194/epoch=0-step=40.ckpt' # 0.4923\n",
    "\n",
    "my_module = LitMyModule.load_from_checkpoint(checkpoint_path, lr=1e-2, num_epochs=100, batch_size=512, K=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1333c1b3-5b08-4479-ba33-fa5dc5d21d65",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test size: 60000\n",
      "Done SETTING data module for FashionMNIST!\n",
      "Testing DataLoader 0: 100%|█████████████████████████████████████████████████████████| 118/118 [00:03<00:00, 36.44it/s]\n",
      "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "       Test metric             DataLoader 0\n",
      "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "        test/acc            0.49230000376701355\n",
      "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'test/acc': 0.49230000376701355}]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.test(model=my_module, datamodule=data_module)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "247905bb-b2f9-410e-b324-94fcafd0ff7b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
