{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "33aaec0d-3499-447d-ba03-679e4425329a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/richard/miniconda3/envs/dl_env/lib/python3.12/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.23). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.\n",
      "  check_for_updates()\n"
     ]
    }
   ],
   "source": [
    "# Standard library\n",
    "import json\n",
    "import math\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Third-party\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torchvision\n",
    "\n",
    "# Append the parent of the current working directory to sys.path.\n",
    "# NOTE(review): presumably this is what makes the project-local packages below importable - confirm.\n",
    "SCRIPT_DIR = os.path.dirname(os.path.abspath(\".\"))\n",
    "sys.path.append(SCRIPT_DIR)\n",
    "\n",
    "# Project-local modules; they must come after the sys.path update above.\n",
    "import helper\n",
    "from utils import data_utils\n",
    "from utils import training_utils\n",
    "from model import models\n",
    "from model import lightning_models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "79ffaa5d-4bbf-46b4-902e-227ea0218b58",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_dir = \"../best_training_imagenet\"\n",
    "pretrained_model_path = os.path.join(model_dir,\"ssl\",'ssl-epoch=999.ckpt')\n",
    "save_model_path = os.path.join(model_dir,\"ssl\",'last_epoch_backbone_resnet50.ckpt')\n",
    "# if loading from GPU\n",
    "device = torch.device(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "76c75285-9edc-4d82-bbaa-ca72489a61c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading default settings...\n",
      "[SemiSL]does not exist in the config file\n",
      "[TL]does not exist in the config file\n",
      "[KNN]does not exist in the config file\n",
      "[SemiSL]does not exist in the config file\n",
      "[TL]does not exist in the config file\n",
      "[KNN]does not exist in the config file\n",
      "[INFO]\n",
      "num_nodes = 4\n",
      "gpus_per_node = 4\n",
      "cpus_per_gpu = 12\n",
      "prefetch_factor = 2\n",
      "precision = 16-mixed\n",
      "fix_random_seed = True\n",
      "strategy = ddp\n",
      "if_profile = False\n",
      "\n",
      "[DATA]\n",
      "dataset = IMAGENET1K\n",
      "n_views = 4\n",
      "n_trans = 2\n",
      "augmentation_package = albumentations\n",
      "augmentations = ['RandomResizedCrop', 'GaussianBlur', 'RandomGrayscale', 'ColorJitter', 'RandomHorizontalFlip', 'RandomSolarize']\n",
      "crop_size = [224]\n",
      "crop_min_scale = [0.08]\n",
      "crop_max_scale = [1.0]\n",
      "hflip_prob = [0.5]\n",
      "blur_kernel_size = [23]\n",
      "blur_prob = [0.5]\n",
      "grayscale_prob = [0.2]\n",
      "jitter_brightness = [0.8]\n",
      "jitter_contrast = [0.8]\n",
      "jitter_saturation = [0.8]\n",
      "jitter_hue = [0.2]\n",
      "jitter_prob = [0.8]\n",
      "solarize_prob = [0.0]\n",
      "imagenet_train_dir = ../datasets/imagenet1k/train/train.lmdb\n",
      "imagenet_val_dir = ../datasets/imagenet1k/val/val.lmdb\n",
      "\n",
      "[SSL]\n",
      "backbone = resnet50\n",
      "use_projection_head = True\n",
      "proj_dim = [8192, 8192]\n",
      "proj_out_dim = 512\n",
      "optimizer = LARS\n",
      "lr = 11.0\n",
      "lr_scale = linear\n",
      "lr_scheduler = cosine-warmup\n",
      "grad_accumulation_steps = 1\n",
      "momentum = 0.0\n",
      "weight_decay = 0.015\n",
      "restart_epochs = -1\n",
      "lars_eta = 0.001\n",
      "loss_function = RepulsiveEllipsoidPackingLossUnitNorm\n",
      "lw0 = 0.0\n",
      "lw1 = 1.0\n",
      "lw2 = 0.0\n",
      "rs = 7.0\n",
      "pot_pow = 2.0\n",
      "warmup_epochs = 10\n",
      "n_epochs = 1000\n",
      "batch_size = 1024\n",
      "save_every_n_epochs = 10\n",
      "restart_training = False\n",
      "max_mem_size = 0\n",
      "max_grad_norm = -1.0\n",
      "\n",
      "[LC]\n",
      "output_dim = 1000\n",
      "optimizer = SGD\n",
      "use_batch_norm = False\n",
      "lr = 0.2\n",
      "lr_scale = linear\n",
      "lr_scheduler = cosine\n",
      "weight_decay = 1e-05\n",
      "momentum = 0.9\n",
      "loss_function = CrossEntropyLoss\n",
      "n_epochs = 100\n",
      "batch_size = 1024\n",
      "save_every_n_epochs = 50\n",
      "apply_simple_augmentations = True\n",
      "standardize_to_imagenet = True\n",
      "restart_training = False\n",
      "lr_sweep = [0.1, 3.0, 4.5, 6.0, 7.5]\n",
      "\n",
      "lr overrided by lr_sweep!!\n",
      "max_mem_size is dummy for RepulsiveEllipsoidPackingLossUnitNorm\n",
      "lw2 is dummy for RepulsiveEllipsoidPackingLossUnitNorm\n"
     ]
    }
   ],
   "source": [
    "# Build the run configuration for model_dir.\n",
    "# NOTE(review): default_config_file presumably supplies values missing from the run's own config - confirm in helper.Config.\n",
    "config = helper.Config(\n",
    "    model_dir,\n",
    "    default_config_file=\"../default_configs/default_config_imagenet1k.ini\",\n",
    ")\n",
    "\n",
    "# The backbone is pruned only when the dataset name mentions CIFAR or MNIST.\n",
    "prune_backbone = any(tag in config.DATA[\"dataset\"] for tag in (\"CIFAR\", \"MNIST\"))\n",
    "\n",
    "# Instantiate CLAP from the config values; lr is the only literal (1.0) instead of a config lookup.\n",
    "ssl_model = lightning_models.CLAP(\n",
    "    # backbone and projection head\n",
    "    backbone_name=config.SSL[\"backbone\"],\n",
    "    prune=prune_backbone,\n",
    "    use_projection_head=config.SSL[\"use_projection_head\"],\n",
    "    proj_dim=config.SSL[\"proj_dim\"],\n",
    "    proj_out_dim=config.SSL[\"proj_out_dim\"],\n",
    "    # loss, optimizer and schedule\n",
    "    loss_name=config.SSL[\"loss_function\"],\n",
    "    optim_name=config.SSL[\"optimizer\"],\n",
    "    lr=1.0,\n",
    "    scheduler_name=config.SSL[\"lr_scheduler\"],\n",
    "    momentum=config.SSL[\"momentum\"],\n",
    "    weight_decay=config.SSL[\"weight_decay\"],\n",
    "    eta=config.SSL[\"lars_eta\"],\n",
    "    warmup_epochs=config.SSL[\"warmup_epochs\"],\n",
    "    n_epochs=config.SSL[\"n_epochs\"],\n",
    "    # data layout\n",
    "    n_views=config.DATA[\"n_views\"],\n",
    "    batch_size=config.SSL[\"batch_size\"],\n",
    "    # loss-specific settings\n",
    "    lw0=config.SSL[\"lw0\"],\n",
    "    lw1=config.SSL[\"lw1\"],\n",
    "    lw2=config.SSL[\"lw2\"],\n",
    "    rs=config.SSL[\"rs\"],\n",
    "    pot_pow=config.SSL[\"pot_pow\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baa0e16b-b879-458d-8601-1772c501b2a3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "dc449c08-5c10-4eb7-bdae-148e19a56832",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found pretrained model at ../best_training_imagenet/ssl/ssl-epoch=999.ckpt, loading...\n",
      "max_mem_size is dummy for RepulsiveEllipsoidPackingLossUnitNorm\n",
      "lw2 is dummy for RepulsiveEllipsoidPackingLossUnitNorm\n"
     ]
    }
   ],
   "source": [
    "# Fail loudly when the checkpoint is missing. Without this guard the cell would go on\n",
    "# to export the backbone of the ssl_model constructed earlier from the config, i.e.\n",
    "# weights that were never loaded from the checkpoint.\n",
    "if not os.path.isfile(pretrained_model_path):\n",
    "    raise FileNotFoundError(f'Model not found at {pretrained_model_path}!')\n",
    "\n",
    "print(f'Found pretrained model at {pretrained_model_path}, loading...')\n",
    "ssl_model = lightning_models.CLAP.load_from_checkpoint(pretrained_model_path)\n",
    "\n",
    "# Persist only the backbone network's weights (no Lightning wrapper around them).\n",
    "torch.save(ssl_model.backbone.net.state_dict(), save_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f510ac88-0928-4325-90ea-2a48cab43fcf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the exported weights must load into a plain torchvision ResNet-50\n",
    "# whose classification layer is replaced by an identity.\n",
    "model = torchvision.models.resnet50()\n",
    "model.fc = torch.nn.Identity()\n",
    "# map_location='cpu' lets the check run on machines without the GPU the weights were\n",
    "# saved from; the target model lives on the CPU anyway.\n",
    "model.load_state_dict(torch.load(save_model_path, map_location='cpu', weights_only=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "86ea2284-b175-44dd-9633-947e0fcceb76",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv1\n",
      "bn1\n",
      "relu\n",
      "maxpool\n",
      "layer1\n",
      "layer1.0\n",
      "layer1.0.conv1\n",
      "layer1.0.bn1\n",
      "layer1.0.conv2\n",
      "layer1.0.bn2\n",
      "layer1.0.conv3\n",
      "layer1.0.bn3\n",
      "layer1.0.relu\n",
      "layer1.0.downsample\n",
      "layer1.0.downsample.0\n",
      "layer1.0.downsample.1\n",
      "layer1.1\n",
      "layer1.1.conv1\n",
      "layer1.1.bn1\n",
      "layer1.1.conv2\n",
      "layer1.1.bn2\n",
      "layer1.1.conv3\n",
      "layer1.1.bn3\n",
      "layer1.1.relu\n",
      "layer1.2\n",
      "layer1.2.conv1\n",
      "layer1.2.bn1\n",
      "layer1.2.conv2\n",
      "layer1.2.bn2\n",
      "layer1.2.conv3\n",
      "layer1.2.bn3\n",
      "layer1.2.relu\n",
      "layer2\n",
      "layer2.0\n",
      "layer2.0.conv1\n",
      "layer2.0.bn1\n",
      "layer2.0.conv2\n",
      "layer2.0.bn2\n",
      "layer2.0.conv3\n",
      "layer2.0.bn3\n",
      "layer2.0.relu\n",
      "layer2.0.downsample\n",
      "layer2.0.downsample.0\n",
      "layer2.0.downsample.1\n",
      "layer2.1\n",
      "layer2.1.conv1\n",
      "layer2.1.bn1\n",
      "layer2.1.conv2\n",
      "layer2.1.bn2\n",
      "layer2.1.conv3\n",
      "layer2.1.bn3\n",
      "layer2.1.relu\n",
      "layer2.2\n",
      "layer2.2.conv1\n",
      "layer2.2.bn1\n",
      "layer2.2.conv2\n",
      "layer2.2.bn2\n",
      "layer2.2.conv3\n",
      "layer2.2.bn3\n",
      "layer2.2.relu\n",
      "layer2.3\n",
      "layer2.3.conv1\n",
      "layer2.3.bn1\n",
      "layer2.3.conv2\n",
      "layer2.3.bn2\n",
      "layer2.3.conv3\n",
      "layer2.3.bn3\n",
      "layer2.3.relu\n",
      "layer3\n",
      "layer3.0\n",
      "layer3.0.conv1\n",
      "layer3.0.bn1\n",
      "layer3.0.conv2\n",
      "layer3.0.bn2\n",
      "layer3.0.conv3\n",
      "layer3.0.bn3\n",
      "layer3.0.relu\n",
      "layer3.0.downsample\n",
      "layer3.0.downsample.0\n",
      "layer3.0.downsample.1\n",
      "layer3.1\n",
      "layer3.1.conv1\n",
      "layer3.1.bn1\n",
      "layer3.1.conv2\n",
      "layer3.1.bn2\n",
      "layer3.1.conv3\n",
      "layer3.1.bn3\n",
      "layer3.1.relu\n",
      "layer3.2\n",
      "layer3.2.conv1\n",
      "layer3.2.bn1\n",
      "layer3.2.conv2\n",
      "layer3.2.bn2\n",
      "layer3.2.conv3\n",
      "layer3.2.bn3\n",
      "layer3.2.relu\n",
      "layer3.3\n",
      "layer3.3.conv1\n",
      "layer3.3.bn1\n",
      "layer3.3.conv2\n",
      "layer3.3.bn2\n",
      "layer3.3.conv3\n",
      "layer3.3.bn3\n",
      "layer3.3.relu\n",
      "layer3.4\n",
      "layer3.4.conv1\n",
      "layer3.4.bn1\n",
      "layer3.4.conv2\n",
      "layer3.4.bn2\n",
      "layer3.4.conv3\n",
      "layer3.4.bn3\n",
      "layer3.4.relu\n",
      "layer3.5\n",
      "layer3.5.conv1\n",
      "layer3.5.bn1\n",
      "layer3.5.conv2\n",
      "layer3.5.bn2\n",
      "layer3.5.conv3\n",
      "layer3.5.bn3\n",
      "layer3.5.relu\n",
      "layer4\n",
      "layer4.0\n",
      "layer4.0.conv1\n",
      "layer4.0.bn1\n",
      "layer4.0.conv2\n",
      "layer4.0.bn2\n",
      "layer4.0.conv3\n",
      "layer4.0.bn3\n",
      "layer4.0.relu\n",
      "layer4.0.downsample\n",
      "layer4.0.downsample.0\n",
      "layer4.0.downsample.1\n",
      "layer4.1\n",
      "layer4.1.conv1\n",
      "layer4.1.bn1\n",
      "layer4.1.conv2\n",
      "layer4.1.bn2\n",
      "layer4.1.conv3\n",
      "layer4.1.bn3\n",
      "layer4.1.relu\n",
      "layer4.2\n",
      "layer4.2.conv1\n",
      "layer4.2.bn1\n",
      "layer4.2.conv2\n",
      "layer4.2.bn2\n",
      "layer4.2.conv3\n",
      "layer4.2.bn3\n",
      "layer4.2.relu\n",
      "avgpool\n",
      "fc\n"
     ]
    }
   ],
   "source": [
    "# Print the qualified name of every submodule of the reloaded model.\n",
    "# Entries with an empty name are skipped.\n",
    "for layer_name, _ in model.named_modules():\n",
    "    if not layer_name:\n",
    "        continue\n",
    "    print(layer_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59a8d3b8-a301-43e6-8ea7-02c9baa7cf47",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
