{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract Your Own ImageNet Pretrained Model for Downstream Tasks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to extract an ImageNet-pretrained model from a training checkpoint in this repo into a standalone state dict that can be loaded for downstream tasks.\n",
    "  \n",
    "First, import the required libraries and choose the GPU device to use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "\n",
    "# Index of the GPU exposed to this kernel; adjust it for your machine.\n",
    "# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is first initialized.\n",
    "GPU_ID = '9'\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, set the dataset and the config name of the pretrained model so that the script finds the right checkpoint path. All parameters are moved to the CPU before the state dict is saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 8M parameters in the state dict\n"
     ]
    }
   ],
   "source": [
    "dataset = 'imagenet'\n",
    "config_name = 'cls_mdeq_SMALLER'\n",
    "\n",
    "# Parameters whose key contains any of these substrings are left out of the reported\n",
    "# size; they are still kept in the extracted state dict.\n",
    "EXCLUDED_KEY_PARTS = ('classifier', 'final', 'incre_modules', 'downsamp_modules', 'copy', 'deq')\n",
    "\n",
    "# Start extracting state dict. map_location='cpu' deserializes every tensor on the CPU,\n",
    "# so the checkpoint loads even if the GPU it was saved from is not visible here.\n",
    "# NOTE: torch.load unpickles the file -- only load checkpoints you trust.\n",
    "checkpoint_path = f'output/{dataset}/{config_name}/checkpoint.pth.tar'\n",
    "state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']\n",
    "pd_cpu = {k: v.cpu() for k, v in state_dict.items()}\n",
    "\n",
    "# Reported in whole millions, rounded down.\n",
    "count = sum(v.nelement() for k, v in state_dict.items()\n",
    "            if not any(part in k for part in EXCLUDED_KEY_PARTS))\n",
    "size = str(count // 1000000) + \"M\"\n",
    "print(f\"Found {size} parameters in the state dict\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved at ./pretrained_models/mdeq_SMALLER.pth\n"
     ]
    }
   ],
   "source": [
    "# Start storing to the pretrained_models directory.\n",
    "# Drop the leading prefix from the config name, e.g. 'cls_mdeq_SMALLER' -> 'mdeq_SMALLER'.\n",
    "new_name = '_'.join(config_name.split('_')[1:])\n",
    "new_dir = './pretrained_models'\n",
    "os.makedirs(new_dir, exist_ok=True)  # torch.save does not create missing directories\n",
    "new_path = f'{new_dir}/{new_name}.pth'\n",
    "torch.save(pd_cpu, new_path)\n",
    "print(f\"Saved at {new_path}\")"
   ]
  },
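  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert Old-Version DEQ Checkpoints to the New Version\n",
    "\n",
    "The conversion keeps every parameter except those whose key contains `copy` or `deq`, and saves the result as a new state dict."
   ]
  },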
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 16M parameters in the state dict\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "# Only effective if CUDA has not been initialized yet in this kernel.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "# Entries whose key contains any of these substrings are dropped during conversion.\n",
    "DROPPED_KEY_PARTS = ('copy', 'deq')\n",
    "\n",
    "# map_location='cpu' loads the tensors on the CPU regardless of the device they were saved from.\n",
    "old_pd = torch.load('pretrained_models/mdeq_SMALL.pth', map_location='cpu')\n",
    "new_pd = {k: v.clone().detach().cpu() for k, v in old_pd.items()\n",
    "          if not any(part in k for part in DROPPED_KEY_PARTS)}\n",
    "\n",
    "# Reported in whole millions, rounded down.\n",
    "count = sum(v.nelement() for v in new_pd.values())\n",
    "size = str(count // 1000000) + \"M\"\n",
    "print(f\"Found {size} parameters in the state dict\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Save the converted state dict under a new file name so the original file is kept.\n",
    "torch.save(new_pd, 'pretrained_models/mdeq_SMALL_new.pth')"
   ]
  },
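  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cityscapes\n",
    "\n",
    "The same conversion for `mdeq_SMALL_Cityscapes.pth`, which additionally drops the keys containing `classifier`, `final`, `incre_modules`, or `downsamp_modules`."
   ]
  },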
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 8M parameters in the state dict\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "# Only effective if CUDA has not been initialized yet in this kernel.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "# Entries whose key contains any of these substrings are dropped during conversion.\n",
    "DROPPED_KEY_PARTS = ('classifier', 'final', 'incre_modules', 'downsamp_modules', 'copy', 'deq')\n",
    "\n",
    "# map_location='cpu' loads the tensors on the CPU regardless of the device they were saved from.\n",
    "old_pd = torch.load('pretrained_models/mdeq_SMALL_Cityscapes.pth', map_location='cpu')\n",
    "new_pd = {k: v.clone().detach().cpu() for k, v in old_pd.items()\n",
    "          if not any(part in k for part in DROPPED_KEY_PARTS)}\n",
    "\n",
    "# Reported in whole millions, rounded down.\n",
    "count = sum(v.nelement() for v in new_pd.values())\n",
    "size = str(count // 1000000) + \"M\"\n",
    "print(f\"Found {size} parameters in the state dict\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the converted state dict under a new file name so the original file is kept.\n",
    "torch.save(new_pd, 'pretrained_models/mdeq_SMALL_Cityscapes_new.pth')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
