{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"10uIwiM24RAW0u9f8EcrN5ZBswVo1OLYl","timestamp":1721142765969},{"file_id":"1r6ZDLfDrgYKH0VZDX7Tm-IQyatA7LvUi","timestamp":1714074954783},{"file_id":"1e02OYj8HRZMAacsad_3YWEa0HbH47gf_","timestamp":1713486718135},{"file_id":"1j81TFYDvirvoBkbF5YH6a1P9-NdaOwpE","timestamp":1712775638458},{"file_id":"1wJFjJyJVxLhXp_ou0tjMZi3T8KUKv-Xd","timestamp":1709280262066}],"last_runtime":{"build_target":"//gdm/memory/multimodal/retrieval:memory_multimodal_colab_binary","kind":"private"},"private_outputs":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["Copyright 2024 REDACTED FOR ANONYMITY\n","\n","Licensed under the Apache License, Version 2.0 (the \"License\");\n","you may not use this file except in compliance with the License.\n","You may obtain a copy of the License at\n","\n","    https://www.apache.org/licenses/LICENSE-2.0\n","\n","Unless required by applicable law or agreed to in writing, software\n","distributed under the License is distributed on an \"AS IS\" BASIS,\n","WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n","See the License for the specific language governing permissions and\n","limitations under the License."],"metadata":{"id":"cu11043SG9Bt"}},{"cell_type":"markdown","source":["# Nearest neighbor evaluation\n","\n","This colab evaluates Dino / CLIP models on classification datasets, storing nearest neighbor information as JSON."],"metadata":{"id":"8Tyhb1XdIj8W"}},{"cell_type":"code","source":["#@title Imports\n","\n","import json\n","import os\n","import time\n","import tqdm\n","import random\n","import numpy as np\n","import functools\n","\n","import jax\n","from jax.sharding import PartitionSpec as P\n","from jax.experimental import mesh_utils\n","import jax.numpy as jnp\n","import numpy as np\n","import tensorflow as 
tf"],"metadata":{"id":"_XMAJ31ZlHqZ"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Jax sharding utils\n",
"\n",
"NamedSharding = jax.sharding.NamedSharding\n",
"\n",
"mesh_shape = (jax.device_count(),)\n",
"mesh = mesh_utils.create_device_mesh(\n",
"    mesh_shape, devices=jax.devices()\n",
")\n",
"mesh = jax.sharding.Mesh(mesh, axis_names=('data',))\n",
"# Shard the leading (batch) axis across every device in the mesh.\n",
"p = NamedSharding(mesh, P('data', None))\n",
"\n",
"print(mesh)\n",
"\n",
"\n",
"def get_shard_array_fn(sharding):\n",
"  \"\"\"Returns a jitted identity function whose output is placed with `sharding`.\"\"\"\n",
"  shard_array_fn = jax.jit(lambda x: x, out_shardings=sharding)\n",
"  return shard_array_fn\n",
"\n",
"\n",
"# Built once and reused: wrapping a fresh lambda in jax.jit on every call\n",
"# bypasses the jit cache, so each call would retrace and recompile.\n",
"_SHARD_ARRAY_FN = get_shard_array_fn(p)\n",
"\n",
"\n",
"def shard_array(arr):\n",
"  \"\"\"Shards `arr` along its leading axis across the mesh when possible.\n",
"\n",
"  Args:\n",
"    arr: array whose leading axis is the batch axis.\n",
"\n",
"  Returns:\n",
"    A jax.Array sharded with `p` when arr.shape[0] is divisible by the number\n",
"    of devices in the mesh; otherwise `arr` unchanged (e.g. a ragged last\n",
"    batch).\n",
"  \"\"\"\n",
"  # Check against the mesh size (all devices, as used by `p`) rather than\n",
"  # jax.local_device_count(), which can differ on multi-host setups.\n",
"  if arr.shape[0] % mesh.devices.size == 0:\n",
"    return _SHARD_ARRAY_FN(arr)\n",
"  return arr\n",
"\n",
"print('num devices = ', jax.local_device_count())"],"metadata":{"id":"IrepdLEa1v78"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Resizing\n",
"\n",
"def resize_smaller_side(\n",
"    images: tf.Tensor, size: int, antialias: bool = False\n",
") -> tf.Tensor:\n",
"  \"\"\"Resizes the smaller side to size preserving the aspect ratio.\n",
"\n",
"  Args:\n",
"    images: a single image [H, W, 3] or an image batch [B, H, W, 3].\n",
"    size: integer that represents a new size of the smaller side of an input\n",
"      image.\n",
"    antialias: whether to use an anti-aliasing filter when downsampling an\n",
"      image.\n",
"\n",
"  Returns:\n",
"    resized images with aspect ratio preserved.\n",
"  \"\"\"\n",
"\n",
"  h, w = tf.shape(images)[-3], tf.shape(images)[-2]\n",
"\n",
"  ratio = tf.cast(size, tf.float32) / tf.cast(tf.minimum(h, w), tf.float32)\n",
"  h = tf.cast(tf.round(tf.cast(h, tf.float32) * ratio), tf.int32)\n",
"  w = tf.cast(tf.round(tf.cast(w, tf.float32) * ratio), tf.int32)\n",
"  images = tf.image.resize(\n",
"      images, [h, w], method=tf.image.ResizeMethod.BICUBIC, antialias=antialias\n",
"  )\n",
"  return images\n",
"\n",
"\n",
"def central_crop(images: tf.Tensor, size: int) -> tf.Tensor:\n",
"  \"\"\"Central crop images to size.\n",
"\n",
"  Args:\n",
"    images: a single image [H, W, C] or a batch [B, H, W, C] as float32\n",
"      tensor; height and width must be at least `size`.\n",
"    size: integer that represents the new height and width of the images.\n",
"\n",
"  Returns:\n",
"    images center-cropped to a spatial size of size x size.\n",
"  \"\"\"\n",
"  h, w = size, size\n",
"  top = (tf.shape(images)[-3] - h) // 2\n",
"  left = (tf.shape(images)[-2] - w) // 2\n",
"  return tf.image.crop_to_bounding_box(images, top, left, h, w)\n",
"\n",
"\n",
"def resize_and_central_crop(example, aspect_ratio_size, central_crop_size,\n",
"                            antialias):\n",
"  \"\"\"Resizes the smaller side, then center-crops `example['image']`.\n",
"\n",
"  Shared by the DINO and CLIP featurizers. The dict is updated in place and\n",
"  returned so it can be used directly inside `tf.data.Dataset.map`.\n",
"  \"\"\"\n",
"  resized_images = resize_smaller_side(\n",
"      example['image'], size=aspect_ratio_size, antialias=antialias\n",
"  )\n",
"  example['image'] = central_crop(resized_images, size=central_crop_size)\n",
"  return example"],"metadata":{"id":"hISVCGBCXEkb"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Define DINO featurizer\n",
"\n",
"class DinoFeatures():\n",
"  \"\"\"A class to extract DINO features from images.\"\"\"\n",
"\n",
"  # DINO was trained on ImageNet data that is centered and scaled using the\n",
"  # following mean and stddev. We need to transform image values from the\n",
"  # [0.0, 1.0] range using the mean and std values below as follows:\n",
"  # val = (val - mean) / std  ... for each R,G,B channel.\n",
"  # float32 so that normalizing a float32 batch does not promote it to float64.\n",
"  MEAN_RGB = np.array([0.485, 0.456, 0.406], dtype=np.float32)\n",
"  STDDEV_RGB = np.array([0.229, 0.224, 0.225], dtype=np.float32)\n",
"\n",
"  def __init__(\n",
"      self,\n",
"      model_name: str = '',\n",
"      is_tpu_inference: bool = True,\n",
"  ):\n",
"    # load model weights, not included here\n",
"    pass\n",
"\n",
"  def preprocess_tpu(\n",
"      self,\n",
"      images: dict,\n",
"      aspect_ratio_size: int,\n",
"      central_crop_size: int,\n",
"      antialias: bool,\n",
"  ) -> dict:\n",
"    \"\"\"Resizes and center-crops the 'image' entry of a tf.data example dict.\n",
"\n",
"    Args:\n",
"      images: example dict with an 'image' tensor; other keys pass through.\n",
"      aspect_ratio_size: target length of the smaller image side.\n",
"      central_crop_size: height and width of the central crop.\n",
"      antialias: whether to anti-alias when downsampling.\n",
"\n",
"    Returns:\n",
"      the same dict with 'image' replaced by the cropped image.\n",
"    \"\"\"\n",
"    return resize_and_central_crop(\n",
"        images, aspect_ratio_size, central_crop_size, antialias\n",
"    )\n",
"\n",
"  def extract_batch(self, image: np.ndarray) -> np.ndarray:\n",
"    \"\"\"Returns L2-normalized DINO features for an image batch.\"\"\"\n",
"    return self._extract_batch(image)\n",
"\n",
"  def _extract_batch(self, images: np.ndarray) -> np.ndarray:\n",
"    \"\"\"Computes DINO features on the given image.\n",
"\n",
"    Args:\n",
"      images: Image tensor [B H W C] with values in [0, 255] range.\n",
"             Image channels must be ordered as RGB and not BGR.\n",
"\n",
"    Returns:\n",
"      features: extracted pooled DINO features for the image, L2-normalized\n",
"        per row so that dot products are cosine similarities.\n",
"    \"\"\"\n",
"\n",
"    # Preprocessing follows DINO v2 code for kNN eval (e.g.\n",
"    # https://github.com/facebookresearch/dino/blob/main/eval_knn.py)\n",
"    # exact preprocessing from here:\n",
"    # https://github.com/facebookresearch/dino/issues/149\n",
"    images = images.astype(np.float32) / 255.0\n",
"    images = (images - DinoFeatures.MEAN_RGB) / DinoFeatures.STDDEV_RGB\n",
"    # Calling model's apply function to encode the images.\n",
"    _, features = self._apply_fn(images)\n",
"    # Vectorized row-wise L2 normalization; the floor keeps an all-zero row at\n",
"    # zero instead of turning it into NaNs.\n",
"    norms = jnp.linalg.norm(features, axis=1, keepdims=True)\n",
"    features = features / jnp.maximum(norms, 1e-12)\n",
"\n",
"    return features"],"metadata":{"id":"05Eo7Nd2FHLT"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Define CLIP featurizer\n",
"\n",
"class ClipFeatures():\n",
"  \"\"\"A class to extract Clip features from images.\"\"\"\n",
"\n",
"  def __init__(\n",
"      self,\n",
"      model_name: str = '',\n",
"      is_tpu_inference: bool = False,\n",
"  ):\n",
"    # load model weights, not included here\n",
"    pass\n",
"\n",
"  def preprocess_tpu(\n",
"      self,\n",
"      images: dict,\n",
"      aspect_ratio_size: int,\n",
"      central_crop_size: int,\n",
"      antialias: bool,\n",
"  ) -> dict:\n",
"    \"\"\"Resizes and center-crops the 'image' entry of a tf.data example dict.\n",
"\n",
"    Args:\n",
"      images: example dict with an 'image' tensor; other keys pass through.\n",
"      aspect_ratio_size: target length of the smaller image side.\n",
"      central_crop_size: height and width of the central crop.\n",
"      antialias: whether to anti-alias when downsampling.\n",
"\n",
"    Returns:\n",
"      the same dict with 'image' replaced by the cropped image.\n",
"    \"\"\"\n",
"    return resize_and_central_crop(\n",
"        images, aspect_ratio_size, central_crop_size, antialias\n",
"    )\n",
"\n",
"  def extract_batch(self, image: np.ndarray) -> np.ndarray:\n",
"    \"\"\"Returns L2-normalized CLIP image features for an image batch.\"\"\"\n",
"    return self._extract_batch(image)\n",
"\n",
"  def _extract_batch(self, images: np.ndarray) -> np.ndarray:\n",
"    \"\"\"Computes CLIP features on the given preprocessed image batch.\n",
"\n",
"    Args:\n",
"      images: Float image tensor [B H W C] with values in [0, 255] range. Image\n",
"        channels must be ordered as RGB and not BGR.\n",
"\n",
"    Returns:\n",
"      pre-logit features [B, D], L2-normalized per row, where D is the model\n",
"      variant dependent feature dimension.\n",
"\n",
"    Raises:\n",
"      ValueError: if images is not a 4-D batch with 3 channels.\n",
"    \"\"\"\n",
"    if len(images.shape) != 4:\n",
"      raise ValueError(f'Image must be (B, H, W, 3) but got {images.shape}')\n",
"    if images.shape[-1] != 3:\n",
"      raise ValueError(f'Image must be 3 channels but got {images.shape}')\n",
"\n",
"    images = images.astype(np.float32) / 255.0\n",
"    images = jnp.array(images)\n",
"    images = clip.normalize_image(images) # CLIP normalization, not included here\n",
"\n",
"    embedding, _ = self._apply_fn(image=images)\n",
"    # Neighbor search scores with 1 - dot product, which is a cosine distance\n",
"    # only for unit-norm vectors. Re-normalizing is a no-op (up to float error)\n",
"    # if the model already returns normalized embeddings.\n",
"    norms = jnp.linalg.norm(embedding, axis=1, keepdims=True)\n",
"    return embedding / jnp.maximum(norms, 1e-12)\n",
"\n",
"  def extract_text_features(self, text_queries) -> np.ndarray:\n",
"    \"\"\"Encodes text queries with the CLIP text tower.\n",
"\n",
"    NOTE(review): relies on self._tokenizer, self._model and self._params,\n",
"    which are presumably set by the omitted weight-loading code.\n",
"    \"\"\"\n",
"    tokens = self._tokenizer(text_queries)\n",
"    _, text_features = self._model.apply(self._params, image=None,\n",
"                                         text=tokens)\n",
"    return text_features"],"metadata":{"id":"2LnD1cPBxvp8"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Define preprocessing\n",
"\n",
"def dataset_preprocess(dataset, featurizer, batch_size, resizing_size, antialias,\n",
"                       crop_size=224):\n",
"  \"\"\"Resizes/center-crops every example and batches the dataset.\n",
"\n",
"  Args:\n",
"    dataset: tf.data.Dataset of example dicts containing an 'image' entry.\n",
"    featurizer: object exposing preprocess_tpu (DinoFeatures/ClipFeatures).\n",
"    batch_size: examples per batch; the last batch may be smaller.\n",
"    resizing_size: target length of the smaller side before cropping.\n",
"    antialias: whether to anti-alias when downsampling.\n",
"    crop_size: side length of the square central crop (224 for both DINO and\n",
"      CLIP in this notebook).\n",
"\n",
"  Returns:\n",
"    the preprocessed, batched tf.data.Dataset.\n",
"  \"\"\"\n",
"  dataset_preprocessed = dataset.map(\n",
"      lambda x: featurizer.preprocess_tpu(\n",
"          x,\n",
"          resizing_size,\n",
"          crop_size,\n",
"          antialias=antialias,\n",
"      ),\n",
"      num_parallel_calls=tf.data.AUTOTUNE,\n",
"  )\n",
"\n",
"  return dataset_preprocessed.batch(batch_size, drop_remainder=False)"],"metadata":{"id":"Mh66WYjIznNb"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Utils for subsampling\n",
"\n",
"def subsample_train_set_embeddings(train_embeddings, num_per_class=100,\n",
"                                   seed=None):\n",
"  \"\"\"Randomly keeps at most `num_per_class` examples of every class.\n",
"\n",
"  Args:\n",
"    train_embeddings: dict with parallel arrays under 'file_name', 'labels'\n",
"      and 'features' (as returned by get_model_embeddings).\n",
"    num_per_class: maximum number of examples kept per class.\n",
"    seed: optional seed for a private random.Random. When None the global\n",
"      random module state is used, so results depend on that state.\n",
"\n",
"  Returns:\n",
"    dict with the same keys holding numpy arrays. Classes appear in order of\n",
"    first occurrence in train_embeddings['labels'].\n",
"  \"\"\"\n",
"  rng = random if seed is None else random.Random(seed)\n",
"  labels = np.asarray(train_embeddings['labels'])\n",
"  features = np.asarray(train_embeddings['features'])\n",
"  file_names = np.asarray(train_embeddings['file_name'])\n",
"\n",
"  # Bucket example indices by class (dicts keep first-occurrence order).\n",
"  indices_by_class = {}\n",
"  for idx, label in enumerate(labels.tolist()):\n",
"    indices_by_class.setdefault(label, []).append(idx)\n",
"\n",
"  # Shuffle within each class and keep the first num_per_class indices.\n",
"  selected = []\n",
"  for class_indices in indices_by_class.values():\n",
"    rng.shuffle(class_indices)\n",
"    selected.extend(class_indices[:num_per_class])\n",
"  selected = np.asarray(selected, dtype=np.int64)\n",
"\n",
"  # Gather all three arrays with the same index array so they stay aligned.\n",
"  return {\n",
"      'file_name': file_names[selected],\n",
"      'labels': labels[selected],\n",
"      'features': features[selected],\n",
"  }"],"metadata":{"id":"AuGbFaE3qBZw"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Evaluate down-scaled memories\n",
"\n",
"def scale_memory(featurizer_name,\n",
"                 train_embeddings,\n",
"                 validation_embeddings,\n",
"                 file_path_prefix,\n",
"                 num_imgs_per_class_list=(1, 10, 100, 1000),\n",
"                 top_k=100,\n",
"                 batch_size=1024,\n",
"                 seed=None):\n",
"  \"\"\"Runs kNN retrieval against randomly subsampled (down-scaled) memories.\n",
"\n",
"  For each value in num_imgs_per_class_list the memory is subsampled to at\n",
"  most that many images per class; memory-x-memory and query-x-memory\n",
"  neighbors are computed, their top-1 accuracy is printed and both results\n",
"  are written to JSON.\n",
"\n",
"  Args:\n",
"    featurizer_name: stored in the JSON files and used in printed labels.\n",
"    train_embeddings: embeddings dict of the full memory.\n",
"    validation_embeddings: embeddings dict of the queries.\n",
"    file_path_prefix: output prefix. Query results are written to\n",
"      {prefix}_{size}_neighbor_info.json and memory-x-memory results to\n",
"      {prefix}_{size}_memory-x-memory_neighbor_info.json, where size is\n",
"      num_classes * num_imgs_per_class.\n",
"    num_imgs_per_class_list: images-per-class values to evaluate.\n",
"    top_k: neighbors kept per query; memory-x-memory keeps top_k + 1 because\n",
"      each memory image normally retrieves itself first.\n",
"    batch_size: upper bound on the dot-product batch size.\n",
"    seed: optional seed forwarded to subsample_train_set_embeddings.\n",
"  \"\"\"\n",
"  num_classes = len(np.unique(train_embeddings['labels']))\n",
"  # Batch sizes are bounded by the number of examples (axis 0), not by the\n",
"  # feature dimension (axis 1).\n",
"  num_queries = validation_embeddings['features'].shape[0]\n",
"\n",
"  for num_imgs_per_class in num_imgs_per_class_list:\n",
"\n",
"    subsampled_train_embeddings = subsample_train_set_embeddings(\n",
"        train_embeddings, num_per_class=num_imgs_per_class, seed=seed)\n",
"    num_memory = subsampled_train_embeddings['features'].shape[0]\n",
"\n",
"    train_x_train_subsampled = get_nearest_neighbors(\n",
"        subsampled_train_embeddings, subsampled_train_embeddings,\n",
"        k=top_k + 1, batch_size=max(1, min(batch_size, num_memory)))\n",
"    val_x_train_subsampled = get_nearest_neighbors(\n",
"        validation_embeddings, subsampled_train_embeddings,\n",
"        k=top_k, batch_size=max(1, min(batch_size, num_queries)))\n",
"    print_accuracy(label=f'{featurizer_name}_train_x_train_subsampled_{num_imgs_per_class}', neighbors=train_x_train_subsampled)\n",
"    print_accuracy(label=f'{featurizer_name}_val_x_train_subsampled_{num_imgs_per_class}', neighbors=val_x_train_subsampled)\n",
"\n",
"    # Nominal memory size; equals 1000 * num_imgs_per_class for ImageNet.\n",
"    size = num_classes * num_imgs_per_class\n",
"    file_path = f'{file_path_prefix}_{size}_neighbor_info.json'\n",
"    write_neighbor_info_to_json(val_x_train_subsampled, file_path, featurizer_name)\n",
"\n",
"    # A distinct path, so the memory-x-memory results do not overwrite the\n",
"    # query results written just above.\n",
"    file_path = f'{file_path_prefix}_{size}_memory-x-memory_neighbor_info.json'\n",
"    write_neighbor_info_to_json(train_x_train_subsampled, file_path, featurizer_name)"],"metadata":{"id":"dXwB_x72ypFB"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Loading dataset from TFDS\n",
"import tensorflow.compat.v2 as tf\n",
"tf.enable_v2_behavior()\n",
"\n",
"import tensorflow_datasets as tfds\n",
"\n",
"# Root directory of ImageFolder-style datasets (currently only 'ninco').\n",
"IMAGE_FOLDER_ROOT = '/path/to'\n",
"\n",
"def rename_feature(example):\n",
"  \"\"\"Renames ImageFolder's 'image/filename' key to 'file_name'.\"\"\"\n",
"  example['file_name'] = example.pop('image/filename')\n",
"  return example\n",
"\n",
"def load_dataset(dataset_name, split, image_folder_root=None):\n",
"  \"\"\"Loads one dataset split without shuffling files.\n",
"\n",
"  Args:\n",
"    dataset_name: TFDS name, or 'ninco', which is read with\n",
"      tfds.folder_dataset.ImageFolder from {image_folder_root}/ninco/.\n",
"    split: split name.\n",
"    image_folder_root: root of ImageFolder datasets; defaults to\n",
"      IMAGE_FOLDER_ROOT.\n",
"\n",
"  Returns:\n",
"    tf.data.Dataset of example dicts.\n",
"  \"\"\"\n",
"  if dataset_name in ('ninco',):\n",
"    # https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/ImageFolder\n",
"    root = IMAGE_FOLDER_ROOT if image_folder_root is None else image_folder_root\n",
"    builder = tfds.folder_dataset.ImageFolder(f'{root}/{dataset_name}/')\n",
"    dataset = builder.as_dataset(split=split, shuffle_files=False)\n",
"    dataset = dataset.map(rename_feature)\n",
"  else:\n",
"    dataset = tfds.load(dataset_name, split=split, shuffle_files=False)\n",
"  return dataset\n",
"\n",
"def load_combined_datasets(list_of_names_and_splits):\n",
"  \"\"\"Loads (dataset_name, split) pairs and concatenates them in order.\n",
"\n",
"  Raises:\n",
"    ValueError: if list_of_names_and_splits is empty.\n",
"  \"\"\"\n",
"  if not list_of_names_and_splits:\n",
"    raise ValueError('list_of_names_and_splits must not be empty')\n",
"\n",
"  first_name, first_split = list_of_names_and_splits[0]\n",
"  combined_dataset = load_dataset(dataset_name=first_name, split=first_split)\n",
"\n",
"  for dataset_name, split in list_of_names_and_splits[1:]:\n",
"    dataset_2 = load_dataset(dataset_name=dataset_name, split=split)\n",
"    combined_dataset = combined_dataset.concatenate(dataset_2)\n",
"\n",
"  return combined_dataset\n",
"\n",
"def load_dataset_batches(list_of_names_and_splits, featurizer, batch_size, resizing_size, antialias):\n",
"  \"\"\"Loads, concatenates, preprocesses and batches the given dataset splits.\"\"\"\n",
"  dataset = load_combined_datasets(list_of_names_and_splits)\n",
"  batches = dataset_preprocess(dataset, featurizer, batch_size, resizing_size, antialias)\n",
"  return batches"],"metadata":{"id":"EBpRT95rz2DD"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Get model embeddings\n",
"\n",
"def get_model_embeddings(featurizer, batched_dataset):\n",
"  \"\"\"Generate embeddings for featurizer.\n",
"\n",
"  Args:\n",
"    featurizer: model to be used to generate embeddings.\n",
"    batched_dataset: dataset preprocessed and batched.\n",
"\n",
"  Returns:\n",
"    features_dict: dictionary with 'features', 'labels' and 'file_name'\n",
"      arrays, concatenated over batches in dataset order.\n",
"\n",
"  Raises:\n",
"    ValueError: if a batch has neither 'file_name' nor 'image/filename', or\n",
"      if batched_dataset yields no batches.\n",
"  \"\"\"\n",
"  features, labels, file_name = [], [], []\n",
"  for batch in tqdm.tqdm(iter(batched_dataset)):\n",
"    batch_sharded = shard_array(batch['image'].numpy())\n",
"    features.append(featurizer.extract_batch(batch_sharded))\n",
"    labels.append(batch['label'].numpy())\n",
"\n",
"    if 'file_name' in batch:\n",
"      file_name.append(batch['file_name'].numpy())\n",
"    elif 'image/filename' in batch:\n",
"      file_name.append(batch['image/filename'].numpy())\n",
"    else:\n",
"      raise ValueError('file_name or image/filename not found in batch')\n",
"\n",
"  if not features:\n",
"    # np.vstack([]) would otherwise fail with a less helpful message.\n",
"    raise ValueError('batched_dataset is empty; no embeddings computed')\n",
"\n",
"  # np.vstack converts the per-batch (possibly sharded) jax arrays to numpy.\n",
"  features_dict = {\n",
"      'features': np.vstack(features),\n",
"      'labels': np.concatenate(labels, axis=0),\n",
"      'file_name': np.concatenate(file_name, axis=0),\n",
"  }\n",
"  return features_dict"],"metadata":{"id":"AHEXLX6G3zVN"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Utils to compute top K neighbors\n",
"\n",
"@jax.jit\n",
"def jitted_dot_fn(x, y):\n",
"  \"\"\"Returns x @ y.T at highest precision, sharded along the query (row) axis.\"\"\"\n",
"  dot_product = jnp.dot(x, y.T, precision=jax.lax.Precision.HIGHEST)\n",
"  dot_product = jax.lax.with_sharding_constraint(dot_product, NamedSharding(mesh, P('data', None)))\n",
"\n",
"  return dot_product\n",
"\n",
"\n",
"# Extracts indices of Top-K neighbors\n",
"@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n",
"def extract_min_indices(dot_products, top_k=100):\n",
"  \"\"\"Returns the top_k largest similarities (smallest 1 - dot distances).\n",
"\n",
"  Args:\n",
"    dot_products: list of [num_queries, neighbor_batch_size] similarity\n",
"      blocks, one per neighbor batch. Donated: do not reuse after the call.\n",
"    top_k: number of neighbors to keep (static; changing it recompiles).\n",
"\n",
"  Returns:\n",
"    (values, indices), each [num_queries, top_k]. Indices are columns of the\n",
"    concatenated similarity matrix, i.e. positions in the full neighbor set.\n",
"  \"\"\"\n",
"  dot_products = jnp.concatenate(dot_products, axis=1)\n",
"  # lax.top_k already works row-wise on the last axis, so no\n",
"  # apply_along_axis wrapper is needed.\n",
"  return jax.lax.top_k(dot_products, k=top_k)\n",
"\n",
"def get_nearest_neighbors(query_features, neighbor_features, k, batch_size):\n",
"  \"\"\"Finds the k nearest neighbors of every query by dot-product similarity.\n",
"\n",
"  Args:\n",
"    query_features: dict with 'features', 'labels' and 'file_name' (bytes)\n",
"      arrays for the queries, aligned along axis 0.\n",
"    neighbor_features: dict with the same keys for the memory to search.\n",
"    k: number of nearest neighbors; clamped to the number of neighbors.\n",
"    batch_size: rows per block for both queries and neighbors in the blocked\n",
"      dot-product computation.\n",
"\n",
"  Returns:\n",
"    dict keyed by decoded query file name. Each value holds 'image_id',\n",
"    'image_class', 'neighbor_image_ids', 'neighbor_classes' and\n",
"    'neighbor_distances' (1 - dot product, a cosine distance assuming the\n",
"    features are L2-normalized), ordered from nearest to farthest.\n",
"\n",
"  Raises:\n",
"    ValueError: if neighbor_features contains no examples.\n",
"  \"\"\"\n",
"\n",
"  def _make_batches(features_dict, name):\n",
"    \"\"\"Splits a features dict into blocks of at most batch_size rows.\"\"\"\n",
"    num_rows = features_dict['features'].shape[0]\n",
"    blocks = []\n",
"    print(f'Sharding {name} batches ....')\n",
"    for start in tqdm.tqdm(range(0, num_rows, batch_size)):\n",
"      end = start + batch_size  # numpy slicing clamps to num_rows\n",
"      blocks.append({\n",
"          'features': shard_array(features_dict['features'][start:end]),\n",
"          'labels': features_dict['labels'][start:end],\n",
"          'filenames': features_dict['file_name'][start:end],\n",
"      })\n",
"    return blocks\n",
"\n",
"  batched_query_features = _make_batches(query_features, 'query')\n",
"  batched_neighbor_features = _make_batches(neighbor_features, 'neighbor')\n",
"  if not batched_neighbor_features:\n",
"    raise ValueError('neighbor_features is empty; nothing to search')\n",
"\n",
"  # The memory's labels and file names do not depend on the query batch, so\n",
"  # concatenate them once instead of once per query batch.\n",
"  all_neighbor_labels = np.concatenate(\n",
"      [b['labels'] for b in batched_neighbor_features], axis=0)\n",
"  all_neighbor_file_names = np.concatenate(\n",
"      [b['filenames'] for b in batched_neighbor_features], axis=0)\n",
"  # lax.top_k requires k to be at most the number of candidates.\n",
"  k = min(k, all_neighbor_labels.shape[0])\n",
"\n",
"  neighbor_info = {}\n",
"  print('Extracting top-K neighbors ....')\n",
"  for query_batch in tqdm.tqdm(batched_query_features):\n",
"    dot_products = [\n",
"        jitted_dot_fn(query_batch['features'], neighbor_batch['features'])\n",
"        for neighbor_batch in batched_neighbor_features\n",
"    ]\n",
"    top_dot_products, top_indices = extract_min_indices(dot_products, top_k=k)\n",
"    del dot_products  # donated to extract_min_indices\n",
"    top_dot_products = np.asarray(jax.device_get(top_dot_products))\n",
"    top_indices = np.asarray(jax.device_get(top_indices))\n",
"\n",
"    # Indexing a 1-D array with a [num_queries, k] index array yields a\n",
"    # [num_queries, k] array: the labels/file names of each row's neighbors.\n",
"    neighbor_labels = all_neighbor_labels[top_indices]\n",
"    neighbor_file_names = all_neighbor_file_names[top_indices]\n",
"\n",
"    # Named num_queries so the batch_size argument is not shadowed.\n",
"    num_queries = top_dot_products.shape[0]\n",
"    for index in range(num_queries):\n",
"      key = str(query_batch['filenames'][index].decode('utf-8'))\n",
"\n",
"      neighbor_info[key] = {\n",
"          'image_id': key,\n",
"          'image_class': query_batch['labels'][index],\n",
"          'neighbor_image_ids': neighbor_file_names[index],\n",
"          'neighbor_classes': neighbor_labels[index],\n",
"          'neighbor_distances': 1.0 - top_dot_products[index],\n",
"      }\n",
"\n",
"  return neighbor_info"],"metadata":{"id":"6AXeAxY6_ySs"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Write neighbor info to JSON\n",
"def write_neighbor_info_to_json(neighbors, file_path, featurizer_name):\n",
"  \"\"\"Serializes the output of get_nearest_neighbors to a JSON file.\n",
"\n",
"  Args:\n",
"    neighbors: dict as returned by get_nearest_neighbors.\n",
"    file_path: destination path; a missing parent directory is created.\n",
"    featurizer_name: stored under 'featurizer' in every entry.\n",
"  \"\"\"\n",
"  t1 = time.time()\n",
"\n",
"  data = {}\n",
"  for entry in neighbors.values():\n",
"    image_id = entry['image_id']\n",
"    # numpy scalars/arrays are converted to plain Python types for json.\n",
"    data[image_id] = {\n",
"        'featurizer': featurizer_name,\n",
"        'image_id': image_id,\n",
"        'image_class': int(entry['image_class']),\n",
"        'neighbor_image_ids': [x.decode('utf-8') for x in entry['neighbor_image_ids'].tolist()],\n",
"        'neighbor_classes': entry['neighbor_classes'].tolist(),\n",
"        'neighbor_distances': entry['neighbor_distances'].tolist(),\n",
"    }\n",
"\n",
"  parent_dir = os.path.dirname(file_path)\n",
"  if parent_dir:\n",
"    os.makedirs(parent_dir, exist_ok=True)\n",
"\n",
"  print(f'Writing json to {file_path}')\n",
"  with open(file_path, mode='w') as f:\n",
"    json.dump(data, f)\n",
"\n",
"  t2 = time.time()\n",
"  print(f'Time to write: {round(t2 - t1, 1)} seconds')"],"metadata":{"id":"NGN2rF0_Spdt"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Evaluate k=1 accuracy\n",
"\n",
"def print_accuracy(label, neighbors):\n",
"  \"\"\"Prints top-1 (k=1) nearest-neighbor classification accuracy.\n",
"\n",
"  An entry is correct when its first (nearest) neighbor has the query's\n",
"  class. For memory-x-memory results the nearest neighbor is normally the\n",
"  query itself, so that accuracy is close to 100% by construction.\n",
"\n",
"  Args:\n",
"    label: name printed alongside the result.\n",
"    neighbors: dict as returned by get_nearest_neighbors.\n",
"  \"\"\"\n",
"  total = len(neighbors)\n",
"  if total == 0:\n",
"    # Avoid ZeroDivisionError on an empty result (e.g. an empty query set).\n",
"    print('label=%s, no neighbors to evaluate' % label)\n",
"    return\n",
"\n",
"  correct = sum(\n",
"      1 for entry in neighbors.values()\n",
"      if entry['neighbor_classes'][0] == entry['image_class'])\n",
"\n",
"  print('label=%s, correct: %d, total: %d, accuracy: %f' % (label, correct, total, correct*100 / total))"],"metadata":{"id":"BdKr4OFy8YJK"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Assert valid dataset names\n",
"\n",
"VALID_MEMORY_DATASETS = ('imagenet2012', 'ninco')\n",
"VALID_QUERY_DATASETS = ('imagenet2012',\n",
"                        'imagenet_v2',\n",
"                        'imagenet_r',\n",
"                        'imagenet_sketch',\n",
"                        'imagenet_a',\n",
"                        'imagenet2012_real',\n",
"                        'ninco')\n",
"\n",
"def assert_valid_memory_dataset(memory_dataset: str):\n",
"  \"\"\"Asserts that memory_dataset is a supported memory dataset name.\"\"\"\n",
"  assert memory_dataset in VALID_MEMORY_DATASETS, (\n",
"      f'unsupported memory dataset {memory_dataset!r}; '\n",
"      f'expected one of {VALID_MEMORY_DATASETS}')\n",
"\n",
"def assert_valid_query_dataset(query_dataset: str):\n",
"  \"\"\"Asserts that query_dataset is a supported query dataset name.\"\"\"\n",
"  assert query_dataset in VALID_QUERY_DATASETS, (\n",
"      f'unsupported query dataset {query_dataset!r}; '\n",
"      f'expected one of {VALID_QUERY_DATASETS}')\n",
"\n",
"def get_memory_name_and_split(memory_list):\n",
"  \"\"\"Joins (name, split) pairs into combined name and split strings.\n",
"\n",
"  E.g. [('imagenet2012', 'train'), ('ninco', 'test')] becomes\n",
"  ('imagenet2012-and-ninco', 'train-and-test').\n",
"\n",
"  Raises:\n",
"    ValueError: if memory_list is empty.\n",
"  \"\"\"\n",
"  if not memory_list:\n",
"    raise ValueError('memory_list must not be empty')\n",
"  mname = '-and-'.join(str(name) for name, _ in memory_list)\n",
"  msplit = '-and-'.join(str(split) for _, split in memory_list)\n",
"  return mname, msplit"],"metadata":{"id":"hFurrENLybcR"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Load multiple models and datasets\n",
"\n",
"# NOTE: multiple memory datasets will be added as a combined memory\n",
"# MEMORY = [('imagenet2012', 'train'), ('ninco', 'test')]\n",
"MEMORY = [('imagenet2012', 'train')]\n",
"\n",
"SCALE_MEMORY = False # set to True for storing down-scaled memory in addition to full memory results\n",
"SEED = None # set to an int to make the SCALE_MEMORY subsampling reproducible\n",
"TOP_K = 100 # Note that for train, TOP_K+1 is saved automatically below.\n",
"batch_size=1024\n",
"\n",
"RESULT_DIR = '/path/to/result/directory'\n",
"\n",
"FEATURIZER_DICT = {\n",
"    'dinov2_vitl14': lambda: DinoFeatures(model_name='dinov2_vitl14', is_tpu_inference=True),\n",
"    #'dinov2_vitb14': lambda: DinoFeatures(model_name='dinov2_vitb14', is_tpu_inference=True),\n",
"    #'dinov2_vits14': lambda: DinoFeatures(model_name='dinov2_vits14', is_tpu_inference=True),\n",
"    #'clip-vit_l14':  lambda: ClipFeatures(model_name='vit_l14', is_tpu_inference=True),\n",
"    #'clip-vit_b16':  lambda: ClipFeatures(model_name='vit_b16', is_tpu_inference=True),\n",
"                   }\n",
"QUERY_DATASET_LIST = [\n",
"    ('imagenet2012', 'validation'),\n",
"    ('imagenet_v2', 'test'),\n",
"    ('imagenet_r', 'test'),\n",
"    ('imagenet_sketch', 'test'),\n",
"    ('imagenet_a', 'test'),\n",
"    ('ninco', 'test'),\n",
"    ('imagenet2012_real', 'validation'),\n",
"    ]\n",
"\n",
"# assert dataset names are valid\n",
"for m, _ in MEMORY:\n",
"  assert_valid_memory_dataset(m)\n",
"for (q, _) in QUERY_DATASET_LIST:\n",
"  assert_valid_query_dataset(q)"],"metadata":{"id":"9o-GL-I2wrOW"},"execution_count":null,"outputs":[]},
{"cell_type":"code","source":["#@title Main run loop\n",
"\n",
"# The memory name/split depend only on MEMORY, not on the featurizer.\n",
"MEMORY_DATASET, MEMORY_SPLIT = get_memory_name_and_split(MEMORY)\n",
"MEMORY_TAG = MEMORY_DATASET.replace('_', '-')\n",
"\n",
"for FEATURIZER_NAME, model_loader in FEATURIZER_DICT.items():\n",
"\n",
"  # Load model\n",
"  print('Loading model: ', FEATURIZER_NAME)\n",
"  featurizer = model_loader()\n",
"\n",
"  # Set preprocessing details\n",
"  # Dino preprocessing: resize(256) then crop(224)\n",
"  # https://github.com/facebookresearch/dinov2/blob/main/dinov2/data/transforms.py#L77\n",
"  # CLIP preprocessing: resize(224) then crop(224)\n",
"  # https://github.com/openai/CLIP/blob/main/clip/clip.py#L79\n",
"\n",
"  antialias = 'dino' in FEATURIZER_NAME\n",
"  print('antialias: ', antialias)\n",
"  resizing_size = 256\n",
"  if 'clip' in FEATURIZER_NAME:\n",
"    resizing_size = 224\n",
"  print('resizing size: ', resizing_size)\n",
"\n",
"  # Load train batches\n",
"  print(f'Loading memory dataset {MEMORY_DATASET} with split {MEMORY_SPLIT}')\n",
"\n",
"  train_batches = load_dataset_batches(list_of_names_and_splits=MEMORY,\n",
"                                       featurizer=featurizer,\n",
"                                       batch_size=batch_size,\n",
"                                       resizing_size=resizing_size,\n",
"                                       antialias=antialias)\n",
"  train_embeddings = get_model_embeddings(featurizer, train_batches)\n",
"  print(train_embeddings['features'].shape)\n",
"\n",
"  # TOP_K+1: each memory image normally retrieves itself as neighbor 0.\n",
"  train_x_train_full_neighbors = get_nearest_neighbors(train_embeddings, train_embeddings, k=TOP_K+1, batch_size=batch_size)\n",
"  print_accuracy(label=f'{FEATURIZER_NAME}_train_x_train_full', neighbors=train_x_train_full_neighbors)\n",
"\n",
"  file_path = f'{RESULT_DIR}/memory-{MEMORY_TAG}_msplit-{MEMORY_SPLIT}_query-{MEMORY_TAG}_qsplit-{MEMORY_SPLIT}_{FEATURIZER_NAME}_full_neighbor_info.json'\n",
"  write_neighbor_info_to_json(train_x_train_full_neighbors, file_path, FEATURIZER_NAME)\n",
"\n",
"  for (QUERY_DATASET, QUERY_SPLIT) in QUERY_DATASET_LIST:\n",
"\n",
"    print(f'Loading query dataset {QUERY_DATASET} with split {QUERY_SPLIT}')\n",
"    query_batches = load_dataset_batches(list_of_names_and_splits=[(QUERY_DATASET, QUERY_SPLIT)],\n",
"                                         featurizer=featurizer,\n",
"                                         batch_size=batch_size,\n",
"                                         resizing_size=resizing_size,\n",
"                                         antialias=antialias)\n",
"    query_embeddings = get_model_embeddings(featurizer, query_batches)\n",
"    val_x_train_full_neighbors = get_nearest_neighbors(query_embeddings, train_embeddings, k=TOP_K, batch_size=batch_size)\n",
"    # Include the query dataset in the label so per-dataset results are distinguishable.\n",
"    print_accuracy(label=f'{FEATURIZER_NAME}_{QUERY_DATASET}-{QUERY_SPLIT}_x_train_full', neighbors=val_x_train_full_neighbors)\n",
"\n",
"    qdataset = QUERY_DATASET.replace('_', '-')\n",
"    if qdataset == 'imagenet2012-real':\n",
"      qdataset = 'imagenet-real'\n",
"    file_path = f'{RESULT_DIR}/memory-{MEMORY_TAG}_msplit-{MEMORY_SPLIT}_query-{qdataset}_qsplit-{QUERY_SPLIT}_{FEATURIZER_NAME}_full_neighbor_info.json'\n",
"    write_neighbor_info_to_json(val_x_train_full_neighbors, file_path, FEATURIZER_NAME)\n",
"\n",
"    if SCALE_MEMORY:\n",
"      scale_memory(featurizer_name=FEATURIZER_NAME,\n",
"                   train_embeddings=train_embeddings,\n",
"                   validation_embeddings=query_embeddings,\n",
"                   file_path_prefix=file_path.replace('_full_neighbor_info.json', ''),\n",
"                   num_imgs_per_class_list=[1, 10, 100, 1000],\n",
"                   top_k=TOP_K,\n",
"                   batch_size=batch_size,\n",
"                   seed=SEED)"],"metadata":{"id":"BcrCFNFy994O"},"execution_count":null,"outputs":[]}]}