{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fe6db1be-6ee0-494d-af38-26244261f25e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/miniforge/envs/iris/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "/mnt/task_runtime/datacomp/my_open_clip/src/my_open_clip/factory.py:129: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  checkpoint = torch.load(checkpoint_path, map_location=map_location)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image features shape: torch.Size([1, 512])\n",
      "Text features shape: torch.Size([1, 512])\n",
      "Logit scale: 21.641887664794922\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/miniforge/envs/iris/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
      "  return fn(*args, **kwargs)\n",
      "/miniforge/envs/iris/lib/python3.10/site-packages/torch/utils/checkpoint.py:92: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import my_open_clip.src.my_open_clip as open_clip\n",
    "\n",
    "# Global dictionary to store original forward methods\n",
    "original_forwards = {}\n",
    "\n",
    "def checkpointed_forward(module, *inputs, **kwargs):\n",
    "    def custom_forward(*inputs):\n",
    "        inputs, kwargs = inputs[:-1], inputs[-1]\n",
    "        return original_forwards[id(module)](*inputs, **kwargs)\n",
    "    return torch.utils.checkpoint.checkpoint(custom_forward, *inputs, kwargs)\n",
    "\n",
    "def apply_grad_checkpointing(model, start_layer=0, end_layer=None):\n",
    "    if end_layer is None:\n",
    "        end_layer = len(model.visual.transformer.resblocks)\n",
    "    \n",
    "    for i, layer in enumerate(model.visual.transformer.resblocks):\n",
    "        if start_layer <= i < end_layer:\n",
    "            if id(layer) not in original_forwards:\n",
    "                original_forwards[id(layer)] = layer.forward\n",
    "            layer.forward = lambda *inputs, layer=layer, **kwargs: checkpointed_forward(layer, *inputs, **kwargs)\n",
    "\n",
    "    for i, layer in enumerate(model.transformer.resblocks):\n",
    "        if start_layer <= i < end_layer:\n",
    "            if id(layer) not in original_forwards:\n",
    "                original_forwards[id(layer)] = layer.forward\n",
    "            layer.forward = lambda *inputs, layer=layer, **kwargs: checkpointed_forward(layer, *inputs, **kwargs)\n",
    "\n",
    "def remove_grad_checkpointing(model):\n",
    "    for layer in model.visual.transformer.resblocks:\n",
    "        if id(layer) in original_forwards:\n",
    "            layer.forward = original_forwards[id(layer)]\n",
    "            del original_forwards[id(layer)]\n",
    "\n",
    "# Updated test function\n",
    "def test_grad_checkpointing(model, tokenizer):\n",
    "    device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model.to(device)\n",
    "    \n",
    "    # Prepare image input\n",
    "    dummy_image = torch.randn(1, 3, 224, 224).to(device)\n",
    "    \n",
    "    # Prepare text input\n",
    "    text = [\"a photo of a cat\"]\n",
    "    text_tokens = tokenizer(text).to(device)\n",
    "    \n",
    "    # Forward pass\n",
    "    with torch.no_grad():  # Use no_grad for testing\n",
    "        image_features, text_features, logit_scale = model(dummy_image, text_tokens)\n",
    "    \n",
    "    print(f\"Image features shape: {image_features.shape}\")\n",
    "    print(f\"Text features shape: {text_features.shape}\")\n",
    "    print(f\"Logit scale: {logit_scale.item()}\")\n",
    "\n",
    "# Load the model and tokenizer\n",
    "my_model, _, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained='datacomp_s_s13m_b4k')\n",
    "my_tokenizer = open_clip.get_tokenizer('ViT-B-32')\n",
    "\n",
    "# Apply gradient checkpointing to layers 4 through 7\n",
    "apply_grad_checkpointing(my_model, start_layer=4, end_layer=8)\n",
    "\n",
    "# Test the model\n",
    "test_grad_checkpointing(my_model, my_tokenizer)\n",
    "\n",
    "# Uncomment the following line when you want to remove checkpointing\n",
    "# remove_grad_checkpointing(my_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adba5000-2db6-413d-a5a4-469441fc5243",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f1b1eb6-5775-4f57-8c16-d043c1580d5d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
