{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "94441c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from transformers import BertModel\n",
    "\n",
    "from models.PersonSAM import PersonSAM\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ef408822",
   "metadata": {},
   "outputs": [],
   "source": [
    "sam_path = r'F:\\preTrainedModels\\sam-vit-base-tbpr'\n",
    "language_model_path = r'F:\\preTrainedModels\\bert-base-uncased'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1abe5572",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at F:\\preTrainedModels\\sam-vit-base-tbpr were not used when initializing PersonSAM: ['prompt_encoder.mask_embed.conv1.bias', 'prompt_encoder.mask_embed.layer_norm1.bias', 'prompt_encoder.no_mask_embed.weight', 'prompt_encoder.point_embed.0.weight', 'prompt_encoder.mask_embed.layer_norm1.weight', 'prompt_encoder.mask_embed.layer_norm2.bias', 'prompt_encoder.point_embed.2.weight', 'prompt_encoder.mask_embed.conv2.bias', 'prompt_encoder.mask_embed.layer_norm2.weight', 'prompt_encoder.shared_embedding.positional_embedding', 'prompt_encoder.point_embed.1.weight', 'prompt_encoder.mask_embed.conv2.weight', 'prompt_encoder.not_a_point_embed.weight', 'prompt_encoder.point_embed.3.weight', 'prompt_encoder.mask_embed.conv3.bias', 'prompt_encoder.mask_embed.conv1.weight', 'prompt_encoder.mask_embed.conv3.weight']\n",
      "- This IS expected if you are initializing PersonSAM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing PersonSAM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of PersonSAM were not initialized from the model checkpoint at F:\\preTrainedModels\\sam-vit-base-tbpr and are newly initialized: ['no_mask_embed.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Some weights of the model checkpoint at F:\\preTrainedModels\\bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "sam_model = PersonSAM.from_pretrained(sam_path)\n",
    "text_prompt_encoder = BertModel.from_pretrained(language_model_path)\n",
    "sam_model.set_pretrained_text_encoder(text_prompt_encoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "64b49ace",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set frozen parameters\n",
    "for name, param in sam_model.named_parameters():\n",
    "    if name.startswith(\"vision_encoder\") or name.startswith(\"prompt_encoder\") \\\n",
    "        or name.startswith(\"text_prompt_encoder.embeddings\") or name.startswith(\"text_prompt_encoder.encoder\"):\n",
    "        param.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "10316fa3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5.322468"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(p.numel() for p in sam_model.parameters() if p.requires_grad) / 1000000.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d595d5c4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
