{
  "cells": [
    {
      "metadata": {
        "id": "SuW4MkPr7isY"
      },
      "cell_type": "markdown",
      "source": [
        "# Customizable full experimentation pipeline\n",
        "\n",
        "This notebook provides (easily configurable) infrastructure to:\n",
        "\n",
        "\n",
        "1.  Pretrain a network (LSTM, Transformer, or Linear) on data from one of the data generators. Pretraining is optional and can be skipped if you want to investigate weight tuning or prefix tuning of untrained networks.\n",
        "2.  Tune the network to a tuning distribution (tuning data generator) using various tuning methods. This can create several tuned models from the same pretrained base network.\n",
        "3.  Evaluate the sequential prediction performance of all tuned networks on one or more evaluation data generators. Each evaluation can be repeated several times with different random seeds (which produces error bands in the comparison plots).\n",
        "4.  Compare all tuned models on each evaluation data generator. The comparison also includes the untuned base model, the Bayes predictor for the respective evaluation data generator (the \"Bayes optimal\" baseline), the Bayes predictor for the pretraining distribution (and a prompt-tuned version of it), and untuned (random) prefixes.\n",
        "\n",
        "The full experiment can be customized via configuration dataclasses for each part. The standard settings can be customized via Colab forms in the \"Main user settings\" section below. Additional settings are accessible via the code in the configuration sections that follow.\n",
        "\n",
        "\u003e Comparing many tuning methods over many repetitions (with different seeds) can quickly produce dozens of tuning runs. Keep models small and the number of tuning steps low to avoid very long run times.\n",
        "\n",
        "\n",
        "## Tuning methods\n",
        "\n",
        "The following tuning methods are available:\n",
        "* Full weight tuning.\n",
        "* Tuning of embedding or unembedding layer only, or both.\n",
        "* LoRA (transformers only): tuning of additive low-rank matrices for all linear layers of the inner transformer block.\n",
        "* Gradient-based tuning of a prompt prefix:\n",
        "  * Simplex: the prefix is a sequence of vectors in the simplex spanned by the one-hot tokens.\n",
        "  * Real: the prefix is a sequence of unconstrained real-valued vectors.\n",
        "  * Soft: the embeddings of the prefix tokens are tuned instead of the tokens themselves.\n",
        "* Hard prefix: exhaustive search over all hard (one-hot) token sequences of the given prefix length.\n",
        "\n",
        "\u003e Hard prefix tuning is done via exhaustive search, which is only feasible for short prefixes: with binary tokens there are $2^L$ candidate prefixes of length $L$. Disable it for long prefixes.\n",
        "\n",
        "\n",
        "## Data generators\n",
        "\n",
        "Thunnini currently implements three families of categorical distributions; this notebook restricts tokens to be binary (one-hot), leading to the following data generators:\n",
        "*  Single coin: coin with fixed bias (Binomial distribution).\n",
        "*  Mixture of two coins: two coins with different biases and a user-definable mixture proportion (mixture of Binomials).\n",
        "*  Random coins: Beta distribution over coin biases (typically set to be uniform).\n",
        "\n",
        "\n",
        "## Performance metric (regret)\n",
        "\n",
        "The main comparison metric shown in most plots is \"regret\", which is a model's excess prediction error, i.e., expected (cumulative) log loss, relative to the best possible prediction error given by an oracle (the data generator) that knows the emission probabilities in each step. A regret of zero thus does not mean zero prediction error, but the lowest theoretically achievable prediction error (given the hidden knowledge).\n",
        "\n",
        "Let $\\pi_\\theta$ denote the neural predictor with parameters $\\theta$, and let the data generator for sequences $x_{1:N}$ of length $N$ be a family of sources parameterized by $\\tau$: $\\xi(x_{1:N}\\vert\\tau)$. For instance, $\\xi(\\cdot \\vert\\tau)$ could be the family of coins with bias $\\tau$. Given a distribution $p(\\tau)$ over $\\tau$ (e.g., the uniform distribution), the data distribution is:  \n",
        "$\\xi(x) = \\int \\xi(x\\vert \\tau) p(\\tau) d\\tau$.\n",
        "\n",
        "The expected log loss of the predictor over the data distribution is:  \n",
        "$\\mathcal{L}_\\xi(\\pi_\\theta)=\\mathbb{E}_{\\xi} \\left[ \\sum_{i=1}^N - \\log \\pi_\\theta(x_i \\vert x_{\u003ci}) \\right]$.  \n",
        "In practice, the expectation over $\\xi$ is replaced with the average over a sample $\\{x^{1}, \\ldots, x^{D}\\}$, where each sequence is generated by first drawing $\\tau^*_k \\sim p(\\tau)$ and then sampling $x^{k} \\sim \\xi(\\cdot \\vert \\tau^*_k)$.\n",
        "\n",
        "The cumulative regret of the predictor is its log loss relative to the best possible (oracle) prediction:  \n",
        "$\\mathcal{R}_\\xi(\\pi_\\theta)=\\frac{1}{D} \\sum_{k=1}^D \\left[ \\sum_{i=1}^N -\\log \\pi_\\theta(x^k_i \\vert x^k_{\u003ci}) + \\log \\xi(x^k_i \\vert x^k_{\u003ci}, \\tau=\\tau^*_k) \\right]$,    \n",
        "where each summand of the inner sum is the instantaneous regret at time step $i$, and $\\tau^*_k$ is the ground truth (e.g., the coin bias) for sequence $k$, which only the data generator knows.\n",
        "\n",
        "Without additional oracle knowledge, zero regret is generally not achievable in this setting. Instead, the lowest achievable expected regret is attained by the Bayes predictor for the data generator:  \n",
        "$\\pi_{\\text{Bayes}}(x) = \\int \\xi(x\\vert \\tau)  p(\\tau) d\\tau = \\xi(x)$, whose sequential predictions are the conditionals $\\pi_{\\text{Bayes}}(x_i \\vert x_{\u003ci}) = \\xi(x_{1:i}) / \\xi(x_{\u003ci})$.\n",
        "\n",
        "Through pretraining, neural predictors approach Bayes optimality on their respective pretraining distribution, provided the architecture has enough capacity and training converges properly. Their predictions and regrets then become practically indistinguishable from those of the Bayes predictor (on the pretraining distribution).\n",
        "\n",
        "By fine tuning with data from a particular value of $\\tau^*$, e.g., a coin with a particular bias, the fine-tuned models can indirectly acquire some or all of the \"oracle info\" (in their weights or via the tuned prefix). They may thus initially outperform the Bayes predictor for the pretraining distribution (as opposed to the Bayes predictor for the tuning distribution), but only on the particular downstream task $\\tau^*$ that they were tuned for. In an abstract sense, this is what fine tuning aims for: there is additional knowledge about the downstream task(s) that was not available at pretraining time (e.g., that the downstream task is answering questions from a high-school biology exam), and the goal is to make this information available to the tuned predictor (by tuning weights or prompts, or even just by writing an ad hoc prompt). After a sufficient number of observations, the pretraining Bayes predictor catches up in terms of instantaneous regret, but the cumulative regret retains the initial difference.\n",
        "\n",
        "\n",
        "\u003e By default, evaluation plots show the median over evaluation repetitions (different seeds for the fine-tuning runs, i.e., different samples from the tuning data generator and different random initial prefixes for prefix tuning), together with 25th to 75th percentile error bars or shaded areas (covering the central 50% of the repetitions). Thin lines show individual repetitions.\n",
        "\n",
        "\n",
        "## Effect of prefixing on internal dynamics\n",
        "\n",
        "The final section of this notebook records the internal states of neural predictors during evaluation. These states are then projected to 2D via PCA, which usually reveals a clearly structured representation: typically, one dimension aligns with the counts of heads/tails and the other with the number of observations. Plotting individual trajectories for various prefixes into the same projection shows the effect of the different prefix types.\n",
        "\n",
        "\u003e This analysis only makes sense for comparing prefix tuning methods, since it is unclear how to project activations from models with different weights (due to weight tuning or LoRA) into the same lower-dimensional space. Intuitively, different prefixes set different initial points in the same activation space, whereas weight tuning methods reshape the activation space to give it a better geometry for the downstream task."
      ]
    },
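    {
      "metadata": {
        "id": "regretSketchMd"
      },
      "cell_type": "markdown",
      "source": [
        "The regret definitions above can be made concrete with a small, self-contained NumPy sketch. It is independent of Thunnini and not its implementation; it assumes the \"Random Coins\" setting with a uniform prior over the coin bias, for which the Bayes predictor is Laplace's rule of succession, $\\pi_{\\text{Bayes}}(x_i = 1 \\vert x_{\u003ci}) = (n_1 + 1)/(i + 1)$, where $n_1$ is the number of ones among the first $i - 1$ tokens.\n",
        "\n",
        "For each sequence, the sketch draws a hidden bias $\\tau^*_k \\sim \\mathrm{Uniform}(0, 1)$, computes the instantaneous regret against the oracle $\\xi(\\cdot \\vert \\tau^*_k)$ at every step, and averages the cumulative regret over all sequences. The helper name `laplace_regret` is hypothetical. For a uniform prior, the cumulative regret of Laplace's rule is at most $\\log(N+1)$ for every sequence, which the final assertion checks."
      ]
    },
    {
      "metadata": {
        "id": "regretSketchCode"
      },
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def laplace_regret(num_sequences: int, sequence_length: int, seed: int = 0):\n",
        "  \"\"\"Hypothetical helper: instantaneous regret of Laplace's rule.\n",
        "\n",
        "  Assumes binary tokens and a uniform prior over the coin bias. Returns an\n",
        "  array of shape [num_sequences, sequence_length] with entries\n",
        "  -log pi_Bayes(x_i | x_\u003ci) + log xi(x_i | tau*_k).\n",
        "  \"\"\"\n",
        "  rng = np.random.default_rng(seed)\n",
        "  # One ground-truth bias tau*_k ~ Uniform(0, 1) per sequence.\n",
        "  taus = rng.uniform(size=(num_sequences, 1))\n",
        "  # Binary tokens: x_i = 1 with probability tau*_k.\n",
        "  x = (rng.uniform(size=(num_sequences, sequence_length)) \u003c taus).astype(float)\n",
        "  # Number of ones observed before step i (exclusive prefix sum).\n",
        "  ones_before = np.cumsum(x, axis=1) - x\n",
        "  steps_before = np.arange(sequence_length)[None, :]\n",
        "  # Laplace's rule of succession: P(x_i = 1 | x_\u003ci) = (n_1 + 1) / (n + 2).\n",
        "  p_one_bayes = (ones_before + 1.0) / (steps_before + 2.0)\n",
        "  log_p_bayes = np.where(x == 1, np.log(p_one_bayes), np.log1p(-p_one_bayes))\n",
        "  # Oracle log-probabilities under the hidden bias tau*_k.\n",
        "  log_p_oracle = np.where(x == 1, np.log(taus), np.log1p(-taus))\n",
        "  return -log_p_bayes + log_p_oracle\n",
        "\n",
        "\n",
        "toy_inst_regret = laplace_regret(num_sequences=2048, sequence_length=200)\n",
        "toy_cum_regret_per_seq = np.cumsum(toy_inst_regret, axis=1)\n",
        "# Cumulative regret averaged over the D sequences, for every time step.\n",
        "toy_cum_regret = toy_cum_regret_per_seq.mean(axis=0)\n",
        "print(f\"Mean cumulative regret after 200 steps: {toy_cum_regret[-1]:.3f} nats\")\n",
        "\n",
        "# Laplace's rule has at most log(N + 1) regret on every sequence.\n",
        "assert np.all(toy_cum_regret_per_seq[:, -1] \u003c= np.log(201) + 1e-9)"
      ],
      "outputs": [],
      "execution_count": null
    },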
    {
      "metadata": {
        "id": "uMC6_1OShR6p"
      },
      "cell_type": "markdown",
      "source": [
        "# Imports"
      ]
    },
    {
      "metadata": {
        "id": "evZftQd0sJII"
      },
      "cell_type": "code",
      "source": [
        "# @title Global imports\n",
        "import collections\n",
        "import dataclasses\n",
        "import logging\n",
        "import pathlib\n",
        "\n",
        "# Set default logging level\n",
        "logging.basicConfig(level=logging.WARNING, force=True)\n",
        "\n",
        "# Utils\n",
        "from matplotlib import pyplot as plt\n",
        "from sklearn.decomposition import PCA\n",
        "\n",
        "# NNs / Linear algebra\n",
        "import numpy as np\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "jax.config.update(\"jax_debug_nans\", False)\n",
        "%matplotlib inline"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "uZVcFQDY32Te"
      },
      "cell_type": "code",
      "source": [
        "# @title Thunnini imports\n",
        "from thunnini.src import builders\n",
        "from thunnini.src import config as config_lib\n",
        "from thunnini.src import evaluation\n",
        "from thunnini.src import plot_utils\n",
        "from thunnini.src import training\n",
        "from thunnini.src import tuning\n",
        "from thunnini.src import types"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "wguslMYdhJjF"
      },
      "cell_type": "markdown",
      "source": [
        "# Experiment Configuration"
      ]
    },
    {
      "metadata": {
        "id": "Q5tnAFlZAJqb"
      },
      "cell_type": "code",
      "source": [
        "# @title Main user settings\n",
        "\n",
        "# @markdown **Architecture**\n",
        "architecture = \"Transformer\"  # @param [\"LSTM\",\"LSTM_untrained\", \"Transformer\", \"Transformer_untrained\", \"Linear\", \"Linear_untrained\"]\n",
        "embedding_dim = 128  # @param {\"type\":\"integer\"}\n",
        "# @markdown ---\n",
        "\n",
        "\n",
        "# @markdown **Pretraining**\n",
        "pretraining_source = \"Random Coins\"  # @param [\"Single Coin\", \"Two-Coin Mixture\", \"Random Coins\"]\n",
        "pretraining_sequence_length = 100  # @param {\"type\":\"integer\"}\n",
        "pretraining_batch_size = 256  # @param {\"type\":\"integer\"}\n",
        "pretraining_num_steps = 1000  # @param {\"type\":\"integer\"}\n",
        "# @markdown ---\n",
        "\n",
        "# @markdown **Tuning**\n",
        "tuning_source = \"Two-Coin Mixture\"  # @param [\"Single Coin\", \"Two-Coin Mixture\", \"Random Coins\"]\n",
        "tuning_sequence_length = 50  # @param {\"type\":\"integer\"}\n",
        "tuning_batch_size = 256  # @param {\"type\":\"integer\"}\n",
        "tuning_num_steps = 1000  # @param {\"type\":\"integer\"}\n",
        "prefix_length = 6  # @param {\"type\":\"integer\"}\n",
        "tuning_num_repetitions = 10  # @param {\"type\":\"integer\"}\n",
        "\n",
        "# @markdown Prefix tuning\n",
        "hard_token_tuning = True  # @param {type:\"boolean\"}\n",
        "simplex_pf_tuning = True  # @param {type:\"boolean\"}\n",
        "real_pf_tuning = True  # @param {type:\"boolean\"}\n",
        "soft_pf_tuning = True  # @param {type:\"boolean\"}\n",
        "# @markdown Fine tuning\n",
        "full_fine_tuning = True  # @param {type:\"boolean\"}\n",
        "lora_tuning = True  # @param {type:\"boolean\"}\n",
        "# @markdown Embedding tuning\n",
        "embedding_tuning = True  # @param {type:\"boolean\"}\n",
        "unembedding_tuning = True  # @param {type:\"boolean\"}\n",
        "un_embedding_tuning = True  # @param {type:\"boolean\"}\n",
        "# @markdown Baselines\n",
        "random_prefix = True  # @param {type:\"boolean\"}\n",
        "pretrain_bayes = True  # @param {type:\"boolean\"}\n",
        "pretrain_bayes_pt = True  # @param {type:\"boolean\"}\n",
        "\n",
        "tuning_names = []\n",
        "if simplex_pf_tuning:\n",
        "  tuning_names.append(\"SimplexPT\")\n",
        "if real_pf_tuning:\n",
        "  tuning_names.append(\"RealPT\")\n",
        "if soft_pf_tuning:\n",
        "  tuning_names.append(\"SoftPT\")\n",
        "if full_fine_tuning:\n",
        "  if architecture.endswith(\"_untrained\"):\n",
        "    full_fine_tuning = False\n",
        "    print(\"Torso non-trainable. Full fine tuning was disabled.\")\n",
        "  else:\n",
        "    tuning_names.append(\"FullWT\")\n",
        "if lora_tuning:\n",
        "  if not architecture.startswith(\"Transformer\"):\n",
        "    raise ValueError(\"LoRA tuning is only supported for transformers.\")\n",
        "  tuning_names.append(\"LoRAWT\")\n",
        "if embedding_tuning:\n",
        "  tuning_names.append(\"EmbedWT\")\n",
        "if unembedding_tuning:\n",
        "  tuning_names.append(\"UnembedWT\")\n",
        "if un_embedding_tuning:\n",
        "  tuning_names.append(\"Un+EmbedWT\")\n",
        "if hard_token_tuning:\n",
        "  tuning_names.append(\"HardPT\")\n",
        "if random_prefix:\n",
        "  tuning_names.append(\"RandomPF\")\n",
        "if pretrain_bayes and architecture.endswith(\"_untrained\"):\n",
        "  print(\"PreBayes is only supported for pretrained models.\")\n",
        "  pretrain_bayes = False\n",
        "if pretrain_bayes_pt and architecture.endswith(\"_untrained\"):\n",
        "  print(\"PreBayesPT is only supported for pretrained models.\")\n",
        "  pretrain_bayes_pt = False\n",
        "# @markdown ---\n",
        "\n",
        "# @markdown **Evaluation**\n",
        "eval_sequence_length = 200  # @param {\"type\":\"integer\"}\n",
        "eval_num_sequences = 2048  # @param {\"type\":\"integer\"}\n",
        "eval_single_coin = True  # @param {type:\"boolean\"}\n",
        "eval_two_coin_mixture = True  # @param {type:\"boolean\"}\n",
        "eval_random_coins = True  # @param {type:\"boolean\"}\n",
        "\n",
        "eval_names = []\n",
        "if eval_single_coin:\n",
        "  eval_names.append(\"Single Coin\")\n",
        "if eval_two_coin_mixture:\n",
        "  eval_names.append(\"Two-Coin Mixture\")\n",
        "if eval_random_coins:\n",
        "  eval_names.append(\"Random Coins\")\n",
        "\n",
        "# @markdown ---\n",
        "store_results = True  # @param {type:\"boolean\"}\n",
        "store_path = \"/tmp/thunnini_exp/\" # @param {type:\"string\"}\n",
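        "if store_results:\n",
        "  # Later cells build file paths via string concatenation, so make sure\n",
        "  # the directory path ends with a slash.\n",
        "  if not store_path.endswith(\"/\"):\n",
        "    store_path += \"/\"\n",
        "  spath = pathlib.Path(store_path)\n",
        "  spath.mkdir(parents=True, exist_ok=True)"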
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "qcna3ZLHAJqc"
      },
      "cell_type": "code",
      "source": [
        "# @title Data sources general configurations\n",
        "\n",
        "all_data_sources = {\n",
        "    \"Single Coin\": config_lib.CategoricalGeneratorConfig(\n",
        "        batch_size=128,\n",
        "        sequence_length=100,\n",
        "        vocab_size=2,\n",
        "        biases=np.array([0.2, 0.8]),\n",
        "    ),\n",
        "    \"Two-Coin Mixture\": config_lib.MixtureOfCategoricalsGeneratorConfig(\n",
        "        batch_size=128,\n",
        "        sequence_length=100,\n",
        "        vocab_size=2,\n",
        "        biases=np.array([[0.2, 0.8], [0.8, 0.2]]),\n",
        "        mixing_weights=np.array([0.5, 0.5]),\n",
        "    ),\n",
        "    \"Random Coins\": config_lib.DirichletCategoricalGeneratorConfig(\n",
        "        batch_size=128,\n",
        "        sequence_length=100,\n",
        "        vocab_size=2,\n",
        "        alphas=np.array([1, 1]),\n",
        "    ),\n",
        "}"
      ],
      "outputs": [],
      "execution_count": null
    },
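    {
      "metadata": {
        "id": "dataSketchMd"
      },
      "cell_type": "markdown",
      "source": [
        "To build intuition for the three data source configurations above, the NumPy sketch below samples binary sequences under one plausible reading of each config. This reading is an assumption, not Thunnini's implementation: the single coin's `biases` are taken as the token probabilities, the mixture draws one coin per sequence according to `mixing_weights`, and \"Random Coins\" draws per-sequence token probabilities from a Dirichlet (here: Beta) distribution with parameters `alphas`.\n",
        "\n",
        "The sketch also computes the Bayes predictor for the two-coin mixture, which reweights the two coins by their posterior probabilities as tokens are observed. All function names are hypothetical."
      ]
    },
    {
      "metadata": {
        "id": "dataSketchCode"
      },
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def sample_single_coin(rng, probs, batch_size, length):\n",
        "  \"\"\"Hypothetical: all sequences share the token probabilities `probs`.\"\"\"\n",
        "  return rng.choice(len(probs), size=(batch_size, length), p=probs)\n",
        "\n",
        "\n",
        "def sample_mixture(rng, probs, mixing_weights, batch_size, length):\n",
        "  \"\"\"Hypothetical: draw one coin per sequence, then i.i.d. tokens from it.\"\"\"\n",
        "  coins = rng.choice(len(mixing_weights), size=batch_size, p=mixing_weights)\n",
        "  return np.stack(\n",
        "      [rng.choice(probs.shape[1], size=length, p=probs[c]) for c in coins]\n",
        "  )\n",
        "\n",
        "\n",
        "def sample_random_coins(rng, alphas, batch_size, length):\n",
        "  \"\"\"Hypothetical: draw per-sequence token probabilities from Dirichlet(alphas).\"\"\"\n",
        "  per_seq_probs = rng.dirichlet(alphas, size=batch_size)\n",
        "  return np.stack(\n",
        "      [rng.choice(len(alphas), size=length, p=p) for p in per_seq_probs]\n",
        "  )\n",
        "\n",
        "\n",
        "def mixture_bayes_predictions(tokens, probs, mixing_weights):\n",
        "  \"\"\"Bayes predictive P(x_i | x_\u003ci) for one sequence; shape [length, vocab].\"\"\"\n",
        "  log_post = np.log(mixing_weights)  # Log posterior over the coins.\n",
        "  preds = []\n",
        "  for token in tokens:\n",
        "    post = np.exp(log_post - log_post.max())\n",
        "    post /= post.sum()\n",
        "    preds.append(post @ probs)  # Posterior-weighted token probabilities.\n",
        "    log_post = log_post + np.log(probs[:, token])  # Bayes update with x_i.\n",
        "  return np.array(preds)\n",
        "\n",
        "\n",
        "# Parameters mirroring the configs above (batch size and length shortened).\n",
        "toy_rng = np.random.default_rng(0)\n",
        "toy_mix_probs = np.array([[0.2, 0.8], [0.8, 0.2]])\n",
        "toy_mix_weights = np.array([0.5, 0.5])\n",
        "toy_single = sample_single_coin(toy_rng, np.array([0.2, 0.8]), 4, 20)\n",
        "toy_mixture = sample_mixture(toy_rng, toy_mix_probs, toy_mix_weights, 4, 20)\n",
        "toy_random = sample_random_coins(toy_rng, np.array([1.0, 1.0]), 4, 20)\n",
        "toy_preds = mixture_bayes_predictions(\n",
        "    toy_mixture[0], toy_mix_probs, toy_mix_weights\n",
        ")\n",
        "print(\"Mixture sequence:\", toy_mixture[0])\n",
        "print(\"Bayes P(x_i = 1 | x_\u003ci):\", np.round(toy_preds[:, 1], 2))"
      ],
      "outputs": [],
      "execution_count": null
    },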
    {
      "metadata": {
        "id": "taCYc4IdAJqc"
      },
      "cell_type": "code",
      "source": [
        "# @title Predictor configuration\n",
        "predictor_config = config_lib.PredictorConfig(\n",
        "    token_dimensionality=2,\n",
        "    embedding_dimensionality=embedding_dim,\n",
        ")\n",
        "\n",
        "train = not architecture.endswith(\"_untrained\")\n",
        "if architecture.startswith(\"LSTM\"):\n",
        "  torso_config = config_lib.LSTMTorsoConfig(\n",
        "      is_trainable=train, hidden_sizes=[128], return_hidden_states=False\n",
        "  )\n",
        "elif architecture.startswith(\"Transformer\"):\n",
        "  torso_config = config_lib.TransformerTorsoConfig(\n",
        "      is_trainable=train,\n",
        "      hidden_sizes=[128],  # One layer per entry; only the width of the\n",
        "      # MLP block is affected.\n",
        "      num_attention_heads=4,\n",
        "      positional_encoding=\"SinCos\",\n",
        "      return_hidden_states=False,\n",
        "      use_bias=False,\n",
        "      widening_factor=4,\n",
        "      normalize_qk=True,\n",
        "      use_lora=True,\n",
        "      reduced_rank=4,\n",
        "  )\n",
        "elif architecture.startswith(\"Linear\"):\n",
        "  torso_config = config_lib.LinearTorsoConfig(\n",
        "      is_trainable=train,\n",
        "      hidden_sizes=[64, 32],\n",
        "  )\n",
        "else:\n",
        "  raise ValueError(f\"Unknown architecture: {architecture}\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "fjOiFwRSAJqc"
      },
      "cell_type": "code",
      "source": [
        "# @title Pretraining configuration\n",
        "training_data_config = dataclasses.replace(\n",
        "    all_data_sources[pretraining_source],\n",
        "    batch_size=pretraining_batch_size,\n",
        "    sequence_length=pretraining_sequence_length,\n",
        ")\n",
        "\n",
        "training_config = config_lib.TrainingConfig(\n",
        "    num_training_steps=pretraining_num_steps,\n",
        "    learning_rate=1e-3,\n",
        "    max_grad_norm=1.0,\n",
        "    data_gen_seed=0,\n",
        "    predictor_init_seed=0,\n",
        ")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "wVt2IRmzAJqc"
      },
      "cell_type": "code",
      "source": [
        "# @title Tuning configuration\n",
        "tuning_data_config = dataclasses.replace(\n",
        "    all_data_sources[tuning_source],\n",
        "    batch_size=tuning_batch_size,\n",
        "    sequence_length=tuning_sequence_length,\n",
        ")\n",
        "\n",
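        "# Base tuning configuration shared by all methods; each method below overrides\n",
        "# tuning_method (and disables the prefix settings for weight tuning methods).\n",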
        "tuning_config_base = config_lib.TuningConfig(\n",
        "    num_tuning_steps=tuning_num_steps,\n",
        "    learning_rate=5e-3,\n",
        "    max_grad_norm=1.0,\n",
        "    data_gen_seed=10,\n",
        "    prefix_init_seed=11,\n",
        "    tuning_method=\"prefix_real\",\n",
        "    num_tuning_repetitions=tuning_num_repetitions,\n",
        "    prefix_length=prefix_length,\n",
        "    prefix_init_method=\"one_hot\",  # [\"one_hot\", \"simplex\", \"zeros\"]\n",
        "    iterate_datagen_seed_over_repetitions=True,\n",
        ")\n",
        "\n",
        "tuning_configs = collections.OrderedDict()\n",
        "\n",
        "for tuning_name in tuning_names:\n",
        "  match tuning_name:\n",
        "    case \"SimplexPT\":\n",
        "      tuning_configs[tuning_name] = dataclasses.replace(\n",
        "          tuning_config_base, tuning_method=\"prefix_simplex\"\n",
        "      )\n",
        "    case \"RealPT\":\n",
        "      tuning_configs[tuning_name] = dataclasses.replace(\n",
        "          tuning_config_base, tuning_method=\"prefix_real\"\n",
        "      )\n",
        "    case \"SoftPT\":\n",
        "      tuning_configs[tuning_name] = dataclasses.replace(\n",
        "          tuning_config_base, tuning_method=\"prefix_soft\"\n",
        "      )\n",
        "    case \"FullWT\":\n",
        "      tuning_configs[tuning_name] = dataclasses.replace(\n",
        "          tuning_config_base,\n",
        "          tuning_method=\"full_parameters\",\n",
        "          prefix_length=None,\n",
        "          prefix_init_method=None,\n",
        "      )\n",
        "    case \"LoRAWT\":\n",
        "      if not torso_config.use_lora:\n",
        "        raise ValueError(\"Torso not set to support LoRA.\")\n",
        "      tuning_configs[tuning_name] = dataclasses.replace(\n",
        "          tuning_config_base,\n",
        "          tuning_method=\"lora_finetune\",\n",
        "          prefix_length=None,\n",
        "          prefix_init_method=None,\n",
        "      )\n",
        "    case \"EmbedWT\":\n",
        "      tuning_configs[tuning_name] = dataclasses.replace(\n",
        "          tuning_config_base,\n",
        "          tuning_method=\"embedding\",\n",
        "          prefix_length=None,\n",
        "          prefix_init_method=None,\n",
        "      )\n",
        "    case \"UnembedWT\":\n",
        "      tuning_configs[tuning_name] = dataclasses.replace(\n",
        "          tuning_config_base,\n",
        "          tuning_method=\"unembedding\",\n",
        "          prefix_length=None,\n",
        "          prefix_init_method=None,\n",
        "      )\n",
        "    case \"Un+EmbedWT\":\n",
        "      tuning_configs[tuning_name] = dataclasses.replace(\n",
        "          tuning_config_base,\n",
        "          tuning_method=\"embedding_unembedding\",\n",
        "          prefix_length=None,\n",
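        "          prefix_init_method=None,\n",
        "      )\n",
        "    case _:\n",
        "      # HardPT and RandomPF have no entry in tuning_configs; they are\n",
        "      # handled by separate code paths.\n",
        "      pass"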
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "bAZufw5JAJqd"
      },
      "cell_type": "code",
      "source": [
        "# @title Evaluation configuration\n",
        "eval_data_configs = {}\n",
        "for eval_name in eval_names:\n",
        "  eval_data_configs[eval_name] = dataclasses.replace(\n",
        "      all_data_sources[eval_name],\n",
        "      batch_size=eval_num_sequences,\n",
        "      sequence_length=eval_sequence_length,\n",
        "  )"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "g8qeLvZ2AJqd"
      },
      "cell_type": "code",
      "source": [
        "# Assign a fixed color index to each method so colors are consistent across plots.\n",
        "tuning_method_index = collections.OrderedDict([\n",
        "    (\"HardPT\", 0),\n",
        "    (\"SimplexPT\", 1),\n",
        "    (\"RealPT\", 2),\n",
        "    (\"SoftPT\", 3),\n",
        "    (\"FullWT\", 4),\n",
        "    (\"LoRAWT\", 5),\n",
        "    (\"EmbedWT\", 6),\n",
        "    (\"UnembedWT\", 7),\n",
        "    (\"Un+EmbedWT\", 8),\n",
        "    (\"TargetBayes\", 9), (\"EvalBayes\", 9),  # These are equivalent.\n",
        "    (\"PreBayes\", 10),\n",
        "    (\"NoTuning\", 11),\n",
        "    (\"PreBayesPT\", 12),\n",
        "    (\"RandomPF\", 13),\n",
        "    (\"ground_truth\", 14),\n",
        "])\n",
        "\n",
        "default_color_cycler = plt.cycler(\n",
        "    color=[\n",
        "        \"deepskyblue\",  # 0\n",
        "        \"lightseagreen\",  # 1\n",
        "        \"tab:blue\",  # 2\n",
        "        \"navy\",  # 3\n",
        "        \"goldenrod\",  # 4\n",
        "        \"tab:orange\",  # 5\n",
        "        \"sienna\",  # 6\n",
        "        \"tab:red\",  # 7\n",
        "        \"maroon\",  # 8\n",
        "        \"black\",  # 9\n",
        "        \"dimgray\",  # 10\n",
        "        \"olivedrab\",  # 11\n",
        "        \"darkgray\",  # 12\n",
        "        \"limegreen\",  # 13\n",
        "        \"palegreen\"  # 14\n",
        "    ]\n",
        ")\n",
        "\n",
        "# Set the default color cycle\n",
        "plt.rcParams[\"axes.prop_cycle\"] = default_color_cycler\n",
        "# Default plot settings\n",
        "rc_context = {\n",
        "    \"axes.facecolor\": \"whitesmoke\",\n",
        "    \"grid.color\": \"gainsboro\",\n",
        "    \"axes.edgecolor\": \"whitesmoke\",\n",
        "    \"axes.labelcolor\": \"#242c2e\",\n",
        "    \"text.color\": \"#242c2e\",\n",
        "    \"ytick.color\": \"#242c2e\",\n",
        "    \"xtick.color\": \"#242c2e\",\n",
        "    \"legend.edgecolor\": \"none\",\n",
        "    \"axes.labelsize\": \"xx-large\",\n",
        "    \"xtick.labelsize\": \"x-large\",\n",
        "    \"ytick.labelsize\": \"x-large\",\n",
        "    \"axes.titlesize\": \"xx-large\",\n",
        "}"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "12DuBDDjAJqd"
      },
      "cell_type": "code",
      "source": [
        "# @title Write configs to file\n",
        "\n",
        "if store_results:\n",
        "  with open(store_path + \"configs.txt\", \"w\") as f:\n",
        "    print(\"--- Architecture configuration ---\", file=f)\n",
        "    print(predictor_config, file=f)\n",
        "    print(torso_config, file=f)\n",
        "\n",
        "    print(\"\\n--- Training configuration ---\", file=f)\n",
        "    print(training_config, file=f)\n",
        "    print(f\"Training data generator: {pretraining_source}\", file=f)\n",
        "    print(\"\\t\", training_data_config, file=f)\n",
        "\n",
        "    print(\"\\n--- Tuning data configuration ---\", file=f)\n",
        "    print(f\"Tuning data generator: {tuning_source}\", file=f)\n",
        "    print(\"\\t\", tuning_data_config, file=f)\n",
        "    print(\"\\n--- Tuning configurations ---\", file=f)\n",
        "    for tuning_name in tuning_names:\n",
        "      if tuning_name in tuning_configs:\n",
        "        print(f\"\\t{tuning_name}\", file=f)\n",
        "        print(\"\\t\", tuning_configs[tuning_name], file=f)\n",
        "      else:\n",
        "        print(f\"\\t{tuning_name}\", file=f)\n",
        "\n",
        "    print(\"\\n--- Evaluation data configuration ---\", file=f)\n",
        "    for eval_name in eval_names:\n",
        "      print(f\"\\t{eval_name}\", file=f)\n",
        "      print(\"\\t\", eval_data_configs[eval_name], file=f)\n",
        "  print(\"Configs written to\", store_path+\"configs.txt\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "R_dk_obSAJqd"
      },
      "cell_type": "markdown",
      "source": [
        "# Pretraining"
      ]
    },
    {
      "metadata": {
        "id": "YncgYpmpAJqd"
      },
      "cell_type": "code",
      "source": [
        "print(\"--- Architecture configuration ---\")\n",
        "print(predictor_config)\n",
        "print(torso_config)\n",
        "print(\"\\n--- Training configuration ---\")\n",
        "print(training_config)\n",
        "print(training_data_config)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "4HPb7PDsAJqd"
      },
      "cell_type": "code",
      "source": [
        "# @title Pretrain predictor\n",
        "trained_params, train_results = training.train(\n",
        "    training_config=training_config,\n",
        "    predictor_config=predictor_config,\n",
        "    torso_config=torso_config,\n",
        "    data_config=training_data_config,\n",
        ")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "HNKp9stlAJqd"
      },
      "cell_type": "code",
      "source": [
        "# Plot loss curve\n",
        "with plt.rc_context(rc_context):\n",
        "  if train_results:\n",
        "    ax = plot_utils.plot_performance_metric(\n",
        "        {architecture: [train_results[\"loss\"]]},\n",
        "        \"Training loss\",\n",
        "        aggregate_fn_only=True,  # No variability band, single repetition.\n",
        "        show_gridlines=True,\n",
        "    )\n",
        "    ax.set_xlabel(\"Training Step\")\n",
        "    ax.set_title(f\"Pretraining on {pretraining_source}.\")\n",
        "    if store_results:\n",
        "      plt.savefig(store_path + \"pretraining_loss_curve.pdf\", bbox_inches=\"tight\")\n",
        "      print(\"Figure written to:\", store_path + \"pretraining_loss_curve.pdf\")\n",
        "  else:\n",
        "    if torso_config.is_trainable:\n",
        "      raise ValueError(\"Training failed but torso is trainable. Aborting.\")\n",
        "    print(\n",
        "        \"Predictor initialized but training skipped since torso is not\"\n",
        "        \" trainable.\"\n",
        "    )"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "BtpXhXdDAJqd"
      },
      "cell_type": "markdown",
      "source": [
        "# Tuning and Evaluation"
      ]
    },
    {
      "metadata": {
        "id": "C1dvdDttAJqd"
      },
      "cell_type": "code",
      "source": [
        "print(\"--- Tuning data configuration ---\")\n",
        "print(tuning_data_config)\n",
        "\n",
        "print(\"\\n--- Tuning methods ---\")\n",
        "for tuning_name in tuning_names:\n",
        "  if tuning_name in tuning_configs:\n",
        "    print(f\"{tuning_name} ({tuning_configs[tuning_name].tuning_method})\")\n",
        "  else:\n",
        "    print(f\"{tuning_name}\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "_C2UYsAIAJqd"
      },
      "cell_type": "code",
      "source": [
        "# @title Run main tuning experiment and evaluations\n",
        "\n",
        "if \"FullWT\" in tuning_configs and not torso_config.is_trainable:\n",
        "  raise ValueError(\"Full weight tuning requires trainable torso.\")\n",
        "\n",
        "# Set logging level so that tuning experiment progress is shown in Colab.\n",
        "logging.basicConfig(level=logging.INFO, force=True)\n",
        "\n",
        "results, sequences = tuning.run_tuning_experiment(\n",
        "  predictor_config=predictor_config,\n",
        "  torso_config=torso_config,\n",
        "  predictor_params=trained_params,\n",
        "  tuning_configs=tuning_configs,\n",
        "  tuning_data_config=tuning_data_config,\n",
        "  eval_data_configs=eval_data_configs,\n",
        "  eval_datgen_seed=0,\n",
        "  eval_batching_batch_size=-1,\n",
        "  evaluate_untuned_predictor=True,\n",
        "  return_tuned_prefix=True,\n",
        ")\n",
        "\n",
        "# Reset logging level\n",
        "logging.basicConfig(level=logging.WARNING, force=True)"
      ],
      "outputs": [],
      "execution_count": null
    },
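    {
      "metadata": {
        "id": "hardPfSketchMd"
      },
      "cell_type": "markdown",
      "source": [
        "The exhaustive hard-prefix search in the next cell enumerates all $2^L$ binary prefixes of length $L$ and keeps the one with the lowest loss on the tuning sequences. The NumPy sketch below illustrates the same search pattern on a toy predictor. It assumes the predictor is the Bayes predictor for the uniform \"Random Coins\" distribution (Laplace's rule), which can be conditioned on a prefix simply by adding the prefix's token counts; it is not Thunnini's `evaluation.evaluate_prefix_list`, and the helper names and toy data are hypothetical.\n",
        "\n",
        "For Laplace's rule only the number of ones in a prefix matters, so many of the candidates tie; this would not necessarily hold for a neural predictor that is sensitive to token order. The number of candidates doubles with each additional prefix token, which is why the search is only practical for short prefixes."
      ]
    },
    {
      "metadata": {
        "id": "hardPfSketchCode"
      },
      "cell_type": "code",
      "source": [
        "import itertools\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def laplace_log_loss_with_prefix(prefix, sequences):\n",
        "  \"\"\"Hypothetical: mean cumulative log loss of Laplace's rule after a prefix.\n",
        "\n",
        "  The prefix only shifts the counts that the predictor conditions on; the loss\n",
        "  is measured on `sequences` (shape [batch, length], binary tokens) only.\n",
        "  \"\"\"\n",
        "  prefix_ones, prefix_len = sum(prefix), len(prefix)\n",
        "  ones_before = np.cumsum(sequences, axis=1) - sequences + prefix_ones\n",
        "  steps_before = np.arange(sequences.shape[1])[None, :] + prefix_len\n",
        "  p_one = (ones_before + 1.0) / (steps_before + 2.0)\n",
        "  log_p = np.where(sequences == 1, np.log(p_one), np.log1p(-p_one))\n",
        "  return -log_p.sum(axis=1).mean()\n",
        "\n",
        "\n",
        "def exhaustive_prefix_search(prefix_length, sequences):\n",
        "  \"\"\"Scores all 2**prefix_length binary prefixes; returns best and all losses.\"\"\"\n",
        "  losses = {\n",
        "      prefix: laplace_log_loss_with_prefix(prefix, sequences)\n",
        "      for prefix in itertools.product((0, 1), repeat=prefix_length)\n",
        "  }\n",
        "  best_prefix = min(losses, key=losses.get)\n",
        "  return best_prefix, losses\n",
        "\n",
        "\n",
        "# Toy tuning data (hypothetical): a single coin with P(x_i = 1) = 0.8.\n",
        "toy_rng = np.random.default_rng(5)\n",
        "toy_sequences = (toy_rng.uniform(size=(2048, 50)) \u003c 0.8).astype(int)\n",
        "toy_best_prefix, toy_losses = exhaustive_prefix_search(6, toy_sequences)\n",
        "print(\"Best prefix:\", \"\".join(map(str, toy_best_prefix)),\n",
        "      f\"({len(toy_losses)} candidates)\")"
      ],
      "outputs": [],
      "execution_count": null
    },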
    {
      "metadata": {
        "id": "dQC0XyBhAJqe"
      },
      "cell_type": "code",
      "source": [
        "# @title Tune hard prefix (exhaustive search)\n",
        "\n",
        "def int_to_binary_one_hot(number: int, binary_length: int) -\u003e np.ndarray:\n",
        "  \"\"\"Returns one-hot binary representation of number as array.\"\"\"\n",
        "  binary_rep_list = list(np.binary_repr(number).zfill(binary_length))\n",
        "  binary_array = np.array(binary_rep_list, dtype=np.uint8)\n",
        "  one_hot = np.vstack([binary_array, 1 - binary_array]).transpose()\n",
        "  return one_hot\n",
        "\n",
        "\n",
        "def one_hot_binary_to_str(one_hot: np.ndarray) -\u003e str:\n",
        "  \"\"\"Returns the bit string encoded by int_to_binary_one_hot (column 0).\"\"\"\n",
        "  return \"\".join(one_hot[:, 0].astype(str))\n",
        "\n",
        "\n",
        "if (\"HardPT\" in tuning_names) or pretrain_bayes_pt:\n",
        "  hard_pf_tuning_batch_size = 2048\n",
        "\n",
        "  # Exhaustively generate all hard prefixes\n",
        "  all_hard_prefixes = [\n",
        "      int_to_binary_one_hot(i, prefix_length) for i in range(2**prefix_length)\n",
        "  ]\n",
        "\n",
        "  # Draw tuning sequences until hard_pf_tuning_batch_size is reached, since\n",
        "  # tuning_data_config.batch_size may be \u003c hard_pf_tuning_batch_size.\n",
        "  hard_pf_tune_dg = builders.build_datagen(tuning_data_config)\n",
        "  rng_key = jax.random.PRNGKey(5)\n",
        "  sequence_batches = []\n",
        "  n_drawn = 0\n",
        "  while n_drawn \u003c hard_pf_tuning_batch_size:\n",
        "    # Split off a fresh key per draw. generate is assumed to be deterministic\n",
        "    # given rng_key, so reusing one key would only duplicate the same batch.\n",
        "    rng_key, sample_key = jax.random.split(rng_key)\n",
        "    seqs = hard_pf_tune_dg.generate(\n",
        "        rng_key=sample_key, return_ground_truth_log_probs=False\n",
        "    )\n",
        "    sequence_batches.append(seqs)\n",
        "    n_drawn += seqs.shape[0]\n",
        "  hard_pf_tune_sequences = jnp.concatenate(sequence_batches, axis=0)[\n",
        "      :hard_pf_tuning_batch_size\n",
        "  ]\n",
        "\n",
        "  if \"HardPT\" in tuning_names:\n",
        "    # Evaluate all hard prefixes on tuning \"batch\"\n",
        "    hard_pf_logits, hard_pf_losses = evaluation.evaluate_prefix_list(\n",
        "        prefix_list=all_hard_prefixes,\n",
        "        prefix_type=\"prepend\",\n",
        "        predictor_config=predictor_config,\n",
        "        torso_config=torso_config,\n",
        "        predictor_params=trained_params,\n",
        "        sequences=hard_pf_tune_sequences,\n",
        "        batch_size=-1,  # Set to -1 to evaluate all sequences as single batch\n",
        "    )\n",
        "else:\n",
        "  print(\"Hard prefix tuning skipped.\")"
      ],
      "outputs": [],
      "execution_count": null
    },
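    {
      "metadata": {
        "id": "grEnumHardPrefixes"
      },
      "cell_type": "markdown",
      "source": [
        "The exhaustive search above scores every one of the `2**prefix_length` binary prefixes, so the number of candidates doubles with each additional prefix token. As a minimal NumPy sketch (not part of the pipeline), the same set of prefixes can be enumerated with `itertools.product`, which yields them in the same order as `int_to_binary_one_hot(i, prefix_length)` for `i = 0, 1, ...`. The helper `enumerate_hard_prefixes` is hypothetical; its encoding mirrors `int_to_binary_one_hot`, where column 0 holds the binary digit and column 1 its complement.\n",
        "\n",
        "```python\n",
        "import itertools\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def enumerate_hard_prefixes(prefix_length: int) -\u003e np.ndarray:\n",
        "  # Hypothetical helper: returns shape (2**prefix_length, prefix_length, 2),\n",
        "  # where row t of each prefix is [bit_t, 1 - bit_t].\n",
        "  prefixes = []\n",
        "  # itertools.product yields bit tuples in lexicographic order, i.e. the\n",
        "  # binary representations of 0, 1, ..., 2**prefix_length - 1 (MSB first).\n",
        "  for bits in itertools.product([0, 1], repeat=prefix_length):\n",
        "    bits = np.array(bits, dtype=np.uint8)\n",
        "    prefixes.append(np.stack([bits, 1 - bits], axis=-1))\n",
        "  return np.stack(prefixes)\n",
        "\n",
        "\n",
        "prefixes = enumerate_hard_prefixes(3)\n",
        "print(prefixes.shape)  # (8, 3, 2)\n",
        "print(''.join(prefixes[5][:, 0].astype(str)))  # 101, the binary digits of 5\n",
        "```"
      ]
    },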
    {
      "metadata": {
        "id": "UAHCWODvAJqe"
      },
      "cell_type": "code",
      "source": [
        "# @title Process hard prefix tuning results\n",
        "\n",
        "if \"HardPT\" in tuning_names:\n",
        "  all_losses = np.array(hard_pf_losses)\n",
        "  cum_losses = np.sum(all_losses, axis=-1)\n",
        "  avg_cum_losses = np.mean(cum_losses, axis=-1)\n",
        "\n",
        "  sort_inds = np.argsort(avg_cum_losses, axis=0)\n",
        "  sorted_avg_cum_losses = avg_cum_losses[sort_inds]\n",
        "  sorted_hard_prefixes = np.array(all_hard_prefixes)[sort_inds]\n",
        "\n",
        "  best_hard_prefix = sorted_hard_prefixes[0]\n",
        "  best_hard_prefix_str = one_hot_binary_to_str(best_hard_prefix)\n",
        "\n",
        "  print(\n",
        "      f\"Best mean cumulative loss: {sorted_avg_cum_losses[0]} nats\"\n",
        "      f\" ({best_hard_prefix_str})\"\n",
        "  )\n",
        "  print(\n",
        "      f\"Worst mean cumulative loss: {sorted_avg_cum_losses[-1]} nats\"\n",
        "      f\" ({one_hot_binary_to_str(sorted_hard_prefixes[-1])})\"\n",
        "  )\n",
        "else:\n",
        "  print(\"Hard prefix tuning skipped.\")"
      ],
      "outputs": [],
      "execution_count": null
    },
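    {
      "metadata": {
        "id": "grRankHardPrefixes"
      },
      "cell_type": "markdown",
      "source": [
        "The cell above reduces the per-step losses in two stages: summing over time gives each trajectory's cumulative loss, and averaging over trajectories gives one score per prefix, which is then sorted in ascending order. A minimal sketch on made-up numbers, assuming `hard_pf_losses` has shape `(n_prefixes, n_sequences, sequence_length)`; this shape is inferred from the axis arguments above and not checked against `evaluation.evaluate_prefix_list`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Toy per-step losses in nats: 3 prefixes x 2 sequences x 4 steps.\n",
        "losses = np.array([\n",
        "    [[0.7, 0.7, 0.7, 0.7], [0.7, 0.7, 0.7, 0.7]],  # prefix 0\n",
        "    [[0.2, 0.3, 0.1, 0.2], [0.4, 0.1, 0.2, 0.2]],  # prefix 1\n",
        "    [[1.0, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]],  # prefix 2\n",
        "])\n",
        "\n",
        "cum_losses = np.sum(losses, axis=-1)  # (n_prefixes, n_sequences)\n",
        "avg_cum_losses = np.mean(cum_losses, axis=-1)  # (n_prefixes,)\n",
        "sort_inds = np.argsort(avg_cum_losses)  # lowest mean cumulative loss first\n",
        "print(sort_inds)  # [1 2 0]\n",
        "```"
      ]
    },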
    {
      "metadata": {
        "id": "VmRbTstUAJqe"
      },
      "cell_type": "code",
      "source": [
        "# @title Plot hard prefix tuning results\n",
        "\n",
        "if \"HardPT\" in tuning_names:\n",
        "  with plt.rc_context(rc_context):\n",
        "    fig = plt.figure(figsize=(15, 3.5))\n",
        "    ax = fig.gca()\n",
        "    xvec = range(len(all_hard_prefixes))\n",
        "    ax.bar(xvec, sorted_avg_cum_losses, zorder=3)\n",
        "    ax.set_xticks(xvec)\n",
        "    ax.set_xticklabels(map(one_hot_binary_to_str, sorted_hard_prefixes))\n",
        "    ax.tick_params(axis=\"x\", labelrotation=90)\n",
        "    ax.set_ylabel(\"Cumulative loss [nats]\")\n",
        "    ax.set_xlabel(\"Hard prefix\")\n",
        "    if torso_config.is_trainable:\n",
        "      ax.set_title(\n",
        "          f\"Hard prefix tuning ({hard_pf_tuning_batch_size} trajectories)\\n{architecture}: {pretraining_source} → {tuning_source}\"\n",
        "      )\n",
        "    else:\n",
        "      ax.set_title(\n",
        "          f\"Hard prefix tuning ({hard_pf_tuning_batch_size} trajectories)\\n{architecture} → {tuning_source}\"\n",
        "      )\n",
        "    ax.grid(True, axis=\"y\", zorder=0)\n",
        "  if store_results:\n",
        "    plt.savefig(store_path + \"hard_prefix_tuning.pdf\", bbox_inches=\"tight\")\n",
        "    print(\"Figure written to:\", store_path + \"hard_prefix_tuning.pdf\")\n",
        "else:\n",
        "  print(\"Hard prefix tuning skipped.\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "XYzhNSryAJqe"
      },
      "cell_type": "code",
      "source": [
        "# @title Evaluate best hard prefix and add to results\n",
        "\n",
        "if \"HardPT\" in tuning_names:\n",
        "  results[\"HardPT\"] = collections.OrderedDict()\n",
        "  results[\"HardPT\"][\"tuned_prefix\"] = best_hard_prefix\n",
        "  for eval_name, eval_sequences in sequences.items():\n",
        "    hard_pf_eval_logits, hard_pf_eval_losses = evaluation.evaluate_prefix_list(\n",
        "        prefix_list=[best_hard_prefix],\n",
        "        prefix_type=\"prepend\",\n",
        "        predictor_config=predictor_config,\n",
        "        torso_config=torso_config,\n",
        "        predictor_params=trained_params,\n",
        "        sequences=eval_sequences,\n",
        "        batch_size=-1,  # Set to -1 to evaluate all sequences as single batch\n",
        "    )\n",
        "    results[\"HardPT\"][eval_name] = {}\n",
        "    results[\"HardPT\"][eval_name][\"logits\"] = hard_pf_eval_logits\n",
        "    results[\"HardPT\"][eval_name][\"losses\"] = hard_pf_eval_losses\n",
        "else:\n",
        "  print(\"Hard prefix tuning skipped.\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "HtyDl8JlAJqe"
      },
      "cell_type": "code",
      "source": [
        "# @title Evaluate untuned real prefixes and add to results (as a control)\n",
        "\n",
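        "if \"RandomPF\" in tuning_names:\n",
        "  if \"RealPT\" not in results:\n",
        "    raise ValueError(\"RandomPF reuses the initial prefixes from RealPT tuning.\")\n",
        "  # Reuse the random initial prefixes of RealPT tuning (before any update).\n",
        "  all_init_prefixes = results[\"RealPT\"][\"initial_prefix\"]\n",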
        "\n",
        "  results[\"RandomPF\"] = collections.OrderedDict()\n",
        "  results[\"RandomPF\"][\"initial_prefix\"] = all_init_prefixes\n",
        "  for eval_name, eval_sequences in sequences.items():\n",
        "    init_pf_logits, init_pf_losses = evaluation.evaluate_prefix_list(\n",
        "        prefix_list=all_init_prefixes,\n",
        "        prefix_type=\"prepend\",\n",
        "        predictor_config=predictor_config,\n",
        "        torso_config=torso_config,\n",
        "        predictor_params=trained_params,\n",
        "        sequences=eval_sequences,\n",
        "        batch_size=-1,  # Set to -1 to evaluate all sequences as single batch\n",
        "    )\n",
        "    results[\"RandomPF\"][eval_name] = {}\n",
        "    results[\"RandomPF\"][eval_name][\"logits\"] = init_pf_logits\n",
        "    results[\"RandomPF\"][eval_name][\"losses\"] = init_pf_losses"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "lzlYCjOK3Qqa"
      },
      "cell_type": "code",
      "source": [
        "# @title Evaluate Bayesian predictor for pretraining distr. on eval distr.\n",
        "\n",
        "if pretrain_bayes or pretrain_bayes_pt:\n",
        "  pretraining_dg = builders.build_datagen(training_data_config)\n",
        "  if pretrain_bayes:\n",
        "    results[\"PreBayes\"] = collections.OrderedDict()\n",
        "    for eval_name, eval_sequences in sequences.items():\n",
        "      # Evaluate Pretraining Bayes\n",
        "      preb_logits, preb_losses = pretraining_dg.solve(eval_sequences)\n",
        "      results[\"PreBayes\"][eval_name] = {}\n",
        "      results[\"PreBayes\"][eval_name][\"logits\"] = preb_logits\n",
        "      results[\"PreBayes\"][eval_name][\"losses\"] = preb_losses\n",
        "\n",
        "  if pretrain_bayes_pt:\n",
        "    # Prefix tune Pretraining Bayes: exhaustively search for the hard prefix\n",
        "    # with the lowest mean cumulative loss on the tuning sequences. The search\n",
        "    # does not depend on the evaluation data, so it runs only once.\n",
        "    best_prefix = None\n",
        "    best_prefix_cum_loss = np.inf\n",
        "    for prefix in all_hard_prefixes:\n",
        "      prefixes = jnp.tile(prefix, (hard_pf_tune_sequences.shape[0], 1, 1))\n",
        "      _, prebpt_losses = pretraining_dg.solve(\n",
        "          jnp.concatenate((prefixes, hard_pf_tune_sequences), axis=1)\n",
        "      )\n",
        "      # Ignore the losses incurred on the prefix itself.\n",
        "      prebpt_losses = prebpt_losses[:, prefix_length:]\n",
        "      prebpt_cum_loss = jnp.mean(jnp.sum(prebpt_losses, axis=-1), axis=-1)\n",
        "      if prebpt_cum_loss \u003c best_prefix_cum_loss:\n",
        "        best_prefix = prefix\n",
        "        best_prefix_cum_loss = prebpt_cum_loss\n",
        "\n",
        "    print(\"Best prefix (PreBayesPT):\", one_hot_binary_to_str(best_prefix))\n",
        "    print(\"Best prefix cumulative loss (PreBayesPT):\", best_prefix_cum_loss, \"nats.\")\n",
        "\n",
        "    # Evaluate the prefix-tuned Pretraining Bayes on each evaluation dataset.\n",
        "    results[\"PreBayesPT\"] = collections.OrderedDict()\n",
        "    for eval_name, eval_sequences in sequences.items():\n",
        "      best_prefixes = jnp.tile(best_prefix, (eval_sequences.shape[0], 1, 1))\n",
        "      prebpt_logits, prebpt_losses = pretraining_dg.solve(\n",
        "          jnp.concatenate((best_prefixes, eval_sequences), axis=1)\n",
        "      )\n",
        "      results[\"PreBayesPT\"][eval_name] = {}\n",
        "      results[\"PreBayesPT\"][eval_name][\"logits\"] = prebpt_logits[:, prefix_length:, :]\n",
        "      results[\"PreBayesPT\"][eval_name][\"losses\"] = prebpt_losses[:, prefix_length:]"
      ],
      "outputs": [],
      "execution_count": null
    },
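    {
      "metadata": {
        "id": "grPreBayesPTSketch"
      },
      "cell_type": "markdown",
      "source": [
        "PreBayesPT applies the same exhaustive prefix search to the Bayes predictor of the pretraining distribution. As a self-contained illustration, the sketch below swaps in Laplace's rule of succession (the Bayes predictor for i.i.d. coin flips with a uniform prior on the bias) for `pretraining_dg.solve`; the functions `laplace_log_losses` and `tune_prefix_exhaustively` are hypothetical. As in the cell above, losses on the prefix positions are discarded before scoring, so the prefix acts as a set of pseudo-observations that shift the predictor's estimate toward the tuning data.\n",
        "\n",
        "```python\n",
        "import itertools\n",
        "\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def laplace_log_losses(bits: np.ndarray) -\u003e np.ndarray:\n",
        "  # Per-step log loss (nats) of Laplace's rule of succession on a 0/1 sequence:\n",
        "  # P(next bit = 1) = (ones seen so far + 1) / (bits seen so far + 2).\n",
        "  ones_before = np.concatenate([[0], np.cumsum(bits)[:-1]])\n",
        "  p_one = (ones_before + 1) / (np.arange(len(bits)) + 2)\n",
        "  return -np.log(np.where(bits == 1, p_one, 1 - p_one))\n",
        "\n",
        "\n",
        "def tune_prefix_exhaustively(sequences, prefix_length):\n",
        "  # Hypothetical helper: score every binary prefix by the mean cumulative loss\n",
        "  # on the tuning sequences, ignoring the loss incurred on the prefix itself.\n",
        "  best_prefix, best_loss = None, np.inf\n",
        "  for prefix in itertools.product([0, 1], repeat=prefix_length):\n",
        "    cum_losses = [\n",
        "        laplace_log_losses(np.concatenate([prefix, seq]))[prefix_length:].sum()\n",
        "        for seq in sequences\n",
        "    ]\n",
        "    mean_cum_loss = float(np.mean(cum_losses))\n",
        "    if mean_cum_loss \u003c best_loss:\n",
        "      best_prefix, best_loss = prefix, mean_cum_loss\n",
        "  return best_prefix, best_loss\n",
        "\n",
        "\n",
        "# Toy tuning sequences from a coin that mostly lands on 1.\n",
        "tuning_sequences = [np.array([1, 1, 1, 0, 1, 1]), np.array([1, 1, 1, 1, 1, 1])]\n",
        "prefix, loss = tune_prefix_exhaustively(tuning_sequences, prefix_length=3)\n",
        "print(prefix, round(loss, 3))  # (1, 1, 1) 2.015\n",
        "```\n",
        "\n",
        "Because Laplace's rule depends only on counts, prefixes with the same number of ones tie; the strict comparison keeps the first such prefix found."
      ]
    },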
    {
      "metadata": {
        "id": "ghaEohJeAJqe"
      },
      "cell_type": "markdown",
      "source": [
        "# Visualization of all tuning results"
      ]
    },
    {
      "metadata": {
        "id": "r3a538d8AJqe"
      },
      "cell_type": "code",
      "source": [
        "# @title Postprocess all results (compute regrets and prepare for plotting)\n",
        "\n",
        "# Compute regrets\n",
        "(gt_losses, instant_regrets, cumulative_regrets, tuning_loss_curves) = (\n",
        "    plot_utils.postprocess_tuning_experiment_results(\n",
        "        results, tuning_names, eval_names\n",
        "    )\n",
        ")\n",
        "\n",
        "\n",
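        "# Rename 'Bayes' to 'TargetBayes' when evaluating on the tuning source and to\n",
        "# 'EvalBayes' otherwise. Then reorder entries so that HardPT comes first and\n",
        "# the baselines (Bayes, PreBayes, NoTuning, PreBayesPT, RandomPF) come last.\n",
        "for eval_name in eval_names:\n",
        "  if 'Bayes' in results:  # pop and re-insert also moves the entry to the end\n",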
        "    if eval_name == tuning_source:\n",
        "      instant_regrets[eval_name]['TargetBayes'] = instant_regrets[eval_name].pop('Bayes')\n",
        "      cumulative_regrets[eval_name]['TargetBayes'] = cumulative_regrets[eval_name].pop('Bayes')\n",
        "    else:\n",
        "      instant_regrets[eval_name]['EvalBayes'] = instant_regrets[eval_name].pop('Bayes')\n",
        "      cumulative_regrets[eval_name]['EvalBayes'] = cumulative_regrets[eval_name].pop('Bayes')\n",
        "  if pretrain_bayes:\n",
        "    # Manually process PreBayes results (not covered by postprocess above)\n",
        "    preb_losses = results[\"PreBayes\"][eval_name][\"losses\"]\n",
        "    instant_regret_preb = jnp.mean(preb_losses - gt_losses[eval_name], axis=0)\n",
        "    instant_regrets[eval_name][\"PreBayes\"] = [instant_regret_preb]\n",
        "    cumulative_regrets[eval_name][\"PreBayes\"] = [np.cumsum(instant_regret_preb)]\n",
        "  if 'HardPT' in results:\n",
        "    instant_regrets[eval_name].move_to_end('HardPT', last=False)\n",
        "    cumulative_regrets[eval_name].move_to_end('HardPT', last=False)\n",
        "  if 'NoTuning' in results:\n",
        "    instant_regrets[eval_name].move_to_end('NoTuning')\n",
        "    cumulative_regrets[eval_name].move_to_end('NoTuning')\n",
        "  if pretrain_bayes_pt:\n",
        "    # Manually process PreBayesPT results (not covered by postprocess above)\n",
        "    prebpt_losses = results[\"PreBayesPT\"][eval_name][\"losses\"]\n",
        "    instant_regret_prebpt = jnp.mean(prebpt_losses - gt_losses[eval_name], axis=0)\n",
        "    instant_regrets[eval_name][\"PreBayesPT\"] = [instant_regret_prebpt]\n",
        "    cumulative_regrets[eval_name][\"PreBayesPT\"] = [np.cumsum(instant_regret_prebpt)]\n",
        "  if 'RandomPF' in results:\n",
        "    instant_regrets[eval_name].move_to_end('RandomPF')\n",
        "    cumulative_regrets[eval_name].move_to_end('RandomPF')"
      ],
      "outputs": [],
      "execution_count": null
    },
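    {
      "metadata": {
        "id": "grRegretSketch"
      },
      "cell_type": "markdown",
      "source": [
        "The regret curves measure how much more log loss a method incurs than the ground-truth losses (`gt_losses`) on the same evaluation sequences: the instant regret at step n is the mean over sequences of the method's loss minus the ground-truth loss, and the cumulative regret is its running sum `r_0 + r_1 + ... + r_n`. A minimal NumPy sketch of the computation applied to `PreBayes` and `PreBayesPT` above, on made-up losses of shape `(n_sequences, sequence_length)`:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Toy per-step log losses (nats) for 2 sequences x 3 steps.\n",
        "method_losses = np.array([[0.9, 0.8, 0.7], [0.7, 0.6, 0.5]])\n",
        "gt_losses = np.array([[0.6, 0.6, 0.6], [0.6, 0.6, 0.6]])\n",
        "\n",
        "# Instant regret: excess loss over the ground truth, averaged over sequences\n",
        "# (axis 0), giving one value per step.\n",
        "instant_regret = np.mean(method_losses - gt_losses, axis=0)\n",
        "# Cumulative regret: running sum of the instant regret over steps.\n",
        "cumulative_regret = np.cumsum(instant_regret)\n",
        "print(instant_regret)  # approximately 0.2, 0.1, 0.0\n",
        "print(cumulative_regret)  # approximately 0.2, 0.3, 0.3\n",
        "```"
      ]
    },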
    {
      "metadata": {
        "id": "tXOI--MtTHoE"
      },
      "cell_type": "code",
      "source": [
        "# @title Group methods for plotting (prefix tuning, weight tuning, baselines)\n",
        "\n",
        "pt_methods = []\n",
        "wt_methods = []\n",
        "baselines = []\n",
        "for mname in tuning_method_index:\n",
        "  if mname in tuning_names:\n",
        "    if mname.endswith(\"PT\"):\n",
        "      pt_methods.append(mname)\n",
        "    elif mname.endswith(\"WT\"):\n",
        "      wt_methods.append(mname)\n",
        "  else:\n",
        "    if mname == \"TargetBayes\":\n",
        "      baselines.append(\"TargetBayes\")\n",
        "    elif mname == \"PreBayes\" and pretrain_bayes:\n",
        "      baselines.append(\"PreBayes\")\n",
        "    elif mname == \"NoTuning\":\n",
        "      baselines.append(\"NoTuning\")\n",
        "    elif mname == \"PreBayesPT\" and pretrain_bayes_pt:\n",
        "      baselines.append(\"PreBayesPT\")\n",
        "\n",
        "if \"RandomPF\" in tuning_names:\n",
        "  baselines.append(\"RandomPF\")\n",
        "\n",
        "n_pt_methods = len(pt_methods)\n",
        "n_wt_methods = len(wt_methods)\n",
        "n_baselines = len(baselines)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "pyVMqef6AJqe"
      },
      "cell_type": "code",
      "source": [
        "# @title Tuning loss curves\n",
        "\n",
        "colors = [f\"C{tuning_method_index[name]}\" for name in tuning_loss_curves]\n",
        "\n",
        "with plt.rc_context(rc_context):\n",
        "  ax = plot_utils.plot_performance_metric(\n",
        "      tuning_loss_curves,\n",
        "      \"Tuning loss\",\n",
        "      aggregate_fn_only=False,\n",
        "      show_gridlines=True,\n",
        "      show_individual_lines=True,\n",
        "      colors=colors,\n",
        "  )\n",
        "  if torso_config.is_trainable:\n",
        "    ax.set_title(\n",
        "        f\"{architecture}: {pretraining_source} → {tuning_source}\"\n",
        "    )\n",
        "  else:\n",
        "    ax.set_title(f\"{architecture}; untrained → {tuning_source}\")\n",
        "  ax.set_xlabel(\"Tuning step\")\n",
        "  ax.get_legend().set_title(\"Tuning method\")\n",
        "\n",
        "if store_results:\n",
        "  plt.savefig(store_path + \"tuning_loss_curves.pdf\", bbox_inches=\"tight\")\n",
        "  print(\"Figure written to:\", store_path + \"tuning_loss_curves.pdf\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "YAtmOT-LAJqe"
      },
      "cell_type": "code",
      "source": [
        "# @title Evaluation performance of tuned models\n",
        "\n",
        "with plt.rc_context(rc_context):\n",
        "  fig, axes = plt.subplots(\n",
        "      nrows=2,\n",
        "      ncols=len(eval_names),\n",
        "      figsize=(6.5 * len(eval_names), 8),\n",
        "      squeeze=False,  # Always return a 2D array of axes.\n",
        "  )\n",
        "  for i, eval_name in enumerate(eval_names):\n",
        "    ax = axes[0, i]\n",
        "    ax2 = axes[1, i]\n",
        "    colors = [\n",
        "        f\"C{tuning_method_index[name]}\"\n",
        "        for name in cumulative_regrets[eval_name]\n",
        "    ]\n",
        "    plot_utils.plot_performance_metric(\n",
        "        instant_regrets[eval_name],\n",
        "        \"Instant regret [nats]\",\n",
        "        axis=ax,\n",
        "        aggregate_fn_only=False,\n",
        "        show_gridlines=True,\n",
        "        show_individual_lines=True,\n",
        "        colors=colors,\n",
        "    )\n",
        "    ax.set_title(\"Evaluation on \" + eval_name)\n",
        "    ax.set_xlabel(\"\")\n",
        "    ax.get_legend().remove()\n",
        "\n",
        "    plot_utils.plot_performance_metric(\n",
        "        cumulative_regrets[eval_name],\n",
        "        \"Cumulative regret [nats]\",\n",
        "        axis=ax2,\n",
        "        aggregate_fn_only=False,\n",
        "        show_gridlines=True,\n",
        "        show_individual_lines=True,\n",
        "        colors=colors,\n",
        "    )\n",
        "    if i \u003e 0:\n",
        "      ax2.get_legend().set_visible(False)\n",
        "      ax.set_ylabel(\"\")\n",
        "      ax2.set_ylabel(\"\")\n",
        "    else:\n",
        "      leg = ax2.legend(\n",
        "          #title=\"Tuning method\",\n",
        "          ncols=int(np.ceil(len(colors) / max(1, 4 - len(eval_names)))),\n",
        "          loc=\"lower center\",\n",
        "          bbox_to_anchor=(0.5, -0.05),\n",
        "          bbox_transform=fig.transFigure,\n",
        "          edgecolor=\"dimgray\",\n",
        "      )\n",
        "    ax2.set_xlabel(\"Step $n$\")\n",
        "\n",
        "  # Mark the tuning, pretraining, and evaluation sequence lengths.\n",
        "  for ax in axes.flatten():\n",
        "    ylims = ax.get_ylim()\n",
        "    ax.vlines(\n",
        "        [\n",
        "            tuning_data_config.sequence_length - 1,\n",
        "            training_data_config.sequence_length - 1,\n",
        "            eval_sequence_length - 1,\n",
        "        ],\n",
        "        ymin=ylims[0],\n",
        "        ymax=ylims[1],\n",
        "        colors=[\"black\"],\n",
        "        linewidth=1.5,\n",
        "        linestyle=\":\",\n",
        "    )\n",
        "    ax.text(\n",
        "        x=tuning_data_config.sequence_length,\n",
        "        y=ylims[1],\n",
        "        s=\"$N_\\\\text{tune}$\",\n",
        "        rotation=0,\n",
        "        ha=\"center\",\n",
        "        va=\"top\",\n",
        "        fontsize=11,\n",
        "        backgroundcolor=\"gainsboro\",\n",
        "        color=\"#242c2e\",\n",
        "    )\n",
        "    ax.text(\n",
        "        x=training_data_config.sequence_length,\n",
        "        y=ylims[1],\n",
        "        s=\"$N_\\\\text{train}$\",\n",
        "        rotation=0,\n",
        "        ha=\"center\",\n",
        "        va=\"top\",\n",
        "        fontsize=11,\n",
        "        backgroundcolor=\"gainsboro\",\n",
        "        color=\"#242c2e\",\n",
        "    )\n",
        "    tb = ax.text(\n",
        "        x=eval_sequence_length,\n",
        "        y=ylims[1],\n",
        "        s=\"$N_\\\\text{eval}$\",\n",
        "        rotation=0,\n",
        "        ha=\"center\",\n",
        "        va=\"top\",\n",
        "        fontsize=11,\n",
        "        backgroundcolor=\"gainsboro\",\n",
        "        color=\"#242c2e\",\n",
        "    )\n",
        "\n",
        "  if torso_config.is_trainable:\n",
        "    fig.suptitle(\n",
        "        f\"{architecture}: {pretraining_source} → {tuning_source}\",\n",
        "        fontsize=18,\n",
        "        y=0.95,\n",
        "    )\n",
        "  else:\n",
        "    fig.suptitle(\n",
        "      f\"{architecture} → {tuning_source}\",\n",
        "      fontsize=18,\n",
        "      y=0.95,\n",
        "    )\n",
        "\n",
        "if store_results:\n",
        "  plt.savefig(store_path + \"evaluation_results.pdf\", bbox_inches=\"tight\")\n",
        "  print(\"Figure written to:\", store_path + \"evaluation_results.pdf\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "urNw-CRgAJqe"
      },
      "cell_type": "code",
      "source": [
        "# @title Quantitative performance at important locations\n",
        "\n",
        "share_y_limits = True  # Share y-axis limits across subplots?\n",
        "n_timesteps = 3  # Must match the lengths of bar_locations and bar_labels.\n",
        "\n",
        "with plt.rc_context(rc_context):\n",
        "  for eval_name in eval_names:\n",
        "    fig, axes = plt.subplots(\n",
        "        ncols=n_timesteps,\n",
        "        nrows=1,\n",
        "        figsize=(5.5 * n_timesteps, 4),\n",
        "        sharey=share_y_limits,\n",
        "    )\n",
        "    bar_locations = [\n",
        "        tuning_data_config.sequence_length - 1,\n",
        "        training_data_config.sequence_length - 1,\n",
        "        eval_data_configs[eval_name].sequence_length - 1,\n",
        "    ]\n",
        "    bar_labels = [\n",
        "        \"$N_\\\\text{tune}$\",\n",
        "        \"$N_\\\\text{train}$\",\n",
        "        \"$N_\\\\text{eval}$\"\n",
        "    ]\n",
        "\n",
        "    colors = [\n",
        "        f\"C{tuning_method_index[name]}\"\n",
        "        for name in cumulative_regrets[eval_name]\n",
        "    ]\n",
        "    for j, bar_location in enumerate(bar_locations):\n",
        "      ax = axes[j]\n",
        "      plot_utils.plot_performance_metric(\n",
        "          cumulative_regrets[eval_name],\n",
        "          \"Cumulative regret [nats]\",\n",
        "          axis=ax,\n",
        "          bar_plot=True,\n",
        "          bar_plot_line_index=bar_location,\n",
        "          aggregate_fn_only=False,\n",
        "          show_gridlines=True,\n",
        "          show_individual_lines=False,\n",
        "          colors=colors,\n",
        "      )\n",
        "      for i, label in enumerate(ax.get_xticklabels()):\n",
        "        label.set_color(colors[i])\n",
        "\n",
        "      ax.set_title(f\"$n={bar_location}$ ({bar_labels[j]})\")\n",
        "      if j \u003e 0:\n",
        "        ax.set_ylabel(\"\")\n",
        "\n",
        "    if torso_config.is_trainable:\n",
        "      fig.suptitle(\n",
        "          f\"{architecture}: {pretraining_source} → {tuning_source}\\nEvaluation on {eval_name}\",\n",
        "          y=1.10,\n",
        "          fontsize=18,\n",
        "      )\n",
        "    else:\n",
        "      fig.suptitle(\n",
        "          f\"{architecture} → {tuning_source}\\nEvaluation on {eval_name}\",\n",
        "          y=1.10,\n",
        "          fontsize=18,\n",
        "      )\n",
        "\n",
        "    # Add markers for separating tuning method types / baselines.\n",
        "    ylim_top = 0\n",
        "    for ax in axes:\n",
        "      yl_top = ax.get_ylim()[1]\n",
        "      if yl_top \u003e ylim_top:\n",
        "        ylim_top = yl_top\n",
        "    for ax in axes:\n",
        "      margins = ax.margins()\n",
        "      marg_increase = 0.15\n",
        "      ax.margins(margins[0], margins[1]+marg_increase)  # make vertical space by increasing margins\n",
        "      if not share_y_limits:\n",
        "        ylim_top = ax.get_ylim()[1]\n",
        "      ylim_marg = ylim_top * (1+margins[1]+marg_increase)\n",
        "      ax.margins(0.05, 0.05)\n",
        "      separator_locations = np.array([\n",
        "              0,\n",
        "              n_pt_methods,\n",
        "              n_wt_methods + n_pt_methods,\n",
        "              n_wt_methods + n_pt_methods + n_baselines,\n",
        "          ])-0.5\n",
        "      separator_strings = [\"Prefix T.\", \"Weight T.\", \"Baselines\"]\n",
        "      ax.vlines(\n",
        "          separator_locations[1:],\n",
        "          ymin=ax.get_ylim()[0],\n",
        "          ymax=ylim_marg,\n",
        "          colors=[\"black\"],\n",
        "          linewidth=1.5,\n",
        "          linestyle=\":\",\n",
        "      )\n",
        "      sep_widths = (separator_locations[1:] - separator_locations[:-1])/2.0\n",
        "      sep_centers = separator_locations[:-1] + sep_widths\n",
        "      for i, location in enumerate(sep_centers):\n",
        "        tb = ax.text(\n",
        "            x=location,\n",
        "            y=ylim_marg,\n",
        "            s=separator_strings[i],\n",
        "            rotation=0,\n",
        "            ha=\"center\",\n",
        "            va=\"top\",\n",
        "            fontsize=11,\n",
        "            backgroundcolor=\"gainsboro\",\n",
        "            color=\"#242c2e\",\n",
        "            fontweight=\"heavy\",\n",
        "        )\n",
        "        tb.set_bbox(dict(color=\"gainsboro\", alpha=0.45))\n",
        "\n",
        "    if store_results:\n",
        "      path = store_path + \"evaluation_results_detail_\" + eval_name.replace(\" \", \"_\") + \".pdf\"\n",
        "      plt.savefig(path, bbox_inches=\"tight\")\n",
        "      print(\"Figure written to:\", path)"
      ],
      "outputs": [],
      "execution_count": null
    },
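    {
      "metadata": {
        "id": "grSeparatorSketch"
      },
      "cell_type": "markdown",
      "source": [
        "In the bar plots, the dotted separators and group labels assume that the bars appear in the order prefix-tuning methods, weight-tuning methods, baselines (as grouped in the cell that counts `n_pt_methods`, `n_wt_methods` and `n_baselines`). A minimal sketch of the position arithmetic with hypothetical group sizes; bars sit at integer x positions, so group boundaries fall at half-integers.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Hypothetical group sizes: 2 prefix-tuning methods, 1 weight-tuning method and\n",
        "# 3 baselines, drawn as bars at x = 0, 1, ..., 5.\n",
        "n_pt_methods, n_wt_methods, n_baselines = 2, 1, 3\n",
        "\n",
        "# Group boundaries, shifted by -0.5 so that they fall between neighboring bars:\n",
        "# here -0.5, 1.5, 2.5 and 5.5.\n",
        "separator_locations = np.cumsum([0, n_pt_methods, n_wt_methods, n_baselines]) - 0.5\n",
        "# Group labels are centered between consecutive boundaries: here 0.5, 2.0, 4.0.\n",
        "sep_centers = (separator_locations[:-1] + separator_locations[1:]) / 2.0\n",
        "print(separator_locations, sep_centers)\n",
        "```\n",
        "\n",
        "With an empty group (for example, no weight-tuning methods), two boundaries coincide and that group's label is drawn on top of a separator."
      ]
    },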
    {
      "metadata": {
        "id": "SUYjisigAJqe"
      },
      "cell_type": "markdown",
      "source": [
        "# Internal dynamics"
      ]
    },
    {
      "metadata": {
        "id": "hVJCwMGQAJqe"
      },
      "cell_type": "code",
      "source": [
        "# Collect states of untuned network on sequences from training distribution\n",
        "torso_config_w_states = dataclasses.replace(\n",
        "    torso_config, return_hidden_states=True\n",
        ")\n",
        "predictor = builders.build_predictor(predictor_config, torso_config_w_states)\n",
        "\n",
        "training_datagen = builders.build_datagen(training_data_config)\n",
        "training_sequences, training_log_probs = training_datagen.generate(\n",
        "    rng_key=jax.random.PRNGKey(5), return_ground_truth_log_probs=True\n",
        ")\n",
        "training_probs = np.exp(training_log_probs)\n",
        "\n",
        "_, training_states, _, training_prefix_states = (\n",
        "    evaluation.predictions_and_states_from_sequences(\n",
        "        predictor_config=None,\n",
        "        torso_config=None,\n",
        "        predictor_params=trained_params,\n",
        "        sequences=training_sequences,\n",
        "        predictor_instance=predictor,\n",
        "        prefix_type=\"none\",\n",
        "        prefix=None,\n",
        "    )\n",
        ")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "-3lqDGCKAJqe"
      },
      "cell_type": "code",
      "source": [
        "# Helper function to concatenate states across layers and reduce datapoints\n",
        "def concat_states(\n",
        "    states: types.Hidden,\n",
        "    prefix_states: types.PrefixHidden,\n",
        "    state_string: str,\n",
        "    dim_red_no_sequences: int,\n",
        "    dim_red_sequence_length: int,\n",
        ") -\u003e np.ndarray | None:\n",
        "  \"\"\"Concatenates states across layers and selects a subset of the data.\"\"\"\n",
        "  if states is None:\n",
        "    return None\n",
        "  states_concat = None\n",
        "\n",
        "  for name in states:\n",
        "    if name.endswith(state_string):\n",
        "      # Concatenate prefix (if given) with states on time axis\n",
        "      if prefix_states is None:\n",
        "        state_sequence = states[name]\n",
        "      else:\n",
        "        state_sequence = np.concatenate([prefix_states[name], states[name]], axis=1)\n",
        "\n",
        "      # Concatenate across layers on feature axis\n",
        "      if states_concat is None:\n",
        "        states_concat = state_sequence[\n",
        "            :dim_red_no_sequences, :dim_red_sequence_length\n",
        "        ]\n",
        "      else:\n",
        "        states_concat = np.concatenate(\n",
        "            [\n",
        "                states_concat,\n",
        "                state_sequence[:dim_red_no_sequences, :dim_red_sequence_length],\n",
        "            ],\n",
        "            axis=-1,\n",
        "        )\n",
        "\n",
        "  return states_concat"
      ],
      "outputs": [],
      "execution_count": null
    },
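    {
      "metadata": {
        "id": "grConcatStatesSketch"
      },
      "cell_type": "markdown",
      "source": [
        "For intuition on shapes: for every layer whose name ends in `state_string`, `concat_states` takes an array of shape `(n_sequences, n_steps, n_features)`, prepends the prefix states on the time axis (if given), truncates to the requested number of sequences and steps, and concatenates the layers on the feature axis. A minimal NumPy sketch with a made-up state dictionary (the layer names and sizes are hypothetical; the real names come from the torso):\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "# Hypothetical states of a 2-layer network: 4 sequences, 10 steps, 8 features.\n",
        "states = {\n",
        "    'layer_0/cell': rng.normal(size=(4, 10, 8)),\n",
        "    'layer_1/cell': rng.normal(size=(4, 10, 8)),\n",
        "    'layer_0/hidden': rng.normal(size=(4, 10, 8)),  # skipped: other suffix\n",
        "}\n",
        "# Hypothetical prefix states for a length-3 prefix.\n",
        "prefix_states = {name: rng.normal(size=(4, 3, 8)) for name in states}\n",
        "\n",
        "selected = []\n",
        "for name in states:\n",
        "  if name.endswith('cell'):\n",
        "    # Prefix steps come first on the time axis (axis 1).\n",
        "    sequence = np.concatenate([prefix_states[name], states[name]], axis=1)\n",
        "    # Keep the first 2 sequences and the first 6 steps (prefix included).\n",
        "    selected.append(sequence[:2, :6])\n",
        "# Concatenate the selected layers on the feature axis.\n",
        "states_concat = np.concatenate(selected, axis=-1)\n",
        "print(states_concat.shape)  # (2, 6, 16)\n",
        "```\n",
        "\n",
        "The truncation to the first steps counts the prefix steps, so with a prefix the window covers correspondingly fewer sequence steps."
      ]
    },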
    {
      "metadata": {
        "id": "zGJuCFlRAJqe"
      },
      "cell_type": "code",
      "source": [
        "# Dimensionality reduction (PCA) of internal state\n",
        "dim_red_no_sequences = 100\n",
        "dim_red_sequence_length = 50\n",
        "\n",
        "state_string = None\n",
        "if isinstance(torso_config, config_lib.LSTMTorsoConfig):\n",
        "  state_string = \"cell\"  # \"cell\" or \"hidden\"\n",
        "elif isinstance(torso_config, config_lib.TransformerTorsoConfig):\n",
        "  state_string = \"attention_out\"\n",
        "else:\n",
        "  raise ValueError(\"Invalid torso type for state analysis.\")\n",
        "\n",
        "# Concatenate states across layers and reduce datapoints\n",
        "training_states_concat = concat_states(\n",
        "    training_states,\n",
        "    training_prefix_states,\n",
        "    state_string,\n",
        "    dim_red_no_sequences,\n",
        "    dim_red_sequence_length,\n",
        ")\n",
        "# Flatten the sequence and time dimensions (keep the feature dimension)\n",
        "training_states_flat = training_states_concat.reshape(\n",
        "    -1, training_states_concat.shape[-1]\n",
        ")\n",
        "\n",
        "# Perform PCA\n",
        "model = PCA(n_components=2, whiten=True)\n",
        "training_states_projected = model.fit_transform(training_states_flat)\n",
        "\n",
        "# Unflatten non-feature dimensions\n",
        "training_states_projected = training_states_projected.reshape(\n",
        "    dim_red_no_sequences, dim_red_sequence_length, -1\n",
        ")"
      ],
      "outputs": [],
      "execution_count": null
    },
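    {
      "metadata": {
        "id": "grPcaWhitenSketch"
      },
      "cell_type": "markdown",
      "source": [
        "For reference, `PCA(n_components=2, whiten=True)` centers the flattened states, projects them onto the top two principal directions, and rescales each component to unit variance. A minimal NumPy sketch of the same computation via the SVD, on random data (the helper `pca_whiten` is hypothetical; component signs may differ from scikit-learn's, which applies its own sign convention):\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def pca_whiten(x: np.ndarray, n_components: int) -\u003e np.ndarray:\n",
        "  # Center each feature.\n",
        "  centered = x - x.mean(axis=0)\n",
        "  # Rows of vt are the principal directions, sorted by singular value.\n",
        "  _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)\n",
        "  projected = centered @ vt[:n_components].T\n",
        "  # Whitening: divide by the per-component standard deviation (ddof=1, as in\n",
        "  # scikit-learn's explained_variance_).\n",
        "  stddev = singular_values[:n_components] / np.sqrt(x.shape[0] - 1)\n",
        "  return projected / stddev\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "# Made-up states: 100 sequences x 50 steps, 16 features per step.\n",
        "states = rng.normal(size=(100, 50, 16))\n",
        "flat = states.reshape(-1, states.shape[-1])  # 5000 points x 16 features\n",
        "projected = pca_whiten(flat, n_components=2).reshape(100, 50, -1)\n",
        "print(projected.shape)  # (100, 50, 2)\n",
        "```"
      ]
    },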
    {
      "metadata": {
        "id": "qKUCviBCAJqf"
      },
      "cell_type": "code",
      "source": [
        "# Plot PC projection of non-prefixed net with training sequences\n",
        "\n",
        "dark_plot = True  # Dark background? Disable for print-friendly version\n",
        "if dark_plot:\n",
        "  rc_context_internal = {\n",
        "      \"figure.facecolor\": \"#1b1e21\",\n",
        "      \"text.color\": \"lightgray\",\n",
        "      \"legend.facecolor\": \"dimgray\",\n",
        "      \"legend.edgecolor\": \"none\",\n",
        "      \"xtick.color\": \"lightgray\",\n",
        "      \"ytick.color\": \"lightgray\",\n",
        "      \"axes.labelcolor\": \"lightgray\",\n",
        "  }\n",
        "else:\n",
        "  rc_context_internal = rc_context\n",
        "\n",
        "with plt.rc_context(rc_context_internal):\n",
        "  # Two subplots: one colored by step index, one by ground-truth probability\n",
        "  fig, axes = plt.subplots(\n",
        "      nrows=1, ncols=2, figsize=(10, 4), sharex=True, sharey=True\n",
        "  )\n",
        "  for j, ax in enumerate(axes):\n",
        "    for i in range(dim_red_no_sequences):\n",
        "      # Individual sequences as gray lines and colored scatter points\n",
        "      ax.plot(\n",
        "          training_states_projected[i, :, 0],\n",
        "          training_states_projected[i, :, 1],\n",
        "          color=\"grey\",\n",
        "          marker=None,\n",
        "          alpha=0.1,\n",
        "          linewidth=1,\n",
        "          zorder=-1,\n",
        "      )\n",
        "      if j == 0:\n",
        "        # Color by step index.\n",
        "        sc_color = np.linspace(0, 1, dim_red_sequence_length)\n",
        "        sc_cmap = plt.cm.plasma\n",
        "      else:\n",
        "        # Color by ground-truth probability\n",
        "        sc_color = training_probs[i, :dim_red_sequence_length, 0]\n",
        "        sc_cmap = plt.cm.viridis\n",
        "      sc = ax.scatter(\n",
        "          training_states_projected[i, :, 0],\n",
        "          training_states_projected[i, :, 1],\n",
        "          c=sc_color,\n",
        "          marker=\".\",\n",
        "          alpha=0.5,\n",
        "          zorder=3,\n",
        "          cmap=sc_cmap,\n",
        "          vmin=0,\n",
        "          vmax=1,\n",
        "          s=12,\n",
        "      )\n",
        "      if i == dim_red_no_sequences - 1:\n",
        "        if j == 0:\n",
        "          cbar = plt.colorbar(sc, label=\"Step $n$\", ticks=[0, 1], aspect=35)\n",
        "          cbar.ax.set_yticklabels([\"0\", f\"{dim_red_sequence_length}\"])\n",
        "        else:\n",
        "          cbar = plt.colorbar(\n",
        "              sc,\n",
        "              label=\"Ground-truth probability $\\\\tau$\",\n",
        "              ticks=[0, 0.5, 1],\n",
        "              aspect=35,\n",
        "          )\n",
        "    ax.axis(\"off\")\n",
        "    cbar.ax.yaxis.set_ticks_position(\"left\")\n",
        "    cbar.solids.set(alpha=1)\n",
        "    cbar.outline.set_visible(False)\n",
        "    cbar.ax.invert_yaxis()\n",
        "\n",
        "  if torso_config.is_trainable:\n",
        "    fig.suptitle(\n",
        "        f\"{architecture} pretrained on {pretraining_source}\\nState ({state_string}) on {dim_red_no_sequences} {pretraining_source} sequences (length: {dim_red_sequence_length})\",\n",
        "        fontsize=18,\n",
        "    )\n",
        "  else:\n",
        "    fig.suptitle(\n",
        "        f\"{architecture}\\nState ({state_string}) on {dim_red_no_sequences} {pretraining_source} sequences (length: {dim_red_sequence_length})\",\n",
        "        fontsize=18,\n",
        "    )\n",
        "  fig.tight_layout()\n",
        "\n",
        "if store_results:\n",
        "  plt.savefig(store_path + \"internal_states.pdf\", bbox_inches=\"tight\")\n",
        "  print(\"Figure written to:\", store_path + \"internal_states.pdf\")"
      ],
      "outputs": [],
      "execution_count": null
    },
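    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "A stripped-down sketch of the step-index coloring in the plot above, on synthetic random-walk trajectories that stand in for projected states. Step indices are normalized to [0, 1] so `vmin`/`vmax` stay fixed for any sequence length; the colorbar keeps its ticks at the normalized positions 0 and 1 and only relabels them with real step numbers. All names are hypothetical and prefixed with `demo_` so running the cell does not overwrite the notebook's variables (in particular `fig`, `ax` and `cbar`)."
      ]
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "# Synthetic 2-D trajectories (random walks), shape (sequences, steps, 2).\n",
        "demo_n_seq, demo_seq_len = 5, 40\n",
        "demo_rng = np.random.default_rng(1)\n",
        "demo_traj = np.cumsum(demo_rng.normal(size=(demo_n_seq, demo_seq_len, 2)), axis=1)\n",
        "\n",
        "demo_fig, demo_ax = plt.subplots(figsize=(5, 4))\n",
        "# Normalized step index in [0, 1]: vmin/vmax do not depend on the sequence length.\n",
        "demo_steps = np.linspace(0, 1, demo_seq_len)\n",
        "for demo_seq in demo_traj:\n",
        "  # Gray line for the trajectory, colored points for the individual steps.\n",
        "  demo_ax.plot(\n",
        "      demo_seq[:, 0], demo_seq[:, 1], color=\"grey\", alpha=0.1, linewidth=1, zorder=-1\n",
        "  )\n",
        "  demo_sc = demo_ax.scatter(\n",
        "      demo_seq[:, 0], demo_seq[:, 1], c=demo_steps, cmap=plt.cm.plasma,\n",
        "      vmin=0, vmax=1, marker=\".\", s=12,\n",
        "  )\n",
        "# Ticks stay at the normalized positions; only their labels show step numbers.\n",
        "demo_cbar = demo_fig.colorbar(demo_sc, ax=demo_ax, label=\"Step $n$\", ticks=[0, 1])\n",
        "demo_cbar.ax.set_yticklabels([\"0\", f\"{demo_seq_len}\"])\n",
        "demo_ax.axis(\"off\")"
      ],
      "outputs": [],
      "execution_count": null
    },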
    {
      "metadata": {
        "id": "YuwawVSpAJqf"
      },
      "cell_type": "code",
      "source": [
        "# Evaluate various prefixes (on all eval datagens) and plot projected state.\n",
        "no_sequences_to_show = (\n",
        "    20  # spread evenly across repetitions; at least 1 per repetition.\n",
        ")\n",
        "pca_results = {}\n",
        "\n",
        "# Only compare prefix methods that use the same weights as above\n",
        "method_exclude_list = [\n",
        "    \"TargetBayes\",\n",
        "    \"EvalBayes\",\n",
        "    \"Bayes\",\n",
        "    \"PreBayes\",\n",
        "    \"PreBayesPT\",\n",
        "    \"ground_truth\",\n",
        "    \"FullWT\",\n",
        "    \"LoRAWT\",\n",
        "    \"EmbedWT\",\n",
        "    \"UnembedWT\",\n",
        "    \"Un+EmbedWT\",\n",
        "]\n",
        "\n",
        "if dark_plot:\n",
        "  init_color = \"white\"\n",
        "else:\n",
        "  init_color = \"tomato\"\n",
        "\n",
        "with plt.rc_context(rc_context_internal):\n",
        "  for eval_name in eval_names:\n",
        "    # Draw sequences from the eval data generator\n",
        "    eval_datagen = builders.build_datagen(eval_data_configs[eval_name])\n",
        "    eval_sequences = eval_datagen.generate(\n",
        "        rng_key=jax.random.PRNGKey(5), return_ground_truth_log_probs=False\n",
        "    )\n",
        "\n",
        "    for method in results:\n",
        "      if method in method_exclude_list:\n",
        "        continue\n",
        "\n",
        "      # Determine the prefix type and the prefix(es) to evaluate\n",
        "      match method:\n",
        "        case \"NoTuning\":\n",
        "          prefix_type = \"none\"\n",
        "          prefixes = [None]\n",
        "        case \"RandomPF\":\n",
        "          prefix_type = \"prepend\"\n",
        "          prefixes = results[method][\"initial_prefix\"]\n",
        "        case \"HardPT\":\n",
        "          prefix_type = \"prepend\"\n",
        "          prefixes = [results[method][\"tuned_prefix\"]]\n",
        "        case \"SimplexPT\":\n",
        "          prefix_type = \"simplex\"\n",
        "          prefixes = results[method][\"tuned_prefix\"]\n",
        "        case \"RealPT\":\n",
        "          prefix_type = \"prepend\"\n",
        "          prefixes = results[method][\"tuned_prefix\"]\n",
        "        case \"SoftPT\":\n",
        "          prefix_type = \"embedding\"\n",
        "          prefixes = results[method][\"tuned_prefix\"]\n",
        "        case _:\n",
        "          raise ValueError(f\"Unknown or unsupported tuning method: {method}.\")\n",
        "\n",
        "      # Evaluate prefix on sequences, preprocess, and project to PCs\n",
        "      eval_projected_states = []\n",
        "      no_sequences_per_repetition = int(\n",
        "          np.ceil(no_sequences_to_show / len(prefixes))\n",
        "      )\n",
        "      print(\n",
        "          f\"Processing: {method} ({len(prefixes)} repetitions,\"\n",
        "          f\" {no_sequences_per_repetition} sequences per repetition)\"\n",
        "      )\n",
        "      for prefix in prefixes:\n",
        "        _, eval_states, _, eval_prefix_states = (\n",
        "            evaluation.predictions_and_states_from_sequences(\n",
        "                predictor_config=None,\n",
        "                torso_config=None,\n",
        "                predictor_params=trained_params,\n",
        "                sequences=eval_sequences,\n",
        "                predictor_instance=predictor,  # Reuse predictor instance from above.\n",
        "                prefix_type=prefix_type,\n",
        "                prefix=prefix,\n",
        "            )\n",
        "        )\n",
        "        # Concatenate states across layers and reduce datapoints\n",
        "        eval_states_concat = concat_states(\n",
        "            eval_states,\n",
        "            eval_prefix_states,\n",
        "            state_string,\n",
        "            no_sequences_per_repetition,\n",
        "            dim_red_sequence_length,\n",
        "        )\n",
        "        # Flatten non-feature dimensions\n",
        "        eval_states_flat = eval_states_concat.reshape(\n",
        "            -1, eval_states_concat.shape[-1]\n",
        "        )\n",
        "        # Project onto the PCs fitted on the training states (no refit)\n",
        "        eval_states_projected = model.transform(eval_states_flat)\n",
        "        # Unflatten non-feature dimensions\n",
        "        eval_states_projected = eval_states_projected.reshape(\n",
        "            no_sequences_per_repetition, dim_red_sequence_length, -1\n",
        "        )\n",
        "        eval_projected_states.append(eval_states_projected)\n",
        "      pca_results[method] = np.concatenate(eval_projected_states, axis=0)\n",
        "\n",
        "    # Plot all projections\n",
        "    n_subplots = len(pca_results)\n",
        "    fig, axes = plt.subplots(\n",
        "        nrows=1,\n",
        "        ncols=n_subplots,\n",
        "        figsize=(2.8 * n_subplots, 4),\n",
        "        sharex=True,\n",
        "        sharey=True,\n",
        "        squeeze=False,  # Always a 2-D array, even if only one method is plotted.\n",
        "    )\n",
        "    for j, method in enumerate(pca_results):\n",
        "      ax = axes[0, j]\n",
        "      ax.set_title(f\"{method}\")\n",
        "\n",
        "      for k in range(dim_red_no_sequences):\n",
        "        # Background: pretraining-distribution trajectories (no prefix)\n",
        "        (pretrain_lines,) = ax.plot(\n",
        "            training_states_projected[k, :, 0],\n",
        "            training_states_projected[k, :, 1],\n",
        "            color=\"grey\",\n",
        "            marker=None,\n",
        "            alpha=0.1,\n",
        "            linewidth=1,\n",
        "            zorder=-1,\n",
        "        )\n",
        "\n",
        "      color = plot_utils.get_color_by_index(tuning_method_index[method])\n",
        "      for seq in pca_results[method]:\n",
        "        (eval_lines,) = ax.plot(\n",
        "            seq[:, 0],\n",
        "            seq[:, 1],\n",
        "            color=color,\n",
        "            marker=\".\",\n",
        "            alpha=0.6,\n",
        "            linewidth=1,\n",
        "            zorder=-1,\n",
        "        )\n",
        "        # Initial state\n",
        "        if method == \"NoTuning\":\n",
        "          (init_marker,) = ax.plot(\n",
        "              seq[0, 0],\n",
        "              seq[0, 1],\n",
        "              marker=\"o\",\n",
        "              ls=\"\",\n",
        "              color=init_color,\n",
        "              fillstyle=\"none\",\n",
        "          )\n",
        "        else:\n",
        "          (init_marker,) = ax.plot(\n",
        "              seq[prefix_length, 0],\n",
        "              seq[prefix_length, 1],\n",
        "              marker=\"o\",\n",
        "              ls=\"\",\n",
        "              color=init_color,\n",
        "              fillstyle=\"none\",\n",
        "          )\n",
        "      ax.axis(\"off\")\n",
        "\n",
        "    init_marker.set_label(\"$n=L_{\\\\text{prefix}}\"+f\"={prefix_length}\"+\"$\")\n",
        "    eval_lines.set_label(f\"{eval_name}\")\n",
        "    pretrain_lines.set_label(f\"{pretraining_source} (no prefix)\")\n",
        "    legend = fig.legend(frameon=True, fontsize=14, ncols=3, loc=\"lower right\")\n",
        "    for lh in legend.legend_handles:\n",
        "      lh.set_alpha(1)\n",
        "    if torso_config.is_trainable:\n",
        "      fig.suptitle(\n",
        "          f\"{architecture}: {pretraining_source} → {tuning_source}\\nState ({state_string}) on {no_sequences_to_show} {eval_name} sequences (length: {dim_red_sequence_length})\",\n",
        "          fontsize=18,\n",
        "      )\n",
        "    else:\n",
        "      fig.suptitle(\n",
        "          f\"{architecture} → {tuning_source}\\nState ({state_string}) on {no_sequences_to_show} {eval_name} sequences (length: {dim_red_sequence_length})\",\n",
        "          fontsize=18,\n",
        "      )\n",
        "    fig.tight_layout()\n",
        "\n",
        "    # Save one figure per eval datagen (saving after the loop keeps only the last one).\n",
        "    if store_results:\n",
        "      fig_path = store_path + f\"internal_states_prefix_trajectories_{eval_name}.pdf\"\n",
        "      fig.savefig(fig_path, bbox_inches=\"tight\")\n",
        "      print(\"Figure written to:\", fig_path)"
      ],
      "outputs": [],
      "execution_count": null
    }
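    ,
    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "A small worked example of how the prefix-trajectory cell above splits `no_sequences_to_show` across repetitions. It takes the ceiling of `no_sequences_to_show / len(prefixes)` per repetition, so every repetition contributes at least one sequence, and the total actually plotted can exceed the requested number when the division is not exact or when there are more repetitions than requested sequences. The helper below is hypothetical and uses `math.ceil`, which matches `int(np.ceil(...))` for these inputs."
      ]
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "import math\n",
        "\n",
        "\n",
        "def demo_sequences_per_repetition(total_to_show: int, n_repetitions: int) -> int:\n",
        "  \"\"\"Ceiling division, as in the cell above (hypothetical helper).\"\"\"\n",
        "  return math.ceil(total_to_show / n_repetitions)\n",
        "\n",
        "\n",
        "# Columns: repetitions, sequences per repetition, total sequences plotted.\n",
        "for demo_n_reps in (1, 3, 20, 30):\n",
        "  demo_per_rep = demo_sequences_per_repetition(20, demo_n_reps)\n",
        "  print(demo_n_reps, demo_per_rep, demo_per_rep * demo_n_reps)\n",
        "# 1 20 20\n",
        "# 3 7 21\n",
        "# 20 1 20\n",
        "# 30 1 30"
      ],
      "outputs": [],
      "execution_count": null
    }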
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python/gpu:ml_notebook",
        "kind": "private"
      },
      "name": "ThunniniExperiment.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
