{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "96999ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "import wandb\n",
    "from nn_core.model_logging import NNLogger\n",
    "from nn_core.serialization import NNCheckpointIO\n",
    "import torch\n",
    "import os\n",
    "import sys\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cb84c91f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "wandb version 0.16.5 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.8"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/home/donato/Code/model-merging/cycle-consistent-model-merging/notebooks/dev/wandb/run-20240326_203904-z9y7hc9h</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/gladia/cycle-consistent-model-merging/runs/z9y7hc9h' target=\"_blank\">true-moon-2344</a></strong> to <a href='https://wandb.ai/gladia/cycle-consistent-model-merging' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/gladia/cycle-consistent-model-merging' target=\"_blank\">https://wandb.ai/gladia/cycle-consistent-model-merging</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/gladia/cycle-consistent-model-merging/runs/z9y7hc9h' target=\"_blank\">https://wandb.ai/gladia/cycle-consistent-model-merging/runs/z9y7hc9h</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "run = wandb.init(project=\"cycle-consistent-model-merging\", entity=\"gladia\", job_type=\"dev\")\n",
    "\n",
    "# Source checkpoint to re-publish (no placeholders, so a plain string rather than an f-string).\n",
    "artifact_path = \"gladia/cycle-consistent-model-merging/ResNet22_16_1:v1\"\n",
    "\n",
    "# A single artifact path yields a single model object, not a dict of models:\n",
    "# later cells read `model.metadata` and hand `model` to the Lightning trainer.\n",
    "model = load_model_from_artifact(run, artifact_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "78e5e920",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'apple': 0,\n",
      " 'aquarium_fish': 1,\n",
      " 'baby': 2,\n",
      " 'bear': 3,\n",
      " 'beaver': 4,\n",
      " 'bed': 5,\n",
      " 'bee': 6,\n",
      " 'beetle': 7,\n",
      " 'bicycle': 8,\n",
      " 'bottle': 9,\n",
      " 'bowl': 10,\n",
      " 'boy': 11,\n",
      " 'bridge': 12,\n",
      " 'bus': 13,\n",
      " 'butterfly': 14,\n",
      " 'camel': 15,\n",
      " 'can': 16,\n",
      " 'castle': 17,\n",
      " 'caterpillar': 18,\n",
      " 'cattle': 19,\n",
      " 'chair': 20,\n",
      " 'chimpanzee': 21,\n",
      " 'clock': 22,\n",
      " 'cloud': 23,\n",
      " 'cockroach': 24,\n",
      " 'couch': 25,\n",
      " 'crab': 26,\n",
      " 'crocodile': 27,\n",
      " 'cup': 28,\n",
      " 'dinosaur': 29,\n",
      " 'dolphin': 30,\n",
      " 'elephant': 31,\n",
      " 'flatfish': 32,\n",
      " 'forest': 33,\n",
      " 'fox': 34,\n",
      " 'girl': 35,\n",
      " 'hamster': 36,\n",
      " 'house': 37,\n",
      " 'kangaroo': 38,\n",
      " 'keyboard': 39,\n",
      " 'lamp': 40,\n",
      " 'lawn_mower': 41,\n",
      " 'leopard': 42,\n",
      " 'lion': 43,\n",
      " 'lizard': 44,\n",
      " 'lobster': 45,\n",
      " 'man': 46,\n",
      " 'maple_tree': 47,\n",
      " 'motorcycle': 48,\n",
      " 'mountain': 49,\n",
      " 'mouse': 50,\n",
      " 'mushroom': 51,\n",
      " 'oak_tree': 52,\n",
      " 'orange': 53,\n",
      " 'orchid': 54,\n",
      " 'otter': 55,\n",
      " 'palm_tree': 56,\n",
      " 'pear': 57,\n",
      " 'pickup_truck': 58,\n",
      " 'pine_tree': 59,\n",
      " 'plain': 60,\n",
      " 'plate': 61,\n",
      " 'poppy': 62,\n",
      " 'porcupine': 63,\n",
      " 'possum': 64,\n",
      " 'rabbit': 65,\n",
      " 'raccoon': 66,\n",
      " 'ray': 67,\n",
      " 'road': 68,\n",
      " 'rocket': 69,\n",
      " 'rose': 70,\n",
      " 'sea': 71,\n",
      " 'seal': 72,\n",
      " 'shark': 73,\n",
      " 'shrew': 74,\n",
      " 'skunk': 75,\n",
      " 'skyscraper': 76,\n",
      " 'snail': 77,\n",
      " 'snake': 78,\n",
      " 'spider': 79,\n",
      " 'squirrel': 80,\n",
      " 'streetcar': 81,\n",
      " 'sunflower': 82,\n",
      " 'sweet_pepper': 83,\n",
      " 'table': 84,\n",
      " 'tank': 85,\n",
      " 'telephone': 86,\n",
      " 'television': 87,\n",
      " 'tiger': 88,\n",
      " 'tractor': 89,\n",
      " 'train': 90,\n",
      " 'trout': 91,\n",
      " 'tulip': 92,\n",
      " 'turtle': 93,\n",
      " 'wardrobe': 94,\n",
      " 'whale': 95,\n",
      " 'willow_tree': 96,\n",
      " 'wolf': 97,\n",
      " 'woman': 98,\n",
      " 'worm': 99}"
     ]
    }
   ],
   "source": [
    "# Sanity-check the label mapping stored with the checkpoint before re-publishing it as a CIFAR-100 model:\n",
    "# it must map exactly 100 class names onto the contiguous labels 0..99.\n",
    "class_vocab = model.metadata.class_vocab\n",
    "assert sorted(class_vocab.values()) == list(range(100)), f\"unexpected CIFAR-100 label mapping ({len(class_vocab)} entries)\"\n",
    "class_vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "642b33e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">2024-03-26 20:39:17 </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> GPU available: <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-style: italic\">True</span> <span style=\"font-weight: bold\">(</span>cuda<span style=\"font-weight: bold\">)</span>, used: <span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>     <a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">pytorch_lightning.utilities.rank_zero</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1751\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1751</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m2024-03-26 20:39:17\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m GPU available: \u001b[3;92mTrue\u001b[0m \u001b[1m(\u001b[0mcuda\u001b[1m)\u001b[0m, used: \u001b[3;91mFalse\u001b[0m     \u001b]8;id=314026;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\u001b\\\u001b[2mpytorch_lightning.utilities.rank_zero\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=419155;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1751\u001b\\\u001b[2m1751\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">                    </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> TPU available: <span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>, using: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> TPU cores    <a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">pytorch_lightning.utilities.rank_zero</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1754\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1754</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m                   \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m TPU available: \u001b[3;91mFalse\u001b[0m, using: \u001b[1;36m0\u001b[0m TPU cores    \u001b]8;id=474375;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\u001b\\\u001b[2mpytorch_lightning.utilities.rank_zero\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=814815;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1754\u001b\\\u001b[2m1754\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">                    </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> IPU available: <span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>, using: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> IPUs         <a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">pytorch_lightning.utilities.rank_zero</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1757\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1757</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m                   \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m IPU available: \u001b[3;91mFalse\u001b[0m, using: \u001b[1;36m0\u001b[0m IPUs         \u001b]8;id=386379;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\u001b\\\u001b[2mpytorch_lightning.utilities.rank_zero\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=601591;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1757\u001b\\\u001b[2m1757\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">                    </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> HPU available: <span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>, using: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> HPUs         <a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">pytorch_lightning.utilities.rank_zero</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1760\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1760</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m                   \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m HPU available: \u001b[3;91mFalse\u001b[0m, using: \u001b[1;36m0\u001b[0m HPUs         \u001b]8;id=728787;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\u001b\\\u001b[2mpytorch_lightning.utilities.rank_zero\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=381736;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1760\u001b\\\u001b[2m1760\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Re-save the loaded model as a Lightning checkpoint and publish it under a CIFAR-100-specific artifact name.\n",
    "trainer = pl.Trainer(\n",
    "    plugins=[NNCheckpointIO(jailing_dir=\"./tmp\")],\n",
    ")\n",
    "\n",
    "temp_path = \"temp_checkpoint.ckpt\"\n",
    "# The checkpoint IO plugin is expected to write the checkpoint as `<temp_path>.zip`; that archive is what gets uploaded.\n",
    "checkpoint_zip = temp_path + \".zip\"\n",
    "\n",
    "# Fully qualified class path, stored in the artifact metadata next to the weights.\n",
    "model_class = model.__class__.__module__ + \".\" + model.__class__.__qualname__\n",
    "\n",
    "artifact_name = \"CIFAR100_ResNet22_16_1\"\n",
    "model_artifact = wandb.Artifact(\n",
    "    name=artifact_name,\n",
    "    type=\"checkpoint\",\n",
    "    metadata={\"model_identifier\": \"ResNet22_16\", \"model_class\": model_class},\n",
    ")\n",
    "\n",
    "# Attach the model to the trainer so save_checkpoint can serialize it.\n",
    "trainer.strategy.connect(model)\n",
    "try:\n",
    "    trainer.save_checkpoint(temp_path)\n",
    "    model_artifact.add_file(checkpoint_zip, name=\"trained.ckpt.zip\")\n",
    "    run.log_artifact(model_artifact)\n",
    "    # The upload may continue in the background: wait for it before deleting the local archive.\n",
    "    model_artifact.wait()\n",
    "finally:\n",
    "    # Always remove the temporary archive; the existence guard keeps a failed save from masking the original error.\n",
    "    if os.path.exists(checkpoint_zip):\n",
    "        os.remove(checkpoint_zip)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a1ed6fc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the publishing run explicitly; finish() also completes pending uploads,\n",
    "# so the artifact version logged above is available before it is fetched again.\n",
    "run.finish()\n",
    "run = wandb.init(project=\"cycle-consistent-model-merging\", entity=\"gladia\", job_type=\"dev\")\n",
    "\n",
    "# Reload the checkpoint published in the previous cell (plain string: no placeholders needed).\n",
    "artifact_path = \"gladia/cycle-consistent-model-merging/CIFAR100_ResNet22_16_1:v0\"\n",
    "\n",
    "# A single artifact path yields a single model object, not a dict of models.\n",
    "model = load_model_from_artifact(run, artifact_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
