{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f70d5708",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "import wandb\n",
    "from nn_core.model_logging import NNLogger\n",
    "from nn_core.serialization import NNCheckpointIO\n",
    "import torch\n",
    "import os\n",
    "import sys\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5193e765",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "import wandb\n",
    "from nn_core.model_logging import NNLogger\n",
    "from nn_core.serialization import NNCheckpointIO\n",
    "import torch\n",
    "import os\n",
    "import sys\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a42b7b68",
   "metadata": {},
   "outputs": [],
   "source": [
    "run = wandb.init(project=\"cycle-consistent-model-merging\", entity=\"gladia\", job_type=\"dev\")\n",
    "\n",
    "artifact_path = f\"gladia/cycle-consistent-model-merging/ResNet22_2_1:v2\"\n",
    "\n",
    "# {a: model_a, b: model_b, c: model_c, ..}\n",
    "model = load_model_from_artifact(run, artifact_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "570dd7ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'n01443537': '0',\n",
      " 'n01629819': '1',\n",
      " 'n01641577': '2',\n",
      " 'n01644900': '3',\n",
      " 'n01698640': '4',\n",
      " 'n01742172': '5',\n",
      " 'n01768244': '6',\n",
      " 'n01770393': '7',\n",
      " 'n01774384': '8',\n",
      " 'n01774750': '9',\n",
      " 'n01784675': '10',\n",
      " 'n01882714': '11',\n",
      " 'n01910747': '12',\n",
      " 'n01917289': '13',\n",
      " 'n01944390': '14',\n",
      " 'n01950731': '15',\n",
      " 'n01983481': '16',\n",
      " 'n01984695': '17',\n",
      " 'n02002724': '18',\n",
      " 'n02056570': '19',\n",
      " 'n02058221': '20',\n",
      " 'n02074367': '21',\n",
      " 'n02094433': '22',\n",
      " 'n02099601': '23',\n",
      " 'n02099712': '24',\n",
      " 'n02106662': '25',\n",
      " 'n02113799': '26',\n",
      " 'n02123045': '27',\n",
      " 'n02123394': '28',\n",
      " 'n02124075': '29',\n",
      " 'n02125311': '30',\n",
      " 'n02129165': '31',\n",
      " 'n02132136': '32',\n",
      " 'n02165456': '33',\n",
      " 'n02226429': '34',\n",
      " 'n02231487': '35',\n",
      " 'n02233338': '36',\n",
      " 'n02236044': '37',\n",
      " 'n02268443': '38',\n",
      " 'n02279972': '39',\n",
      " 'n02281406': '40',\n",
      " 'n02321529': '41',\n",
      " 'n02364673': '42',\n",
      " 'n02395406': '43',\n",
      " 'n02403003': '44',\n",
      " 'n02410509': '45',\n",
      " 'n02415577': '46',\n",
      " 'n02423022': '47',\n",
      " 'n02437312': '48',\n",
      " 'n02480495': '49',\n",
      " 'n02481823': '50',\n",
      " 'n02486410': '51',\n",
      " 'n02504458': '52',\n",
      " 'n02509815': '53',\n",
      " 'n02666347': '54',\n",
      " 'n02669723': '55',\n",
      " 'n02699494': '56',\n",
      " 'n02769748': '57',\n",
      " 'n02788148': '58',\n",
      " 'n02791270': '59',\n",
      " 'n02793495': '60',\n",
      " 'n02795169': '61',\n",
      " 'n02802426': '62',\n",
      " 'n02808440': '63',\n",
      " 'n02814533': '64',\n",
      " 'n02814860': '65',\n",
      " 'n02815834': '66',\n",
      " 'n02823428': '67',\n",
      " 'n02837789': '68',\n",
      " 'n02841315': '69',\n",
      " 'n02843684': '70',\n",
      " 'n02883205': '71',\n",
      " 'n02892201': '72',\n",
      " 'n02909870': '73',\n",
      " 'n02917067': '74',\n",
      " 'n02927161': '75',\n",
      " 'n02948072': '76',\n",
      " 'n02950826': '77',\n",
      " 'n02963159': '78',\n",
      " 'n02977058': '79',\n",
      " 'n02988304': '80',\n",
      " 'n03014705': '81',\n",
      " 'n03026506': '82',\n",
      " 'n03042490': '83',\n",
      " 'n03085013': '84',\n",
      " 'n03089624': '85',\n",
      " 'n03100240': '86',\n",
      " 'n03126707': '87',\n",
      " 'n03160309': '88',\n",
      " 'n03179701': '89',\n",
      " 'n03201208': '90',\n",
      " 'n03255030': '91',\n",
      " 'n03355925': '92',\n",
      " 'n03373237': '93',\n",
      " 'n03388043': '94',\n",
      " 'n03393912': '95',\n",
      " 'n03400231': '96',\n",
      " 'n03404251': '97',\n",
      " 'n03424325': '98',\n",
      " 'n03444034': '99',\n",
      " 'n03447447': '100',\n",
      " 'n03544143': '101',\n",
      " 'n03584254': '102',\n",
      " 'n03599486': '103',\n",
      " 'n03617480': '104',\n",
      " 'n03637318': '105',\n",
      " 'n03649909': '106',\n",
      " 'n03662601': '107',\n",
      " 'n03670208': '108',\n",
      " 'n03706229': '109',\n",
      " 'n03733131': '110',\n",
      " 'n03763968': '111',\n",
      " 'n03770439': '112',\n",
      " 'n03796401': '113',\n",
      " 'n03814639': '114',\n",
      " 'n03837869': '115',\n",
      " 'n03838899': '116',\n",
      " 'n03854065': '117',\n",
      " 'n03891332': '118',\n",
      " 'n03902125': '119',\n",
      " 'n03930313': '120',\n",
      " 'n03937543': '121',\n",
      " 'n03970156': '122',\n",
      " 'n03977966': '123',\n",
      " 'n03980874': '124',\n",
      " 'n03983396': '125',\n",
      " 'n03992509': '126',\n",
      " 'n04008634': '127',\n",
      " 'n04023962': '128',\n",
      " 'n04070727': '129',\n",
      " 'n04074963': '130',\n",
      " 'n04099969': '131',\n",
      " 'n04118538': '132',\n",
      " 'n04133789': '133',\n",
      " 'n04146614': '134',\n",
      " 'n04149813': '135',\n",
      " 'n04179913': '136',\n",
      " 'n04251144': '137',\n",
      " 'n04254777': '138',\n",
      " 'n04259630': '139',\n",
      " 'n04265275': '140',\n",
      " 'n04275548': '141',\n",
      " 'n04285008': '142',\n",
      " 'n04311004': '143',\n",
      " 'n04328186': '144',\n",
      " 'n04356056': '145',\n",
      " 'n04366367': '146',\n",
      " 'n04371430': '147',\n",
      " 'n04376876': '148',\n",
      " 'n04398044': '149',\n",
      " 'n04399382': '150',\n",
      " 'n04417672': '151',\n",
      " 'n04456115': '152',\n",
      " 'n04465666': '153',\n",
      " 'n04486054': '154',\n",
      " 'n04487081': '155',\n",
      " 'n04501370': '156',\n",
      " 'n04507155': '157',\n",
      " 'n04532106': '158',\n",
      " 'n04532670': '159',\n",
      " 'n04540053': '160',\n",
      " 'n04560804': '161',\n",
      " 'n04562935': '162',\n",
      " 'n04596742': '163',\n",
      " 'n04598010': '164',\n",
      " 'n06596364': '165',\n",
      " 'n07056680': '166',\n",
      " 'n07583066': '167',\n",
      " 'n07614500': '168',\n",
      " 'n07615774': '169',\n",
      " 'n07646821': '170',\n",
      " 'n07647870': '171',\n",
      " 'n07657664': '172',\n",
      " 'n07695742': '173',\n",
      " 'n07711569': '174',\n",
      " 'n07715103': '175',\n",
      " 'n07720875': '176',\n",
      " 'n07749582': '177',\n",
      " 'n07753592': '178',\n",
      " 'n07768694': '179',\n",
      " 'n07871810': '180',\n",
      " 'n07873807': '181',\n",
      " 'n07875152': '182',\n",
      " 'n07920052': '183',\n",
      " 'n07975909': '184',\n",
      " 'n08496334': '185',\n",
      " 'n08620881': '186',\n",
      " 'n08742578': '187',\n",
      " 'n09193705': '188',\n",
      " 'n09246464': '189',\n",
      " 'n09256479': '190',\n",
      " 'n09332890': '191',\n",
      " 'n09428293': '192',\n",
      " 'n12267677': '193',\n",
      " 'n12520864': '194',\n",
      " 'n13001041': '195',\n",
      " 'n13652335': '196',\n",
      " 'n13652994': '197',\n",
      " 'n13719102': '198',\n",
      " 'n14991210': '199'}"
     ]
    }
   ],
   "source": [
    "model.metadata.class_vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "df7b69a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(\n",
    "    plugins=[NNCheckpointIO(jailing_dir=\"./tmp\")],\n",
    ")\n",
    "\n",
    "temp_path = \"temp_checkpoint.ckpt\"\n",
    "\n",
    "trainer.strategy.connect(model)\n",
    "trainer.save_checkpoint(temp_path)\n",
    "\n",
    "model_class = model.__class__.__module__ + \".\" + model.__class__.__qualname__\n",
    "\n",
    "artifact_name = f\"\"\n",
    "model_artifact = wandb.Artifact(\n",
    "    name=\"CIFAR100_ResNet22_16_2\",\n",
    "    type=\"checkpoint\",\n",
    "    metadata={\"model_identifier\": \"ResNet22_16\", \"model_class\": model_class},\n",
    ")\n",
    "\n",
    "model_artifact.add_file(temp_path + \".zip\", name=\"trained.ckpt.zip\")\n",
    "run.log_artifact(model_artifact)\n",
    "\n",
    "os.remove(temp_path + \".zip\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0017109e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "import wandb\n",
    "from nn_core.model_logging import NNLogger\n",
    "from nn_core.serialization import NNCheckpointIO\n",
    "import torch\n",
    "import os\n",
    "import sys\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e26cd666",
   "metadata": {},
   "outputs": [],
   "source": [
    "current_name = \"ResNet22_2_1\"\n",
    "curr_version = \"v2\"\n",
    "new_name = \"tiny_imagenet_ResNet22_2_1\"\n",
    "model_id = \"ResNet22_2_1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bf5feeeb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "import wandb\n",
    "from nn_core.model_logging import NNLogger\n",
    "from nn_core.serialization import NNCheckpointIO\n",
    "import torch\n",
    "import os\n",
    "import sys\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b34ed59c",
   "metadata": {},
   "outputs": [],
   "source": [
    "current_name = \"ResNet22_2_1\"\n",
    "curr_version = \"v2\"\n",
    "new_name = \"tiny_imagenet_ResNet22_2_1\"\n",
    "model_id = \"ResNet22_2_1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3c008c22",
   "metadata": {},
   "outputs": [],
   "source": [
    "run = wandb.init(project=\"cycle-consistent-model-merging\", entity=\"gladia\", job_type=\"logistics\")\n",
    "\n",
    "artifact_path = f\"gladia/cycle-consistent-model-merging/{current_name}:{curr_version}\"\n",
    "\n",
    "# {a: model_a, b: model_b, c: model_c, ..}\n",
    "model = load_model_from_artifact(run, artifact_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dd38127f",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(\n",
    "    plugins=[NNCheckpointIO(jailing_dir=\"./tmp\")],\n",
    ")\n",
    "\n",
    "temp_path = \"temp_checkpoint.ckpt\"\n",
    "\n",
    "trainer.strategy.connect(model)\n",
    "trainer.save_checkpoint(temp_path)\n",
    "\n",
    "model_class = model.__class__.__module__ + \".\" + model.__class__.__qualname__\n",
    "\n",
    "artifact_name = f\"\"\n",
    "model_artifact = wandb.Artifact(\n",
    "    name=new_name,\n",
    "    type=\"checkpoint\",\n",
    "    metadata={\"model_identifier\": model_id, \"model_class\": model_class},\n",
    ")\n",
    "\n",
    "model_artifact.add_file(temp_path + \".zip\", name=\"trained.ckpt.zip\")\n",
    "run.log_artifact(model_artifact)\n",
    "\n",
    "os.remove(temp_path + \".zip\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "28169fe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "import wandb\n",
    "from nn_core.model_logging import NNLogger\n",
    "from nn_core.serialization import NNCheckpointIO\n",
    "import torch\n",
    "import os\n",
    "import sys\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "42461372",
   "metadata": {},
   "outputs": [],
   "source": [
    "current_name = \"ResNet22_2_2\"\n",
    "curr_version = \"v0\"\n",
    "new_name = \"tiny_imagenet_ResNet22_2_2\"\n",
    "model_id = \"ResNet22_2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2660da8b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Finishing last run (ID:006qdgea) before initializing another..."
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da101ea5d3134bca80b8a4a82e1eceb0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='4.292 MB of 4.308 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=0.996308…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">eager-pine-2360</strong> at: <a href='https://wandb.ai/gladia/cycle-consistent-model-merging/runs/006qdgea' target=\"_blank\">https://wandb.ai/gladia/cycle-consistent-model-merging/runs/006qdgea</a><br/>Synced 7 W&B file(s), 0 media file(s), 1 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20240327_100818-006qdgea/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Successfully finished last run (ID:006qdgea). Initializing new run:<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0be3019db2c6495ab773d4713472627c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.016669480666557015, max=1.0…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.16.5 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.8"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/home/donato/Code/model-merging/cycle-consistent-model-merging/notebooks/dev/wandb/run-20240327_100952-aw37lurw</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/gladia/cycle-consistent-model-merging/runs/aw37lurw' target=\"_blank\">tough-capybara-2361</a></strong> to <a href='https://wandb.ai/gladia/cycle-consistent-model-merging' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/gladia/cycle-consistent-model-merging' target=\"_blank\">https://wandb.ai/gladia/cycle-consistent-model-merging</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/gladia/cycle-consistent-model-merging/runs/aw37lurw' target=\"_blank\">https://wandb.ai/gladia/cycle-consistent-model-merging/runs/aw37lurw</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "run = wandb.init(project=\"cycle-consistent-model-merging\", entity=\"gladia\", job_type=\"logistics\")\n",
    "\n",
    "artifact_path = f\"gladia/cycle-consistent-model-merging/{current_name}:{curr_version}\"\n",
    "\n",
    "# {a: model_a, b: model_b, c: model_c, ..}\n",
    "model = load_model_from_artifact(run, artifact_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "15c85235",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">2024-03-27 10:10:14 </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> GPU available: <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-style: italic\">True</span> <span style=\"font-weight: bold\">(</span>cuda<span style=\"font-weight: bold\">)</span>, used: <span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>     <a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">pytorch_lightning.utilities.rank_zero</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1751\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1751</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m2024-03-27 10:10:14\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m GPU available: \u001b[3;92mTrue\u001b[0m \u001b[1m(\u001b[0mcuda\u001b[1m)\u001b[0m, used: \u001b[3;91mFalse\u001b[0m     \u001b]8;id=523753;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\u001b\\\u001b[2mpytorch_lightning.utilities.rank_zero\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=776996;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1751\u001b\\\u001b[2m1751\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">                    </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> TPU available: <span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>, using: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> TPU cores    <a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">pytorch_lightning.utilities.rank_zero</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1754\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1754</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m                   \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m TPU available: \u001b[3;91mFalse\u001b[0m, using: \u001b[1;36m0\u001b[0m TPU cores    \u001b]8;id=158084;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\u001b\\\u001b[2mpytorch_lightning.utilities.rank_zero\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=61832;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1754\u001b\\\u001b[2m1754\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">                    </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> IPU available: <span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>, using: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> IPUs         <a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">pytorch_lightning.utilities.rank_zero</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1757\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1757</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m                   \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m IPU available: \u001b[3;91mFalse\u001b[0m, using: \u001b[1;36m0\u001b[0m IPUs         \u001b]8;id=681821;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\u001b\\\u001b[2mpytorch_lightning.utilities.rank_zero\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=960749;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1757\u001b\\\u001b[2m1757\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">                    </span><span style=\"color: #000080; text-decoration-color: #000080\">INFO    </span> HPU available: <span style=\"color: #ff0000; text-decoration-color: #ff0000; font-style: italic\">False</span>, using: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> HPUs         <a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">pytorch_lightning.utilities.rank_zero</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1760\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">1760</span></a>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[2;36m                   \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m HPU available: \u001b[3;91mFalse\u001b[0m, using: \u001b[1;36m0\u001b[0m HPUs         \u001b]8;id=79692;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py\u001b\\\u001b[2mpytorch_lightning.utilities.rank_zero\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=36353;file:///home/donato/miniconda3/envs/ccmm/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py#1760\u001b\\\u001b[2m1760\u001b[0m\u001b]8;;\u001b\\\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer = pl.Trainer(\n",
    "    plugins=[NNCheckpointIO(jailing_dir=\"./tmp\")],\n",
    ")\n",
    "\n",
    "temp_path = \"temp_checkpoint.ckpt\"\n",
    "\n",
    "trainer.strategy.connect(model)\n",
    "trainer.save_checkpoint(temp_path)\n",
    "\n",
    "model_class = model.__class__.__module__ + \".\" + model.__class__.__qualname__\n",
    "\n",
    "artifact_name = f\"\"\n",
    "model_artifact = wandb.Artifact(\n",
    "    name=new_name,\n",
    "    type=\"checkpoint\",\n",
    "    metadata={\"model_identifier\": model_id, \"model_class\": model_class},\n",
    ")\n",
    "\n",
    "model_artifact.add_file(temp_path + \".zip\", name=\"trained.ckpt.zip\")\n",
    "run.log_artifact(model_artifact)\n",
    "\n",
    "os.remove(temp_path + \".zip\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9e3bde70",
   "metadata": {},
   "outputs": [],
   "source": [
    "current_name = \"ResNet22_2_3\"\n",
    "curr_version = \"v0\"\n",
    "new_name = \"tiny_imagenet_ResNet22_2_3\"\n",
    "model_id = \"ResNet22_2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e46bc1c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "run = wandb.init(project=\"cycle-consistent-model-merging\", entity=\"gladia\", job_type=\"logistics\")\n",
    "\n",
    "artifact_path = f\"gladia/cycle-consistent-model-merging/{current_name}:{curr_version}\"\n",
    "\n",
    "# {a: model_a, b: model_b, c: model_c, ..}\n",
    "model = load_model_from_artifact(run, artifact_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
