{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "gpuClass": "standard",
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "### Summary\n",
        "\n",
        "- Choose encoders (ViT, DINO, SentenceT)\n",
        "- Design the visual understanding  of your model\n",
        "    - upload an image-caption dataset or choose from the existing ones (CC12M, ImageNet, ImageNet v2, CIFAR, PETS)\n",
        "    - format of the dataset: \n",
        "        - tsv with images URL and captions\n",
        "        - folder of images zipped and a tsv with image file names and captions\n",
        "        - a standard Pytorch dataset\n",
        "            - if you upload an image-label dataset, transform it to captions defining some templates\n",
        "    - do you want to integrate your dataset with CC12M?\n",
        "- Test your model\n",
        "    - upload an image-caption test dataset\n",
        "    - upload a single image and some captions to choose from\n",
        "- Deep dive into the classification\n",
        "    - look at the elements in the dataset responsible for the classification."
      ],
      "metadata": {
        "id": "DPodlNdnU3j3"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FL8wUrEeUxQ-"
      },
      "outputs": [],
      "source": [
        "!pip install gdown\n",
        "!pip install -r requirements.txt"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# @title import\n",
        "import os\n",
        "import torch\n",
        "import torchvision.transforms as transforms\n",
        "from typing import Tuple, List, Type, Union\n",
        "from sentence_transformers import SentenceTransformer\n",
        "from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTModel, ConvNextFeatureExtractor, ConvNextModel, AutoFeatureExtractor, DeiTModel\n",
        "from matplotlib import pyplot as plt\n",
        "from matplotlib import cm\n",
        "import matplotlib.image as mpimg\n",
        "import numpy as np\n",
        "import gdown\n",
        "import subprocess\n",
        "from embdatasets import EmbeddedImagenet, EmbeddedCIFAR100, EmbeddedPETS, EmbeddedDataset, EmbeddedImagenetV2\n",
        "import re\n",
        "from PIL import Image\n",
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "import pickle\n",
        "\n",
        "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"  # avoid using paallel tokenizers for text transformers for conflicts with dataloaders with multiple workers\n",
        "if not os.path.exists(f'example_images'): os.makedirs('example_images')\n",
        "cc12m_drive_keys = {\n",
        "        'VIT' : 'ANONYMIZED',\n",
        "        'DINO': 'ANONYMIZED',\n",
        "        'DEITtiny': 'ANONYMIZED',\n",
        "        'DEITmedium': 'ANONYMIZED',\n",
        "        'DEITlarge': 'ANONYMIZED',\n",
        "        'SentenceT' : 'ANONYMIZED',\n",
        "        'SentenceTmini' : 'ANONYMIZED',\n",
        "        'SentenceTmedium' : 'ANONYMIZED',\n",
        "        'Metadata' : 'ANONYMIZED',\n",
        "    }"
      ],
      "metadata": {
        "id": "fTPApZwf50jA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Choose the encoders"
      ],
      "metadata": {
        "id": "Sqm8b5SwcFmN"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# @title Visual and text encoders\n",
        "\n",
        "class VIT():\n",
        "    \"\"\"\n",
        "    Supervised image encoder based on the original ViT model from Google.\n",
        "    \"\"\"\n",
        "    def __init__(self, model_name: str = 'google/vit-base-patch16-224-in21k', device: str = None, load_jit=True):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            model_name: name of the pretrained model to use\n",
        "            device: device to use for inference\n",
        "            load_jit: load model just in time (in the encode method)\n",
        "        \"\"\"\n",
        "        if not device:\n",
        "            self.device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
        "        else:\n",
        "            self.device = device\n",
        "        self.load_jit = load_jit\n",
        "        self.model_name = model_name\n",
        "        if self.model_name == 'google/vit-base-patch16-224-in21k':\n",
        "            self.embedding_size = 768  # to be faster if embeddings are precomputed\n",
        "        if not self.load_jit or self.model_name != 'google/vit-base-patch16-224-in21k':\n",
        "            self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)\n",
        "            self.model = ViTModel.from_pretrained(model_name).to(self.device)\n",
        "            self.embedding_size = self.model(torch.zeros(1, 3, 224, 224).to(self.device)).last_hidden_state[:,0,:].shape[1]  # 768\n",
        "        \n",
        "    def encode(self, images: torch.Tensor, alr_preprocessed: bool = False) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            images: images to encode\n",
        "            alr_preprocessed: whether the images are already preprocessed to tensors\n",
        "        Returns:\n",
        "            image embeddings\n",
        "        \"\"\"\n",
        "        if self.load_jit:\n",
        "            self.feature_extractor = ViTFeatureExtractor.from_pretrained(self.model_name)\n",
        "            self.model = ViTModel.from_pretrained(self.model_name).to(self.device)\n",
        "            self.embedding_size = self.model(torch.zeros(1, 3, 224, 224).to(self.device)).last_hidden_state[:,0,:].shape[1]  # 768\n",
        "            self.load_jit = False  # we load the model a single time\n",
        "        with torch.no_grad():\n",
        "            if not alr_preprocessed:\n",
        "                x = self.feature_extractor(images=images, return_tensors=\"pt\").to(self.device)\n",
        "                x = self.model(**x)\n",
        "            else:\n",
        "                x = self.model(images.to(self.device))\n",
        "            x = x.last_hidden_state[:,0,:]\n",
        "        return x  # taking the embedding of the CLS token as image representation\n",
        "\n",
        "\n",
        "class DINO():\n",
        "    \"\"\"\n",
        "    Unsupervised image encoder using the DINO model from Meta Research.\n",
        "    \"\"\"\n",
        "    \n",
        "    def __init__(self, model_name: str = 'dino_vits8', device: str = None, load_jit=True):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            model_name: name of the model to use: dino_vit{\"s\"mall or \"b\"ase}{\"8\" or \"16\" patch size}\n",
        "            device: device to use.\n",
        "            load_jit: load model just in time (in the encode method)\n",
        "        \"\"\"\n",
        "        if not device:\n",
        "            self.device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
        "        else:\n",
        "            self.device = device \n",
        "        self.load_jit = load_jit\n",
        "        self.model_name = model_name\n",
        "        if self.model_name == 'dino_vits8':\n",
        "            self.embedding_size = 384\n",
        "        if not self.load_jit or self.model_name != 'dino_vits8':\n",
        "            self.model = torch.hub.load('facebookresearch/dino:main', model_name)\n",
        "            self.model = self.model.to(self.device)\n",
        "            self.model.eval()\n",
        "            self.embedding_size = self.model(torch.zeros(1, 3, 224, 224).to(self.device)).shape[1]  # 768\n",
        "\n",
        "    def encode(self, images: torch.Tensor, alr_preprocessed: bool = True) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            images: images to encode.\n",
        "            alr_preprocessed: if True, the images are already preprocessed.\n",
        "        Returns:\n",
        "            encoded images.\n",
        "        \"\"\"\n",
        "        if self.load_jit:\n",
        "            self.model = torch.hub.load('facebookresearch/dino:main', self.model_name)\n",
        "            self.model = self.model.to(self.device)\n",
        "            self.model.eval()\n",
        "            self.embedding_size = self.model(torch.zeros(1, 3, 224, 224).to(self.device)).shape[1]  # 768\n",
        "            self.load_jit = False  # we load the model a single time\n",
        "        if not alr_preprocessed:\n",
        "            preprocess = transforms.Compose([transforms.ToTensor(),\n",
        "                                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # assuming imagenet mean and val\n",
        "                                            transforms.Resize((224, 224))])\n",
        "            images = torch.stack([preprocess(image) for image in images])\n",
        "\n",
        "        with torch.no_grad():\n",
        "            images = self.model(images.to(self.device))\n",
        "        return images  # taking the embedding of the CLS token as image representation\n",
        "\n",
        "\n",
        "class DEIT():\n",
        "    \"\"\"\n",
        "    Supervised image encoder based on the DEIT model from Meta (Visual Transformer).\n",
        "    \"\"\"\n",
        "    def __init__(self, model_name: str = 'facebook/deit-tiny-distilled-patch16-224', device: str = None):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            model_name: name of the pretrained model to use\n",
        "            device: device to use for inference\n",
        "        \"\"\"\n",
        "        if not device:\n",
        "            self.device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
        "        else:\n",
        "            self.device = device\n",
        "        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)\n",
        "        self.model = DeiTModel.from_pretrained(model_name).to(self.device)\n",
        "        self.embedding_size = self.model(torch.zeros(1, 3, 224, 224).to(self.device)).last_hidden_state[:,0,:].shape[1]  # 768\n",
        "\n",
        "    def encode(self, images: torch.Tensor, alr_preprocessed: bool = False) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            images: images to encode\n",
        "            alr_preprocessed: whether the images are already preprocessed to tensors\n",
        "        Returns:\n",
        "            image embeddings\n",
        "        \"\"\"\n",
        "        with torch.no_grad():\n",
        "            if not alr_preprocessed:\n",
        "                x = self.feature_extractor(images=images, return_tensors=\"pt\").to(self.device)\n",
        "                x = self.model(**x)\n",
        "            else:\n",
        "                x = self.model(images.to(self.device))\n",
        "            x = x.last_hidden_state[:,0,:]\n",
        "        return x\n",
        "\n",
        "\n",
        "class DEITmedium(DEIT):\n",
        "    def __init__(self, model_name: str='facebook/deit-small-distilled-patch16-224', device: str='cuda', load_jit=False) -> None:\n",
        "        super().__init__(model_name)\n",
        "\n",
        "class DEITlarge(DEIT):\n",
        "    def __init__(self, model_name: str='facebook/deit-base-distilled-patch16-224', device: str='cuda', load_jit=False) -> None:\n",
        "        super().__init__(model_name)\n",
        "\n",
        "\n",
        "class SentenceT():\n",
        "    \"\"\"\n",
        "    Unsupervised text encoder known as SentenceT. This class wraps the sentence-transformers library.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, model_name: str='sentence-transformers/all-mpnet-base-v2', device: str='cuda', load_jit=True) -> None:\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            model_name (str): The name of the model to use.\n",
        "            device (str): The device to use.\n",
        "        \"\"\"\n",
        "        self.device = device\n",
        "        self.load_jit = load_jit\n",
        "        self.model_name = model_name\n",
        "        if self.model_name == 'sentence-transformers/all-mpnet-base-v2':\n",
        "            self.embedding_size = 768  # lazy default, if not load_jit we compute it \n",
        "        if not self.load_jit or self.model_name != 'sentence-transformers/all-mpnet-base-v2':\n",
        "            self.model = SentenceTransformer(model_name, device=self.device).eval()\n",
        "            self.text_encoder = self.model.encode\n",
        "            self.embedding_size = self.text_encoder([\"Chi vuol esser lieto, sia: di doman non c'è certezza\"], show_progress_bar=False).shape[1]\n",
        "\n",
        "    def encode(self, texts: List[str]) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            texts (List[str]): A list of texts to encode.\n",
        "\n",
        "        Returns:\n",
        "            torch.Tensor: A tensor of shape (len(texts), 768)\n",
        "        \"\"\"\n",
        "        if self.load_jit:\n",
        "            print('loading encoder now...')\n",
        "            self.model = SentenceTransformer(self.model_name, device=self.device).eval()\n",
        "            self.text_encoder = self.model.encode\n",
        "            self.embedding_size = self.text_encoder([\"Chi vuol esser lieto, sia: di doman non c'è certezza\"], show_progress_bar=False).shape[1]\n",
        "            self.load_jit = False  # we load the model a single time\n",
        "        ztxts = self.text_encoder(texts, show_progress_bar=False)\n",
        "        return torch.tensor(ztxts).to(self.device)\n",
        "\n",
        "\n",
        "class SentenceTmini(SentenceT):\n",
        "    def __init__(self, model_name: str='sentence-transformers/all-MiniLM-L6-v2', device: str='cuda', load_jit=False) -> None:\n",
        "        super().__init__(model_name, device, load_jit)\n",
        "\n",
        "class SentenceTmedium(SentenceT):\n",
        "    def __init__(self, model_name: str='sentence-transformers/all-MiniLM-L12-v2', device: str='cuda', load_jit=False) -> None:\n",
        "        super().__init__(model_name, device, load_jit)\n",
        "\n",
        "\n",
        "class CustomImageEncoder():\n",
        "    def __init__(self):\n",
        "        raise NotImplementedError\n",
        "\n",
        "    def encode(self, images):\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "class CustomTextEncoder():\n",
        "    def __init__(self):\n",
        "        raise NotImplementedError\n",
        "\n",
        "    def encode(self, texts):\n",
        "        raise NotImplementedError"
      ],
      "metadata": {
        "id": "YTQ28PjotBAB",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "encoders = {\n",
        "    'VIT': VIT,\n",
        "    'DINO': DINO,\n",
        "    'DEIT': DEIT,\n",
        "    'DEITmedium' : DEITmedium,\n",
        "    'DEITlarge' : DEITlarge,\n",
        "    'SentenceT' : SentenceT,\n",
        "    'SentenceTmini' : SentenceTmini,\n",
        "    'SentenceTmedium' : SentenceTmedium,\n",
        "    'Custom_Image_Encoder' : CustomImageEncoder,\n",
        "    'Custom_Text_Encoder' : CustomTextEncoder,\n",
        "}\n",
        "\n",
        "select_image_encoder = 'VIT (supervised pretraining on Im21k)' #@param [\"VIT (supervised pretraining on Im21k)\", \"DINO (unsupervised pretraining on Im21k)\", \"DEIT (supervised pretraining on Im1k, tiny VIT)\", \"DEITmedium (supervised pretraining on Im1k, small VIT)\", \"DEITlarge (supervised pretraining on Im1k, base VIT)\", \"Custom_Image_Encoder\"]\n",
        "select_text_encoder = 'SentenceT (all-mpnet-base-v2)' #@param [\"SentenceT (all-mpnet-base-v2)\", \"SentenceTmini (all-MiniLM-L6-v2)\", \"SentenceTmedium (all-MiniLM-L12-v2)\", \"Custom_Text_Encoder\"]\n",
        "\n",
        "image_encoder = encoders[select_image_encoder.split()[0]]()\n",
        "text_encoder = encoders[select_text_encoder.split()[0]]()"
      ],
      "metadata": {
        "id": "IqHkro-t5m_v",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Design the visual understanding of your model"
      ],
      "metadata": {
        "id": "zUYhvNFZcIv5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# EDIT THIS CLASS IF YOU WANT TO LOAD YOUR DATASET FOR TRAINING OR TEST\n",
        "\n",
        "class EmbeddedCustomDataset(EmbeddedDataset):\n",
        "    \"\"\"\n",
        "    Custom loader.\n",
        "    \"\"\"\n",
        "    def __init__(self, classes: List[str] = None, templates: List[str] = None, subsample_size: int = 10_000, sampler_seed: int = 42, load_precomputed_embeddings=False, split = 'test'):\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            classes: a list of classes, like [\"tench\", \"goldfish\", ...\n",
        "            templates: a list of templates to craft captions using classes, like ['itap of a {}.', 'a bad photo of the {}.', ...\n",
        "            subsample_size: how much images of the dataset you want to use. (they will be sampled randomly)\n",
        "            sampler_seed: the seed for the random sampler.\n",
        "        \"\"\"\n",
        "        ####### YOUR CODE HERE ########\n",
        "        self.classes = classes  # define here the list of classes to not pass them to the constructor\n",
        "        if not templates: self.templates = [  # if your dataset is already in the form of free text as CC12M, pass texts to classes and use ['{}'] as templates\n",
        "            'itap of a {}.',\n",
        "            'a bad photo of the {}.',\n",
        "            'a origami {}.',\n",
        "            'a photo of the large {}.',\n",
        "            'a {} in a video game.',\n",
        "            'art of the {}.',\n",
        "            'a photo of the small {}.',\n",
        "        ]\n",
        "        self.texts = [\n",
        "            prompt.format(classname)\n",
        "            for classname in self.classes\n",
        "            for prompt in self.templates\n",
        "        ]\n",
        "        self.load_precomputed = load_precomputed_embeddings\n",
        "        self.subsample_size = subsample_size\n",
        "        self.split = ''\n",
        "        if not self.load_precomputed:\n",
        "            preprocess = transforms.Compose([transforms.ToTensor(),\n",
        "                                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Imagenet mean and std, fill with mean and std of your custom dataset\n",
        "                                            transforms.Resize((224, 224))])\n",
        "            ####### YOUR CODE HERE ########\n",
        "            # load here your pytorch dataset, look at the PETS example\n",
        "            images = None # datasets.OxfordIIITPet(root='pets', transform=preprocess, split={'Train' : 'trainval', 'Test' : 'test'}[split], download=True)\n",
        "            self.loader = torch.utils.data.DataLoader(images, batch_size=100, num_workers=4)"
      ],
      "metadata": {
        "id": "mw6A800xsfO7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# @title utilities to compute the relative representations and store them in sparse tensors and to manage tensors larger than RAM\n",
        "\n",
        "from typing import Tuple, List, Type, Union\n",
        "import torch\n",
        "from tqdm import tqdm\n",
        "\n",
        "def get_free_gpu_mem(verbose=False):\n",
        "    t = torch.cuda.get_device_properties(0).total_memory\n",
        "    r = torch.cuda.memory_reserved(0)\n",
        "    a = torch.cuda.memory_allocated(0)\n",
        "    if verbose:\n",
        "        print(\"Total memory:\", t)\n",
        "        print(\"Reserved memory:\", r)\n",
        "        print(\"Allocated memory:\", a)\n",
        "        print(\"Free memory:\", t-a)  # or t-r\n",
        "    return t-a\n",
        "\n",
        "\n",
        "\n",
        "def relative_represent_naive(y: torch.Tensor, basis: torch.Tensor, non_zeros: int = 800) -> Tuple[torch.Tensor, torch.Tensor]:\n",
        "    \"\"\"\n",
        "    Compute the sparse decomposition of a tensor y with respect to a basis.\n",
        "    \n",
        "    Args:\n",
        "        y (torch.Tensor): vectors to relative represent\n",
        "        basis (torch.Tensor): basis to represent with respect to\n",
        "        non_zeros (int): nonzero entries in the relative representation\n",
        "        \n",
        "    Returns:\n",
        "        indices (torch.Tensor): indices of the nonzero entries in each relative representation of y\n",
        "        values (torch.Tensor): corresponding coefficients of the entries\n",
        "    \"\"\"\n",
        "    print(y.shape, basis.shape)\n",
        "    with torch.no_grad():\n",
        "        in_prods = torch.einsum('ik, jk -> ij', y, basis)\n",
        "        values, indices = torch.topk(in_prods, non_zeros, dim=1)  # cosine similarity\n",
        "    return indices.to('cpu'), values.to('cpu')\n",
        "\n",
        "\n",
        "def relative_represent_autogpu(y: torch.Tensor, basis: torch.Tensor, non_zeros: int = 800, gpu_mem_free_margin_in_mb=5000) -> Tuple[torch.Tensor, torch.Tensor]:\n",
        "    \"\"\"\n",
        "    Compute the sparse decomposition of a tensor y with respect to a basis aware of the available gpu memory.\n",
        "    \n",
        "    Args:\n",
        "        y (torch.Tensor): vectors to relative represent\n",
        "        basis (torch.Tensor): basis to represent with respect to\n",
        "        non_zeros (int): nonzero entries in the relative representation\n",
        "        \n",
        "    Returns:\n",
        "        indices (torch.Tensor): indices of the nonzero entries in each relative representation of y\n",
        "        values (torch.Tensor): corresponding coefficients of the entries\n",
        "    \"\"\"\n",
        "    values, indices = torch.zeros((y.shape[0], non_zeros)), torch.zeros((y.shape[0], non_zeros), dtype=torch.long)\n",
        "\n",
        "    free_gpu_mem = get_free_gpu_mem() - gpu_mem_free_margin_in_mb * 1024 ** 2\n",
        "    max_floats_in_mem = free_gpu_mem / 4\n",
        "    max_chunk_y = max_floats_in_mem / basis.shape[0]\n",
        "    n_chunks = int(y.shape[0] / max_chunk_y) + 1  # should be + 1\n",
        "    chunk_y = int(y.shape[0] / n_chunks) + n_chunks   # final sum to avoid one more chunk\n",
        "    print(max_chunk_y, n_chunks, chunk_y)\n",
        "    with torch.no_grad():\n",
        "        for c in range(n_chunks):\n",
        "            in_prods = torch.einsum('ik, jk -> ij', y[c * chunk_y : (c + 1) * chunk_y], basis)\n",
        "            values[c * chunk_y : (c + 1) * chunk_y], indices[c * chunk_y : (c + 1) * chunk_y] = torch.topk(in_prods, non_zeros, dim=1)  # cosine similarity\n",
        "            del in_prods  # it seems this is necessary\n",
        "    return indices.to('cpu'), values.to('cpu')\n",
        "\n",
        "\n",
        "def relative_represent(y: torch.Tensor, basis: torch.Tensor, non_zeros: int = 800, max_gpu_mem_gb=8) -> Tuple[torch.Tensor, torch.Tensor]:\n",
        "    \"\"\"\n",
        "    Compute the sparse decomposition of a tensor y with respect to a basis aware of the available gpu memory.\n",
        "    \n",
        "    Args:\n",
        "        y (torch.Tensor): vectors to relative represent\n",
        "        basis (torch.Tensor): basis to represent with respect to\n",
        "        non_zeros (int): nonzero entries in the relative representation\n",
        "        \n",
        "    Returns:\n",
        "        indices (torch.Tensor): indices of the nonzero entries in each relative representation of y\n",
        "        values (torch.Tensor): corresponding coefficients of the entries\n",
        "    \"\"\"\n",
        "    values, indices = torch.zeros((y.shape[0], non_zeros)), torch.zeros((y.shape[0], non_zeros), dtype=torch.long)\n",
        "\n",
        "    free_gpu_mem = max_gpu_mem_gb * 1024 ** 3\n",
        "    max_floats_in_mem = free_gpu_mem / 4\n",
        "    max_chunk_y = max_floats_in_mem / basis.shape[0]\n",
        "    n_chunks = int(y.shape[0] / max_chunk_y) + 1  \n",
        "    chunk_y = int(y.shape[0] / n_chunks) + n_chunks   # final sum to avoid one more chunk\n",
        "    # print(max_gpu_mem_gb, max_chunk_y, n_chunks, chunk_y)\n",
        "    with torch.no_grad():\n",
        "        for c in range(n_chunks):\n",
        "            in_prods = torch.einsum('ik, jk -> ij', y[c * chunk_y : (c + 1) * chunk_y], basis)\n",
        "            values[c * chunk_y : (c + 1) * chunk_y], indices[c * chunk_y : (c + 1) * chunk_y] = torch.topk(in_prods, non_zeros, dim=1)  # cosine similarity\n",
        "            del in_prods  # it seems this is necessary\n",
        "    return indices.to('cpu'), values.to('cpu')\n",
        "\n",
        "\n",
        "def sparsify(i: torch.Tensor, v: torch.Tensor, size: torch.Size) -> torch.sparse.FloatTensor:\n",
        "    \"\"\"\n",
        "    Organize indices and values of n vectors into a single sparse tensor.\n",
        "    \n",
        "    Args:\n",
        "        i (torch.Tensor): indices of non-zero elements of every vector. Shape: (n_vectors, nonzero elements)\n",
        "        v (torch.Tensor): values of non-zero elements of every vector. Shape: (n_vectors, nonzero elements)\n",
        "        size (torch.Size): shape of the output tensor\n",
        "        \n",
        "    Returns:\n",
        "        torch.sparse.FloatTensor: sparse tensor of shape \"size\" (n_vectors, zero + nonzero elements)\n",
        "    \"\"\"\n",
        "    flat_dim = len(i.flatten())\n",
        "    coo_first_row_idxs = torch.div(torch.arange(flat_dim), i.shape[1], rounding_mode='floor')\n",
        "    stacked_idxs = torch.cat((coo_first_row_idxs.unsqueeze(0), i.flatten().unsqueeze(0)), 0)\n",
        "    return torch.sparse_coo_tensor(stacked_idxs, v.flatten(), size)\n",
        "\n",
        "\n",
        "def normalize_sparse(tensor: torch.sparse.FloatTensor, nnz_per_row: int) -> torch.sparse.FloatTensor:\n",
        "    \"\"\"Normalize a sparse tensor by row.\n",
        "\n",
        "    Args:\n",
        "        tensor (torch.sparse.FloatTensor): The sparse tensor to normalize.\n",
        "        nnz_per_row (int): The number of non-zero elements per row.\n",
        "\n",
        "    Returns:\n",
        "        torch.sparse.FloatTensor: The normalized sparse tensor.\n",
        "    \"\"\"    \n",
        "    norms = torch.sparse.sum(tensor * tensor, dim=1).to_dense()\n",
        "    v = tensor._values().clone().detach().reshape(-1, nnz_per_row).t()\n",
        "    v /= torch.sqrt(norms)\n",
        "    return torch.sparse_coo_tensor(tensor._indices(), v.t().flatten(), tensor.shape)\n",
        "\n",
        "\n",
        "class memory_tensor():\n",
        "    \"\"\"\n",
        "    Class to load tensors from disk in chunks and store them in memory.\n",
        "    \"\"\"\n",
        "    def __init__(self, ordered_filepaths: List[str], chunk_size: int = 100_000, device: str = 'cuda', normalized=True) -> None:\n",
        "        \"\"\"\n",
        "        Args:\n",
        "            ordered_filepaths (List[str]): list of filepaths to load tensors from\n",
        "            chunk_size (int): size of chunks to load from disk\n",
        "            device (str): device to load tensors to\n",
        "            normalized (bool): whether dim1 of the tensor should be normalized to norm=1\n",
        "        \"\"\"\n",
        "        self.paths = ordered_filepaths\n",
        "        self.chunk_size = chunk_size\n",
        "        if normalized:\n",
        "            for i, fp in enumerate(ordered_filepaths):\n",
        "                tmp = torch.load(fp).to(device)\n",
        "                if i == 0:  # if first chunk is already normalized, assume everything is normalized\n",
        "                    tmp_sum = torch.einsum('ik ->', tmp * tmp)\n",
        "                    if tmp_sum - tmp.shape[0] < 0.01 * tmp.shape[0]: \n",
        "                        break # if already normalized break (1% error tolerated)\n",
        "                tmp /= torch.einsum('ik -> i', tmp * tmp).unsqueeze(1) ** 0.5\n",
        "                torch.save(tmp, fp)\n",
        "            del tmp\n",
        "        self.chunk_in_memory = 0\n",
        "        self.x = torch.load(self.paths[-1]).to(device)\n",
        "        self.device = device\n",
        "        self.len = self.chunk_size * (len(self.paths) - 1) + len(self.x)\n",
        "    def __getitem__(self, index: Union[int, slice, torch.Tensor]) -> torch.Tensor:\n",
        "        \"\"\"\n",
        "        Returns the rows selected by index, loading from disk the chunk(s) that contain them.\n",
        "        At most one chunk is kept in self.x at a time.\n",
        "\n",
        "        Args:\n",
        "            index (Union[int, slice, torch.Tensor]): index to get item from\n",
        "                - int: a single global row index (negative values count from the end)\n",
        "                - slice: a contiguous range with step 1 touching at most two consecutive chunks;\n",
        "                  missing, negative or overshooting bounds are resolved against len(self)\n",
        "                - torch.Tensor: global row indices; only the chunk of the first index is loaded,\n",
        "                  so all the indices are expected to fall in that chunk\n",
        "        Returns:\n",
        "            torch.Tensor: tensor at index\n",
        "        Raises:\n",
        "            ValueError: if the slice has a step different from 1\n",
        "            Exception: if the slice touches more than two chunks\n",
        "            TypeError: if the type of index is not supported\n",
        "        \"\"\"\n",
        "        def load_chunk(chunk: int) -> None:\n",
        "            # reload from disk only when the requested chunk is not the one already in memory\n",
        "            if chunk != self.chunk_in_memory:\n",
        "                self.chunk_in_memory = chunk\n",
        "                self.x = torch.load(self.paths[chunk]).to(self.device)\n",
        "\n",
        "        if isinstance(index, int):\n",
        "            if index < 0:\n",
        "                index += self.len\n",
        "            load_chunk(index // self.chunk_size)\n",
        "            return self.x[index % self.chunk_size]\n",
        "        if isinstance(index, slice):\n",
        "            start, stop, step = index.indices(self.len)\n",
        "            if step != 1:\n",
        "                raise ValueError(\"only slices with step 1 are supported\")\n",
        "            if stop <= start:\n",
        "                # empty slice: return an empty view of a valid chunk to keep dtype and trailing shape\n",
        "                load_chunk(min(start, max(self.len - 1, 0)) // self.chunk_size)\n",
        "                return self.x[0:0]\n",
        "            c_start = start // self.chunk_size\n",
        "            # stop is exclusive, so the last chunk needed is the one holding row stop - 1\n",
        "            # (clamping stop to len - 1, as done before, silently dropped the last row)\n",
        "            c_stop = (stop - 1) // self.chunk_size\n",
        "            if c_stop - c_start > 1:\n",
        "                raise Exception(\"slice too big (bigger than chunk size)\")\n",
        "            load_chunk(c_start)\n",
        "            if c_start == c_stop:\n",
        "                return self.x[start - c_start * self.chunk_size : stop - c_start * self.chunk_size]\n",
        "            # the slice crosses a chunk boundary: keep the tail of the first chunk, then load the second\n",
        "            head = self.x[start - c_start * self.chunk_size :]\n",
        "            load_chunk(c_stop)\n",
        "            return torch.cat((head, self.x[: stop - c_stop * self.chunk_size]), dim=0)\n",
        "        if isinstance(index, torch.Tensor):\n",
        "            load_chunk(int(index.flatten()[0]) // self.chunk_size)\n",
        "            return self.x[index % self.chunk_size]\n",
        "        raise TypeError(f\"unsupported index type: {type(index).__name__}\")\n",
        "\n",
        "    def __len__(self) -> int:\n",
        "        \"\"\"\n",
        "        Returns:\n",
        "            int: the value stored in self.len, i.e. the global length that __getitem__ uses\n",
        "                to bound slices (not the size of the chunk currently held in self.x)\n",
        "        \"\"\"\n",
        "        return self.len\n",
        "\n",
        "\n",
        "def zero_shot_classification(zimgs: torch.Tensor, ztxts: torch.Tensor, aimgs: torch.Tensor, atxts: torch.Tensor, test_labels: list, non_zeros: int, range_anch: range, val_exps: list, dic_size: int = 100_000, max_gpu_mem_gb: float = 8.) -> Tuple[list, dict, torch.Tensor]:\n",
        "    \"\"\"\n",
        "    Computes the zero-shot classification accuracy using relative representations\n",
        "    over sets of anchors of different sizes and raising the similarities to the given exponents.\n",
        "\n",
        "    Args:\n",
        "        zimgs (torch.Tensor): absolute embeddings of the images\n",
        "        ztxts (torch.Tensor): absolute embeddings of the texts\n",
        "        aimgs (torch.Tensor): absolute embeddings of the anchor images\n",
        "        atxts (torch.Tensor): absolute embeddings of the anchor texts\n",
        "        test_labels (list): ground truth labels of the images\n",
        "        non_zeros (int): nonzero entries in the relative representation\n",
        "        range_anch (range): range of sizes of the anchor's sets to use (overshoot is ok)\n",
        "        val_exps (list): similarity exponents to test\n",
        "        dic_size (int): size of the chunk of aimgs to load in memory to fit all intermediate variables in RAM\n",
        "        max_gpu_mem_gb (float): memory budget in GB, forwarded to relative_represent and used to choose\n",
        "            between a single dense product and a chunked sparse product for the similarities\n",
        "\n",
        "    Returns:\n",
        "        n_anchors (list): list of sizes of the anchor's sets (with overshooting fixed)\n",
        "        scores (dict): dictionary of scores for each tested similarity exponent\n",
        "        sims (torch.Tensor): similarity matrix between images and texts, for the last anchor-set size\n",
        "            and the last exponent evaluated (all zeros if range_anch is empty)\n",
        "    \"\"\"\n",
        "    n_anchors = []\n",
        "    scores = {ve : [] for ve in val_exps}\n",
        "    n_templates = max(int(ztxts.shape[0] / (max(test_labels) - min(test_labels) + 1)), 1)\n",
        "    # loop invariant: build the label tensor once instead of once per score\n",
        "    labels = torch.tensor(test_labels)\n",
        "    # keeps the returned value defined even when range_anch yields nothing\n",
        "    sims = torch.zeros((len(zimgs), len(ztxts)))\n",
        "    for i in tqdm(range_anch, position=0, leave=True):\n",
        "        n_used = min(len(aimgs), i)  # requested size, capped at the anchors actually available\n",
        "        sims = torch.zeros((len(zimgs), len(ztxts)))\n",
        "        # in each buffer the first non_zeros columns receive the candidates of the current chunk,\n",
        "        # the last non_zeros columns keep the best entries found over the previous chunks\n",
        "        idxs_imgs = torch.zeros(((len(zimgs), non_zeros * 2)), dtype=torch.long)\n",
        "        idxs_txts = torch.zeros(((len(ztxts), non_zeros * 2)), dtype=torch.long)\n",
        "        vals_imgs = torch.zeros(((len(zimgs), non_zeros * 2)))\n",
        "        vals_txts = torch.zeros(((len(ztxts), non_zeros * 2)))\n",
        "        # NOTE(review): n_used // (dic_size + 1) + 1 is not an exact ceil(n_used / dic_size): when n_used\n",
        "        # falls strictly between k * dic_size and k * (dic_size + 1) only k chunks are visited and the\n",
        "        # trailing n_used - k * dic_size anchors are never compared. Left unchanged because it is unclear\n",
        "        # whether relative_represent accepts a chunk with fewer than non_zeros anchors - TODO confirm\n",
        "        for d in range(n_used // (dic_size + 1) + 1):\n",
        "            chunk = slice(d * dic_size, min(i, (d + 1) * dic_size))\n",
        "            idxs, values = relative_represent(zimgs, aimgs[chunk], non_zeros=non_zeros, max_gpu_mem_gb=max_gpu_mem_gb)\n",
        "            idxs_imgs[:, :non_zeros] = idxs + d * dic_size  # chunk-local positions -> global anchor ids\n",
        "            vals_imgs[:, :non_zeros] = values\n",
        "            idxs, values = relative_represent(ztxts, atxts[chunk], non_zeros=non_zeros, max_gpu_mem_gb=max_gpu_mem_gb)\n",
        "            idxs_txts[:, :non_zeros] = idxs + d * dic_size\n",
        "            vals_txts[:, :non_zeros] = values\n",
        "\n",
        "            # merge the new candidates with the running best and keep the top non_zeros of the union\n",
        "            top_valsi, indices = torch.topk(vals_imgs, non_zeros, dim=1)\n",
        "            top_idxsi = torch.gather(idxs_imgs, 1, indices)\n",
        "            top_valst, indices = torch.topk(vals_txts, non_zeros, dim=1)\n",
        "            top_idxst = torch.gather(idxs_txts, 1, indices)\n",
        "\n",
        "            idxs_imgs[:, non_zeros:] = top_idxsi\n",
        "            vals_imgs[:, non_zeros:] = top_valsi\n",
        "            idxs_txts[:, non_zeros:] = top_idxst\n",
        "            vals_txts[:, non_zeros:] = top_valst\n",
        "\n",
        "        for val_exp in val_exps:\n",
        "            ztxts_t = sparsify(top_idxst, top_valst ** val_exp, (len(ztxts), n_used)).to(zimgs.device)\n",
        "            ztxts_t = normalize_sparse(ztxts_t, non_zeros)\n",
        "            if i < max_gpu_mem_gb * 1024 ** 3 / 4 / zimgs.shape[0]:  # einsum until it fits in gpu mem\n",
        "                zimgs_t = sparsify(top_idxsi, top_valsi ** val_exp, (len(zimgs), n_used)).to(zimgs.device)\n",
        "                sims = torch.einsum('ij, kj -> ik', zimgs_t.to_dense(), ztxts_t.to_dense()).to('cpu')\n",
        "            else:\n",
        "                # too many anchors for a dense product: process the images in row blocks with sparse matmul\n",
        "                n_chunks = 6\n",
        "                zs = zimgs.shape[0]\n",
        "                chunks = [c * (zs // n_chunks) for c in range(n_chunks)] + [zs]\n",
        "                for ci in range(n_chunks):\n",
        "                    rows = slice(chunks[ci], chunks[ci + 1])\n",
        "                    zimgs_t = sparsify(top_idxsi[rows], top_valsi[rows] ** val_exp, (chunks[ci + 1] - chunks[ci], n_used)).to(zimgs.device)\n",
        "                    sims[rows] = torch.sparse.mm(zimgs_t, ztxts_t.t()).to('cpu').to_dense()\n",
        "            # each label owns n_templates consecutive text rows: integer division maps the best row to its label\n",
        "            score = float((torch.div(sims.argmax(axis=1), n_templates, rounding_mode='floor') == labels).sum() / len(zimgs))\n",
        "            scores[val_exp].append(score)\n",
        "        n_anchors.append(n_used)\n",
        "    return n_anchors, scores, sims\n",
        "\n",
        "\n",
        "def rand_mul_indices(indices_list, n_templates):\n",
        "    \"\"\"\n",
        "    Maps every index to a random position inside its own group of len(n_templates) consecutive slots.\n",
        "\n",
        "    Args:\n",
        "        indices_list: integer indices (anything accepted by torch.tensor), one per sample\n",
        "        n_templates: sized collection; only its length is used\n",
        "    Returns:\n",
        "        torch.Tensor: indices_list * len(n_templates) plus a random offset in [0, len(n_templates))\n",
        "    \"\"\"\n",
        "    group_size = len(n_templates)\n",
        "    offsets = torch.randint(low=0, high=group_size, size=(len(indices_list),))\n",
        "    return torch.tensor(indices_list) * group_size + offsets"
      ],
      "metadata": {
        "id": "x51O0Ndmj1Sy",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# display name -> (dataset class, split, load_precomputed_embeddings); the CC12M entry is never looked up,\n",
        "# CC12M is recognised from the name prefix below\n",
        "available_datasets = {\n",
        "    'CC12M (1.5M image-caption couples, precomputed)': False,\n",
        "    'ImageNet val full': (EmbeddedImagenet, 'full', False),\n",
        "    'ImageNet val split1 (precomputed)': (EmbeddedImagenet, 'Val', True),\n",
        "    'ImageNet val split2 (precomputed)': (EmbeddedImagenet, 'Test', True),\n",
        "    'CIFAR100' : (EmbeddedCIFAR100, 'Test', False),\n",
        "    'PETS' : (EmbeddedPETS, 'Test', False),\n",
        "    'ImageNetV2' : (EmbeddedImagenetV2, '', False),\n",
        "    'CustomDataset': (EmbeddedCustomDataset, '', False)\n",
        "}\n",
        "\n",
        "select_train_dataset = 'CIFAR100' #@param [\"CC12M (1.5M image-caption couples)\", \"ImageNet val split1 (precomputed)\", \"ImageNet val split2 (precomputed)\", \"CIFAR100\", \"PETS\", \"ImageNetV2\", \"CustomDataset\"]\n",
        "integrate_with_CC12M = False #@param {type:\"boolean\"}\n",
        "\n",
        "if select_train_dataset[:5] != 'CC12M':\n",
        "    train_dataset_class, split, load_precomputed_embeddings = available_datasets[select_train_dataset]\n",
        "    train_dataset = train_dataset_class(split=split, load_precomputed_embeddings=load_precomputed_embeddings)\n",
        "    aimgs, atxts, atest_labels = train_dataset.embed(image_encoder, text_encoder)\n",
        "    # one text embedding per anchor image, picked at random among the templates of its label\n",
        "    atxts = atxts[rand_mul_indices(atest_labels, train_dataset.templates)]\n",
        "    # normalization to perform cosine similarity with a simple matmul\n",
        "    aimgs /= torch.einsum('ik -> i', aimgs * aimgs).unsqueeze(1) ** 0.5\n",
        "    atxts /= torch.einsum('ik -> i', atxts * atxts).unsqueeze(1) ** 0.5\n",
        "    if integrate_with_CC12M:\n",
        "        # saved to disk so that they can be appended below as one more chunk after the CC12M ones\n",
        "        torch.save(aimgs, 'aimgs.pt')\n",
        "        torch.save(atxts, 'atxts.pt')\n",
        "if select_train_dataset[:5] == 'CC12M' or integrate_with_CC12M:\n",
        "    aimgs_paths, atxts_paths = [], []\n",
        "    i = 0\n",
        "    im_enc_name = select_image_encoder.split()[0]\n",
        "    tx_enc_name = select_text_encoder.split()[0]\n",
        "\n",
        "    for enc_name in [im_enc_name, tx_enc_name]:\n",
        "        if not os.path.exists(f'EmbeddedCC12M{enc_name}'):\n",
        "            key = cc12m_drive_keys[enc_name]\n",
        "            file_name = f'EmbeddedCC12M{enc_name}.tar'\n",
        "            print(f'Downloading and loading CC12M {enc_name} embeddings will take between 1 and 3 minutes')\n",
        "            gdown.download('https://drive.google.com/uc?id=' + key, file_name, quiet=False)\n",
        "            print(f'extracting {file_name}...')\n",
        "            # argument list instead of a shell string, and check=True so that a failed extraction stops\n",
        "            # here instead of surfacing later as an empty list of chunks\n",
        "            subprocess.run(['tar', '-xf', file_name], check=True)\n",
        "    # collect the chunk files in order, stopping at the first index missing for either encoder\n",
        "    while True:\n",
        "        img_chunk = f'EmbeddedCC12M{im_enc_name}/zimgs_{im_enc_name}_{i:04d}.pt'\n",
        "        txt_chunk = f'EmbeddedCC12M{tx_enc_name}/ztexts_{tx_enc_name}_{i:04d}.pt'\n",
        "        if not (os.path.exists(img_chunk) and os.path.exists(txt_chunk)):\n",
        "            break\n",
        "        aimgs_paths.append(img_chunk)\n",
        "        atxts_paths.append(txt_chunk)\n",
        "        i += 1\n",
        "    if integrate_with_CC12M and select_train_dataset[:5] != 'CC12M':\n",
        "        aimgs_paths.append('aimgs.pt')\n",
        "        atxts_paths.append('atxts.pt')\n",
        "    aimgs = memory_tensor(aimgs_paths, chunk_size=100_000)\n",
        "    atxts = memory_tensor(atxts_paths, chunk_size=100_000)\n"
      ],
      "metadata": {
        "id": "vqPRQImZ6sR0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "select_test_dataset = 'PETS' #@param [\"ImageNet val split1 (precomputed)\", \"ImageNet val split2 (precomputed)\", \"CIFAR100\", \"PETS\", \"ImageNetV2\", \"CustomDataset\"]\n",
        "\n",
        "# look up the dataset class and its options, build it and embed it with the current encoders\n",
        "test_dataset_class, split, load_precomputed_embeddings = available_datasets[select_test_dataset]\n",
        "test_dataset = test_dataset_class(split=split, load_precomputed_embeddings=load_precomputed_embeddings)\n",
        "zimgs, ztxts, ztest_labels = test_dataset.embed(image_encoder, text_encoder)\n",
        "\n",
        "# normalization to perform cosine similarity with a simple matmul\n",
        "# (in place: every row is divided by the square root of the sum of its squared entries)\n",
        "zimgs /= torch.einsum('ik -> i', zimgs * zimgs).unsqueeze(1) ** 0.5\n",
        "ztxts /= torch.einsum('ik -> i', ztxts * ztxts).unsqueeze(1) ** 0.5"
      ],
      "metadata": {
        "id": "VjV2WZi6DpLW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# @title Trend\n",
        "\n",
        "# free memory from encoders (since models are loaded by default jit to encode and here we will not encode anything, we can just reassign variables)\n",
        "image_encoder = encoders[select_image_encoder.split()[0]]()\n",
        "text_encoder = encoders[select_text_encoder.split()[0]]()\n",
        "\n",
        "# plot zero-shot accuracy trend vs different sizes of the anchors set\n",
        "ds_name = select_test_dataset\n",
        "# keep the dataset name up to the first '('; strip() removes the space before the parenthesis without\n",
        "# cutting the last character of names that have no parenthesis ('CIFAR100' used to become 'CIFAR10')\n",
        "train_ds_name = ('CC12M + ' if integrate_with_CC12M else '') + select_train_dataset.split('(')[0].strip()\n",
        "non_zeros = 800\n",
        "# anchor-set sizes: powers of two, from the first one above non_zeros to the first one above len(aimgs)\n",
        "range_anch = [2 ** i for i in range(int(np.log2(non_zeros) + 1), int(np.log2(len(aimgs))) + 2 )]\n",
        "val_exps = [8]\n",
        "max_gpu_mem_gb = 4\n",
        "\n",
        "n_anchors, scores, sims = zero_shot_classification(zimgs, ztxts, aimgs, atxts, ztest_labels, non_zeros, range_anch, val_exps, max_gpu_mem_gb=max_gpu_mem_gb)\n",
        "\n",
        "fig1 = plt.figure(figsize=(15,10))\n",
        "ax1 = fig1.add_subplot(111)\n",
        "\n",
        "# one accuracy curve per similarity exponent\n",
        "for val_exp in val_exps:\n",
        "    ax1.plot(n_anchors, scores[val_exp], label=f'non_zeros: {non_zeros}, val_exp: {val_exp}')\n",
        "\n",
        "ax1.legend(loc='lower right')\n",
        "ax1.set_xlabel('Number of anchors')\n",
        "ax1.set_xscale('log')\n",
        "ax1.set_ylabel(f'{ds_name} 0-shot accuracy')\n",
        "ax1.set_title(f'{ds_name} 0-shot accuracy vs number of anchors (train dataset: {train_ds_name})')\n",
        "ax1.grid(True)\n",
        "ax1.minorticks_on()\n",
        "ax1.grid(which='minor', linestyle='-', linewidth='0.5', color='black', alpha=0.2)\n",
        "plt.show()\n",
        "fig1.savefig(f'{ds_name}_0shot_accuracy_vs_anchors.png', dpi=300)\n"
      ],
      "metadata": {
        "id": "apg9vAFYcFzz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## New demos"
      ],
      "metadata": {
        "id": "1UHaU06Svdyj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def _topk_relative_represent(z: torch.Tensor, anchors: torch.Tensor, n_anchors: int, non_zeros: int, val_exp: int, chunk_size: int) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Shared implementation: compares z with the anchors one chunk at a time, keeps for every row of z\n",
        "    the non_zeros largest values over all the chunks and packs them with sparsify.\n",
        "\n",
        "    Args:\n",
        "        z (torch.Tensor): embeddings to represent\n",
        "        anchors (torch.Tensor): anchor embeddings, sliced in blocks of chunk_size rows\n",
        "        n_anchors (int): anchor count used for the number of chunks and for the width of the output\n",
        "        non_zeros (int): entries kept per row\n",
        "        val_exp (int): exponent applied to the kept values\n",
        "        chunk_size (int): number of anchors compared at a time\n",
        "    Returns:\n",
        "        torch.Tensor: output of sparsify with shape argument (len(z), n_anchors)\n",
        "    \"\"\"\n",
        "    # NOTE(review): not an exact ceil(n_anchors / chunk_size); kept as in the original code - TODO confirm\n",
        "    n_chunks = n_anchors // (chunk_size + 1) + 1\n",
        "    all_idxs = torch.zeros((len(z), non_zeros * n_chunks), dtype=torch.long)\n",
        "    all_vals = torch.zeros((len(z), non_zeros * n_chunks))\n",
        "    for d in range(n_chunks):\n",
        "        idxs, values = relative_represent(z, anchors[d * chunk_size : (d + 1) * chunk_size], non_zeros=non_zeros)\n",
        "        cols = slice(d * non_zeros, (d + 1) * non_zeros)\n",
        "        all_idxs[:, cols] = idxs + d * chunk_size  # chunk-local positions -> global anchor ids\n",
        "        all_vals[:, cols] = values\n",
        "    top_vals, positions = torch.topk(all_vals, non_zeros, dim=1)\n",
        "    top_idxs = torch.gather(all_idxs, 1, positions)\n",
        "    return sparsify(top_idxs, top_vals ** val_exp, (len(z), n_anchors))\n",
        "\n",
        "\n",
        "def single_relative_represent(zimgs: torch.Tensor, aimgs: torch.Tensor, non_zeros: int = 800, val_exp: int = 8, chunk_size: int = 100_000) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    Relative representation of zimgs with respect to the anchors aimgs.\n",
        "\n",
        "    Args:\n",
        "        zimgs (torch.Tensor): embeddings to represent\n",
        "        aimgs (torch.Tensor): anchor embeddings\n",
        "        non_zeros (int): entries kept per row\n",
        "        val_exp (int): exponent applied to the kept values\n",
        "        chunk_size (int): number of anchors compared at a time\n",
        "    Returns:\n",
        "        torch.Tensor: output of sparsify with shape argument (len(zimgs), len(aimgs))\n",
        "    \"\"\"\n",
        "    return _topk_relative_represent(zimgs, aimgs, len(aimgs), non_zeros, val_exp, chunk_size)\n",
        "\n",
        "\n",
        "def proc_relative_represent(zimgs: torch.Tensor, ztxts: torch.Tensor, aimgs: torch.Tensor, atxts: torch.Tensor, test_labels: list, non_zeros: int = 800, val_exp: int = 8, chunk_size: int = 100_000) -> Tuple[torch.Tensor, torch.Tensor]:\n",
        "    \"\"\"\n",
        "    Relative representations of images and texts with respect to their own anchors.\n",
        "\n",
        "    Args:\n",
        "        zimgs (torch.Tensor): image embeddings\n",
        "        ztxts (torch.Tensor): text embeddings\n",
        "        aimgs (torch.Tensor): anchor image embeddings\n",
        "        atxts (torch.Tensor): anchor text embeddings\n",
        "        test_labels (list): not used, kept so that existing calls keep working\n",
        "        non_zeros (int): entries kept per row\n",
        "        val_exp (int): exponent applied to the kept values\n",
        "        chunk_size (int): number of anchors compared at a time\n",
        "    Returns:\n",
        "        Tuple[torch.Tensor, torch.Tensor]: representation of the images and of the texts, both built\n",
        "            with len(aimgs) as anchor count\n",
        "    \"\"\"\n",
        "    # len(aimgs) sizes both outputs, exactly as before (atxts is expected to have the same length)\n",
        "    n_anchors = len(aimgs)\n",
        "    zimgs_t = _topk_relative_represent(zimgs, aimgs, n_anchors, non_zeros, val_exp, chunk_size)\n",
        "    ztxts_t = _topk_relative_represent(ztxts, atxts, n_anchors, non_zeros, val_exp, chunk_size)\n",
        "    return zimgs_t, ztxts_t\n"
      ],
      "metadata": {
        "id": "bUTCglim4uyW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# download and extract the CC12M metadata archive only when the extracted folder is not there yet\n",
        "if not os.path.exists(f'EmbeddedCC12Mmetadata'):\n",
        "    print('Downloading CC12M metadata...')\n",
        "    file_name = 'CC12M_metadata.tar'\n",
        "    gdown.download(f'https://drive.google.com/uc?id={cc12m_drive_keys[\"Metadata\"]}', file_name, quiet=False)\n",
        "    print(f'Extracting {file_name}...')\n",
        "    # NOTE(review): the exit status of tar is ignored; a failed extraction shows up as a missing file below\n",
        "    subprocess.call(f\"tar -xf {file_name}\", shell=True)\n",
        "# NOTE(review): pickle.load can run arbitrary code, so this is safe only while the downloaded archive is trusted\n",
        "# text_list and image_list are indexed by anchor position in the deep-dive cell (text_list[i], image_list[i])\n",
        "with open('EmbeddedCC12Mmetadata/text_list.pkl', 'rb') as f:\n",
        "    text_list = pickle.load(f)\n",
        "with open('EmbeddedCC12Mmetadata/url_image_list.pkl', 'rb') as f:\n",
        "    image_list = pickle.load(f)"
      ],
      "metadata": {
        "id": "BHr4Tca4Onkg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Upload your images\n",
        "Upload your images to the `example_images` folder; you can drag and drop them into it.\n",
        "\n"
      ],
      "metadata": {
        "id": "CRdf2txA1tqc"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "example_images = os.listdir('example_images')\n",
        "use_image_names_as_captions = True #@param {type:\"boolean\"}\n",
        "remove_numbers_from_image_names = True\n",
        "replace_underscores_with_spaces = True\n",
        "normalize_camelized_text = True\n",
        "\n",
        "# build one caption per image from its file name (extension removed)\n",
        "example_captions = []\n",
        "if use_image_names_as_captions:\n",
        "    for image_name in example_images:\n",
        "        caption = os.path.splitext(image_name)[0]\n",
        "        if replace_underscores_with_spaces:\n",
        "            caption = caption.replace('_', ' ')\n",
        "        if normalize_camelized_text:\n",
        "            # 'redCar' -> 'red Car': a space goes between a lowercase and an uppercase letter\n",
        "            caption = re.sub(r'([a-z])([A-Z])', r'\\1 \\2', caption)\n",
        "        # empty tokens (repeated spaces) are always dropped\n",
        "        words = [word for word in caption.split(' ') if word]\n",
        "        # tokens containing digits are dropped only when requested (the flag used to be ignored)\n",
        "        if remove_numbers_from_image_names:\n",
        "            words = [word for word in words if not any(char.isdigit() for char in word)]\n",
        "        caption = ' '.join(words)\n",
        "        example_captions.append(caption)\n",
        "\n",
        "example_captions = [caption.lower() for caption in example_captions]"
      ],
      "metadata": {
        "id": "-WUiHR-w87Lv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# hand-written captions are used only when the file names are not used as captions\n",
        "if not use_image_names_as_captions:\n",
        "    # NOTE(review): this first list is overwritten by the assignment right below, so it has no effect\n",
        "    example_captions = [\"salty water\",\n",
        "                    \"a crowd of hands\",\n",
        "                    \"yellow hands\",\n",
        "                    \"a lot of coloured hands with a yellow background\",\n",
        "                    \"a red car\",\n",
        "                    \"Audience with hands in the air at a music festival\"]\n",
        "    example_captions = [\"this is a photo of healthy lymph node tissue\",\n",
        "                        \"this is a photo of lymph node tumor tissue\"]\n",
        "\n",
        "# zimgs, ztxts and ztest_labels replace the ones computed from the test dataset in the cell above\n",
        "# NOTE(review): unlike that cell, no normalization is applied here - confirm that encode() already normalizes\n",
        "zimgs = image_encoder.encode(images=[Image.open(f'example_images/{x}').convert('RGB') for x in example_images], alr_preprocessed=False)\n",
        "ztxts = text_encoder.encode(example_captions)\n",
        "ztest_labels = [i for i in range(len(example_images))]  # permutation map between example_images and example_captions, if the order of example_images and example_captions does not match, edit it\n",
        "\n"
      ],
      "metadata": {
        "id": "-HaixV9_5bRL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# relative representations of the example images and captions w.r.t. the anchors\n",
        "# (defaults of proc_relative_represent: non_zeros=800, val_exp=8, chunk_size=100_000)\n",
        "proc_rr_imgs, proc_rr_txts = proc_relative_represent(zimgs, ztxts, aimgs, atxts, ztest_labels)"
      ],
      "metadata": {
        "id": "2SWoHxuh7tU2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# single evaluation using all the anchors (range_anch has one entry); sims is used by the plot cell below\n",
        "n_anchors, scores, sims = zero_shot_classification(zimgs, ztxts, aimgs, atxts, ztest_labels, 800, range_anch=[len(aimgs)], dic_size=100_000, val_exps=[8])"
      ],
      "metadata": {
        "id": "AcJfRTjBVd3f"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def split_string(string, n, max_lines=9999):\n",
        "    \"\"\"\n",
        "    Greedily wraps a string on spaces into lines of about n characters.\n",
        "\n",
        "    Args:\n",
        "        string (str): text to wrap\n",
        "        n (int): a word is added to the current line while line length + word length stays below n\n",
        "        max_lines (int): maximum number of lines kept; when lines are cut, the last three characters\n",
        "            of the last kept line are replaced by '...'\n",
        "    Returns:\n",
        "        str: the lines joined by newlines; every line keeps the trailing space added after its last word\n",
        "    \"\"\"\n",
        "    lines = []\n",
        "    current_line = ''\n",
        "    for word in string.split(' '):\n",
        "        if len(current_line) + len(word) < n:\n",
        "            current_line += word + ' '\n",
        "        else:\n",
        "            # flush only non-empty lines: a first word of n or more characters used to\n",
        "            # produce a spurious empty first line\n",
        "            if current_line:\n",
        "                lines.append(current_line)\n",
        "            current_line = word + ' '\n",
        "    # add the last line to the list of lines\n",
        "    lines.append(current_line)\n",
        "    if len(lines) > max_lines:\n",
        "        lines = lines[:max_lines]\n",
        "        lines[-1] = lines[-1][:-3] + '...'\n",
        "    return '\\n'.join(lines)\n",
        "\n",
        "def show_images(images, captions, n_cols=8):\n",
        "    \"\"\"\n",
        "    Shows the images of the folder example_images in a grid, each one titled with its position and caption.\n",
        "\n",
        "    Args:\n",
        "        images (list): file names inside the folder example_images\n",
        "        captions (list): one caption per image; captions longer than 93 characters are cut to 90 plus '...'\n",
        "        n_cols (int): number of columns of the grid\n",
        "    \"\"\"\n",
        "    # ceil(len(images) / n_cols) rows (no extra empty row when the count is a multiple of n_cols),\n",
        "    # and at least one row so that the grid is valid even without images\n",
        "    n_rows = max(1, -(-len(images) // n_cols))\n",
        "    figsize = (n_cols * 5, n_rows * 6)\n",
        "    # squeeze=False: axes is always a 2D array, also for a 1x1 grid\n",
        "    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)\n",
        "    axes = axes.flatten()\n",
        "    n_shown = 0\n",
        "    for i, (image, caption) in enumerate(zip(images, captions)):\n",
        "        axes[i].imshow(mpimg.imread(f'example_images/{image}', 0))\n",
        "        if len(caption) > 93: caption = caption[:90] + '...'\n",
        "        axes[i].set_title(f'### {i} ###\\n' + split_string(caption, 30))\n",
        "        axes[i].axis('off')\n",
        "        n_shown = i + 1\n",
        "    # hide the frames of the cells left without an image\n",
        "    for ax in axes[n_shown:]:\n",
        "        ax.axis('off')\n",
        "    plt.tight_layout()\n",
        "    plt.show()"
      ],
      "metadata": {
        "id": "tFJlSZzldO38"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# @title Plot result\n",
        "\n",
        "if not use_image_names_as_captions: \n",
        "    print('Candidate captions:')\n",
        "    for i, cap in enumerate(example_captions): print(f'  ({i}) ' + cap)\n",
        "# every row (one image) is rescaled so that its similarities sum to 1\n",
        "# NOTE(review): a row summing to 0 produces a division by zero here\n",
        "sims = sims.div(sims.sum(dim=1, keepdim=True))\n",
        "# columns: first wrapped line of each caption followed by '...'; rows: image positions\n",
        "df = pd.DataFrame(sims.numpy(), columns=[split_string(cap, 15).split('\\n')[0] + '...' for cap in example_captions], index=[f'### {i} ###' for i in range(len(example_images))])\n",
        "# sns.heatmap(df, annot=True, fmt='.2f')\n",
        "sns.heatmap(df, annot=True, fmt='.2f', xticklabels=df.columns, yticklabels=df.index)\n",
        "plt.show()\n",
        "# captions are shown under the images only when they come from the file names\n",
        "show_images(example_images, {True : example_captions, False : ['' for c in example_images]}[use_image_names_as_captions], n_cols=4)"
      ],
      "metadata": {
        "id": "n_Q63DiJclDR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Deep dive into a classification"
      ],
      "metadata": {
        "id": "oGAB0Kytzs7x"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install --q dash==2.0.0 jupyter-dash==0.4.0;"
      ],
      "metadata": {
        "id": "-v-3W9zLAkIV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### final deep dive"
      ],
      "metadata": {
        "id": "HIU8EPZfXHmW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from jupyter_dash import JupyterDash\n",
        "from dash import dcc, html, Input, Output, no_update\n",
        "from IPython.display import Image as IpyImage\n",
        "import plotly.graph_objects as go\n",
        "\n",
        "\n",
        "# position (in example_images) of the image to inspect\n",
        "image_to_deep_dive = 1\n",
        "display(IpyImage(filename=f'example_images/{example_images[image_to_deep_dive]}', width=600))\n",
        "show_only_common_hits = False\n",
        "calculate_val_exp = False\n",
        "# NOTE(review): normalize_size is not read anywhere below; the marker sizes are always normalized\n",
        "normalize_size = True  # if false, you can compare uncertainty between images looking at the size of the points (the largest the less uncertain classification)\n",
        "\n",
        "val_exp = {True : 8, False : 1}[calculate_val_exp]\n",
        "# anchors with a nonzero entry for the selected image or for the caption at the same position\n",
        "non_zero_indices = set(torch.cat((proc_rr_imgs[image_to_deep_dive]._indices()[0], proc_rr_txts[image_to_deep_dive]._indices()[0]), dim=0).tolist())\n",
        "x = []\n",
        "ys = {ec : [] for ec in example_captions}\n",
        "tags, image_tags = [], []\n",
        "# NOTE(review): these two lines overwrite proc_rr_imgs / proc_rr_txts, so running the cell again\n",
        "# applies val_exp and normalize_sparse on top of the previous result\n",
        "proc_rr_imgs = normalize_sparse(proc_rr_imgs ** val_exp, proc_rr_imgs[0]._nnz())\n",
        "proc_rr_txts = normalize_sparse(proc_rr_txts ** val_exp, proc_rr_txts[0]._nnz())\n",
        "\n",
        "for i in non_zero_indices:\n",
        "    # keep anchors hit by both image and caption; when show_only_common_hits is False, a hit by either is enough\n",
        "    if (not show_only_common_hits and (proc_rr_imgs[image_to_deep_dive][i] or proc_rr_txts[image_to_deep_dive][i])) or (proc_rr_imgs[image_to_deep_dive][i] and proc_rr_txts[image_to_deep_dive][i]):\n",
        "        # x.append(proc_rr_imgs[image_to_deep_dive][i] ** val_exp)\n",
        "        # y.append(proc_rr_txts[image_to_deep_dive][i] ** val_exp)\n",
        "        x.append(proc_rr_imgs[image_to_deep_dive][i])\n",
        "        for k, ec in enumerate(example_captions):\n",
        "            ys[ec].append(proc_rr_txts[k][i])\n",
        "        # image_tags is filled only in the CC12M branch; for the other datasets it stays empty\n",
        "        if select_train_dataset.split()[0] != 'CC12M':\n",
        "            tags.append(train_dataset.classes[atest_labels[i]])\n",
        "        else:\n",
        "            tags.append(text_list[i])\n",
        "            image_tags.append(image_list[i])\n",
        "\n",
        "fig = go.Figure()\n",
        "\n",
        "# marker size of a point: image value times caption value for that anchor;\n",
        "# norm_size_factor is the total over all captions and anchors\n",
        "# NOTE(review): if every product is 0 the division in the marker size below is a division by zero\n",
        "norm_size_factor = 0\n",
        "sizes = []\n",
        "for k, ec in enumerate(example_captions):\n",
        "    size = np.array(x) * np.array(ys[ec])\n",
        "    sizes.append(size)\n",
        "    norm_size_factor += np.sum(size)\n",
        "\n",
        "\n",
        "# one scatter trace per caption; all the traces share the same x values and the same tags\n",
        "for k, ec in enumerate(example_captions):\n",
        "    fig.add_trace(go.Scatter(\n",
        "        x=np.array(x),\n",
        "        y=np.array(ys[ec]),\n",
        "        mode='markers',\n",
        "        marker=dict(\n",
        "            size=300 * sizes[k] / norm_size_factor,\n",
        "            # color=np.array(x),\n",
        "            colorscale='Viridis',\n",
        "            showscale=False\n",
        "        ),\n",
        "        text=tags,\n",
        "        name=ec\n",
        "    ))\n",
        "\n",
        "fig.update_layout(\n",
        "    title='Similarity between image and text',\n",
        "    # xaxis_type=\"log\",\n",
        "    # yaxis_type=\"log\",\n",
        "    xaxis_title='Image similarity',\n",
        "    yaxis_title='Text similarity',\n",
        "    legend=dict(\n",
        "        itemsizing='constant'\n",
        "    )\n",
        ")\n",
        "\n",
        "\n",
        "# the default hover box is disabled: the tooltip is drawn by the dash callback defined below\n",
        "fig.update_traces(hoverinfo=\"none\", hovertemplate=None)\n",
        "\n",
        "app = JupyterDash(__name__)\n",
        "\n",
        "# layout: the figure plus the tooltip component updated by the callback\n",
        "app.layout = html.Div([\n",
        "    dcc.Graph(id=\"graph\", figure=fig, clear_on_unhover=True),\n",
        "    dcc.Tooltip(id=\"graph-tooltip\"),\n",
        "])\n",
        "\n",
        "\n",
        "@app.callback(\n",
        "    Output(\"graph-tooltip\", \"show\"),\n",
        "    Output(\"graph-tooltip\", \"bbox\"),\n",
        "    Output(\"graph-tooltip\", \"children\"),\n",
        "    Input(\"graph\", \"hoverData\"),\n",
        ")\n",
        "def display_hover(hoverData):\n",
        "    \"\"\"\n",
        "    Builds the tooltip of the hovered point: the anchor image (when a URL is available) and its description.\n",
        "\n",
        "    Args:\n",
        "        hoverData (dict): hover payload of the graph, None when no point is hovered\n",
        "    Returns:\n",
        "        tuple: (show, bbox, children) for the tooltip component\n",
        "    \"\"\"\n",
        "    if hoverData is None:\n",
        "        return False, no_update, no_update\n",
        "\n",
        "    # demo only shows the first point, but other points may also be available\n",
        "    pt = hoverData[\"points\"][0]\n",
        "    bbox = pt[\"bbox\"]\n",
        "    num = pt[\"pointNumber\"]\n",
        "\n",
        "    name = ''\n",
        "    form = ''\n",
        "    desc = tags[num]\n",
        "    if len(desc) > 300: desc = desc[:100] + '...'\n",
        "\n",
        "    content = []\n",
        "    # image_tags is filled only for CC12M anchors: with the other datasets there is no picture to show,\n",
        "    # and indexing the empty list used to raise an IndexError at every hover\n",
        "    if num < len(image_tags):\n",
        "        content.append(html.Img(src=image_tags[num], style={\"width\": \"100%\"}))\n",
        "    content.extend([\n",
        "        html.H2(f\"{name}\", style={\"color\": \"darkblue\"}),\n",
        "        html.P(f\"{form}\"),\n",
        "        html.P(f\"{desc}\"),\n",
        "    ])\n",
        "\n",
        "    children = [\n",
        "        html.Div(children=content, style={'width': '400px', 'white-space': 'normal'})\n",
        "    ]\n",
        "\n",
        "    return True, bbox, children\n",
        "\n",
        "\n",
        "# start the dash server in debug mode, with mode='inline' passed to JupyterDash\n",
        "if __name__ == \"__main__\":\n",
        "    app.run_server(debug=True, mode='inline')"
      ],
      "metadata": {
        "id": "S1t5nJCrCczg"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}