{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "import wandb\n",
    "from nn_core.model_logging import NNLogger\n",
    "from nn_core.serialization import NNCheckpointIO\n",
    "import torch\n",
    "import os\n",
    "import sys\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "current_name = \"ResNet22_2_3\"\n",
    "curr_version = \"v0\"\n",
    "new_name = \"tiny_imagenet_ResNet22_2_3\"\n",
    "model_id = \"ResNet22_2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run = wandb.init(project=\"cycle-consistent-model-merging\", entity=\"ANONYMIZED\", job_type=\"logistics\")\n",
    "\n",
    "artifact_path = f\"ANONYMIZED/cycle-consistent-model-merging/{current_name}:{curr_version}\"\n",
    "\n",
    "# {a: model_a, b: model_b, c: model_c, ..}\n",
    "model = load_model_from_artifact(run, artifact_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(\n",
    "    plugins=[NNCheckpointIO(jailing_dir=\"./tmp\")],\n",
    ")\n",
    "\n",
    "temp_path = \"temp_checkpoint.ckpt\"\n",
    "\n",
    "trainer.strategy.connect(model)\n",
    "trainer.save_checkpoint(temp_path)\n",
    "\n",
    "model_class = model.__class__.__module__ + \".\" + model.__class__.__qualname__\n",
    "\n",
    "artifact_name = f\"\"\n",
    "model_artifact = wandb.Artifact(\n",
    "    name=new_name,\n",
    "    type=\"checkpoint\",\n",
    "    metadata={\"model_identifier\": model_id, \"model_class\": model_class},\n",
    ")\n",
    "\n",
    "model_artifact.add_file(temp_path + \".zip\", name=\"trained.ckpt.zip\")\n",
    "run.log_artifact(model_artifact)\n",
    "\n",
    "os.remove(temp_path + \".zip\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
