{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10dc9b37-e9af-4896-9151-2af37b48e7ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fbd89aa-a61e-40b2-b5a2-84207476ab53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4c8d3f5-4d6e-480b-a79e-1a958916be68",
   "metadata": {},
   "source": [
    "## Image Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85111346-8be7-4486-a1ce-409758bb7d70",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from src.model.components.image_encoder import ImageEncoder, available_models\n",
    "encoder = ImageEncoder(\"resnet18\", pretrained=False)\n",
    "img = torch.zeros(2,3,224,224)\n",
    "encoder.eval()\n",
    "x = encoder(img)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fa0252a-308e-47b9-bc8e-5c06bd84aaee",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = ImageEncoder(\"inception_v3\", pretrained=False, promise_input_dim=299)\n",
    "img = torch.zeros(2,3,299,299)\n",
    "encoder.eval()\n",
    "x = encoder(img)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4fc8402-0a66-4245-b6c5-1840ffc1b5ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cd2d64e-75c2-42c4-a4ad-9aa845c8f906",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = timm.create_model(\"swin_large_patch4_window7_224\", num_classes=0, pretrained=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6024ff8-bb99-4c5b-8692-c48c8c1b52b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "o = model(torch.randn(2, 3, 224, 224))\n",
    "o.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9174d338-28c1-406f-b354-a06b52fd1dce",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.num_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e446c9a0-3f5d-4b00-9269-966a4edecbf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "dir(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8acc4a3c-2136-4d99-8a23-e21bb8043cfb",
   "metadata": {},
   "source": [
    "## Context Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "035bf2fa-a2a1-4f54-9655-8b26d4fec644",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from src.model.components.context_encoder import SimpleLSTMContext, GLocalContext, TransformerContext"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42917dab-3802-46f6-9ecd-4b1d452daaa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = SimpleLSTMContext(bidirectional=True)\n",
    "x = torch.zeros(3,5,512).float()\n",
    "mask = torch.rand(3,5) > 0.5\n",
    "y = encoder(x, mask)\n",
    "print(f\"context: {y['context'].shape}\")\n",
    "print(f\"state h_n: {y['state'][0].shape}\")\n",
    "print(f\"state c_n: {y['state'][1].shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc8202b0-51b0-46dc-828a-216a615da4ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = GLocalContext()\n",
    "x = torch.zeros(2,5,512).float()\n",
    "mask = torch.rand(2,5) > 0.5\n",
    "y = encoder(x, mask)\n",
    "print(f\"context: {y['context'].shape}\")\n",
    "print(f\"state h_n: {y['state'][0].shape}\")\n",
    "print(f\"state c_n: {y['state'][1].shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ef51942-17b2-4480-ab5d-4f738c3fe501",
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = TransformerContext()\n",
    "x = torch.zeros(2,5,512).float()\n",
    "mask = torch.rand(2,5) > 0.5\n",
    "y = encoder(x, mask)\n",
    "print(f\"context: {y['context'].shape}\")\n",
    "for i, attn in enumerate(y['attention']):\n",
    "    print(f\"{i}: {attn.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ef8a2a1-0d62-4f70-8d12-ffdcc602f854",
   "metadata": {},
   "source": [
    "## Text Decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b23557fd-4f69-4349-bf1c-9e3abfc46bcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from src.model.components.text_decoder import ContextLSTMText, N_WORDS, IndependentLSTMText, TransformerText"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "943d1445-72d4-4ad5-8247-121e604d9031",
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder = IndependentLSTMText()\n",
    "features = torch.zeros(2, 5, 512)\n",
    "hn = torch.zeros(1, 2, 512)\n",
    "cn = torch.zeros(1, 2, 512)\n",
    "mask = torch.rand(2,5,24) > 0.5\n",
    "captions = torch.randint(0, N_WORDS, (2, 5, 24))\n",
    "x = decoder((hn, cn), features, captions=captions, mask=mask)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "442eee8f-4fb4-4551-aa03-1458ef0dd402",
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder = ContextLSTMText()\n",
    "context = torch.zeros(2, 5, 512)\n",
    "mask = torch.rand(2,5,24) > 0.5\n",
    "captions = torch.randint(0, N_WORDS, (2, 5, 24))\n",
    "x = decoder(context, captions, mask)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea419dfb-950c-49a8-a7e1-ded3b7b99694",
   "metadata": {},
   "outputs": [],
   "source": [
    "decoder = TransformerText()\n",
    "context = torch.zeros(2, 5, 512)\n",
    "mask = torch.rand(2,5,24) > 0.5\n",
    "captions = torch.randint(0, N_WORDS, (2, 5, 24))\n",
    "x = decoder(context, captions, mask)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e438f9b-ab91-42aa-ac9e-d96c2e2ab444",
   "metadata": {},
   "source": [
    "## CSTNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0d3809c-a3b5-4e85-8923-e37aac3dcf05",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.model.cst import CST\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab7c1e28-5a56-4965-9aca-2d4a751be4df",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CST()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d44b8758-c34e-4bd0-a39a-d9080b39d179",
   "metadata": {},
   "outputs": [],
   "source": [
    "states = torch.zeros(1,13,3,224,224)\n",
    "states_mask = torch.rand(1,13)>0.5\n",
    "captions = torch.randint(0, 1000, (1, 12, 24))\n",
    "captions_mask = torch.rand(1,12,24)>0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce23aa66-9769-461a-9a75-1c64d8c433af",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = model(states, states_mask, captions, captions_mask)\n",
    "x[\"logits\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a3ac707-0652-4258-a2ff-8b01c87e5d28",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "torch.onnx.export(\n",
    "    model,\n",
    "    args=(states, states_mask, captions, captions_mask),\n",
    "    f=\"cst.onnx\",\n",
    "    export_params=False,\n",
    "    opset_version=12\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4d426ee-9656-4adb-997e-3c2dec978711",
   "metadata": {},
   "source": [
    "## GLACNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1218c6a-0d15-49c6-9766-920626225c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.model.glacnet import GLACNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47afbd0e-5bcd-4a57-b536-1ee4c9a8f71b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = GLACNet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eccd242a-159e-4bdb-b0ae-b7dee150c907",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = model(states, states_mask, captions, captions_mask)\n",
    "x[\"logits\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0605cb98-54b0-4b8c-8b85-825455912e35",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.onnx.export(\n",
    "    model,\n",
    "    args=(states, states_mask, captions, captions_mask),\n",
    "    f=\"glacnet.onnx\",\n",
    "    export_params=False,\n",
    "    opset_version=12\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45254853-a2d6-4b7e-8bb2-08b20bb9e251",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.model.ttnet import TTNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "702b17bf-c8cf-4983-bd30-44c48df82fea",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TTNet()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4456078f-b23f-4243-8e73-cec94318e071",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = model(states, states_mask, captions, captions_mask)\n",
    "x[\"logits\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18adead2-fd82-49ca-bac9-8036a6397b5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.onnx.export(\n",
    "    model,\n",
    "    args=(states, states_mask, captions, captions_mask),\n",
    "    f=\"ttnet.onnx\",\n",
    "    export_params=False,\n",
    "    opset_version=12\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6c42cbe-8fa1-4b6b-a672-31b173d224b6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
