{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tMPs4M5kgt6_"
      },
      "source": [
        "# License\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at:\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "On-5AY2HPQqw"
      },
      "source": [
        "# Instructions\n",
        "\n",
        "This Notebook allows to reproduce all the experiments reported in the publication titled:\n",
        "\n",
        "\"*An Evolutionary Approach to Dynamic Introduction\n",
        "of Tasks in Large-scale Multitask Learning Systems*\" (2022)\n",
        "\n",
        "---\n",
        "\n",
        "Set `EXPERIMENT_NAME` to a name of choice.\n",
        "\n",
        "Set `BENCHMARK` to:\n",
        "\n",
        "1. `ViT tiny 0 layers / Cmaterdb benchmark` to reproduce the preliminarty experiment.\n",
        "1. `ViT large / ViT benchmark` to run 1 iteration on the ViT benchmark.\n",
        "1. `ViT large / VTAB-full benchmark` to run 1 iteration on the VTAB-full benchmark.\n",
        "1. `ViT large / VDD benchmark` to run 1 iteration on the VDD benchmark.\n",
        "1. `ViT large / Chars benchmark` to run 1 iteration on the Multitask Character Classification benchmark.\n",
        "1. `ViT large / VTAB-1k benchmark` to run 1 iteration on the VTAB-1k benchmark.\n",
        "\n",
        "To reproduce the sequence of iterations described in the paper, it is required to iteratively extend the experiment by setting `CONTINUED_FROM_STATE_DIR` to the directory storing the final state of the experiment that needs to be continued (will not be overwritten).\n",
        "If `CONTINUED_FROM_STATE_DIR` is set to empty, then the new experiment will start to generate a new multimodal system instead of extending an existing one.\n",
        "\n",
        "Set `CONFIGURATION` to `muNet` to run the μ2Net evolutionary method.\n",
        "\n",
        "Select `AUTOTUNE` to activate auto-tuning for muNet experiments.\n",
        "\n",
        "Set `EXPERIMENTS_ROOT_DIR` to the desired root directory that will contain experiment directories storing configuration and state.\n",
        "\n",
        "To reproduce the configuration of the experiments reported in the paper it is required to connect to a multihost TPUv4 allocation with 32 chip in MegaCore configuration.\n",
        "\n",
        "To start the configured experiment select \"Run all\" from the \"Runtime\" menu.\n",
        "\n",
        "The output is printed after the last cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M93tll7z29rX"
      },
      "outputs": [],
      "source": [
        "# @title Experiment parameters\n",
        "EXPERIMENT_NAME = 'Experiment'  # @param { type: 'string', isTemplate: true }\n",
        "BENCHMARK = 'ViT tiny 0 layers / Cmaterdb benchmark' # @param ['ViT tiny 0 layers / Cmaterdb benchmark', 'ViT tiny 3 layers / Chars benchmark', 'ViT base / VDD benchmark', 'ViT large / ViT benchmark', 'ViT large / VTAB-full benchmark', 'ViT large / VDD benchmark', 'ViT large / Chars benchmark', 'ViT large / VTAB-1k benchmark'] { type: 'string', isTemplate: true }\n",
        "CONFIGURATION = 'muNet' # @param ['muNet', 'Size scale:98', 'Size scale:95', 'Size scale:90', 'Size scale:70', 'Size scale:30', 'Size scale:2', 'Finetune all', 'Freeze bottom layers:0', 'Freeze bottom layers:1', 'Freeze bottom layers:2', 'Freeze bottom layers:3', 'Freeze bottom layers:4', 'Freeze bottom layers:12', 'Adapters:8', 'Adapters:16', 'Adapters:32', 'Adapters:64', 'Adapters:128', 'Adapters:256', 'Adapters:512']  { type: 'string', isTemplate: true }\n",
        "AUTO_TUNE = True # @param [True, False] { type: 'boolean', isTemplate: true }\n",
        "EXPERIMENTS_ROOT_DIR = '/tmp/' # @param { type: 'string', isTemplate: true }\n",
        "CONTINUED_FROM_STATE_DIR = '' # @param { type: 'string', isTemplate: true }\n",
        "\n",
        "if AUTO_TUNE:\n",
        "  assert CONFIGURATION == 'muNet' or CONFIGURATION.startswith('Size scale:'), \\\n",
        "      f'Invalid configuration for auto-tune: {CONFIGURATION}'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iqE9f8QHR8j8"
      },
      "outputs": [],
      "source": [
        "# @title Additional parameters\n",
        "# Set to true to continue interrupted experiment with matching EXPERIMENT_NAME\n",
        "AUTO_CONTINUE = False  # @param [True, False] { type: 'boolean', isTemplate: true }\n",
        "# Print debug statements.\n",
        "VERBOSE = False  # @param [True, False] { type: 'boolean', isTemplate: true }\n",
        "# Skip intermediate state save if last state was written within this time range.\n",
        "SKIP_INTERMEDIATE_STATE_SECS = 3600  # @param { type: 'integer', isTemplate: true }"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RrmU42P6Wm1r"
      },
      "outputs": [],
      "source": [
        "!pip install -q flax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QmcbmqaGWnRI"
      },
      "outputs": [],
      "source": [
        "!pip install -q ml_collections"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PxktiVGLjisX"
      },
      "outputs": [],
      "source": [
        "!pip install -q tensorflow_addons"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L5Gv2qoaWo1s"
      },
      "outputs": [],
      "source": [
        "![ -d task_adaptation ] || git clone --depth=1 https://github.com/google-research/task_adaptation\n",
        "![ -d vision_transformer ] || git clone --depth=1 https://github.com/google-research/vision_transformer\n",
        "!pip install -qr vision_transformer/vit_jax/requirements.txt\n",
        "import sys\n",
        "if './task_adaptation' not in sys.path:\n",
        "  sys.path.append('./task_adaptation')\n",
        "if './vision_transformer' not in sys.path:\n",
        "  sys.path.append('./vision_transformer')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LzhEweKzMg6k"
      },
      "outputs": [],
      "source": [
        "import copy\n",
        "import datetime\n",
        "import gc\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import json\n",
        "import math\n",
        "import matplotlib\n",
        "import numpy as np\n",
        "import os\n",
        "import optax\n",
        "import pandas as pd\n",
        "import random\n",
        "import re\n",
        "import time\n",
        "from collections import defaultdict\n",
        "from functools import partial\n",
        "from matplotlib import pyplot as plt\n",
        "from threading import Thread\n",
        "from typing import Optional"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MXlZac9HXRpN"
      },
      "outputs": [],
      "source": [
        "import flax\n",
        "import flax.linen as nn\n",
        "from flax.training import checkpoints as flax_checkpoints"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CKrbsWJ2PfBV"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "tf.compat.v1.enable_eager_execution()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6QHBzcuUYeh5"
      },
      "outputs": [],
      "source": [
        "from ml_collections import ConfigDict, FrozenConfigDict\n",
        "from vision_transformer.vit_jax import input_pipeline\n",
        "from vision_transformer.vit_jax import checkpoint\n",
        "from vision_transformer.vit_jax.configs import models as models_config  # Model configurations.\n",
        "from vision_transformer.vit_jax import models_vit as models # Actual model code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d5XWg-sLbFVH"
      },
      "outputs": [],
      "source": [
        "import task_adaptation.registry as task_adapt_registry\n",
        "import task_adaptation.data.caltech\n",
        "import task_adaptation.data.cifar\n",
        "import task_adaptation.data.dtd\n",
        "import task_adaptation.data.oxford_flowers102\n",
        "import task_adaptation.data.oxford_iiit_pet\n",
        "import task_adaptation.data.sun397\n",
        "import task_adaptation.data.svhn\n",
        "import task_adaptation.data.patch_camelyon\n",
        "import task_adaptation.data.eurosat\n",
        "import task_adaptation.data.resisc45\n",
        "import task_adaptation.data.diabetic_retinopathy\n",
        "import task_adaptation.data.clevr\n",
        "import task_adaptation.data.dmlab\n",
        "import task_adaptation.data.dsprites\n",
        "import task_adaptation.data.kitti\n",
        "import task_adaptation.data.smallnorb"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SHBWj0JmpWDX"
      },
      "outputs": [],
      "source": [
        "# Ref TFDS catalog: https://www.tensorflow.org/datasets/catalog/beans\n",
        "# These are usable task ids of image classificatiion tasks from the TFDS catalog.\n",
        "TFDS_IMAGE_CLASSIFCATON_DATASETS = set([\n",
        "    'beans',\n",
        "    'binary_alpha_digits',\n",
        "    'caltech_birds2010',\n",
        "    'caltech_birds2011',\n",
        "    'cars196',\n",
        "    'cassava',\n",
        "    'cats_vs_dogs',\n",
        "    'cifar10',\n",
        "    'cifar100',\n",
        "    'cifar10_1',\n",
        "    'citrus_leaves',\n",
        "    'cmaterdb/bangla',\n",
        "    'cmaterdb/devanagari',\n",
        "    'cmaterdb/telugu',\n",
        "    'colorectal_histology',\n",
        "    'deep_weeds',\n",
        "    'emnist/balanced',\n",
        "    'emnist/byclass',\n",
        "    'emnist/bymerge',\n",
        "    'emnist/digits',\n",
        "    'emnist/letters',\n",
        "    'emnist/mnist',\n",
        "    'fashion_mnist',\n",
        "    'food101',\n",
        "    'horses_or_humans',\n",
        "    'imagenet2012',\n",
        "    'imagenet2012_subset',\n",
        "    'imagenet_lt',\n",
        "    'imagenet_resized/8x8',\n",
        "    'imagenet_resized/16x16',\n",
        "    'imagenet_resized/32x32',\n",
        "    'imagenet_resized/64x64',\n",
        "    'imagenet_v2',\n",
        "    'imagenette',\n",
        "    'kmnist',\n",
        "    'malaria',\n",
        "    'mnist',\n",
        "    'mnist_corrupted',\n",
        "    'omniglot',\n",
        "    'plant_village',\n",
        "    'plantae_k',\n",
        "    'quickdraw_bitmap',\n",
        "    'rock_paper_scissors',\n",
        "    'stanford_dogs',\n",
        "    'stl10',\n",
        "    'tf_flowers',\n",
        "    'uc_merced',\n",
        "    'visual_domain_decathlon/aircraft',\n",
        "    'visual_domain_decathlon/cifar100',\n",
        "    'visual_domain_decathlon/daimlerpedcls',\n",
        "    'visual_domain_decathlon/dtd',\n",
        "    'visual_domain_decathlon/gtsrb',\n",
        "    'visual_domain_decathlon/imagenet12',\n",
        "    'visual_domain_decathlon/omniglot',\n",
        "    'visual_domain_decathlon/svhn',\n",
        "    'visual_domain_decathlon/ucf101',\n",
        "    'visual_domain_decathlon/vgg-flowers',\n",
        "    ])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XI7xcfWLXx-G"
      },
      "outputs": [],
      "source": [
        "# Tasks ids of the VTAB tasks.\n",
        "# Append suffix '/1k' to get the 1k version of each task.\n",
        "VTAB_TASKS = [\n",
        "              ## NATURAL TASKS\n",
        "              'caltech101',\n",
        "              # cifar100/10 were already added with slightly different val split but same test set.\n",
        "              # So here is added only the 1k versions.\n",
        "              'cifar100/1k',\n",
        "              'cifar10/1k',\n",
        "              'dtd',\n",
        "              'oxford_flowers102',\n",
        "              'oxford_iiit_pet',\n",
        "              'sun397',\n",
        "              'svhn_cropped',\n",
        "              ## SPECIALIZED TASKS\n",
        "              'patch_camelyon',\n",
        "              'eurosat',\n",
        "              'resisc45',\n",
        "              'diabetic_retinopathy_detection/btgraham-300',\n",
        "              ## STRUCTURED TASKS\n",
        "              'clevr/count_cylinders',  # Not in results table.\n",
        "              'clevr/count_all',  # Clevr-Count\n",
        "              'clevr/closest_object_distance',  # Clevr-Dist\n",
        "              'dmlab',\n",
        "              'dsprites/label_x_position',  # dSpr-Loc\n",
        "              'dsprites/label_orientation',  # dSpr-Ori\n",
        "              'kitti/closest_object_distance',  # Not in results table.\n",
        "              'kitti/count_vehicles',  # Not in results table.\n",
        "              'kitti/closest_vehicle_distance',  # Kitti-dist\n",
        "              'smallnorb/label_category',  # Not in results table.\n",
        "              'smallnorb/label_lighting',  # Not in results table.\n",
        "              'smallnorb/label_azimuth',  # Azim\n",
        "              'smallnorb/label_elevation',  # Elev\n",
        "              ]\n",
        "\n",
        "for tn in VTAB_TASKS:\n",
        "  assert tn not in TFDS_IMAGE_CLASSIFCATON_DATASETS, tn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "deuT4iMO4Nhf"
      },
      "outputs": [],
      "source": [
        "TFDS_BUILDERS_CACHE = {}\n",
        "\n",
        "def get_tfds_builder(tfds_name):\n",
        "  global TFDS_BUILDERS_CACHE\n",
        "  if tfds_name not in TFDS_BUILDERS_CACHE:\n",
        "    TFDS_BUILDERS_CACHE[tfds_name] = tfds.builder(tfds_name)\n",
        "    TFDS_BUILDERS_CACHE[tfds_name].download_and_prepare()\n",
        "  return TFDS_BUILDERS_CACHE[tfds_name]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jTOBlYBpWNMl"
      },
      "outputs": [],
      "source": [
        "def ids_str2ints(ids_str):\n",
        "  return [int(v) for v in ids_str.split('_')] if ids_str else []\n",
        "\n",
        "def ids_ints2str(ids_ints):\n",
        "  return '_'.join([str(v) for v in sorted(ids_ints)])"
      ]
    },
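    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two helpers above encode a set of layer ids as an underscore-separated string, which is how the `adapter_layers` hyperparameter is stored later in this notebook. The next cell is an illustrative round-trip check added for clarity: a minimal sketch that assumes the previous cell has been run. It is not part of the original experiments and can be skipped."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative sanity check (assumes ids_str2ints / ids_ints2str from the cell above).\n",
        "assert ids_ints2str([2, 0, 1]) == '0_1_2'  # Ids are sorted before joining.\n",
        "assert ids_str2ints('0_1_2') == [0, 1, 2]\n",
        "assert ids_str2ints('') == []  # Empty string means no layer ids.\n",
        "assert ids_ints2str(range(3)) == '0_1_2'  # E.g. select all layers of a 3 layer model."
      ]
    },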
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e88M0y8YVZrC"
      },
      "outputs": [],
      "source": [
        "AddPositionEmbs = models.AddPositionEmbs\n",
        "Encoder1DBlock = models.Encoder1DBlock\n",
        "VisionTransformer = models.VisionTransformer\n",
        "\n",
        "class ResidualAdapter(nn.Module):\n",
        "  adapter_dim: int\n",
        "\n",
        "  @nn.compact\n",
        "  def __call__(self, x):\n",
        "    hidden_dim = x.shape[-1]\n",
        "    y = nn.LayerNorm()(x)\n",
        "    y = nn.Dense(self.adapter_dim)(y)\n",
        "    y = nn.gelu(y)\n",
        "    # Default initalization.\n",
        "    # y = nn.Dense(hidden_dim)(y)\n",
        "    # Initialization from https://arxiv.org/pdf/1902.00751.pdf\n",
        "    # y = nn.Dense(hidden_dim, kernel_init=nn.initializers.normal(stddev=1e-3))(y)\n",
        "    # Zero Initialization so that added adapter does not change the representation.\n",
        "    y = nn.Dense(hidden_dim, kernel_init=jax.nn.initializers.zeros)(y)\n",
        "    return x + y  # Residual.\n",
        "\n",
        "# Modified from vision_transformer/vit_jax/models Encoder to add residual adapters.\n",
        "class Encoder(nn.Module):\n",
        "  num_layers: int\n",
        "  mlp_dim: int\n",
        "  num_heads: int\n",
        "  adapter_layers: str  # <MOD\n",
        "  adapter_dim: int  # MOD>\n",
        "  dropout_rate: float = 0.1\n",
        "  attention_dropout_rate: float = 0.1\n",
        "\n",
        "  @nn.compact\n",
        "  def __call__(self, inputs, *, train):\n",
        "    assert inputs.ndim == 3  # (batch, len, emb)\n",
        "\n",
        "    x = AddPositionEmbs(\n",
        "        posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.\n",
        "        name='posembed_input')(\n",
        "            inputs)\n",
        "    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)\n",
        "\n",
        "    # Input Encoder\n",
        "    adapter_layers_ids = ids_str2ints(self.adapter_layers)  # <MOD>\n",
        "    for lyr in range(self.num_layers):\n",
        "      if lyr in adapter_layers_ids:  # <MOD\n",
        "        x = ResidualAdapter(\n",
        "            adapter_dim=self.adapter_dim,\n",
        "            name=f'residual_adapter_{lyr}'\n",
        "            )(x)  # MOD>\n",
        "      x = Encoder1DBlock(\n",
        "          mlp_dim=self.mlp_dim,\n",
        "          dropout_rate=self.dropout_rate,\n",
        "          attention_dropout_rate=self.attention_dropout_rate,\n",
        "          name=f'encoderblock_{lyr}',\n",
        "          num_heads=self.num_heads)(\n",
        "              x, deterministic=not train)\n",
        "    encoded = nn.LayerNorm(name='encoder_norm')(x)\n",
        "    return encoded"
      ]
    },
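    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the output projection of `ResidualAdapter` is zero-initialized, a freshly added adapter should behave as the identity function. The next cell is an illustrative check of that property: a minimal sketch with arbitrary toy shapes that assumes the previous cell has been run. It is not part of the original experiments and can be skipped."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative check: a newly initialized adapter leaves its input unchanged.\n",
        "# The shapes (batch=2, len=5, emb=16) and adapter_dim=4 are arbitrary toy values.\n",
        "_x = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 16))\n",
        "_adapter = ResidualAdapter(adapter_dim=4)\n",
        "_variables = _adapter.init(jax.random.PRNGKey(1), _x)\n",
        "_y = _adapter.apply(_variables, _x)\n",
        "assert _y.shape == _x.shape\n",
        "assert jnp.allclose(_y, _x)  # Zero kernel and default zero bias give y = x + 0."
      ]
    },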
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vvZ_4-kJ9Pt3"
      },
      "outputs": [],
      "source": [
        "def get_vit_filename(query):\n",
        "  df = checkpoint.get_augreg_df()\n",
        "  res = df.query(query).filename.unique()\n",
        "  assert len(res) == 1\n",
        "  return res[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0lvqd47g9ZsW"
      },
      "outputs": [],
      "source": [
        "USE_DROPOUT = False\n",
        "VIT_CONFIG_CACHE = {}\n",
        "\n",
        "def get_vit_config(query):\n",
        "  if query not in VIT_CONFIG_CACHE:\n",
        "    filename = get_vit_filename(query)\n",
        "    config = models_config.AUGREG_CONFIGS[filename.split('-')[0]].copy_and_resolve_references()\n",
        "    # Overwrite with custom Encoder.\n",
        "    config.unlock()\n",
        "    config.encoder = Encoder\n",
        "    config.transformer.adapter_layers = ''\n",
        "    config.transformer.adapter_dim = -1\n",
        "    if not USE_DROPOUT:\n",
        "      config.transformer.dropout_rate = 0.0\n",
        "      config.transformer.attention_dropout_rate = 0.0\n",
        "    config.lock()\n",
        "    VIT_CONFIG_CACHE[query] = config\n",
        "  return VIT_CONFIG_CACHE[query].copy_and_resolve_references()\n",
        "\n",
        "def get_max_num_layers(query):\n",
        "  config = get_vit_config(query)\n",
        "  return config.transformer.num_layers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lC8aZvQ_Osh5"
      },
      "outputs": [],
      "source": [
        "# Benchmarks used in the paper\n",
        "VIT_BENCHMARK = [\n",
        "  'imagenet2012',\n",
        "  'cifar100',\n",
        "  'cifar10',\n",
        "  ]\n",
        "VTAB_FULL_BENCHMARK = [\n",
        "  'caltech101',\n",
        "  # 'cifar100',  # Already added with VIT_BENCHMARK\n",
        "  'dtd',\n",
        "  'oxford_flowers102',\n",
        "  'oxford_iiit_pet',\n",
        "  'sun397',\n",
        "  'svhn_cropped',\n",
        "  'patch_camelyon',\n",
        "  'eurosat',\n",
        "  'resisc45',\n",
        "  'diabetic_retinopathy_detection/btgraham-300',\n",
        "  'clevr/count_cylinders',\n",
        "  'clevr/count_all',\n",
        "  'clevr/closest_object_distance',\n",
        "  'dmlab',\n",
        "  'dsprites/label_x_position',\n",
        "  'dsprites/label_orientation',\n",
        "  'kitti/closest_object_distance',\n",
        "  'kitti/count_vehicles',\n",
        "  'kitti/closest_vehicle_distance',\n",
        "  'smallnorb/label_category',\n",
        "  'smallnorb/label_lighting',\n",
        "  'smallnorb/label_azimuth',\n",
        "  'smallnorb/label_elevation',\n",
        "]\n",
        "CHARS_BENCHMARK = [\n",
        "  'emnist/digits',\n",
        "  'emnist/letters',\n",
        "  'kmnist',\n",
        "  'mnist',\n",
        "  'omniglot',\n",
        "  'cmaterdb/bangla',\n",
        "  'cmaterdb/devanagari',\n",
        "  'cmaterdb/telugu',\n",
        "  ]\n",
        "VDD_BENCHMARK = [\n",
        "  'visual_domain_decathlon/imagenet12',\n",
        "  'visual_domain_decathlon/svhn',\n",
        "  'visual_domain_decathlon/cifar100',\n",
        "  'visual_domain_decathlon/gtsrb',\n",
        "  'visual_domain_decathlon/daimlerpedcls',\n",
        "  'visual_domain_decathlon/omniglot',\n",
        "  'visual_domain_decathlon/ucf101',\n",
        "  'visual_domain_decathlon/aircraft',\n",
        "  'visual_domain_decathlon/dtd',\n",
        "  'visual_domain_decathlon/vgg-flowers',\n",
        "  ]\n",
        "VTAB_1K_BENCHMARK = [\n",
        "  'caltech101/1k',\n",
        "  'cifar100/1k',\n",
        "  'cifar10/1k',\n",
        "  'dtd/1k',\n",
        "  'oxford_flowers102/1k',\n",
        "  'oxford_iiit_pet/1k',\n",
        "  'sun397/1k',\n",
        "  'svhn_cropped/1k',\n",
        "  'patch_camelyon/1k',\n",
        "  'eurosat/1k',\n",
        "  'resisc45/1k',\n",
        "  'diabetic_retinopathy_detection/btgraham-300/1k',\n",
        "  'clevr/count_cylinders/1k',\n",
        "  'clevr/count_all/1k',\n",
        "  'clevr/closest_object_distance/1k',\n",
        "  'dmlab/1k',\n",
        "  'dsprites/label_x_position/1k',\n",
        "  'dsprites/label_orientation/1k',\n",
        "  'kitti/closest_object_distance/1k',\n",
        "  'kitti/count_vehicles/1k',\n",
        "  'kitti/closest_vehicle_distance/1k',\n",
        "  'smallnorb/label_category/1k',\n",
        "  'smallnorb/label_lighting/1k',\n",
        "  'smallnorb/label_azimuth/1',\n",
        "  'smallnorb/label_elevation/1k',\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZRKdUDAs86zr"
      },
      "outputs": [],
      "source": [
        "def set_continue_configs(exp_config):\n",
        "  if CONTINUED_FROM_STATE_DIR:\n",
        "    exp_config.load_rand_init = False\n",
        "    exp_config.load_vit_checkpoint = False\n",
        "    exp_config.load_experiment = True\n",
        "    exp_config.load_experiment_dir = CONTINUED_FROM_STATE_DIR"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P5_XVkhGhRlp"
      },
      "outputs": [],
      "source": [
        "DATASET_HPARAMS_KEYS_PRERFIX = 'ds_'\n",
        "OPTIMIZER_HPARAMS_KEYS_PRERFIX = 'opt_'\n",
        "\n",
        "def get_exp_config_ti3_chars():\n",
        "  exp_config = ConfigDict()\n",
        "  exp_config.experiment_name = EXPERIMENT_NAME\n",
        "  exp_config.experiments_root_dir = EXPERIMENTS_ROOT_DIR\n",
        "  exp_config.num_train_examples_between_validations_max = 51200  # 100 batches.\n",
        "  exp_config.num_validations_per_path_training = 5\n",
        "  exp_config.num_validation_examples_max = 5120  # 10 batches.\n",
        "  exp_config.batch_size = 512\n",
        "  exp_config.num_task_iters = 2\n",
        "  exp_config.num_samples_per_task = 8*8\n",
        "  exp_config.mutate_adapters = True\n",
        "  # Force finetune last layer norm that technically is part of the head.\n",
        "  exp_config.force_finetune_components = ['encoder_norm']\n",
        "  # Population policy params:\n",
        "  exp_config.policy_class = 'PPDecay'\n",
        "  exp_config.policy_kwargs = {}\n",
        "  # Scorer params:\n",
        "  exp_config.scorer_class = 'ScorerDecay'\n",
        "  exp_config.scorer_kwargs = dict(\n",
        "      base=1.0,\n",
        "      num_params=1_484_162,  # params in Ti/16 with 3 layers.\n",
        "      )\n",
        "\n",
        "  # Seed models params:\n",
        "  exp_config.load_rand_init = False\n",
        "  exp_config.load_vit_checkpoint = True\n",
        "  exp_config.load_vit_checkpoint_query = 'name==\"Ti/16\" and ds==\"i21k\" and aug==\"light1\" and wd==0.1 and sd==0.0'\n",
        "  exp_config.load_experiment = False\n",
        "  exp_config.load_experiment_dir = ''\n",
        "  set_continue_configs(exp_config)\n",
        "\n",
        "  # Hyperparameters:\n",
        "  exp_config.models_default_hparams = {\n",
        "      '_mu_': 0.1,\n",
        "      # Default num_classes has no effect since it is always overwritten or used\n",
        "      # for rand init models whose head is always replaced.\n",
        "      'num_classes': 1,\n",
        "      # Set to ids_ints2str(range(max_num_layers)) to activate all adapters.\n",
        "      'adapter_layers': '',\n",
        "      'num_layers': 3,\n",
        "      'adapter_dim': 32,\n",
        "      'opt_lr': 0.01,\n",
        "      'opt_lr_schedule': 'cosine',\n",
        "      'opt_lr_warmup_ratio': 0.1,\n",
        "      'opt_momentum': 0.9,\n",
        "      'opt_nesterov': False,\n",
        "      'ds_image_size': 32,\n",
        "      'ds_crop': True,\n",
        "      'ds_area_range_min': 0.05,\n",
        "      'ds_aspect_ratio_range_min': 0.75,\n",
        "      'ds_flip_left_right': True,\n",
        "      'ds_brightness_delta': 0.0,\n",
        "      'ds_contrast_delta': 0.0,\n",
        "      'ds_saturation_delta': 0.0,\n",
        "      'ds_hue_delta': 0.0,\n",
        "  }\n",
        "  exp_config.models_mutation_ranges = {}\n",
        "  exp_config.task_names = CHARS_BENCHMARK\n",
        "  exp_config_validate(exp_config)\n",
        "  return exp_config\n",
        "\n",
        "def get_exp_config_base_deca():\n",
        "  exp_config = ConfigDict()\n",
        "  exp_config.experiment_name = EXPERIMENT_NAME\n",
        "  exp_config.experiments_root_dir = EXPERIMENTS_ROOT_DIR\n",
        "  exp_config.num_train_examples_between_validations_max = 51200  # 200 batches.\n",
        "  exp_config.num_validations_per_path_training = 30\n",
        "  exp_config.num_validation_examples_max = 5120  # 20 batches.\n",
        "  exp_config.batch_size = 256\n",
        "  exp_config.num_task_iters = 2\n",
        "  exp_config.num_samples_per_task = 8*8\n",
        "  exp_config.mutate_adapters = True\n",
        "  exp_config.force_finetune_components = ['encoder_norm']\n",
        "  # Population policy params:\n",
        "  exp_config.policy_class = 'PPDecay'\n",
        "  exp_config.policy_kwargs = {}\n",
        "  # Scorer params:\n",
        "  exp_config.scorer_class = 'ScorerDecay'\n",
        "  exp_config.scorer_kwargs = dict(\n",
        "      base=1.0,\n",
        "      num_params=85_652_738,  # params in B/16\n",
        "      )\n",
        "  # Seed models params:\n",
        "  exp_config.load_rand_init = False\n",
        "  exp_config.load_vit_checkpoint = True\n",
        "  exp_config.load_vit_checkpoint_query = 'name==\"B/16\" and ds==\"i21k\" and aug==\"medium1\" and wd==0.1 and sd==0'\n",
        "  exp_config.load_experiment = False\n",
        "  exp_config.load_experiment_dir = ''\n",
        "  set_continue_configs(exp_config)\n",
        "\n",
        "  # Hyperparameters:\n",
        "  max_num_layers = get_max_num_layers(exp_config.load_vit_checkpoint_query)\n",
        "  exp_config.models_default_hparams = {\n",
        "      '_mu_': 0.1,\n",
        "      'num_classes': 1,\n",
        "      'adapter_layers': '',\n",
        "      'num_layers': max_num_layers,\n",
        "      'adapter_dim': 32,\n",
        "      'opt_lr': 0.01,\n",
        "      'opt_lr_schedule': 'cosine',\n",
        "      'opt_lr_warmup_ratio': 0.1,\n",
        "      'opt_momentum': 0.9,\n",
        "      'opt_nesterov': False,\n",
        "      'ds_image_size': 80,\n",
        "      'ds_crop': True,\n",
        "      'ds_area_range_min': 0.05,\n",
        "      'ds_aspect_ratio_range_min': 0.75,\n",
        "      'ds_flip_left_right': True,\n",
        "      'ds_brightness_delta': 0.0,\n",
        "      'ds_contrast_delta': 0.0,\n",
        "      'ds_saturation_delta': 0.0,\n",
        "      'ds_hue_delta': 0.0,\n",
        "  }\n",
        "  exp_config.models_mutation_ranges = {}\n",
        "  exp_config.task_names = VDD_BENCHMARK\n",
        "  exp_config_validate(exp_config)\n",
        "  return exp_config\n",
        "\n",
        "def get_exp_config_ti0_cmaterdb():\n",
        "  exp_config = ConfigDict()\n",
        "  exp_config.experiment_name = EXPERIMENT_NAME\n",
        "  exp_config.experiments_root_dir = EXPERIMENTS_ROOT_DIR\n",
        "  exp_config.num_train_examples_between_validations_max = 51200  # 100 batches.\n",
        "  exp_config.num_validations_per_path_training = 4\n",
        "  exp_config.num_validation_examples_max = 5120  # 10 batches.\n",
        "  exp_config.batch_size = 512\n",
        "  exp_config.num_task_iters = 2\n",
        "  exp_config.num_samples_per_task = 8*4\n",
        "  exp_config.mutate_adapters = False\n",
        "  exp_config.force_finetune_components = ['encoder_norm']\n",
        "  # Population policy params:\n",
        "  exp_config.policy_class = 'PPDecay'\n",
        "  exp_config.policy_kwargs = {}\n",
        "  # Scorer params:\n",
        "  exp_config.scorer_class = 'ScorerDecay'\n",
        "  exp_config.scorer_kwargs = dict(\n",
        "      base=1.0,\n",
        "      num_params=1_484_162,  # params in Ti/16 with 3 layers.\n",
        "      )\n",
        "\n",
        "  # Seed models params:\n",
        "  exp_config.load_rand_init = True\n",
        "  exp_config.load_vit_checkpoint = False\n",
        "  # The query is used to get the model configs even if the checkpoint is not loaded.\n",
        "  exp_config.load_vit_checkpoint_query = 'name==\"Ti/16\" and ds==\"i21k\" and aug==\"light1\" and wd==0.1 and sd==0.0'\n",
        "  exp_config.load_experiment = False\n",
        "  exp_config.load_experiment_dir = ''\n",
        "  set_continue_configs(exp_config)\n",
        "\n",
        "  # Hyperparameters:\n",
        "  exp_config.models_default_hparams = {\n",
        "      '_mu_': 0.2,\n",
        "      'num_classes': 1,\n",
        "      'adapter_layers': '',\n",
        "      'num_layers': 0,\n",
        "      'adapter_dim': 32,\n",
        "      'opt_lr': 0.01,\n",
        "      'opt_lr_schedule': 'cosine',\n",
        "      'opt_lr_warmup_ratio': 0.1,\n",
        "      'opt_momentum': 0.9,\n",
        "      'opt_nesterov': False,\n",
        "      'ds_image_size': 32,\n",
        "      'ds_crop': True,\n",
        "      'ds_area_range_min': 0.05,\n",
        "      'ds_aspect_ratio_range_min': 0.75,\n",
        "      'ds_flip_left_right': True,\n",
        "      'ds_brightness_delta': 0.0,\n",
        "      'ds_contrast_delta': 0.0,\n",
        "      'ds_saturation_delta': 0.0,\n",
        "      'ds_hue_delta': 0.0,\n",
        "  }\n",
        "  exp_config.models_mutation_ranges = {\n",
        "      'num_layers': list(range(0, 4)),\n",
        "  }\n",
        "  exp_config.task_names = [\n",
        "    'cmaterdb/bangla',\n",
        "    'cmaterdb/devanagari',\n",
        "    'private:cmaterdb/telugu',\n",
        "    ]\n",
        "  exp_config_validate(exp_config)\n",
        "  return exp_config\n",
        "\n",
        "def get_exp_config_large(benchmark_string_id):\n",
        "  exp_config = ConfigDict()\n",
        "  exp_config.experiment_name = EXPERIMENT_NAME\n",
        "  exp_config.experiments_root_dir = EXPERIMENTS_ROOT_DIR\n",
        "  # Cap to 1/10th of imagenet train set size to have similar ratio of exps reported in:\n",
        "  # https://arxiv.org/abs/2106.10270\n",
        "  exp_config.num_train_examples_between_validations_max = 128_116\n",
        "  exp_config.num_validations_per_path_training = 4\n",
        "  exp_config.num_validation_examples_max = 10_000\n",
        "  # Fit HBM memory: TPUv4 megacore=64, TPUv3=32.\n",
        "  exp_config.batch_size = 64\n",
        "  exp_config.num_task_iters = 1\n",
        "  # Assuming TPUv4 32 cores * 4 generations.\n",
        "  exp_config.num_samples_per_task = 32 * 4\n",
        "  exp_config.mutate_adapters = False\n",
        "  exp_config.force_finetune_components = ['encoder_norm']\n",
        "  # Population policy params:\n",
        "  exp_config.policy_class = 'PPDecay'\n",
        "  exp_config.policy_kwargs = {}\n",
        "  # Scorer params:\n",
        "  exp_config.scorer_class = 'ScorerDecay'\n",
        "  exp_config.scorer_kwargs = dict(\n",
        "      base=1.0,\n",
        "      num_params=303_303_682,  # Params in L/16\n",
        "      )\n",
        "  # Seed models params:\n",
        "  exp_config.load_rand_init = False\n",
        "  exp_config.load_vit_checkpoint = True\n",
        "  exp_config.load_vit_checkpoint_query = 'name==\"L/16\" and ds==\"i21k\" and aug==\"medium2\" and wd==0.03 and sd==0.1'\n",
        "  exp_config.load_experiment = False\n",
        "  exp_config.load_experiment_dir = ''\n",
        "  set_continue_configs(exp_config)\n",
        "\n",
        "  # Hyperparameters:\n",
        "  max_num_layers = get_max_num_layers(exp_config.load_vit_checkpoint_query)\n",
        "  exp_config.models_default_hparams = {\n",
        "      '_mu_': 0.2,\n",
        "      'num_classes': 1,\n",
        "      'adapter_layers': '',\n",
        "      'num_layers': max_num_layers,\n",
        "      'adapter_dim': 16,\n",
        "      'opt_lr': 0.01,\n",
        "      'opt_lr_schedule': 'cosine',\n",
        "      'opt_lr_warmup_ratio': 0.05,\n",
        "      'opt_momentum': 0.9,\n",
        "      'opt_nesterov': False,\n",
        "      'ds_image_size': 384,\n",
        "      'ds_crop': True,\n",
        "      'ds_area_range_min': 0.05,\n",
        "      'ds_aspect_ratio_range_min': 0.75,\n",
        "      'ds_flip_left_right': True,\n",
        "      'ds_brightness_delta': 0.0,\n",
        "      'ds_contrast_delta': 0.0,\n",
        "      'ds_saturation_delta': 0.0,\n",
        "      'ds_hue_delta': 0.0,\n",
        "  }\n",
        "  exp_config.models_mutation_ranges = {}\n",
        "  if benchmark_string_id == 'ViT large / ViT benchmark':\n",
        "    exp_config.task_names = VIT_BENCHMARK\n",
        "  elif benchmark_string_id == 'ViT large / VTAB-full benchmark':\n",
        "    exp_config.task_names = VTAB_FULL_BENCHMARK\n",
        "  elif benchmark_string_id == 'ViT large / VDD benchmark':\n",
        "    exp_config.task_names = VDD_BENCHMARK\n",
        "  elif benchmark_string_id == 'ViT large / Chars benchmark':\n",
        "    exp_config.task_names = CHARS_BENCHMARK\n",
        "  elif benchmark_string_id == 'ViT large / VTAB-1k benchmark':\n",
        "    exp_config.task_names = VTAB_1K_BENCHMARK\n",
        "  else:\n",
        "    assert False, f'Unknown benchmark: {benchmark_string_id}'\n",
        "  exp_config_validate(exp_config)\n",
        "  return exp_config\n",
        "\n",
        "def exp_config_add_auto_tune(exp_config):\n",
        "  exp_config.models_mutation_ranges['adapter_dim'] = [8, 16, 32, 64, 128]\n",
        "  exp_config.models_mutation_ranges['opt_lr'] = [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5]\n",
        "  exp_config.models_mutation_ranges['opt_lr_schedule'] = ['constant', 'cosine', 'restarts']\n",
        "  exp_config.models_mutation_ranges['opt_lr_warmup_ratio'] = [0.01, 0.02, 0.05, 0.1, 0.2, 0.3, 0.4]\n",
        "  exp_config.models_mutation_ranges['opt_momentum'] = [None, 0.2, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 0.98, 0.99]\n",
        "  exp_config.models_mutation_ranges['opt_nesterov'] = [True, False]\n",
        "  exp_config.models_mutation_ranges['ds_image_size'] = [ 16 * i for i in (range(1, 1+int(384/16))) ]\n",
        "  exp_config.models_mutation_ranges['ds_area_range_min'] = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0]\n",
        "  exp_config.models_mutation_ranges['ds_aspect_ratio_range_min'] = [0.25, 0.5, 0.75, 1.0]\n",
        "  exp_config.models_mutation_ranges['ds_flip_left_right'] = [True, False]\n",
        "  exp_config.models_mutation_ranges['ds_brightness_delta'] = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2]\n",
        "  exp_config.models_mutation_ranges['ds_contrast_delta'] = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2]\n",
        "  exp_config.models_mutation_ranges['ds_saturation_delta'] = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2]\n",
        "  exp_config.models_mutation_ranges['ds_hue_delta'] = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2]\n",
        "  return exp_config\n",
        "\n",
        "def exp_config_add_auto_tune_v2(exp_config):\n",
        "  exp_config.models_mutation_ranges['_mu_'] = [0.10, 0.12, 0.14, 0.16, 0.18, 0.20, 0.22, 0.24, 0.26, 0.28, 0.30]\n",
        "  exp_config.models_mutation_ranges['opt_lr'] = [0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5]\n",
        "  exp_config.models_mutation_ranges['opt_lr_warmup_ratio'] = [0.01, 0.02, 0.05, 0.1, 0.2, 0.3, 0.4]\n",
        "  exp_config.models_mutation_ranges['opt_momentum'] = [0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 0.98, 0.99]\n",
        "  exp_config.models_mutation_ranges['opt_nesterov'] = [True, False]\n",
        "  exp_config.models_mutation_ranges['ds_area_range_min'] = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0]\n",
        "  exp_config.models_mutation_ranges['ds_crop'] = [True, False]\n",
        "  exp_config.models_mutation_ranges['ds_aspect_ratio_range_min'] = [0.25, 0.5, 0.75, 1.0]\n",
        "  exp_config.models_mutation_ranges['ds_flip_left_right'] = [True, False]\n",
        "  exp_config.models_mutation_ranges['ds_brightness_delta'] = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2]\n",
        "  exp_config.models_mutation_ranges['ds_contrast_delta'] = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2]\n",
        "  exp_config.models_mutation_ranges['ds_saturation_delta'] = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2]\n",
        "  exp_config.models_mutation_ranges['ds_hue_delta'] = [0.0, 0.01, 0.02, 0.05, 0.1, 0.2]\n",
        "  return exp_config\n",
        "\n",
        "def exp_config_validate(exp_config):\n",
        "  for khp in exp_config.models_default_hparams:\n",
        "    if khp in exp_config.models_mutation_ranges:\n",
        "      assert exp_config.models_default_hparams[khp] \\\n",
        "          in exp_config.models_mutation_ranges[khp]\n",
        "\n",
        "def exp_config_set_size_scale(exp_config, base_percent:int):\n",
        "  exp_config.scorer_kwargs['base'] = float(base_percent) / 100.0\n",
        "  if 'num_layers' not in exp_config.models_mutation_ranges:\n",
        "    exp_config.models_mutation_ranges['num_layers'] = list(\n",
        "        range(1, exp_config.models_default_hparams['num_layers']+1))\n",
        "  return exp_config\n",
        "\n",
        "def exp_config_set_baseline_common(exp_config):\n",
        "  parallelism = jax.local_device_count()\n",
        "  assert (int(exp_config.num_samples_per_task / parallelism) ==\n",
        "          exp_config.num_samples_per_task / parallelism)\n",
        "  exp_config.num_validations_per_path_training *= \\\n",
        "      exp_config.num_task_iters \\\n",
        "      * int(exp_config.num_samples_per_task/parallelism)\n",
        "  exp_config.num_task_iters = 1\n",
        "  exp_config.num_samples_per_task = parallelism\n",
        "  exp_config.models_mutation_ranges = {}\n",
        "  exp_config.policy_class = 'PPBaseline'\n",
        "  exp_config.policy_kwargs = {}\n",
        "  exp_config_validate(exp_config)\n",
        "  return exp_config\n",
        "\n",
        "def exp_config_set_baseline_finetune_all(exp_config):\n",
        "  exp_config = exp_config_set_baseline_common(exp_config)\n",
        "  exp_config.mutate_adapters = False\n",
        "  exp_config.models_default_hparams['_mu_'] = 1.0\n",
        "  exp_config.models_default_hparams['adapter_layers'] = ''\n",
        "  exp_config_validate(exp_config)\n",
        "  return exp_config\n",
        "\n",
        "def exp_config_set_baseline_freeze_bottom_layers(exp_config, num_layers:int):\n",
        "  exp_config = exp_config_set_baseline_common(exp_config)\n",
        "  max_num_layers = exp_config.models_default_hparams['num_layers']\n",
        "  assert max_num_layers >= num_layers\n",
        "  unfrozen_layers = [f'encoderblock_{id}' for id in range(num_layers, max_num_layers)]\n",
        "  exp_config.force_finetune_components = ['encoder_norm'] + unfrozen_layers\n",
        "  exp_config.mutate_adapters = False\n",
        "  exp_config.models_default_hparams['_mu_'] = 0.0\n",
        "  exp_config.models_default_hparams['adapter_layers'] = ''\n",
        "  exp_config_validate(exp_config)\n",
        "  return exp_config\n",
        "\n",
        "def exp_config_set_baseline_adapters(exp_config, adapter_dim:int):\n",
        "  exp_config = exp_config_set_baseline_common(exp_config)\n",
        "  exp_config.force_finetune_components = ['encoder_norm']\n",
        "  exp_config.mutate_adapters = True\n",
        "  exp_config.models_default_hparams['_mu_'] = 0.0\n",
        "  max_num_layers = exp_config.models_default_hparams['num_layers']\n",
        "  exp_config.models_default_hparams['adapter_layers'] = ids_ints2str(\n",
        "      range(max_num_layers))\n",
        "  exp_config.models_default_hparams['adapter_dim'] = adapter_dim\n",
        "  exp_config_validate(exp_config)\n",
        "  return exp_config"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3MRsK4hvocq8"
      },
      "outputs": [],
      "source": [
        "def get_sample_images(image_size:int, batch_size:int):\n",
        "  return np.zeros((batch_size, image_size, image_size, 3))\n",
        "\n",
        "def get_sample_labels(batch_size:int):\n",
        "  return np.zeros(batch_size, dtype=np.int32)\n",
        "\n",
        "def get_sample_batch(image_size:int, batch_size:int):\n",
        "  return {'image': get_sample_images(image_size, batch_size),\n",
        "          'label': get_sample_labels(batch_size),}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0jPjzRlYYi0x"
      },
      "outputs": [],
      "source": [
        "def get_vit_checkpoint(image_size, query):\n",
        "  filename = get_vit_filename(query)\n",
        "  config = get_vit_config(query)\n",
        "  model = VisionTransformer(**config, num_classes=2)  # num_classes unsed.\n",
        "  init_params = copy.deepcopy(jax.device_get(\n",
        "      model.init(jax.random.PRNGKey(0),\n",
        "                 get_sample_images(image_size=image_size,\n",
        "                                   batch_size=1),\n",
        "                 train=USE_DROPOUT)['params']))\n",
        "  params = checkpoint.load_pretrained(\n",
        "    pretrained_path=f'gs://vit_models/augreg/{filename}.npz',\n",
        "    init_params=init_params,\n",
        "    model_config=config)\n",
        "  return params\n",
        "\n",
        "def get_vit_checkpoint_mapped(image_size, query):\n",
        "  params = get_vit_checkpoint(image_size, query)\n",
        "  params = params_model_to_comps(params)\n",
        "  return params\n",
        "\n",
        "def get_reshaped_posembed_component(image_size, query):\n",
        "  params = get_vit_checkpoint_mapped(image_size, query)['posembed_input']\n",
        "  return Component(name='posembed_input',\n",
        "                   params=params,\n",
        "                   train_locks=[NOT_TRAINABLE])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G_Xsw_tdLIRC"
      },
      "outputs": [],
      "source": [
        "# Parameter mapping.\n",
        "TRANSFORMER_KEYS = set(\n",
        "    ['encoder_norm', 'posembed_input'] + \\\n",
        "    [f'encoderblock_{k}' for k in range(24)])\n",
        "\n",
        "def params_model_to_comps(params):\n",
        "  global TRANSFORMER_KEYS\n",
        "  TRANSFORMER_KEYS.update(params['Transformer'].keys())\n",
        "  new_params = {}\n",
        "  for k in params.keys():\n",
        "    if k == 'Transformer':\n",
        "      t_params = params[k]\n",
        "      for t_k in t_params.keys():\n",
        "        new_params[t_k] = t_params[t_k]\n",
        "    else:\n",
        "      new_params[k] = params[k]\n",
        "  params = flax.core.freeze(new_params)\n",
        "  return flax.core.freeze(params)\n",
        "\n",
        "def params_comps_to_model(params):\n",
        "  params = params.unfreeze()\n",
        "  params['Transformer'] = {}\n",
        "  keys = list(params.keys())\n",
        "  assert len(TRANSFORMER_KEYS) != 0\n",
        "  for k in keys:\n",
        "    if k in TRANSFORMER_KEYS:\n",
        "      params['Transformer'][k] = params.pop(k)\n",
        "  return flax.core.freeze(params)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2Ktk16O9PhYZ"
      },
      "outputs": [],
      "source": [
        "def get_model_kwargs(hparams, exp_config):\n",
        "  # Validate adapters params.\n",
        "  for v in ids_str2ints(hparams['adapter_layers']):\n",
        "    assert v < hparams['num_layers']\n",
        "  return {\n",
        "        'num_classes': int(hparams['num_classes']),\n",
        "        'num_layers': int(hparams['num_layers']),\n",
        "        'image_size': int(hparams['ds_image_size']),\n",
        "        'adapter_layers': str(hparams['adapter_layers']),\n",
        "        'adapter_dim': int(hparams['adapter_dim']),\n",
        "        'query': str(exp_config.load_vit_checkpoint_query),\n",
        "    }\n",
        "\n",
        "def get_vit_model(num_classes, num_layers, adapter_layers, adapter_dim, query):\n",
        "  config = get_vit_config(query)\n",
        "  config['transformer']['num_layers'] = num_layers\n",
        "  config['transformer']['adapter_layers'] = adapter_layers\n",
        "  config['transformer']['adapter_dim'] = adapter_dim\n",
        "  config = FrozenConfigDict(config)\n",
        "  model = VisionTransformer(**config, num_classes=num_classes)\n",
        "  return model\n",
        "\n",
        "def get_vit_model_and_params(\n",
        "    num_classes, num_layers, image_size, adapter_layers, adapter_dim, query,\n",
        "    rng_key=0):\n",
        "  model = get_vit_model(\n",
        "      num_classes, num_layers, adapter_layers, adapter_dim, query)\n",
        "  init_params = copy.deepcopy(jax.device_get(\n",
        "      model.init(\n",
        "          jax.random.PRNGKey(rng_key),\n",
        "          get_sample_images(image_size=image_size, batch_size=1),\n",
        "          train=USE_DROPOUT)['params']))\n",
        "  return model, init_params\n",
        "\n",
        "def get_vit_params_mapped(**kwargs):\n",
        "  model, init_params = get_vit_model_and_params(**kwargs)\n",
        "  init_params = params_model_to_comps(init_params)\n",
        "  return init_params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-F8s09QiK8ri"
      },
      "outputs": [],
      "source": [
        "def format_params(a, b):\n",
        "  params = a.copy(b)\n",
        "  assert len(params) == len(a) + len(b)  # Dicts should not overlap.\n",
        "  params = params_comps_to_model(params)\n",
        "  return params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FsuttUqHvYGr"
      },
      "outputs": [],
      "source": [
        "def get_optimizer(\n",
        "    lr: float,\n",
        "    lr_schedule: str,\n",
        "    lr_warmup_ratio: float,\n",
        "    momentum: float,\n",
        "    nesterov: bool,\n",
        "    num_train_batches_between_validations: int,\n",
        "    num_validations_per_path_training: int,\n",
        "    ):\n",
        "  if lr_schedule == 'constant':\n",
        "    # Divide by 2 so that average lr is the same as other types.\n",
        "    learning_rate = 0.5 * lr\n",
        "  elif lr_schedule == 'cosine':\n",
        "    train_steps = int(num_train_batches_between_validations\n",
        "                      * num_validations_per_path_training)\n",
        "    learning_rate = optax.warmup_cosine_decay_schedule(\n",
        "        init_value=lr/100.0,\n",
        "        peak_value=lr,\n",
        "        warmup_steps=int(lr_warmup_ratio * train_steps),\n",
        "        decay_steps=train_steps)\n",
        "  elif lr_schedule == 'restarts':\n",
        "    train_steps = num_train_batches_between_validations\n",
        "    repeats = num_validations_per_path_training\n",
        "    kwargs = dict(\n",
        "        init_value=lr/100.0,\n",
        "        peak_value=lr,\n",
        "        warmup_steps=int(lr_warmup_ratio * train_steps),\n",
        "        decay_steps=train_steps,\n",
        "    )\n",
        "    kwargs = [kwargs] * repeats\n",
        "    learning_rate = optax.sgdr_schedule(kwargs)\n",
        "  else:\n",
        "    assert False, f'Invalid lr schedule: {lr_schedule}'\n",
        "\n",
        "  return optax.chain(\n",
        "      optax.clip_by_global_norm(1.0),\n",
        "      optax.sgd(\n",
        "          learning_rate=learning_rate,\n",
        "          momentum=momentum,\n",
        "          nesterov=nesterov,\n",
        "          accumulator_dtype=jnp.bfloat16))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kqnhOnk5jbnm"
      },
      "outputs": [],
      "source": [
        "def get_default_splits(tfds_name):\n",
        "  info = get_tfds_builder(tfds_name).info\n",
        "  splits = list(info.splits.keys())\n",
        "  assert 'train' in splits, splits\n",
        "  splits.remove('train')\n",
        "  used_percent = 0\n",
        "  slice_percent = 5\n",
        "  pp = {}\n",
        "  for k in ['test', 'validation']:\n",
        "    if k in splits:\n",
        "      pp[k] = k\n",
        "      splits.remove(k)\n",
        "    else:\n",
        "      pp[k] = f'train[{used_percent}%:{used_percent+slice_percent}%]'\n",
        "      used_percent += slice_percent\n",
        "  pp['train'] = f'train[{used_percent}%:]'\n",
        "  return pp\n",
        "\n",
        "def get_dataset_and_splits(tfds_name: str):\n",
        "  vtab_class = None\n",
        "  if tfds_name in ['imagenet_v2', 'cifar10_1']:\n",
        "    assert False,  f\"{tfds_name} used as validation set for other tasks.\"\n",
        "\n",
        "  if tfds_name == 'imagenet2012':\n",
        "    dataset = {\n",
        "        'train':'imagenet2012', 'validation':'imagenet_v2', 'test':'imagenet2012'}\n",
        "    splits = {\n",
        "        'train':'train', 'validation':'test', 'test':'validation'}\n",
        "  elif tfds_name == 'cifar100':\n",
        "    dataset = tfds_name\n",
        "    splits = {\n",
        "        'train':'train[:98%]', 'validation':'train[98%:]', 'test':'test'}\n",
        "  elif tfds_name == 'cifar10':\n",
        "    dataset = {\n",
        "        'train':'cifar10', 'validation':'cifar10_1', 'test':'cifar10'}\n",
        "    splits = {\n",
        "        'train':'train', 'validation':'test', 'test':'test'}\n",
        "  elif tfds_name.startswith('visual_domain_decathlon/'):\n",
        "    dataset = tfds_name\n",
        "    # Test has no labels, split validation in half.\n",
        "    splits =  {\n",
        "        'train':'train', 'validation':'validation[:50%]', 'test':'validation[50%:]'}\n",
        "  elif tfds_name.startswith('cmaterdb/'):\n",
        "    dataset = tfds_name\n",
        "    # Increase size of validation set due to small dataset size.\n",
        "    splits =  {\n",
        "        'train':'train[20%:]', 'validation':'train[:20%]', 'test':'test'}\n",
        "  elif tfds_name == 'omniglot':\n",
        "    # Test has no labels, and missing validation, use additional splits.\n",
        "    dataset = tfds_name\n",
        "    splits = {'train':'train', 'validation':'small1', 'test':'small2'}\n",
        "  elif tfds_name in VTAB_TASKS or (\n",
        "      tfds_name.endswith('/1k') and tfds_name.replace('/1k', '') in VTAB_TASKS):\n",
        "    is_vtab_1k = tfds_name.endswith('/1k')\n",
        "    tfds_name = tfds_name.replace('/1k', '')\n",
        "    registry_name = {\n",
        "        'diabetic_retinopathy_detection/btgraham-300': 'diabetic_retinopathy',\n",
        "        'svhn_cropped': 'svhn',\n",
        "        'cifar100': 'cifar',\n",
        "        'cifar10': 'cifar',\n",
        "    }.get(tfds_name, tfds_name.split('/')[0])\n",
        "    args = {\n",
        "        'clevr/count_all': ('count_all',),\n",
        "        'clevr/count_cylinders': ('count_cylinders',),\n",
        "        'clevr/closest_object_distance': ('closest_object_distance',),\n",
        "        'dsprites/label_x_position': ('label_x_position',),\n",
        "        'dsprites/label_orientation': ('label_orientation',),\n",
        "        'kitti/closest_object_distance': ('closest_object_distance',),\n",
        "        'kitti/count_vehicles': ('count_vehicles',),\n",
        "        'kitti/closest_vehicle_distance': ('closest_vehicle_distance',),\n",
        "        'smallnorb/label_category': ('label_category',),\n",
        "        'smallnorb/label_lighting': ('label_lighting',),\n",
        "        'smallnorb/label_azimuth': ('label_azimuth',),\n",
        "        'smallnorb/label_elevation': ('label_elevatio',),\n",
        "        'cifar100': (100,),\n",
        "        'cifar10': (10,),\n",
        "    }.get(tfds_name, ())\n",
        "    vtab_class = task_adapt_registry.Registry.lookup(\n",
        "        f'data.{registry_name}')(*args)\n",
        "    vtab_splits = vtab_class._tfds_splits\n",
        "    dataset = {\n",
        "        'caltech101': 'caltech101:3.*.*',\n",
        "        'dtd': 'dtd:3.*.*',\n",
        "        'oxford_flowers102': 'oxford_flowers102:2.*.*',\n",
        "        'oxford_iiit_pet': 'oxford_iiit_pet:3.*.*',\n",
        "        'sun397': 'sun397/tfds:4.*.',\n",
        "        'svhn': 'svhn_cropped:3.*.*',\n",
        "        'patch_camelyon': 'patch_camelyon:2.*.*',\n",
        "        'eurosat': 'eurosat/rgb:2.*.*',\n",
        "        'resisc45': 'resisc45:3.*.*',\n",
        "        'diabetic_retinopathy': 'diabetic_retinopathy_detection/btgraham-300:3.*.*',\n",
        "        'clevr': 'clevr:3.*.*',\n",
        "        'dmlab': 'dmlab:2.0.1',\n",
        "        'dsprites': 'dsprites:2.*.*',\n",
        "        'kitti': 'kitti:3.2.0',\n",
        "        'smallnorb': 'smallnorb:2.*.*',\n",
        "        'cifar' : 'cifar100:3.*.*' if tfds_name == 'cifar100' else 'cifar10:3.*.*',\n",
        "    }[registry_name]\n",
        "    if is_vtab_1k:\n",
        "      splits =  {\n",
        "          'train': str(vtab_splits['train800']),\n",
        "          'validation': str(vtab_splits['val200']),\n",
        "          'test': str(vtab_splits['test']),\n",
        "          }\n",
        "    else:\n",
        "      splits =  {\n",
        "          'train': str(vtab_splits['train']),\n",
        "          'validation': str(vtab_splits['val']),\n",
        "          'test': str(vtab_splits['test']),\n",
        "          }\n",
        "  else:\n",
        "    dataset = tfds_name\n",
        "    splits = get_default_splits(tfds_name)\n",
        "  return dataset, splits, vtab_class\n",
        "\n",
        "class Task():\n",
        "  def __init__(self, name, exp_config):\n",
        "    self.exp_config = exp_config\n",
        "    if name.startswith(NOT_TRAINABLE):\n",
        "      self.name = name\n",
        "      self.private = False\n",
        "      return\n",
        "\n",
        "    if name.startswith('private:'):\n",
        "      _, name = name.split('private:')\n",
        "      self.private = True\n",
        "    else:\n",
        "      self.private = False\n",
        "\n",
        "    self.dataset, self.splits, self.vtab_class = get_dataset_and_splits(name)\n",
        "    self.name = name\n",
        "    if self.vtab_class:\n",
        "      self.num_classes = self.vtab_class.get_num_classes()\n",
        "    else:\n",
        "      self.num_classes = self.get_builder('train').info.features['label'].num_classes\n",
        "    num_train_examples = self.get_builder('train').info.splits[self.splits['train']].num_examples\n",
        "    self.train_batch_size = exp_config.batch_size\n",
        "    self.num_train_batches_between_validations = math.ceil(\n",
        "        min(num_train_examples,\n",
        "            exp_config.num_train_examples_between_validations_max)\n",
        "        / self.train_batch_size)\n",
        "    self.cache_train = num_train_examples < min(100_000, (\n",
        "        exp_config.num_validations_per_path_training\n",
        "        * self.num_train_batches_between_validations\n",
        "        * self.train_batch_size))\n",
        "\n",
        "    num_validation_examples_tot = self.get_builder('validation').info.splits[self.splits['validation']].num_examples\n",
        "    if exp_config.num_validation_examples_max <= num_validation_examples_tot:\n",
        "      self.validation_batch_size = exp_config.batch_size\n",
        "      self.num_validation_batches = math.floor(\n",
        "          exp_config.num_validation_examples_max / self.validation_batch_size)\n",
        "    else:\n",
        "      # Adjust batch_size and num_batches to cover the smaller validation sets.\n",
        "      self.num_validation_batches = math.ceil(\n",
        "          num_validation_examples_tot / exp_config.batch_size)\n",
        "      self.validation_batch_size = math.floor(\n",
        "          num_validation_examples_tot / self.num_validation_batches)\n",
        "      assert num_validation_examples_tot >= (self.num_validation_batches*self.validation_batch_size)\n",
        "    self.num_validation_examples = self.num_validation_batches * self.validation_batch_size\n",
        "\n",
        "    print(f'Task: {self.name}')\n",
        "    print(f'  Train batches between validations: {self.num_train_batches_between_validations}')\n",
        "    print(f'  Validation batches: {self.num_validation_batches}')\n",
        "    print(f'  Validation batch size: {self.validation_batch_size}')\n",
        "    print(f'  Dataset {{\\n{self.dataset}}}')\n",
        "    print(f'  Splits {{\\n{self.splits}}}')\n",
        "\n",
        "\n",
        "  def get_builder(self, mode):\n",
        "    if type(self.dataset) == str:\n",
        "      return get_tfds_builder(self.dataset)\n",
        "    return get_tfds_builder(self.dataset[mode])\n",
        "\n",
        "  def __str__(self):\n",
        "    return f'Task_{self.name}'\n",
        "\n",
        "  def is_trainable(self):\n",
        "    return not self.name.startswith(NOT_TRAINABLE)\n",
        "\n",
        "  def is_private(self):\n",
        "    return self.private\n",
        "\n",
        "  def get_ds(self, mode, hparams):\n",
        "    data = self.get_builder(mode).as_dataset(\n",
        "        split=self.splits[mode],\n",
        "        shuffle_files=mode=='train')\n",
        "\n",
        "    def _pp(data):\n",
        "      im = data['image']\n",
        "      im = tf.cast(im, tf.float32)\n",
        "      # Must have 3 channels.\n",
        "      if im.shape[-1] == 1:\n",
        "        im = tf.squeeze(tf.stack([im] * 3, -1), axis=-2)\n",
        "      assert im.shape[-1] == 3\n",
        "      # Values in range [-1 , 1]\n",
        "      im = im / 127.5 - 1\n",
        "\n",
        "      if mode == 'train':\n",
        "        if hparams['ds_crop'] and hparams['ds_area_range_min'] < 1.0:\n",
        "          channels = im.shape[-1]\n",
        "          begin, size, _ = tf.image.sample_distorted_bounding_box(\n",
        "              tf.shape(im),\n",
        "              tf.zeros([0, 0, 4], tf.float32),\n",
        "              aspect_ratio_range=[hparams['ds_aspect_ratio_range_min'],\n",
        "                                  1.0/hparams['ds_aspect_ratio_range_min']],\n",
        "              area_range=[hparams['ds_area_range_min'], 1.0],\n",
        "              # Overlap with bounding box, the bounding box should anyway\n",
        "              # default defaults to whole image in this case.\n",
        "              min_object_covered=0,\n",
        "              use_image_if_no_bounding_boxes=True)\n",
        "          im = tf.slice(im, begin, size)\n",
        "          # Restore the depth-dimension lost by the above operation.\n",
        "          im.set_shape([None, None, channels])\n",
        "        if hparams['ds_flip_left_right']:\n",
        "          if tf.random.uniform(shape=[]) > 0.5:\n",
        "            im = tf.image.flip_left_right(im)\n",
        "        if hparams['ds_brightness_delta'] > 0.0:\n",
        "          im = tf.image.random_brightness(\n",
        "              im, max_delta=hparams['ds_brightness_delta'])\n",
        "        if hparams['ds_contrast_delta'] > 0.0:\n",
        "          im = tf.image.random_contrast(\n",
        "              im, lower=1 - hparams['ds_contrast_delta'],\n",
        "              upper=1 + hparams['ds_contrast_delta'])\n",
        "        if hparams['ds_saturation_delta'] > 0.0:\n",
        "          im = tf.image.random_saturation(\n",
        "              im, lower=1 - hparams['ds_saturation_delta'],\n",
        "              upper=1 + hparams['ds_saturation_delta'])\n",
        "        if hparams['ds_hue_delta'] > 0.0:\n",
        "          im = tf.image.random_hue(im, max_delta=hparams['ds_hue_delta'])\n",
        "\n",
        "      im = tf.image.resize(im, [hparams['ds_image_size'],\n",
        "                                hparams['ds_image_size']])\n",
        "      im = tf.clip_by_value(im, -1, 1)\n",
        "\n",
        "      return {'image': im, 'label': data['label']}\n",
        "\n",
        "    if mode == 'validation':\n",
        "      data = data.take(self.num_validation_examples)\n",
        "    if mode == 'validation' or (mode == 'train' and self.cache_train):\n",
        "      data = data.cache()\n",
        "    if mode != 'test':\n",
        "      data = data.repeat()\n",
        "    if self.vtab_class and self.vtab_class._base_preprocess_fn:\n",
        "      data = data.map(self.vtab_class._base_preprocess_fn, tf.data.AUTOTUNE)\n",
        "    data = data.map(_pp, tf.data.AUTOTUNE)\n",
        "    if mode == 'train':\n",
        "      batch_size = self.train_batch_size\n",
        "    else:\n",
        "      batch_size = self.validation_batch_size\n",
        "    data = data.batch(batch_size)\n",
        "    if mode == 'train':\n",
        "      data = data.shuffle(10)\n",
        "    return tfds.as_numpy(data.prefetch(tf.data.AUTOTUNE))\n",
        "\n",
        "def get_task_factory_fn(exp_config):\n",
        "  def get_task(task_name):\n",
        "    return Task(name=task_name, exp_config=exp_config)\n",
        "  return get_task\n",
        "\n",
        "NOT_TRAINABLE = 'NOT_TRAINABLE'\n",
        "not_trainable = Task(NOT_TRAINABLE, None)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2A3aZXSWYwCd"
      },
      "outputs": [],
      "source": [
        "def get_num_params(params):\n",
        "  return sum(jax.tree_util.tree_leaves(\n",
        "      jax.tree_util.tree_map(lambda p: np.prod(p.shape), params)\n",
        "      ))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rzxoZ4rdQZcA"
      },
      "outputs": [],
      "source": [
        "# Convert frozen dict of params to a list of components.\n",
        "def params2comps(params, train_locks, name=None):\n",
        "  components = []\n",
        "  for k in params:\n",
        "    if name is None or name == k:\n",
        "      c = Component(name=k, params=params[k], train_locks=train_locks)\n",
        "      components.append(c)\n",
        "  return components\n",
        "\n",
        "def params2comp_names(params):\n",
        "  return list(params.keys())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "97GXYGUJN8yY"
      },
      "outputs": [],
      "source": [
        "def is_local(params):\n",
        "  for p in jax.tree_util.tree_leaves(params):\n",
        "    if type(p) != np.ndarray:\n",
        "      assert issubclass(type(p), jnp.DeviceArray), type(p)\n",
        "      return False\n",
        "  return True\n",
        "\n",
        "def check_is_local(params):\n",
        "  for p in jax.tree_util.tree_leaves(params):\n",
        "    assert type(p) == np.ndarray, type(p)\n",
        "\n",
        "def check_is_on_device(params, device):\n",
        "  for p in jax.tree_util.tree_leaves(params):\n",
        "    assert p.device_buffer.device() == device, device"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DNxjDX13_dm_"
      },
      "outputs": [],
      "source": [
        "def fingerprint_params(params):\n",
        "  return np.sum(np.array(jax.tree_util.tree_leaves(\n",
        "      jax.tree_util.tree_map(jnp.sum, params))))\n",
        "\n",
        "class Component():\n",
        "  counter = 0\n",
        "\n",
        "  @staticmethod\n",
        "  def reset_globals():\n",
        "    Component.counter = 0\n",
        "\n",
        "  def __init__(self, name:str, params, train_locks:set, opt_state=None):\n",
        "    self.name = name\n",
        "    if is_local(params):\n",
        "      self.params = params\n",
        "    else:\n",
        "      self.params = jax.device_get(params)\n",
        "    check_is_local(self.params)\n",
        "    self.opt_state = opt_state\n",
        "    if self.opt_state is not None:\n",
        "      check_is_local(self.opt_state)\n",
        "    self.num_params = None\n",
        "    self.train_locks = set(train_locks)\n",
        "    self.id = Component.counter\n",
        "    Component.counter += 1\n",
        "\n",
        "  def __str__(self):\n",
        "    rtn = f'Component: {self.id}\\n  Name: {self.name}'\n",
        "    rtn += f'\\n  Train locks: {self.train_locks}'\n",
        "    rtn += f'\\n  Fingerprint: {self.fingerprint()}'\n",
        "    rtn += f'\\n  Num params: {self.get_num_params()}'\n",
        "    return rtn\n",
        "\n",
        "  def get_num_params(self):\n",
        "    if self.num_params is None:\n",
        "      self.num_params = get_num_params(self.params)\n",
        "    return self.num_params\n",
        "\n",
        "  def fingerprint(self):\n",
        "    return fingerprint_params(self.params)\n",
        "\n",
        "  def is_trainable(self):\n",
        "    return len(self.train_locks) == 0\n",
        "\n",
        "  def clone(self):\n",
        "    check_is_local(self.params)\n",
        "    return Component(name=self.name,\n",
        "                     params=copy.deepcopy(self.params),\n",
        "                     train_locks=set(),\n",
        "                     opt_state=copy.deepcopy(self.opt_state))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2FrDHFPU6NV-"
      },
      "outputs": [],
      "source": [
        "class ObjectCache():\n",
        "  def __init__(self, factory_fn, max_size=None):\n",
        "    self.factory_fn = factory_fn\n",
        "    self.cache = {}\n",
        "    self.max_size = max_size\n",
        "\n",
        "  def __call__(self, *args, **kwargs):\n",
        "    assert not args\n",
        "    key = json.dumps(kwargs, sort_keys=True)\n",
        "    if key not in self.cache:\n",
        "      if self.max_size and self.max_size <= len(self.cache):\n",
        "        # Cache is full: evict a randomly chosen entry.\n",
        "        rm_key = random.choice(list(self.cache.keys()))\n",
        "        print(f'Removed from cache: {self.factory_fn.__name__}({rm_key})  [cache size {len(self.cache)}]')\n",
        "        rm_obj = self.cache.pop(rm_key)\n",
        "        del rm_obj\n",
        "      self.cache[key] = self.factory_fn(**kwargs)\n",
        "      if VERBOSE:\n",
        "        print(f'Added to cache: {self.factory_fn.__name__}({key})  [cache size {len(self.cache)}]')\n",
        "    else:\n",
        "      if VERBOSE:\n",
        "        print(f'Cache hit: {self.factory_fn.__name__}({key})  [cache size {len(self.cache)}]')\n",
        "    return self.cache[key]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kwhVlaqRHp-0"
      },
      "outputs": [],
      "source": [
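        "# Moves value one step up or down (with equal probability) within\n",
        "# values_list, clamping at the ends of the list.\n",
        "def incremental_mutation(value, values_list:list):\n",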
        "  assert value in values_list, f'{value} not in {values_list}'\n",
        "  idx = values_list.index(value)\n",
        "  idx += 1 if np.random.uniform() < 0.5 else -1\n",
        "  idx = max(0, min(len(values_list) - 1, idx))\n",
        "  return values_list[idx]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8PU5ffvd_gC9"
      },
      "outputs": [],
      "source": [
        "class Path():\n",
        "  @staticmethod\n",
        "  def reset_globals(exp_config):\n",
        "    Path.exp_config = exp_config\n",
        "    Path.counter = 0\n",
        "    Path.paths = []\n",
        "    Path.scorer = None  # To be set to scorer of choice during init of exp.\n",
        "    # Cache the output of function calls with the same args.\n",
        "    Path.tasks = ObjectCache(get_task_factory_fn(exp_config))\n",
        "    Path.posembed_components = ObjectCache(get_reshaped_posembed_component)\n",
        "    Path.optimizers = ObjectCache(get_optimizer)\n",
        "    Path.models = ObjectCache(get_vit_model)\n",
        "    Path.init_params = ObjectCache(get_vit_params_mapped, max_size=1)\n",
        "\n",
        "  def __init__(self, hparams, components, parent, task:Task):\n",
        "    self.components = components\n",
        "    self.id = Path.counter\n",
        "    Path.counter += 1\n",
        "    self.task = task\n",
        "    self.parent = parent\n",
        "    self.hparams = hparams\n",
        "    self.metrics = {\n",
        "        'offsprings': 0,\n",
        "        'offsprings_per_task': json.dumps({}),\n",
        "        'reloads': 0,\n",
        "        'generation': 0 if parent is None else parent.metrics['generation'] + 1,\n",
        "        'private': task.is_private(),\n",
        "    }\n",
        "    self.model = Path.models(\n",
        "        num_classes = int(hparams['num_classes']),\n",
        "        num_layers = int(hparams['num_layers']),\n",
        "        adapter_layers = str(hparams['adapter_layers']),\n",
        "        adapter_dim = int(hparams['adapter_dim']),\n",
        "        query = str(self.exp_config.load_vit_checkpoint_query))\n",
        "    Path.paths.append(self)\n",
        "\n",
        "  def __str__(self):\n",
        "    rtn = f'Path: {self.id}'\n",
        "    rtn += f'\\n  Components: {[c.id for c in self.components]}'\n",
        "    if self.parent:\n",
        "      rtn += f'\\n  Parent: {self.parent.id}'\n",
        "    rtn += f'\\n  Task: {self.task.name}'\n",
        "    rtn += f'\\n  Total Parameters: {get_num_params(self.get_all_params())}'\n",
        "    rtn += f'\\n  Accounted params: {self.accounted_num_params()}'\n",
        "    for k,v in self.hparams.items():\n",
        "      rtn += f'\\n    {k}: {v}'\n",
        "    for k,v in self.metrics.items():\n",
        "      rtn += f'\\n    {k}: {v}'\n",
        "    rtn += f'\\n  Score: {self.score()}'\n",
        "    return rtn\n",
        "\n",
        "  def is_trainable(self):\n",
        "    return self.task.is_trainable()\n",
        "\n",
        "  def is_private(self):\n",
        "    return self.task.is_private()\n",
        "\n",
        "  def score(self):\n",
        "    return Path.scorer.score(self)\n",
        "\n",
        "  def get_all_params(self):\n",
        "    params = {}\n",
        "    for c in self.components:\n",
        "      params[c.name] = c.params\n",
        "    return flax.core.freeze(params)\n",
        "\n",
        "  def get_trainable_params(self):\n",
        "    params = {}\n",
        "    for c in self.components:\n",
        "      if c.is_trainable():\n",
        "        params[c.name] = c.params\n",
        "    return flax.core.freeze(params)\n",
        "\n",
        "  def get_fixed_params(self):\n",
        "    params = {}\n",
        "    for c in self.components:\n",
        "      if not c.is_trainable():\n",
        "        params[c.name] = c.params\n",
        "    return flax.core.freeze(params)\n",
        "\n",
        "  def update_trainable(self, trained_params, opt_state):\n",
        "    assert len(trained_params.keys()) == len(opt_state.keys())\n",
        "    trainable_count = 0\n",
        "    for c in self.components:\n",
        "      if c.is_trainable():\n",
        "        trainable_count += 1\n",
        "        assert c.name in trained_params.keys()\n",
        "        assert c.name in opt_state.keys()\n",
        "        c.params = trained_params[c.name]\n",
        "        c.opt_state = opt_state[c.name]\n",
        "        check_is_local(c.params)\n",
        "        check_is_local(c.opt_state)\n",
        "    assert len(trained_params.keys()) == trainable_count, f'{len(trained_params.keys())} {trainable_count}'\n",
        "\n",
        "  def accounted_num_params(self):\n",
        "    # Each component's size is split evenly among the tasks sharing it.\n",
        "    rtn = 0\n",
        "    for c in self.components:\n",
        "      tl = copy.copy(c.train_locks)\n",
        "      assert type(tl) is set\n",
        "      tl.add(self.task.name)\n",
        "      if NOT_TRAINABLE in tl:\n",
        "        tl.remove(NOT_TRAINABLE)\n",
        "      if len(tl) == 0:\n",
        "        return np.nan\n",
        "      rtn += c.get_num_params() / len(tl)\n",
        "    return rtn\n",
        "\n",
        "  def clone(\n",
        "      self,\n",
        "      task:Task,\n",
        "      ds_hparams,\n",
        "      policy):\n",
        "    exp_config = Path.exp_config\n",
        "    assert exp_config == task.exp_config\n",
        "    comps = []\n",
        "    new_hparams = copy.deepcopy(self.hparams)\n",
        "    new_hparams['num_classes'] = task.num_classes\n",
        "    # Overwrite dataset hparams with those sampled for the generation batch.\n",
        "    new_hparams.update(ds_hparams)\n",
        "\n",
        "    def get_component_ref(c, clone=False):\n",
        "      if c.is_trainable() or clone:\n",
        "        # Clone trainable component.\n",
        "        return c.clone()\n",
        "      # Refer to frozen component.\n",
        "      return c\n",
        "\n",
        "    for k in sorted(exp_config.models_mutation_ranges):\n",
        "      if (policy.do_mutate(new_hparams['_mu_']) and\n",
        "          (k in ['_mu_', 'num_layers', 'adapter_dim']\n",
        "            or k.startswith(OPTIMIZER_HPARAMS_KEYS_PRERFIX))):\n",
        "        new_hparams[k] = incremental_mutation(\n",
        "            new_hparams[k],\n",
        "            exp_config.models_mutation_ranges[k])\n",
        "    new_hparams['adapter_layers'] = mutate_adapters(\n",
        "        exp_config.mutate_adapters,\n",
        "        adapter_layers_ids=new_hparams['adapter_layers'],\n",
        "        num_layers=new_hparams['num_layers'],\n",
        "        mutation_prob=new_hparams['_mu_'],\n",
        "        policy=policy)\n",
        "\n",
        "    init_params = Path.init_params(\n",
        "        **get_model_kwargs(new_hparams, exp_config))\n",
        "    new_comp_names = params2comp_names(init_params)\n",
        "    for new_comp_name in new_comp_names:\n",
        "      comp = None\n",
        "      # Attempt to reuse a matching component from the closest ancestor.\n",
        "      ancestor = self\n",
        "      while ancestor is not None:\n",
        "        comps_lookup = {c.name:c for c in ancestor.components}\n",
        "        if new_comp_name in comps_lookup:\n",
        "          # The head must be trainable: if no ancestor is of the same task,\n",
        "          # fall back to a random init of the correct shape.\n",
        "          if new_comp_name == 'head' and not comps_lookup[new_comp_name].is_trainable():\n",
        "            assert task.name != ancestor.task.name, f\"{task.name} != {ancestor.task.name}\"\n",
        "            ancestor = ancestor.parent\n",
        "            continue\n",
        "\n",
        "          # Check shapes match otherwise skip.\n",
        "          if jax.tree_map(jnp.shape, init_params[new_comp_name]) != jax.tree_map(jnp.shape, comps_lookup[new_comp_name].params):\n",
        "            if new_comp_name == 'posembed_input':\n",
        "              # Change of image size changed shape of position embeddings,\n",
        "              # this can happen if ds_image_size is tuned,\n",
        "              # continue searching through ancestors for matching size.\n",
        "              assert 'ds_image_size' in exp_config.models_mutation_ranges\n",
        "              assert new_hparams['ds_image_size'] != ancestor.hparams['ds_image_size']\n",
        "              ancestor = ancestor.parent\n",
        "              continue\n",
        "            if new_comp_name.startswith('residual_adapter_'):\n",
        "              # Change of adapter inner dimension changed shape of dense layers,\n",
        "              # this can happen if adapter_dim is tuned,\n",
        "              # continue searching through ancestors for matching size.\n",
        "              assert 'adapter_dim' in exp_config.models_mutation_ranges\n",
        "              assert new_hparams['adapter_dim'] != ancestor.hparams['adapter_dim']\n",
        "              ancestor = ancestor.parent\n",
        "              continue\n",
        "\n",
        "            print(f'WARNING: Shapes do not match for component: {new_comp_name}  {ancestor.task.name}->{task.name}')\n",
        "            print(jax.tree_map(jnp.shape, init_params[new_comp_name]))\n",
        "            print(jax.tree_map(jnp.shape, comps_lookup[new_comp_name].params))\n",
        "            assert False  # Should not happen in current configuration.\n",
        "\n",
        "          comp = get_component_ref(comps_lookup[new_comp_name],\n",
        "                                   clone=policy.do_mutate(new_hparams['_mu_'],\n",
        "                                                          new_comp_name))\n",
        "          break\n",
        "        ancestor = ancestor.parent\n",
        "\n",
        "      # Get reshaped posembed_input.\n",
        "      if comp is None and new_comp_name == 'posembed_input':\n",
        "        pe_comp = Path.posembed_components(\n",
        "            image_size=new_hparams['ds_image_size'],\n",
        "            query=exp_config.load_vit_checkpoint_query)\n",
        "        # Clone to make the component trainable.\n",
        "        comp = get_component_ref(pe_comp, clone=True)\n",
        "\n",
        "      # Otherwise create one from random init params.\n",
        "      if comp is None:\n",
        "        if VERBOSE:\n",
        "          print('Init:', new_comp_name)\n",
        "        # Combinations that can trigger a random init in current configurations.\n",
        "        assert (\n",
        "            new_comp_name == 'head'\n",
        "            or new_comp_name.startswith('residual_adapter_')\n",
        "            or (new_comp_name.startswith('encoderblock_') and \\\n",
        "                exp_config.models_default_hparams['num_layers'] < max(\n",
        "                exp_config.models_mutation_ranges.get('num_layers', [-1])))\n",
        "            )\n",
        "        comp = params2comps(init_params, train_locks=[], name=new_comp_name)[0]\n",
        "      assert comp is not None\n",
        "      comps.append(comp)\n",
        "\n",
        "    rtn = Path(new_hparams, comps, parent=self, task=task)\n",
        "    if task == self.task:\n",
        "      self.metrics['offsprings'] = self.metrics.get('offsprings', 0) + 1\n",
        "\n",
        "    offsprings_per_task = json.loads(self.metrics['offsprings_per_task'])\n",
        "    offsprings_per_task[task.name] = offsprings_per_task.get(task.name, 0) + 1\n",
        "    self.metrics['offsprings_per_task'] = json.dumps(offsprings_per_task)\n",
        "\n",
        "    return rtn\n",
        "\n",
        "  def get_optimizer(self):\n",
        "    return Path.optimizers(\n",
        "        lr=float(self.hparams['opt_lr']),\n",
        "        lr_schedule=str(self.hparams['opt_lr_schedule']),\n",
        "        lr_warmup_ratio=float(self.hparams['opt_lr_warmup_ratio']),\n",
        "        momentum=float(self.hparams['opt_momentum']),\n",
        "        nesterov=bool(self.hparams['opt_nesterov']),\n",
        "        num_train_batches_between_validations=int(\n",
        "            self.task.num_train_batches_between_validations),\n",
        "        num_validations_per_path_training=int(\n",
        "            self.task.exp_config.num_validations_per_path_training),\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C3-3N-LPV09d"
      },
      "outputs": [],
      "source": [
        "def mutate_adapters(mutate, adapter_layers_ids, num_layers, mutation_prob, policy, allow_removal=False):\n",
        "  a_ids = set(ids_str2ints(adapter_layers_ids))\n",
        "  if mutate:\n",
        "    for a_id in range(num_layers):\n",
        "      if policy.do_mutate(mutation_prob):\n",
        "        if a_id in a_ids:\n",
        "          if allow_removal:\n",
        "            a_ids.remove(a_id)\n",
        "        else:\n",
        "          a_ids.add(a_id)\n",
        "  # Drop adapters of layers dropped by a possible mutation in num_layers.\n",
        "  a_ids = [a_id for a_id in a_ids if a_id < num_layers]\n",
        "  return ids_ints2str(a_ids)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bkRmcJgzUbwN"
      },
      "outputs": [],
      "source": [
        "class Scorer():\n",
        "  def score(self, path):\n",
        "    assert False, 'Not implemented'\n",
        "\n",
        "class ScorerQuality(Scorer):\n",
        "  def score(self, path):\n",
        "    if ('quality' not in path.metrics\n",
        "        or math.isnan(path.metrics['quality'])):\n",
        "      return None\n",
        "    assert path.metrics['quality'] >= 0, \\\n",
        "        f'{path.task.name} {path.metrics[\"quality\"]}'\n",
        "    score = path.metrics['quality']\n",
        "    assert score >= 0\n",
        "    return score\n",
        "\n",
        "class ScorerDecay(Scorer):\n",
        "  def __init__(self, base, num_params):\n",
        "    self.base = base\n",
        "    assert self.base > 0.0\n",
        "    assert self.base <= 1.0\n",
        "    self.num_params = num_params\n",
        "    assert self.num_params > 0\n",
        "\n",
        "  def score(self, path):\n",
        "    if ('quality' not in path.metrics\n",
        "        or math.isnan(path.metrics['quality'])):\n",
        "      return None\n",
        "    assert path.metrics['quality'] >= 0, \\\n",
        "        f'{path.task.name} {path.metrics[\"quality\"]}'\n",
        "    score = path.metrics['quality'] * (self.base ** (path.accounted_num_params() / self.num_params))\n",
        "    assert score >= 0\n",
        "    return score"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iWZeLq5zZlfb"
      },
      "outputs": [],
      "source": [
        "class PPDecay():\n",
        "  def __init__(self, exp_config):\n",
        "    self.exp_config = exp_config\n",
        "\n",
        "  def do_mutate(self, mutation_prob, comp_name=None):\n",
        "    if comp_name:\n",
        "      if comp_name in self.exp_config.force_finetune_components:\n",
        "        return True\n",
        "    return mutation_prob > np.random.uniform()\n",
        "\n",
        "  def sample_parent(self, paths, task_name):\n",
        "    for path in paths:\n",
        "      offsprings = json.loads(path.metrics['offsprings_per_task']).get(task_name, 0)\n",
        "      print(' ', path.id, offsprings)\n",
        "      assert not math.isnan(offsprings)\n",
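        "      # Selection probability halves with each offspring already sampled\n",
        "      # from this path for the task.\n",
        "      if np.random.uniform() < 0.5 ** offsprings:\n",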
        "        return path\n",
        "    return None\n",
        "\n",
        "  def sample_path(self, pop, task:Task, ds_hparams):\n",
        "    parent = self.sample_parent(\n",
        "        sorted(pop.paths[task], key=lambda p: p.score(), reverse=True),\n",
        "        task.name)\n",
        "    if not parent:\n",
        "      print('  seeds')\n",
        "      parent = self.sample_parent(pop.seed_paths, task.name)\n",
        "      if parent:  # Rotate seeds.\n",
        "        pos = pop.seed_paths.index(parent) + 1\n",
        "        pop.seed_paths = pop.seed_paths[pos:] + pop.seed_paths[:pos]\n",
        "    if not parent:  # Random sample.\n",
        "      parent = random.choice(pop.paths[task] + pop.seed_paths)\n",
        "      print('  random', parent.id)\n",
        "\n",
        "    child = parent.clone(task, ds_hparams, self)\n",
        "\n",
        "    gc.collect()\n",
        "\n",
        "    # Store record of mutations.\n",
        "    mutations = {}\n",
        "    for k in child.hparams:\n",
        "      if parent.hparams.get(k) != child.hparams[k]:\n",
        "        mutations[k] = (parent.hparams.get(k), child.hparams[k])\n",
        "    child.metrics['mutations'] = json.dumps(mutations)\n",
        "    print(child.id, child.metrics['mutations'])\n",
        "    return child\n",
        "\n",
        "  def sample_ds_hparams(self, pop, task:Task):\n",
        "    assert pop.exp_config is self.exp_config\n",
        "    ds_hparams = {}\n",
        "    for key in self.exp_config.models_default_hparams:\n",
        "      if key.startswith(DATASET_HPARAMS_KEYS_PRERFIX):\n",
        "        ds_hparams[key] = self.exp_config.models_default_hparams[key]\n",
        "    best_path = pop.get_best_path(task)\n",
        "    if best_path:\n",
        "      ds_hparams.update(\n",
        "          {k : best_path.hparams[k] for k in ds_hparams if k in best_path.hparams})\n",
        "    for k in ds_hparams:\n",
        "      if (k in self.exp_config.models_mutation_ranges\n",
        "          and pop.policy.do_mutate(self.exp_config.models_default_hparams['_mu_'])):\n",
        "        ds_hparams[k] = incremental_mutation(\n",
        "            ds_hparams[k],\n",
        "            self.exp_config.models_mutation_ranges[k])\n",
        "    return ds_hparams"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tKZ8VTr84VQu"
      },
      "outputs": [],
      "source": [
        "class PPBaseline():\n",
        "  def __init__(self, exp_config):\n",
        "    self.exp_config = exp_config\n",
        "\n",
        "  def sample_parent(self, paths):\n",
        "    assert False, 'Baselines should not reach evolutionary codepath.'\n",
        "\n",
        "  def do_mutate(self, mutation_prob, comp_name=None):\n",
        "    if comp_name:\n",
        "      if comp_name in self.exp_config.force_finetune_components:\n",
        "        return True\n",
        "    if mutation_prob == 0.0:\n",
        "      return False\n",
        "    elif mutation_prob == 1.0:\n",
        "      return True\n",
        "    else:\n",
        "      assert False, mutation_prob\n",
        "\n",
        "  def sample_path(self, pop, task:Task, ds_hparams):\n",
        "    assert len(pop.paths[not_trainable]) == 1\n",
        "    parent = pop.paths[not_trainable][0]\n",
        "    child = parent.clone(task, ds_hparams, self)\n",
        "    return child\n",
        "\n",
        "  def sample_ds_hparams(self, pop, task:Task):\n",
        "    ds_hparams = {}\n",
        "    for key in self.exp_config.models_default_hparams:\n",
        "      if key.startswith(DATASET_HPARAMS_KEYS_PRERFIX):\n",
        "        ds_hparams[key] = self.exp_config.models_default_hparams[key]\n",
        "    return ds_hparams"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YMbYgKd8_nyi"
      },
      "outputs": [],
      "source": [
        "class Population():\n",
        "  def __init__(self, exp_config):\n",
        "    self.paths = defaultdict(list)\n",
        "    self.exp_config = exp_config\n",
        "    self.paths_df = pd.DataFrame()\n",
        "    self.comps_df = pd.DataFrame()\n",
        "    self.policy = globals()[exp_config.policy_class](\n",
        "        **exp_config.policy_kwargs,\n",
        "        exp_config=exp_config)\n",
        "\n",
        "  def get_best_path(self, task:Task):\n",
        "    if len(self.paths[task]) == 0:\n",
        "      return None\n",
        "    # Most recent path achieving max score.\n",
        "    return max(sorted(self.paths[task], key=lambda p: p.id, reverse=True),\n",
        "               key=lambda p: p.score())\n",
        "\n",
        "  def sample_path(self, task:Task, ds_hparams):\n",
        "    return self.policy.sample_path(pop=self, task=task, ds_hparams=ds_hparams)\n",
        "\n",
        "  def sample_ds_hparams(self, task:Task):\n",
        "    ds_hparams = self.policy.sample_ds_hparams(pop=self, task=task)\n",
        "    return ds_hparams\n",
        "\n",
        "  def add_train_locks(self, task:Task):\n",
        "    # Check.\n",
        "    for ps in self.paths.values():\n",
        "      for p in ps:\n",
        "        for c in p.components:\n",
        "          assert task.name not in c.train_locks\n",
        "    # Add locks.\n",
        "    paths = self.paths[task]\n",
        "    for p in paths:\n",
        "      for c in p.components:\n",
        "        c.train_locks.add(task.name)\n",
        "\n",
        "  def rm_train_locks(self, task:Task):\n",
        "    # Remove locks.\n",
        "    paths = self.paths[task]\n",
        "    for p in paths:\n",
        "      for c in p.components:\n",
        "        if task.name in c.train_locks:\n",
        "          c.train_locks.remove(task.name)\n",
        "    # Check.\n",
        "    for ps in self.paths.values():\n",
        "      for p in ps:\n",
        "        for c in p.components:\n",
        "          assert task.name not in c.train_locks\n",
        "\n",
        "  def set_seed_paths(self, task:Task):\n",
        "    self.seed_paths = []\n",
        "    for paths in self.paths.values():\n",
        "      for path in paths:\n",
        "        if path.task is task:\n",
        "          continue\n",
        "        if path.task.is_private():\n",
        "          continue\n",
        "        self.seed_paths.append(path)\n",
        "    random.shuffle(self.seed_paths)\n",
        "\n",
        "  def start_task(self, task:Task):\n",
        "    self.set_seed_paths(task)\n",
        "    self.rm_train_locks(task)\n",
        "\n",
        "  def end_task(self, task:Task):\n",
        "    # Keep only best one.\n",
        "    best_path = self.get_best_path(task)\n",
        "    assert best_path is not None\n",
        "    self.paths[task] = [best_path]\n",
        "\n",
        "    self.add_train_locks(task)\n",
        "    self.garbage_collect_paths()\n",
        "\n",
        "  def garbage_collect_paths(self):\n",
        "    # Store stats before dropping references to trigger garbage collection\n",
        "    # of unused paths, components and parameters.\n",
        "    self.paths_df = pd.concat([self.paths_df, paths_to_df(Path.paths)],\n",
        "                              ignore_index=True)\n",
        "    self.comps_df = pd.concat([self.comps_df, components_to_df(Path.paths)],\n",
        "                              ignore_index=True)\n",
        "\n",
        "    # Drop unused paths generated in this task iteration for garbage collection.\n",
        "    Path.paths = []\n",
        "    # Simplify ancestor tree to contain only live paths.\n",
        "    live_paths_ids = [p.id for paths in self.paths.values() for p in paths]\n",
        "    # Notice that the simplification is done also for paths of other tasks,\n",
        "    # since they may be pointing to a path of this task that was just pruned.\n",
        "    for path in [path for paths in self.paths.values() for path in paths]:\n",
        "      ancestor = path.parent\n",
        "      if ancestor is None:\n",
        "        continue\n",
        "      while True:\n",
        "        if ancestor.id in live_paths_ids:\n",
        "          path.parent = ancestor\n",
        "          break\n",
        "        ancestor = ancestor.parent"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4EgQHbawpNcS"
      },
      "outputs": [],
      "source": [
        "pd.set_option('display.expand_frame_repr', False)\n",
        "pd.set_option('display.max_columns', 100)\n",
        "pd.set_option('display.max_rows', 100)\n",
        "\n",
        "def pop_to_df(pop):\n",
        "  return paths_to_df([p for paths in pop.paths.values() for p in paths])\n",
        "\n",
        "def paths_to_df(paths):\n",
        "  # Collect all metrics names.\n",
        "  metrics_keys = set()\n",
        "  hparams_keys = set()\n",
        "  for path in paths:\n",
        "    metrics_keys.update(path.metrics)\n",
        "    hparams_keys.update(path.hparams)\n",
        "\n",
        "  data = defaultdict(list)\n",
        "  for path in paths:\n",
        "    data['task_name'].append(path.task.name)\n",
        "    data['id'].append(path.id)\n",
        "    data['parent_id'].append(path.parent.id if path.parent else -1)\n",
        "    data['parent_task_name'].append(path.parent.task.name if path.parent else None)\n",
        "    data['final_accounted_params'].append(path.accounted_num_params())\n",
        "    data['components'].append('_'.join([str(c.id) for c in path.components]))\n",
        "    for k in hparams_keys:\n",
        "      data[f'hparams.{k}'].append(path.hparams[k] if k in path.hparams else None)\n",
        "    for k in metrics_keys:\n",
        "      data[f'metrics.{k}'].append(path.metrics[k] if k in path.metrics else None)\n",
        "    data['score'].append(path.score())\n",
        "  return pd.DataFrame(data)\n",
        "\n",
        "def components_to_df(paths):\n",
        "  # Collect all components.\n",
        "  comps = set()\n",
        "  for p in paths:\n",
        "    comps.update(p.components)\n",
        "\n",
        "  data = defaultdict(list)\n",
        "  for c in comps:\n",
        "    data['id'].append(c.id)\n",
        "    data['name'].append(c.name)\n",
        "    data['num_params'].append(c.get_num_params())\n",
        "    data['train_locks'].append(','.join(c.train_locks))\n",
        "  return pd.DataFrame(data)\n",
        "\n",
        "def print_df_segments(df, segment_length:int = 10):\n",
        "  tot_length = df.shape[0]\n",
        "  # Pad column title with spaces to keep alignment across segments.\n",
        "  def prepend_spaces(original_str, pad_to_len):\n",
        "    return ' ' * (pad_to_len-len(original_str)) + original_str\n",
        "  pad_to_len = max([len(tn) for tn in set(df['task_name'].to_list())])+1\n",
        "  df = df.rename(columns={\n",
        "    'task_name': prepend_spaces('task_name', pad_to_len),\n",
        "    'parent_task_name': prepend_spaces('parent_task_name', pad_to_len),\n",
        "    })\n",
        "  for x in range(0, tot_length, segment_length):\n",
        "    print(df[x:min(x+segment_length, tot_length)])\n",
        "\n",
        "def df_leaderboard(df):\n",
        "  df = df.loc[df['task_name'] != NOT_TRAINABLE]\n",
        "  # Place columns on the left for readability.\n",
        "  all_keys = sorted(df.columns.tolist())\n",
        "  first_keys = ['task_name','score', 'metrics.quality', 'metrics.test_quality',\n",
        "                'id', 'parent_id','parent_task_name', 'final_accounted_params']\n",
        "  first_keys = [k for k in first_keys if k in all_keys]\n",
        "  sorted_keys = first_keys + [k for k in all_keys if k not in first_keys]\n",
        "  df = df[sorted_keys]\n",
        "  print_df_segments(df)\n",
        "  print(f'Avg score:        {df[\"score\"].mean():.6f}')\n",
        "  print(f'Avg quality:      {df[\"metrics.quality\"].mean():.6f}')\n",
        "  if 'metrics.test_quality' in df:\n",
        "    print(f'Avg test quality: {df[\"metrics.test_quality\"].mean():.6f}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xmYkTEY-PBgh"
      },
      "outputs": [],
      "source": [
        "# Print path.\n",
        "def prp(path):\n",
        "  rtn = []\n",
        "  if VERBOSE:\n",
        "    rtn.append(str(path))\n",
        "    for c in path.components:\n",
        "      rtn.append(str(c))\n",
        "  else:\n",
        "    rtn.append(str(path.id))\n",
        "  return '\\n'.join(rtn)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M_RTFjoOwfpM"
      },
      "outputs": [],
      "source": [
        "def df_write_to_file(df, dir_path, df_name):\n",
        "  filename_df = os.path.join(dir_path, f'{df_name}.csv')\n",
        "  with tf.io.gfile.GFile(filename_df, 'w') as outfile:\n",
        "    df.to_csv(outfile, index=False)\n",
        "\n",
        "def df_read_from_file(dir_path, df_name):\n",
        "  filename_df = os.path.join(dir_path, f'{df_name}.csv')\n",
        "  with tf.io.gfile.GFile(filename_df, 'r') as infile:\n",
        "    df = pd.read_csv(infile)\n",
        "  # Pandas read_csv() reads empty stings as NaNs. Set NaNs to empty strings in\n",
        "  # columns with type strings/object.\n",
        "  for c in df.columns:\n",
        "    if df[c].dtype == np.object_:\n",
        "        df[c].fillna('', inplace=True)\n",
        "  return df\n",
        "\n",
        "def get_comps_params(pop:Population):\n",
        "  comps_params = {}\n",
        "  for c in set([c for paths in pop.paths.values() for p in paths for c in p.components]):\n",
        "    comps_params[f'{c.name}:{c.id}'] = c.params\n",
        "    if c.opt_state is not None:\n",
        "      comps_params[f'opt_state:{c.name}:{c.id}'] = c.opt_state\n",
        "  return comps_params\n",
        "\n",
        "LAST_CHECKPOINT_TIME = time.time()\n",
        "def checkpoint_save(\n",
        "    experiment_dir:str, comps_params, generation_id:int, loop_id:int):\n",
        "  flax_checkpoints.save_checkpoint(\n",
        "      ckpt_dir=experiment_dir,\n",
        "      target=comps_params,\n",
        "      step=loop_id,\n",
        "      prefix=f\"checkpoint_{generation_id}_\",\n",
        "      overwrite=True)\n",
        "  global LAST_CHECKPOINT_TIME\n",
        "  LAST_CHECKPOINT_TIME = time.time()\n",
        "\n",
        "def save_state(\n",
        "    pop:Population,\n",
        "    generation_id:int,\n",
        "    loop_id:int,\n",
        "    exp_config:FrozenConfigDict):\n",
        "  # Save data needed to resume exp.\n",
        "  write_st = time.time()\n",
        "  df_leaderboard(pop_to_df(pop))\n",
        "  pop.garbage_collect_paths()\n",
        "  print('WRITING CHECKPOINT:', loop_id, generation_id)\n",
        "  if loop_id == 0:\n",
        "    tf.io.gfile.makedirs(exp_config.experiment_dir)\n",
        "    json.dump(exp_config.as_configdict().to_dict(),\n",
        "              tf.io.gfile.GFile(os.path.join(exp_config.experiment_dir,\n",
        "                                      'config.json'),\n",
        "                         'wb'), indent=2)\n",
        "  state_dir = os.path.join(\n",
        "      exp_config.experiment_dir, f\"state_{loop_id}_{generation_id}\")\n",
        "  tf.io.gfile.makedirs(state_dir)\n",
        "  # Write state in background threads.\n",
        "  write_threads = []\n",
        "  write_threads.append(Thread(target=df_write_to_file, args=(pop_to_df(pop), state_dir, 'population')))\n",
        "  write_threads.append(Thread(target=df_write_to_file, args=(pop.paths_df, state_dir, 'paths')))\n",
        "  write_threads.append(Thread(target=df_write_to_file, args=(pop.comps_df, state_dir, 'components')))\n",
        "  write_threads.append(Thread(target=checkpoint_save, args=(state_dir, get_comps_params(pop), loop_id, generation_id)))\n",
        "  for t in write_threads:\n",
        "    t.start()\n",
        "  print(f'WRITE START TIME: {time.time() - write_st:.2f} s')\n",
        "  return write_threads"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7yWdy7DskBph"
      },
      "outputs": [],
      "source": [
        "def load_population_from_checkpoint(\n",
        "    pop:Population,\n",
        "    ckpt_dir:str,\n",
        "    population_df):\n",
        "  loaded_params = flax.core.freeze(\n",
        "      flax_checkpoints.restore_checkpoint(\n",
        "          ckpt_dir=ckpt_dir,\n",
        "          target=None))\n",
        "  id_2_comp = {}\n",
        "  for k in loaded_params.keys():\n",
        "    if k.startswith('opt_state:'):\n",
        "      continue\n",
        "    name, id = k.split(':')\n",
        "    if 'opt_state:'+k in loaded_params.keys():\n",
        "      opt_state = loaded_params['opt_state:' + k]\n",
        "    else:\n",
        "      opt_state = None\n",
        "    c = Component(name=name, params=loaded_params[k], train_locks=[], opt_state=opt_state)\n",
        "    c.id = int(id)\n",
        "    assert c.id not in id_2_comp\n",
        "    id_2_comp[c.id] = c\n",
        "  # For parent assignemt.\n",
        "  id_2_path = {}\n",
        "  path_2_parent_id = {}\n",
        "  for index, row in population_df.iterrows():\n",
        "    comps_ids = row['components'].split('_')\n",
        "    comps = []\n",
        "    for id in comps_ids:\n",
        "      comps.append(id_2_comp[int(id)])\n",
        "    task_name = row['task_name']\n",
        "    if task_name == NOT_TRAINABLE:\n",
        "      task = not_trainable\n",
        "    else:\n",
        "      task = Path.tasks(task_name=task_name)\n",
        "    # Retrieve hparams and metrics.\n",
        "    hparams = {}\n",
        "    metrics = {}\n",
        "    for k in row.keys():\n",
        "      if k.startswith('hparams.'):\n",
        "        hparams[k[len('hparams.'):]] = row[k]\n",
        "      if k.startswith('metrics.'):\n",
        "        metrics[k[len('metrics.'):]] = row[k]      \n",
        "    if type(hparams['adapter_layers']) is float:\n",
        "      if math.isnan(hparams['adapter_layers']):\n",
        "        hparams['adapter_layers'] = ''\n",
        "      else:\n",
        "        hparams['adapter_layers'] = str(int(hparams['adapter_layers']))\n",
        "    metrics['reloads'] = metrics['reloads'] + 1\n",
        "    # Create path.\n",
        "    path = Path(\n",
        "        hparams=hparams,\n",
        "        components=comps,\n",
        "        parent=None,\n",
        "        task=task,\n",
        "        )\n",
        "    path.metrics = metrics\n",
        "    path.id = int(row['id'])\n",
        "    # Add train locks.\n",
        "    for c in path.components:\n",
        "      c.train_locks.add(task_name)\n",
        "    pop.paths[task].append(path)\n",
        "    assert path.id not in id_2_path\n",
        "    id_2_path[path.id] = path\n",
        "    if task_name != NOT_TRAINABLE:\n",
        "      path_2_parent_id[path] = int(row['parent_id'])\n",
        "\n",
        "  # Set parents.\n",
        "  for path, parent_id in path_2_parent_id.items():\n",
        "    path.parent = id_2_path[parent_id]\n",
        "  Path.counter = 1 + max([id for id in id_2_path])\n",
        "  Component.counter = 1 + max([id for id in id_2_comp])\n",
        "  Path.paths = []"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m6vSjSIvNPq4"
      },
      "outputs": [],
      "source": [
        "@partial(jax.jit, static_argnames='model')\n",
        "def eval_step(params, images, labels, model):\n",
        "  logits = model.apply({'params': params}, images, train=USE_DROPOUT)\n",
        "  # Avg accuracy on the batch.\n",
        "  return (logits.argmax(axis=-1) == labels).mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_m2xl8XR7cWy"
      },
      "outputs": [],
      "source": [
        "@partial(jax.jit, static_argnames=['model', 'optimizer'], donate_argnums=[0, 2])\n",
        "def train_step(params, fixed_params, opt_state, images, labels, model, optimizer):\n",
        "  def loss_fn(params, fixed_params, images, labels):\n",
        "    logits = model.apply({'params': format_params(params, fixed_params)},\n",
        "                         images, train=USE_DROPOUT)\n",
        "    labels = jax.nn.one_hot(labels, logits.shape[-1])\n",
        "    return -jnp.mean(jnp.sum(labels * nn.log_softmax(logits), axis=-1))\n",
        "  grads = jax.grad(loss_fn)(params, fixed_params, images, labels)\n",
        "  updates, opt_state = optimizer.update(grads, opt_state, params=params)\n",
        "  params = optax.apply_updates(params, updates)\n",
        "  return params, opt_state"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lRjic_IpGYJU"
      },
      "outputs": [],
      "source": [
        "LOOP_START = time.time()\n",
        "\n",
        "def train_loop(paths, ds_train, ds_validation, devices, exp_config):\n",
        "  global LOOP_START\n",
        "  timing = {'start_time': time.time(),\n",
        "            'start_time_loop': LOOP_START}\n",
        "  task = paths[0].task\n",
        "  # The following values should be shared by all paths in this generation batch.\n",
        "  for path in paths:\n",
        "    assert task == path.task\n",
        "    assert paths[0].hparams['ds_image_size'] == path.hparams['ds_image_size']\n",
        "\n",
        "  gc.collect()\n",
        "\n",
        "  # Compile.\n",
        "  compile_train_batches_arr = jax.device_put_replicated(\n",
        "      get_sample_batch(\n",
        "        paths[0].hparams['ds_image_size'],\n",
        "        task.train_batch_size),\n",
        "      devices)\n",
        "  compile_eval_batches_arr = jax.device_put_replicated(\n",
        "      get_sample_batch(\n",
        "          paths[0].hparams['ds_image_size'],\n",
        "          task.validation_batch_size),\n",
        "      devices)\n",
        "\n",
        "  for p_id, path in enumerate(paths):\n",
        "    if VERBOSE:\n",
        "      print('Parent')\n",
        "      print(prp(path.parent))\n",
        "      print(prp(path))\n",
        "    path.device_id = p_id % len(devices)\n",
        "    path.device = devices[path.device_id]\n",
        "    path.optimizer = path.get_optimizer()\n",
        "    path.optimizer_init_fn = jax.jit(path.optimizer.init, device=path.device)\n",
        "    path.best_params_local = None\n",
        "    path.best_opt_state_local = None\n",
        "    path.best_quality = None\n",
        "    path.best_score = path.parent.score() if path.task is path.parent.task else -np.inf\n",
        "    path.evals = []\n",
        "\n",
        "    # Launch parallel compilation of eval and train step functions.\n",
        "    params_local = path.get_trainable_params()\n",
        "    check_is_local(params_local)\n",
        "    path.compile_params_device = jax.device_put(params_local, path.device)\n",
        "    path.compile_fixed_params_device = jax.device_put(\n",
        "        path.get_fixed_params(),\n",
        "        path.device)\n",
        "    path.compile_train = Thread(\n",
        "        target=train_step,\n",
        "        args=(path.compile_params_device,\n",
        "              path.compile_fixed_params_device,\n",
        "              path.optimizer_init_fn(params_local),\n",
        "              compile_train_batches_arr['image'][path.device_id],\n",
        "              compile_train_batches_arr['label'][path.device_id],\n",
        "              path.model,\n",
        "              path.optimizer))\n",
        "    path.compile_eval = Thread(\n",
        "        target=eval_step,\n",
        "        args=(format_params(\n",
        "                  path.compile_params_device,\n",
        "                  path.compile_fixed_params_device),\n",
        "              compile_eval_batches_arr['image'][path.device_id],\n",
        "              compile_eval_batches_arr['label'][path.device_id],\n",
        "              path.model))\n",
        "    path.compile_eval.start()\n",
        "\n",
        "  for path in paths:\n",
        "    path.compile_eval.join()\n",
        "    del path.compile_eval\n",
        "    timing['end_compile_eval'] = time.time()\n",
        "    path.compile_train.start()\n",
        "  del compile_eval_batches_arr\n",
        "\n",
        "  for path in paths:\n",
        "    path.compile_train.join()\n",
        "    del path.compile_train\n",
        "    del path.compile_params_device\n",
        "    del path.compile_fixed_params_device\n",
        "    timing['end_compile'] = time.time()\n",
        "  del compile_train_batches_arr\n",
        "\n",
        "  gc.collect()\n",
        "\n",
        "  # Parameter tranfer.\n",
        "  for path in paths:\n",
        "    path.params_device = jax.device_put(\n",
        "        path.get_trainable_params(),\n",
        "        path.device)\n",
        "    path.fixed_params_device = jax.device_put(\n",
        "        path.get_fixed_params(),\n",
        "        path.device)\n",
        "    path.opt_state_device = path.optimizer_init_fn(path.params_device)\n",
        "    # Set opt state.\n",
        "    for c in path.components:\n",
        "      if c.is_trainable():\n",
        "        assert c.name in path.opt_state_device[1][0].trace.keys()\n",
        "        if c.opt_state is not None:\n",
        "          path.opt_state_device = (\n",
        "              path.opt_state_device[0],\n",
        "              (optax.TraceState(\n",
        "                  trace=path.opt_state_device[1][0].trace.copy(\n",
        "                      {c.name: jax.device_put(c.opt_state,\n",
        "                                              path.device)})),\n",
        "               path.opt_state_device[1][1]\n",
        "               )\n",
        "          )\n",
        "    check_is_on_device(path.opt_state_device, path.device)\n",
        "\n",
        "  iter_ds_validation = iter(ds_validation)\n",
        "  # TRAIN\n",
        "  for t_step, train_batch in zip(\n",
        "      range(exp_config.num_validations_per_path_training\n",
        "            * task.num_train_batches_between_validations),\n",
        "      ds_train,\n",
        "  ):\n",
        "    train_batch_arr = jax.device_put_replicated(train_batch, devices)\n",
        "    for p_id, path in enumerate(paths):\n",
        "      if t_step == 0:\n",
        "        timing['end_prep'] = time.time()\n",
        "        t_step_0_time = time.time()\n",
        "      path.params_device, path.opt_state_device = train_step(\n",
        "          path.params_device,\n",
        "          path.fixed_params_device,\n",
        "          path.opt_state_device,\n",
        "          train_batch_arr['image'][path.device_id],\n",
        "          train_batch_arr['label'][path.device_id],\n",
        "          path.model,\n",
        "          path.optimizer)\n",
        "      if t_step == 0 and time.time() - t_step_0_time > 1:\n",
        "        print(f'WARNING: First train step took: {time.time()-t_step_0_time:.2f} s')\n",
        "    del train_batch, train_batch_arr\n",
        "\n",
        "    # EVAL\n",
        "    if (t_step+1) % task.num_train_batches_between_validations == 0:\n",
        "      first_eval = ((t_step+1) == task.num_train_batches_between_validations)\n",
        "      if first_eval:\n",
        "        timing['start_eval'] = time.time()\n",
        "      for path in paths:\n",
        "        path.accs = []\n",
        "      for e_step, eval_batch in zip(\n",
        "          range(task.num_validation_batches),\n",
        "          iter_ds_validation,\n",
        "          ):\n",
        "        eval_batch_arr = jax.device_put_replicated(eval_batch, devices)\n",
        "        for p_id, path in enumerate(paths):\n",
        "          if first_eval and e_step == 0:\n",
        "            e_step_0_time = time.time()\n",
        "          path.accs.append(\n",
        "              eval_step(\n",
        "                  format_params(path.params_device, path.fixed_params_device),\n",
        "                  eval_batch_arr['image'][path.device_id],\n",
        "                  eval_batch_arr['label'][path.device_id],\n",
        "                  path.model))\n",
        "          if first_eval and e_step == 0 and time.time() - e_step_0_time > 1:\n",
        "            print(f'WARNING: First eval step took: {time.time()-e_step_0_time:.2f} s')\n",
        "      del eval_batch, eval_batch_arr\n",
        "\n",
        "      # Get params of best models.\n",
        "      qs = []\n",
        "      eval_idx = (t_step+1) // task.num_train_batches_between_validations\n",
        "      for path in paths:\n",
        "        quality = np.mean(path.accs)\n",
        "        del path.accs\n",
        "        qs.append(f'{quality:.4f}')\n",
        "        path.evals.append(quality)\n",
        "        # Set quality in metrics for current score computation.\n",
        "        path.metrics['quality'] = quality\n",
        "        path_score = path.score()\n",
        "        if path_score >= path.best_score:\n",
        "          path.best_params_local = jax.device_get(path.params_device)\n",
        "          path.best_opt_state_local = jax.device_get(path.opt_state_device[1][0].trace)\n",
        "          path.best_score = path_score\n",
        "          path.best_quality = quality\n",
        "          qs[-1] += '*'\n",
        "      train_time = time.time() - timing['end_compile']\n",
        "      avg_path_time = (train_time / eval_idx) / len(paths)\n",
        "      print(('\\t'.join(qs) + f'\\t< Eval {eval_idx}').expandtabs(8),\n",
        "            f'tot:{train_time:.1f}s', f'avg/path:{avg_path_time:.1f}s')\n",
        "\n",
        "      if first_eval:\n",
        "        timing['end_eval'] = time.time()\n",
        "\n",
        "  for path in paths:\n",
        "    del path.params_device\n",
        "    del path.fixed_params_device\n",
        "    del path.opt_state_device\n",
        "    del path.optimizer\n",
        "    del path.optimizer_init_fn\n",
        "  gc.collect()\n",
        "\n",
        "  timing['end_train'] = time.time()\n",
        "\n",
        "  loop_time = timing['start_time'] - LOOP_START\n",
        "  compile_time = timing['end_compile'] - timing['start_time']\n",
        "  compile_eval_time = timing['end_compile_eval'] - timing['start_time']\n",
        "  compile_train_time = timing['end_compile'] - timing['end_compile_eval']\n",
        "  prep_time = timing['end_prep'] - timing['end_compile']\n",
        "  train_time = timing['end_train'] - timing['end_prep']\n",
        "  eval_time = timing['end_eval'] - timing['start_eval']\n",
        "  LOOP_START = time.time()\n",
        "\n",
        "  for path in paths:\n",
        "    path.metrics['loop_time'] = loop_time\n",
        "    path.metrics['compile_time'] = compile_time\n",
        "    path.metrics['prep_time'] = prep_time\n",
        "    path.metrics['train_time'] = train_time\n",
        "    path.metrics['eval_time'] = eval_time\n",
        "    path.metrics['start_time'] = timing['start_time']\n",
        "    path.metrics['start_time_loop'] = timing['start_time_loop']\n",
        "    path.metrics['end_time'] = time.time()\n",
        "    num_all_params = get_num_params(path.get_all_params())\n",
        "    num_trainable_params = get_num_params(path.get_trainable_params())\n",
        "    path.metrics['trainable_params_ratio'] = num_trainable_params/num_all_params\n",
        "    path.metrics['num_trainable_params'] = num_trainable_params\n",
        "    path.metrics['quality'] = max(path.evals)\n",
        "    path.metrics['evals'] = json.dumps([float(v) for v in path.evals])\n",
        "    path.metrics['training_accounted_params'] = path.accounted_num_params()\n",
        "    path.metrics['training_score'] = path.score()\n",
        "\n",
        "    if path.best_params_local:\n",
        "      path.metrics['improved'] = True\n",
        "      path.update_trainable(path.best_params_local,\n",
        "                            path.best_opt_state_local)\n",
        "      assert path.best_quality == path.metrics['quality']\n",
        "      assert path.best_score == path.metrics['training_score']\n",
        "    else:\n",
        "      path.metrics['improved'] = False\n",
        "      # Sampled path will be dropped if not improved, so skip paramter update.\n",
        "      assert path.best_params_local == None\n",
        "      assert path.best_opt_state_local == None\n",
        "      assert path.best_quality == None\n",
        "\n",
        "    del path.best_params_local\n",
        "    del path.best_opt_state_local\n",
        "    del path.best_score\n",
        "    del path.best_quality\n",
        "    del path.evals\n",
        "\n",
        "    if VERBOSE:\n",
        "      print('UPDATED:')\n",
        "      print(prp(path))\n",
        "\n",
        "  pqs = []\n",
        "  qs = []\n",
        "  psc = []\n",
        "  sc = []\n",
        "  for path in paths:\n",
        "    if path.task is path.parent.task:\n",
        "      pqs.append(f'{path.parent.metrics[\"quality\"]:.4f}')\n",
        "      psc.append(f'{path.parent.score():.4f}')\n",
        "    else:\n",
        "      pqs.append('NEW')\n",
        "      psc.append('NEW')\n",
        "    qs.append(f'{path.metrics[\"quality\"]:.4f}')\n",
        "    sc.append(f'{path.score():.4f}')\n",
        "    if path.metrics['improved']:\n",
        "      sc[-1] += '+'\n",
        "\n",
        "  print(('\\t'.join([f'{path.parent.id}' for path in paths]) +\n",
        "        '\\t< Parent id').expandtabs(8))\n",
        "  print(('\\t'.join([f'{path.id}' for path in paths]) +\n",
        "        '\\t< Path id').expandtabs(8))\n",
        "  print(('\\t'.join(pqs) + '\\t< Parent best quality').expandtabs(8))\n",
        "  print(('\\t'.join(qs) + '\\t< Path best quality').expandtabs(8))\n",
        "  print(('\\t'.join(psc) + '\\t< Parent score').expandtabs(8))\n",
        "  print(('\\t'.join(sc) + '\\t< Path score').expandtabs(8))\n",
        "\n",
        "  print('time\\tLOOP\\tCOMPevl\\tCOMPtrn\\tPREP\\tTRN+EVL\\t1stEVAL'.expandtabs(8))\n",
        "  print(f'(s)\\t{loop_time:.1f}\\t{compile_eval_time:.1f}\\t{compile_train_time:.1f}\\t{prep_time:.1f}\\t{train_time:.1f}\\t{eval_time:.1f}'.expandtabs(8))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p-3gYmECRBOU"
      },
      "outputs": [],
      "source": [
        "# Run a full paths sampling iteration for a task.\n",
        "def task_iter(\n",
        "    task, devices, pop:Population, generation_id:int, loop_id:int,\n",
        "    exp_config:FrozenConfigDict):\n",
        "  num_devices = len(devices)\n",
        "  # Track best path.\n",
        "  best_path = pop.get_best_path(task)\n",
        "  num_gen_batches = math.ceil(exp_config.num_samples_per_task/num_devices)\n",
        "  for _ in range(num_gen_batches):\n",
        "    if generation_id >= num_gen_batches:\n",
        "      break\n",
        "    print('----')\n",
        "    print(f'GENERATION: [{generation_id+1}/{num_gen_batches}]')\n",
        "    ds_hparams = pop.sample_ds_hparams(task)\n",
        "    ds_train = task.get_ds('train', ds_hparams)\n",
        "    ds_validation = task.get_ds('validation', ds_hparams)\n",
        "    paths = []\n",
        "    for i in range(num_devices):\n",
        "      paths.append(pop.sample_path(task, ds_hparams))\n",
        "    train_loop(paths, ds_train, ds_validation, devices, exp_config)\n",
        "    for path in paths:\n",
        "      if path.metrics['improved']:\n",
        "        assert path not in pop.paths\n",
        "        pop.paths[task].append(path)\n",
        "    # Track best path.\n",
        "    curr_best_path = pop.get_best_path(task)\n",
        "    if curr_best_path != best_path:\n",
        "      if best_path:\n",
        "        assert curr_best_path.score() >= best_path.score()\n",
        "      best_path = curr_best_path\n",
        "      best_path.metrics['new_best'] = True\n",
        "      print(f'Best id:{best_path.id}',\n",
        "            f'score:{best_path.score():.4f}',\n",
        "            f'quality:{best_path.metrics[\"quality\"]:.4f}',\n",
        "            f'gen:{generation_id}',\n",
        "            f'\\n{best_path.hparams}')\n",
        "    generation_id += 1\n",
        "    if generation_id < num_gen_batches:\n",
        "      # Skip intermediate state save if last state was written recently.\n",
        "      if (time.time() - LAST_CHECKPOINT_TIME) > SKIP_INTERMEDIATE_STATE_SECS:\n",
        "        save_state(pop, generation_id, loop_id, exp_config)\n",
        "      else:\n",
        "        print('Skip checkpointing, seconds since last save:',\n",
        "              f'{time.time() - LAST_CHECKPOINT_TIME:.0f}')\n",
        "  assert best_path in pop.paths[task]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JVT8nwIWMAVf"
      },
      "outputs": [],
      "source": [
        "TEST_MODELS_IMMUTABILITY = False\n",
        "\n",
        "# Run final eval on test set.\n",
        "def run_test_eval(path, ds_test):\n",
        "  # Running on same device should allow to reuse the fn compiled for validation\n",
        "  # if batch size matches.\n",
        "  params = path.get_all_params()\n",
        "  check_is_local(params)\n",
        "  if not hasattr(path, 'device'):\n",
        "    path.device = random.choice(jax.local_devices())\n",
        "  params_device = jax.device_put(params_comps_to_model(params), path.device)\n",
        "  acc_sum = []\n",
        "  tot_num_samples = 0\n",
        "  # Warning: if repeat() is called on this dataset, then this loop never ends.\n",
        "  for batch in ds_test:\n",
        "    acc_avg = eval_step(\n",
        "        params_device,\n",
        "        batch['image'],\n",
        "        batch['label'],\n",
        "        path.model)\n",
        "    batch_size = batch['image'].shape[0]\n",
        "    # Need to recompute sum because last batch can have different size to allow\n",
        "    # for exact eval on the test set.\n",
        "    acc_sum.append(acc_avg * batch_size)\n",
        "    tot_num_samples += batch_size\n",
        "  del params_device\n",
        "  acc_avg = np.sum(acc_sum) / tot_num_samples\n",
        "  if 'test_quality' in path.metrics and not math.isnan(path.metrics['test_quality']):\n",
        "    assert np.isclose(path.metrics['test_quality'], acc_avg), \\\n",
        "        f'{path.task.name} {path.metrics[\"test_quality\"]} {acc_avg}'\n",
        "  path.metrics['test_quality'] = acc_avg\n",
        "\n",
        "def run_all_test_evals(pop):\n",
        "  eval_st = time.time()\n",
        "  # threads = []\n",
        "  for path in [path for paths in pop.paths.values() for path in paths if path.is_trainable()]:\n",
        "    if 'test_quality' in path.metrics and not math.isnan(path.metrics['test_quality']) and not TEST_MODELS_IMMUTABILITY:\n",
        "      continue\n",
        "    ds_test = path.task.get_ds('test', path.hparams)\n",
        "    run_test_eval(path, ds_test)\n",
        "  print(f'TEST EVAL TIME: {time.time() - eval_st:.2f} s')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WjIr72eO0oBq"
      },
      "outputs": [],
      "source": [
        "def reset_globals(exp_config):\n",
        "  Path.reset_globals(exp_config)\n",
        "  Component.reset_globals()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zc1DDzKl3sjc"
      },
      "outputs": [],
      "source": [
        "def init_population(exp_config:FrozenConfigDict, continue_exp_dir:str = ''):\n",
        "  reset_globals(exp_config)\n",
        "\n",
        "  Path.scorer = globals()[exp_config.scorer_class](**exp_config.scorer_kwargs)\n",
        "  pop = Population(exp_config=exp_config)\n",
        "\n",
        "  def reload_state(load_exp_dir):\n",
        "    pop.paths_df = df_read_from_file(\n",
        "        load_exp_dir,\n",
        "        df_name='paths')\n",
        "    pop.comps_df = df_read_from_file(\n",
        "        load_exp_dir,\n",
        "        df_name='components')\n",
        "    df_reloaded_population = df_read_from_file(\n",
        "        load_exp_dir,\n",
        "        df_name='population')\n",
        "    load_population_from_checkpoint(\n",
        "        pop,\n",
        "        load_exp_dir,\n",
        "        df_reloaded_population)\n",
        "    print('Loaded models from', load_exp_dir, ':')\n",
        "    df_leaderboard(pop_to_df(pop))\n",
        "    Path.counter = 1 + int(pop.paths_df['id'].max())\n",
        "    Component.counter = 1 + int(pop.comps_df['id'].max())\n",
        "\n",
        "  # Load population from previous experiment.\n",
        "  if continue_exp_dir:\n",
        "    reload_state(continue_exp_dir)\n",
        "    return pop\n",
        "  elif exp_config.load_experiment:\n",
        "    reload_state(exp_config.load_experiment_dir)\n",
        "\n",
        "  # Add new seed models.\n",
        "  if exp_config.load_rand_init or exp_config.load_vit_checkpoint:\n",
        "    hparams = exp_config.models_default_hparams.as_configdict()\n",
        "    # Add a randomly initialized model.\n",
        "    if exp_config.load_rand_init:\n",
        "      path0_params = get_vit_params_mapped(\n",
        "          **get_model_kwargs(hparams, exp_config))\n",
        "      path = Path(\n",
        "          hparams,\n",
        "          params2comps(path0_params, train_locks=[NOT_TRAINABLE]),\n",
        "          parent=None,\n",
        "          task=not_trainable)\n",
        "      pop.paths[not_trainable].append(path)\n",
        "    # Add model loaded from checkpoint.\n",
        "    if exp_config.load_vit_checkpoint:\n",
        "      path_params = get_vit_checkpoint_mapped(\n",
        "          hparams['ds_image_size'],\n",
        "          exp_config.load_vit_checkpoint_query)\n",
        "      path = Path(hparams, params2comps(\n",
        "          path_params,\n",
        "          train_locks=[NOT_TRAINABLE]),\n",
        "          parent=None,\n",
        "          task=not_trainable)\n",
        "      pop.paths[not_trainable].append(path)\n",
        "\n",
        "  return pop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3HlykGxjGpvg"
      },
      "outputs": [],
      "source": [
        "def latest_checkpoint(ckpt_dir: str, prefix: str = 'checkpoint_'):\n",
        "  ckpt_dir = os.fspath(ckpt_dir)\n",
        "  glob_path = os.path.join(ckpt_dir, f'{prefix}*')\n",
        "  checkpoint_files = flax_checkpoints.natural_sort(tf.io.gfile.glob(glob_path))\n",
        "  checkpoint_files = [f for f in checkpoint_files if not f.endswith('_tmp')]\n",
        "  return checkpoint_files[-1] if checkpoint_files else None\n",
        "\n",
        "def continue_exp(exp_dir):\n",
        "  # Load configs.\n",
        "  print('CONTINUING EXISTING EXPERIMENT:', exp_dir)\n",
        "  load_config_dict_file = os.path.join(exp_dir, 'config.json')\n",
        "  exp_config = FrozenConfigDict(json.load(\n",
        "      tf.io.gfile.GFile(load_config_dict_file, 'r')))\n",
        "  # Get loop_id from checkpoint file name.\n",
        "  checkpoint_path = latest_checkpoint(exp_dir+'/state_*/')\n",
        "  matched = re.findall(r'checkpoint_([0-9]+)_([0-9]+)$', checkpoint_path)\n",
        "  assert len(matched)==1\n",
        "  generation_id = int(matched[0][1])\n",
        "  loop_id = int(matched[0][0])\n",
        "  pop = init_population(exp_config, continue_exp_dir=exp_dir+f'/state_{loop_id}_{generation_id}/')\n",
        "  print('FROM CHECKPOINT:', loop_id, generation_id)\n",
        "  assert exp_config.experiment_dir == exp_dir\n",
        "  return pop, exp_config, generation_id, loop_id\n",
        "\n",
        "def setup_new_experiment(exp_config):\n",
        "  # Finalize and save config.\n",
        "  exp_config.experiment_id = exp_config.experiment_name \\\n",
        "      + datetime.datetime.strftime(\n",
        "          datetime.datetime.now(), ':%Y-%m-%d-%H-%M-%S')\n",
        "  exp_config.experiment_dir = os.path.join(\n",
        "      exp_config.experiments_root_dir, exp_config.experiment_id)\n",
        "  exp_config = FrozenConfigDict(exp_config)\n",
        "  pop = init_population(exp_config)\n",
        "  print('NEW EXPERIMENT:', exp_config.experiment_dir)\n",
        "  return pop, exp_config, 0, 0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PXCBJxPR4zd1"
      },
      "outputs": [],
      "source": [
        "def setup_exp():\n",
        "  if BENCHMARK == 'ViT tiny 3 layers / Chars benchmark':\n",
        "    exp_config = get_exp_config_ti3_chars()\n",
        "    exp_config.experiment_name += ':t3-chars'\n",
        "  elif BENCHMARK == 'ViT base / VDD benchmark':\n",
        "    exp_config = get_exp_config_base_deca()\n",
        "    exp_config.experiment_name += ':b-deca'\n",
        "  elif BENCHMARK == 'ViT tiny 0 layers / Cmaterdb benchmark':\n",
        "    exp_config = get_exp_config_ti0_cmaterdb()\n",
        "    exp_config.experiment_name += ':t0-cmaterdb'\n",
        "  elif BENCHMARK.startswith('ViT large / '):\n",
        "    exp_config = get_exp_config_large(BENCHMARK)\n",
        "    exp_config.experiment_name += ':large'\n",
        "  else:\n",
        "    assert False, BENCHMARK\n",
        "\n",
        "  if AUTO_TUNE:\n",
        "    assert CONFIGURATION == 'muNet' or CONFIGURATION.startswith('Size scale:')\n",
        "    exp_config.experiment_name += ':autotune'\n",
        "    if BENCHMARK in ['ViT tiny 3 layers / Chars benchmark',\n",
        "                      'ViT base / VDD benchmark']:\n",
        "      exp_config = exp_config_add_auto_tune(exp_config)\n",
        "    elif BENCHMARK in ['ViT tiny 0 layers / Cmaterdb benchmark'] or BENCHMARK.startswith('ViT large / '):\n",
        "      exp_config = exp_config_add_auto_tune_v2(exp_config)\n",
        "    else:\n",
        "      assert False, BENCHMARK\n",
        "\n",
        "  if CONFIGURATION == 'Finetune all':\n",
        "    exp_config = exp_config_set_baseline_finetune_all(exp_config)\n",
        "    exp_config.experiment_name += ':finetune'\n",
        "  elif CONFIGURATION.startswith('Freeze bottom layers'):\n",
        "    num_layers = int(CONFIGURATION.split(':')[1])\n",
        "    exp_config = exp_config_set_baseline_freeze_bottom_layers(\n",
        "        exp_config, num_layers)\n",
        "    exp_config.experiment_name += f':freeze{num_layers}'\n",
        "  elif CONFIGURATION.startswith('Adapters:'):\n",
        "    adapter_dim = int(CONFIGURATION.split(':')[1])\n",
        "    exp_config = exp_config_set_baseline_adapters(exp_config, adapter_dim)\n",
        "    exp_config.experiment_name += f':adapters{adapter_dim}'\n",
        "  elif CONFIGURATION.startswith('Size scale:'):\n",
        "    base_percent = int(CONFIGURATION.split(':')[1])\n",
        "    exp_config = exp_config_set_size_scale(exp_config, base_percent)\n",
        "    exp_config.experiment_name += f':size{base_percent}'\n",
        "  elif CONFIGURATION == 'muNet':\n",
        "    exp_config.experiment_name += f':munet'\n",
        "  else:\n",
        "    assert False, CONFIGURATION\n",
        "\n",
        "  if AUTO_CONTINUE:\n",
        "    exp_dir_prefix = os.path.join(\n",
        "      exp_config.experiments_root_dir, exp_config.experiment_name)\n",
        "    matching_dirs = tf.io.gfile.glob(exp_dir_prefix + '*')\n",
        "    assert len(matching_dirs) < 2, \\\n",
        "        f'Multiple dirs matched for auto restart {matching_dirs}'\n",
        "    if len(matching_dirs) == 1:\n",
        "      print('AUTO CONTINE')\n",
        "      return continue_exp(matching_dirs[0])\n",
        "\n",
        "  return setup_new_experiment(exp_config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uOhykDjptuKj"
      },
      "outputs": [],
      "source": [
        "# Main loop over tasks.\n",
        "pop, exp_config, generation_id, loop_id = setup_exp()\n",
        "\n",
        "devices = jax.local_devices()\n",
        "print('DEVICE COUNT:', len(devices))\n",
        "num_tasks = len(exp_config.task_names)\n",
        "num_loops = exp_config.num_task_iters * num_tasks\n",
        "write_threads = []\n",
        "for _ in range(num_loops):\n",
        "  if loop_id >= num_loops:\n",
        "    break\n",
        "  t_i = loop_id // num_tasks\n",
        "  task_idx = loop_id % num_tasks\n",
        "  task_name = exp_config.task_names[task_idx]\n",
        "  print('\\n\\n====')\n",
        "  print(f'LOOP: [{loop_id+1}/{exp_config.num_task_iters * num_tasks}]')\n",
        "  print(f'TASK: {task_name}')\n",
        "  task = Path.tasks(task_name=task_name)\n",
        "  pop.start_task(task)\n",
        "  task_iter(task, devices, pop, generation_id, loop_id, exp_config)\n",
        "  pop.end_task(task)\n",
        "  loop_id += 1\n",
        "  generation_id = 0\n",
        "\n",
        "  run_all_test_evals(pop)\n",
        "  write_threads = save_state(pop, generation_id, loop_id, exp_config)\n",
        "  # Display stats.\n",
        "  avg_time_per_sample = (\n",
        "      pop.paths_df['metrics.end_time'].mean() \\\n",
        "          - pop.paths_df['metrics.start_time_loop'].mean()\n",
        "      ) / len(devices)\n",
        "  print(f'Avg time per path: {avg_time_per_sample:.2f} s')\n",
        "\n",
        "# Wait for last state write to complete.\n",
        "for t in write_threads:\n",
        "  t.join()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "μ2Net.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}