{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "94441c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from transformers import BertModel\n",
    "\n",
    "from models.PersonSAM import PersonSAM\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ef408822",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local checkpoint directories passed to from_pretrained() when the models are loaded.\n",
    "sam_path, language_model_path = (\n",
    "    r'F:\\preTrainedModels\\sam-vit-base-tbpr',\n",
    "    r'F:\\preTrainedModels\\bert-base-uncased',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1abe5572",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at F:\\preTrainedModels\\sam-vit-base-tbpr were not used when initializing PersonSAM: ['prompt_encoder.mask_embed.conv1.bias', 'prompt_encoder.mask_embed.layer_norm1.bias', 'prompt_encoder.no_mask_embed.weight', 'prompt_encoder.point_embed.0.weight', 'prompt_encoder.mask_embed.layer_norm1.weight', 'prompt_encoder.mask_embed.layer_norm2.bias', 'prompt_encoder.point_embed.2.weight', 'prompt_encoder.mask_embed.conv2.bias', 'prompt_encoder.mask_embed.layer_norm2.weight', 'prompt_encoder.shared_embedding.positional_embedding', 'prompt_encoder.point_embed.1.weight', 'prompt_encoder.mask_embed.conv2.weight', 'prompt_encoder.not_a_point_embed.weight', 'prompt_encoder.point_embed.3.weight', 'prompt_encoder.mask_embed.conv3.bias', 'prompt_encoder.mask_embed.conv1.weight', 'prompt_encoder.mask_embed.conv3.weight']\n",
      "- This IS expected if you are initializing PersonSAM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing PersonSAM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of PersonSAM were not initialized from the model checkpoint at F:\\preTrainedModels\\sam-vit-base-tbpr and are newly initialized: ['no_mask_embed.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Some weights of the model checkpoint at F:\\preTrainedModels\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "# Load PersonSAM from the local SAM checkpoint directory.\n",
    "sam_model = PersonSAM.from_pretrained(sam_path)\n",
    "# Load BERT from its local checkpoint and hand it to PersonSAM as the text encoder.\n",
    "# NOTE(review): set_pretrained_text_encoder is defined in models.PersonSAM, not shown here;\n",
    "# presumably it registers the module as `text_prompt_encoder` - confirm.\n",
    "text_prompt_encoder = BertModel.from_pretrained(language_model_path)\n",
    "sam_model.set_pretrained_text_encoder(text_prompt_encoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "64b49ace",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze every parameter whose name begins with one of the listed prefixes.\n",
    "# str.startswith accepts a tuple and is true when any of its entries matches.\n",
    "for name, param in sam_model.named_parameters():\n",
    "    if name.startswith((\n",
    "        \"vision_encoder\",\n",
    "        \"prompt_encoder\",\n",
    "        \"text_prompt_encoder.embeddings\",\n",
    "        \"text_prompt_encoder.encoder\",\n",
    "    )):\n",
    "        param.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b33a73ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<bound method Module.requires_grad_ of PersonSAM(\n",
       "  (shared_image_embedding): SamPositionalEmbedding()\n",
       "  (vision_encoder): SamVisionEncoder(\n",
       "    (patch_embed): SamPatchEmbeddings(\n",
       "      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
       "    )\n",
       "    (layers): ModuleList(\n",
       "      (0-11): 12 x SamVisionLayer(\n",
       "        (layer_norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
       "        (attn): SamVisionAttention(\n",
       "          (qkv): Linear(in_features=768, out_features=2304, bias=True)\n",
       "          (proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "        )\n",
       "        (layer_norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)\n",
       "        (mlp): SamMLPBlock(\n",
       "          (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "          (act): GELUActivation()\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (neck): SamVisionNeck(\n",
       "      (conv1): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (layer_norm1): SamLayerNorm()\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (layer_norm2): SamLayerNorm()\n",
       "    )\n",
       "  )\n",
       "  (no_mask_embed): Embedding(1, 256)\n",
       "  (mask_decoder): SamMaskDecoder(\n",
       "    (iou_token): Embedding(1, 256)\n",
       "    (mask_tokens): Embedding(4, 256)\n",
       "    (transformer): SamTwoWayTransformer(\n",
       "      (layers): ModuleList(\n",
       "        (0-1): 2 x SamTwoWayAttentionBlock(\n",
       "          (self_attn): SamAttention(\n",
       "            (q_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (k_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (v_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "            (out_proj): Linear(in_features=256, out_features=256, bias=True)\n",
       "          )\n",
       "          (layer_norm1): LayerNorm((256,), eps=1e-06, elementwise_affine=True)\n",
       "          (cross_attn_token_to_image): SamAttention(\n",
       "            (q_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (k_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (v_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (out_proj): Linear(in_features=128, out_features=256, bias=True)\n",
       "          )\n",
       "          (layer_norm2): LayerNorm((256,), eps=1e-06, elementwise_affine=True)\n",
       "          (mlp): SamMLPBlock(\n",
       "            (lin1): Linear(in_features=256, out_features=2048, bias=True)\n",
       "            (lin2): Linear(in_features=2048, out_features=256, bias=True)\n",
       "            (act): ReLU()\n",
       "          )\n",
       "          (layer_norm3): LayerNorm((256,), eps=1e-06, elementwise_affine=True)\n",
       "          (layer_norm4): LayerNorm((256,), eps=1e-06, elementwise_affine=True)\n",
       "          (cross_attn_image_to_token): SamAttention(\n",
       "            (q_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (k_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (v_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "            (out_proj): Linear(in_features=128, out_features=256, bias=True)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (final_attn_token_to_image): SamAttention(\n",
       "        (q_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "        (k_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "        (v_proj): Linear(in_features=256, out_features=128, bias=True)\n",
       "        (out_proj): Linear(in_features=128, out_features=256, bias=True)\n",
       "      )\n",
       "      (layer_norm_final_attn): LayerNorm((256,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (upscale_conv1): ConvTranspose2d(256, 64, kernel_size=(2, 2), stride=(2, 2))\n",
       "    (upscale_conv2): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))\n",
       "    (upscale_layer_norm): SamLayerNorm()\n",
       "    (activation): GELU(approximate='none')\n",
       "    (output_hypernetworks_mlps): ModuleList(\n",
       "      (0-3): 4 x SamFeedForward(\n",
       "        (activation): ReLU()\n",
       "        (proj_in): Linear(in_features=256, out_features=256, bias=True)\n",
       "        (proj_out): Linear(in_features=256, out_features=32, bias=True)\n",
       "        (layers): ModuleList(\n",
       "          (0): Linear(in_features=256, out_features=256, bias=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (iou_prediction_head): SamFeedForward(\n",
       "      (activation): ReLU()\n",
       "      (proj_in): Linear(in_features=256, out_features=256, bias=True)\n",
       "      (proj_out): Linear(in_features=256, out_features=4, bias=True)\n",
       "      (layers): ModuleList(\n",
       "        (0): Linear(in_features=256, out_features=256, bias=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (text_prompt_encoder): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (token_type_embeddings): Embedding(2, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-11): 12 x BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (text_proj): Linear(in_features=768, out_features=256, bias=True)\n",
       "  (transform_linear): Linear(in_features=768, out_features=256, bias=True)\n",
       "  (transform_attention): MultiheadAttention(\n",
       "    (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)\n",
       "  )\n",
       ")>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the repr of `sam_model.base_model` as this cell's output.\n",
    "sam_model.base_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "10316fa3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "203.885028"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of trainable parameters (those with requires_grad=True), in millions.\n",
    "sum(p.numel() for p in sam_model.parameters() if p.requires_grad) / 1000000.0"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
