{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "######## Loading the vision-language model VLModel (BERT + ConvNeXt) from the Hugging Face Hub ########\n",
    "import torch.nn as nn\n",
    "from huggingface_hub import PyTorchModelHubMixin\n",
    "\n",
    "class HFModel(nn.Module, PyTorchModelHubMixin):\n",
    "    \"\"\"Hub-loadable wrapper around an arbitrary model.\n",
    "\n",
    "    Combines ``nn.Module`` with ``PyTorchModelHubMixin`` and keeps the wrapped\n",
    "    model in ``self.custom_model``; the wrapped model's own attributes are\n",
    "    therefore reached through ``self.custom_model``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model):\n",
    "        \"\"\"Store ``model`` as ``self.custom_model``.\"\"\"\n",
    "        super().__init__()\n",
    "        self.custom_model = model\n",
    "\n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        \"\"\"Call the wrapped model and return its result unchanged.\n",
    "\n",
    "        ``x`` and any extra positional/keyword arguments are passed through\n",
    "        as-is, so a wrapped model that takes more than one input can be\n",
    "        called through this wrapper.\n",
    "        \"\"\"\n",
    "        return self.custom_model(x, *args, **kwargs)\n",
    "\n",
    "# Placeholder text, not a real Hub repo id; replace it once the checkpoint is released.\n",
    "HF_REPO_ID = \"will be shared upon acceptance\"\n",
    "\n",
    "# NOTE(review): HFModel.__init__ requires `model`; confirm the Hub config supplies it,\n",
    "# otherwise pass it explicitly, e.g. HFModel.from_pretrained(HF_REPO_ID, model=...).\n",
    "hf_model = HFModel.from_pretrained(HF_REPO_ID)\n",
    "\n",
    "# HFModel keeps the loaded network in `custom_model`, and nn.Module attribute lookup\n",
    "# does not search child modules, so the sub-networks are read from the wrapped model.\n",
    "vision_encoder = hf_model.custom_model.vision_encoder  # ConvNeXt model\n",
    "text_encoder = hf_model.custom_model.text_encoder  # BERT model\n",
    "fusion_module = hf_model.custom_model.fusion_module  # fusion model\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.12 ('pytorch_light')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.12"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "691730d83c554cbe42ce42c85ae3aeec2114e87f06410b30204f3d9c0d65d36c"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
