{
  "cells": [
    {
      "metadata": {
        "id": "zxMmL36TBwUS"
      },
      "cell_type": "markdown",
      "source": [
        "# Thunnini Demo\n",
        "\n",
        "This notebook briefly showcases Thunnini's main features. The notebook pretrains a network (LSTM or Transformer) on a distribution over coins with uniformly random bias (using the `DirichletCategorical` data generator). This network is then fine-tuned on a mixture of two coins with biases 0.2 and 0.8 via soft prefix tuning (i.e., by tuning the embeddings of a prefix of 6 tokens). The tuned prefix is then used during evaluation on the same mixture of two coins, and finally the performance of the tuned predictor is compared against the Bayes-optimal predictor for the two-coin mixture, the pretrained network, and the untrained network.\n",
        "\n",
        "The main aim of this notebook is to showcase how easy it is to set up predictors, data generators, pretraining, tuning, and evaluation with Thunnini.\n",
        "See `ThunniniExperiment.ipynb` for a much more comprehensive notebook that features most of Thunnini's functionality and wraps it into an easily configurable interface."
      ]
    },
    {
      "metadata": {
        "id": "uMC6_1OShR6p"
      },
      "cell_type": "markdown",
      "source": [
        "# Imports"
      ]
    },
    {
      "metadata": {
        "id": "evZftQd0sJII"
      },
      "cell_type": "code",
      "source": [
        "# @title Global imports\n",
        "\n",
        "import dataclasses\n",
        "\n",
        "# Utils\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "# NNs / Linear algebra\n",
        "import numpy as np\n",
        "import jax\n",
        "\n",
        "jax.config.update(\"jax_debug_nans\", False)\n",
        "%matplotlib inline"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "uZVcFQDY32Te"
      },
      "cell_type": "code",
      "source": [
        "# @title Thunnini imports\n",
        "from thunnini.src import builders\n",
        "from thunnini.src import config as config_lib\n",
        "from thunnini.src import evaluation\n",
        "from thunnini.src import plot_utils\n",
        "from thunnini.src import training\n",
        "from thunnini.src import tuning"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "wguslMYdhJjF"
      },
      "cell_type": "markdown",
      "source": [
        "# Experiment Configurations"
      ]
    },
    {
      "metadata": {
        "id": "6U4dtDqLgrrn"
      },
      "cell_type": "code",
      "source": [
        "#@title Predictor configuration\n",
        "\n",
        "# Tunable settings for this demo. `torso_type` is also used to label plots.\n",
        "embedding_dim = 16\n",
        "torso_type = \"LSTM\"  # One of: \"LSTM\", \"Transformer\".\n",
        "hidden_sizes = [64, 32]\n",
        "\n",
        "\n",
        "predictor_config = config_lib.PredictorConfig(\n",
        "    token_dimensionality=2,  # binary tokens\n",
        "    embedding_dimensionality=embedding_dim,\n",
        ")\n",
        "\n",
        "\n",
        "# Select the torso config that matches `torso_type`. An unrecognised value\n",
        "# raises instead of silently building a Transformer, so that typos such as\n",
        "# \"lstm\" are caught here rather than producing a mislabeled experiment.\n",
        "if torso_type == \"LSTM\":\n",
        "  torso_config = config_lib.LSTMTorsoConfig(\n",
        "      is_trainable=True,\n",
        "      hidden_sizes=hidden_sizes,\n",
        "      return_hidden_states=False,\n",
        "  )\n",
        "elif torso_type == \"Transformer\":\n",
        "  torso_config = config_lib.TransformerTorsoConfig(\n",
        "      is_trainable=True,\n",
        "      hidden_sizes=hidden_sizes,\n",
        "      num_attention_heads=4,\n",
        "      positional_encoding='SinCos',\n",
        "      return_hidden_states=False,\n",
        "      use_bias=False,\n",
        "      widening_factor=4,\n",
        "      normalize_qk=True,\n",
        "      use_lora=True,\n",
        "      reduced_rank=4,\n",
        "  )\n",
        "else:\n",
        "  raise ValueError(\n",
        "      f\"Unknown torso_type {torso_type!r}; expected 'LSTM' or 'Transformer'.\"\n",
        "  )"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "K8Ex14OsMDij"
      },
      "cell_type": "code",
      "source": [
        "#@title Training configuration\n",
        "\n",
        "# Pretraining data: Dirichlet-Categorical generator over a binary vocabulary.\n",
        "training_data_config = config_lib.DirichletCategoricalGeneratorConfig(\n",
        "    vocab_size=2,\n",
        "    alphas=np.array([1, 1]),\n",
        "    sequence_length=50,\n",
        "    batch_size=128,\n",
        ")\n",
        "\n",
        "# Optimisation settings and seeds for the pretraining run.\n",
        "training_config = config_lib.TrainingConfig(\n",
        "    predictor_init_seed=0,\n",
        "    data_gen_seed=0,\n",
        "    num_training_steps=1000,\n",
        "    learning_rate=5e-3,\n",
        "    max_grad_norm=1.0,\n",
        ")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "2GCZN_dkH05N"
      },
      "cell_type": "code",
      "source": [
        "#@title Tuning configuration\n",
        "\n",
        "# Tuning data: mixture of two categoricals over a binary vocabulary.\n",
        "tuning_data_config = config_lib.MixtureOfCategoricalsGeneratorConfig(\n",
        "    vocab_size=2,\n",
        "    biases=np.array([[0.2, 0.8], [0.8, 0.2]]),\n",
        "    mixing_weights=np.array([0.25, 0.75]),\n",
        "    sequence_length=50,\n",
        "    batch_size=128,\n",
        ")\n",
        "\n",
        "# Soft prefix tuning with a prefix of length 6.\n",
        "tuning_config = config_lib.TuningConfig(\n",
        "    tuning_method=\"prefix_soft\",\n",
        "    prefix_length=6,\n",
        "    prefix_init_method=\"one_hot\",\n",
        "    prefix_init_seed=10,\n",
        "    data_gen_seed=10,\n",
        "    num_tuning_steps=1000,\n",
        "    learning_rate=5e-3,\n",
        "    max_grad_norm=1.0,\n",
        ")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "aoOLKl1Ag_k8"
      },
      "cell_type": "code",
      "source": [
        "#@title Evaluation configuration\n",
        "\n",
        "# Only a single batch is drawn from this generator for evaluation, hence the\n",
        "# large batch size.\n",
        "eval_data_config = config_lib.MixtureOfCategoricalsGeneratorConfig(\n",
        "    vocab_size=2,\n",
        "    biases=np.array([[0.2, 0.8], [0.8, 0.2]]),\n",
        "    mixing_weights=np.array([0.25, 0.75]),\n",
        "    sequence_length=100,\n",
        "    batch_size=1024,\n",
        ")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "DU2LsMOphDFN"
      },
      "cell_type": "markdown",
      "source": [
        "# Pretraining"
      ]
    },
    {
      "metadata": {
        "id": "ctpIZqiaMd_3"
      },
      "cell_type": "code",
      "source": [
        "#@title Pretrain predictor\n",
        "\n",
        "# Pretrain on the training data generator. The second return value is indexed\n",
        "# with 'loss' when the training curve is plotted.\n",
        "pretraining_outputs = training.train(\n",
        "    data_config=training_data_config,\n",
        "    training_config=training_config,\n",
        "    predictor_config=predictor_config,\n",
        "    torso_config=torso_config,\n",
        ")\n",
        "trained_params, train_results = pretraining_outputs"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "jt3JMFHjoHAe"
      },
      "cell_type": "code",
      "source": [
        "# Plot training loss curve\n",
        "\n",
        "# One entry per model, each holding a list with one loss curve per repetition.\n",
        "training_loss_curves = {torso_type: [train_results['loss']]}\n",
        "ax = plot_utils.plot_performance_metric(\n",
        "    training_loss_curves,\n",
        "    'Training loss',\n",
        "    show_gridlines=True,\n",
        "    aggregate_fn_only=True,  # Single repetition, so no variability band.\n",
        ")\n",
        "ax.set_xlabel('Training Step')\n",
        "ax.set_title('Pretraining on ' + training_data_config.generator_type)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "x1E0dlJLMJ9b"
      },
      "cell_type": "code",
      "source": [
        "# @title Manually using the pretrained predictor\n",
        "\n",
        "# This cell demonstrates how to use a predictor manually and evaluate it on\n",
        "# some sequences. Thunnini also has convenience functions that encapsulate\n",
        "# this, see, e.g., `evaluation.evaluate_predictor_from_datagen`.\n",
        "\n",
        "# Build the data generator and draw one batch from it.\n",
        "datagen_tmp = builders.build_datagen(training_data_config)\n",
        "batch_tmp = datagen_tmp.generate(\n",
        "    return_ground_truth_log_probs=False,\n",
        "    rng_key=jax.random.PRNGKey(1337),\n",
        ")\n",
        "\n",
        "# Copy the torso config, this time asking for the hidden states as well.\n",
        "torso_config_tmp = dataclasses.replace(torso_config, return_hidden_states=True)\n",
        "# The predictor is stateless - simply build a new instance.\n",
        "predictor_tmp = builders.build_predictor(predictor_config, torso_config_tmp)\n",
        "\n",
        "# Forward pass without any prefix.\n",
        "forward_outputs = predictor_tmp.apply(\n",
        "    trained_params, sequences=batch_tmp, prefix_type='None', prefix=None\n",
        ")\n",
        "logits, hidden_states, prefix_logits, prefix_hidden = forward_outputs\n",
        "\n",
        "# Score the predictions against the sampled batch.\n",
        "predictor_log_losses = datagen_tmp.instant_log_loss_from_logits(logits, batch_tmp)\n",
        "\n",
        "print('Instant log loss shape:', predictor_log_losses.shape)\n",
        "print('Hidden states dict keys:', hidden_states.keys())\n",
        "\n",
        "print('Prefix logits:', prefix_logits)  # Will be None - no prefix used.\n",
        "print('Prefix hidden states:', prefix_hidden)  # Will be None - no prefix used."
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "yqBAumj6KJ8m"
      },
      "cell_type": "markdown",
      "source": [
        "# Tuning"
      ]
    },
    {
      "metadata": {
        "id": "y-mqSajpjCk8"
      },
      "cell_type": "code",
      "source": [
        "#@title Tune pretrained predictor\n",
        "\n",
        "# Tune starting from the pretrained parameters. The call yields the parameters\n",
        "# after tuning, the tuned prefix, and a results object indexed with 'loss'\n",
        "# when the tuning curve is plotted.\n",
        "tuning_outputs = tuning.tune(\n",
        "    predictor_params=trained_params,\n",
        "    data_config=tuning_data_config,\n",
        "    tuning_config=tuning_config,\n",
        "    predictor_config=predictor_config,\n",
        "    torso_config=torso_config,\n",
        ")\n",
        "tuned_params, tuned_prefix, tuning_results = tuning_outputs"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "PACyQocmqele"
      },
      "cell_type": "code",
      "source": [
        "# Plot tuning loss curve\n",
        "\n",
        "# One entry per model, each holding a list with one loss curve per repetition.\n",
        "tuning_loss_curves = {torso_type: [tuning_results['loss']]}\n",
        "ax = plot_utils.plot_performance_metric(\n",
        "    tuning_loss_curves,\n",
        "    'Tuning loss',\n",
        "    show_gridlines=True,\n",
        "    aggregate_fn_only=True,  # Single repetition, so no variability band.\n",
        ")\n",
        "ax.set_xlabel('Tuning Step')\n",
        "ax.set_title(tuning_config.tuning_method + ' tuning on ' + tuning_data_config.generator_type)"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "ZziHpwyt4Ssu"
      },
      "cell_type": "markdown",
      "source": [
        "# Evaluation"
      ]
    },
    {
      "metadata": {
        "id": "kxeSUGZH_g4h"
      },
      "cell_type": "code",
      "source": [
        "#@title Evaluate pretrained predictor\n",
        "\n",
        "# Draw one evaluation batch and score the pretrained predictor on it. With\n",
        "# `return_gt_and_optimal_results=True` the result unpacks into seven values,\n",
        "# including the Bayes-optimal (`bo_*`) and ground-truth (`gt_*`) quantities.\n",
        "eval_results_trained = evaluation.evaluate_predictor_from_datagen(\n",
        "    datagen_config=eval_data_config,\n",
        "    datagen_seed=1337,\n",
        "    datagen_num_batches=1,\n",
        "    predictor_config=predictor_config,\n",
        "    torso_config=torso_config,\n",
        "    predictor_params=trained_params,\n",
        "    return_gt_and_optimal_results=True,\n",
        ")\n",
        "(\n",
        "    sequences,\n",
        "    trained_logits,\n",
        "    trained_log_losses,\n",
        "    bo_log_probs,\n",
        "    bo_losses,\n",
        "    gt_log_probs,\n",
        "    gt_losses,\n",
        ") = eval_results_trained"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "X7a8NbXBLzxg"
      },
      "cell_type": "code",
      "source": [
        "#@title Evaluate tuned predictor on the same sequences\n",
        "\n",
        "# Re-use the evaluation sequences and feed the tuned prefix to the predictor.\n",
        "# NOTE(review): batch_size=-1 presumably evaluates all sequences at once -\n",
        "# confirm against `evaluation.evaluate_predictor_from_sequences`.\n",
        "tuned_eval_outputs = evaluation.evaluate_predictor_from_sequences(\n",
        "    sequences=sequences,\n",
        "    batch_size=-1,\n",
        "    predictor_config=predictor_config,\n",
        "    torso_config=torso_config,\n",
        "    predictor_params=tuned_params,\n",
        "    prefix_type=\"embedding\",\n",
        "    prefix=tuned_prefix,\n",
        ")\n",
        "tuned_logits, tuned_log_losses = tuned_eval_outputs"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "ax5ASI-lVuFN"
      },
      "cell_type": "code",
      "source": [
        "#@title Construct and evaluate untrained predictor on the same sequences\n",
        "\n",
        "# Initialise fresh parameters; a few eval sequences serve as dummy inputs.\n",
        "untrained_predictor = builders.build_predictor(predictor_config, torso_config)\n",
        "dummy_sequences = sequences[0:10]\n",
        "untrained_params = untrained_predictor.init(\n",
        "    rngs=jax.random.PRNGKey(815),\n",
        "    sequences=dummy_sequences,\n",
        ")\n",
        "\n",
        "# Score the freshly initialised predictor on the evaluation sequences.\n",
        "untrained_eval_outputs = evaluation.evaluate_predictor_from_sequences(\n",
        "    sequences=sequences,\n",
        "    batch_size=-1,\n",
        "    predictor_config=predictor_config,\n",
        "    torso_config=torso_config,\n",
        "    predictor_params=untrained_params,\n",
        ")\n",
        "untrained_logits, untrained_log_losses = untrained_eval_outputs"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "x4ujM-gaS5_L"
      },
      "cell_type": "code",
      "source": [
        "#@title Compute regrets\n",
        "\n",
        "# Log losses of every predictor on the shared evaluation sequences, keyed by\n",
        "# the label used in the plots.\n",
        "log_losses_by_model = {\n",
        "    'Bayes-optimal (' + eval_data_config.generator_type + ')': bo_losses,\n",
        "    'Pretrained ' + torso_type + ' (' + training_data_config.generator_type + ')': trained_log_losses,\n",
        "    'Tuned ' + torso_type + ' (' + tuning_data_config.generator_type + ')': tuned_log_losses,\n",
        "    'Untrained ' + torso_type: untrained_log_losses,\n",
        "}\n",
        "\n",
        "# Instant regret: excess log loss over the ground truth, averaged over axis 0.\n",
        "# Each value is a list of curves (a single one here).\n",
        "instant_regret = {\n",
        "    model: [np.mean(log_losses - gt_losses, axis=0)]\n",
        "    for model, log_losses in log_losses_by_model.items()\n",
        "}\n",
        "\n",
        "# Accumulate every curve separately. Calling `np.cumsum` on the whole list\n",
        "# flattens it, which is only correct while the list holds exactly one curve.\n",
        "cumulative_regret = {\n",
        "    model: [np.cumsum(np.asarray(curve)) for curve in curves]\n",
        "    for model, curves in instant_regret.items()\n",
        "}"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "3CyTXS_1lC8x"
      },
      "cell_type": "code",
      "source": [
        "# Plot evaluation results\n",
        "fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(9, 8))\n",
        "instant_ax, cumulative_ax = axes[0], axes[1]\n",
        "\n",
        "# No variability band needed (single repetition); gridlines on both panels.\n",
        "shared_plot_kwargs = dict(aggregate_fn_only=True, show_gridlines=True)\n",
        "\n",
        "# Top panel: instant regret, without x-label and legend.\n",
        "plot_utils.plot_performance_metric(\n",
        "    instant_regret,\n",
        "    'Instant regret [nats]',\n",
        "    axis=instant_ax,\n",
        "    **shared_plot_kwargs,\n",
        ")\n",
        "instant_ax.set_title('Evaluation on ' + eval_data_config.generator_type)\n",
        "instant_ax.set_xlabel('')\n",
        "instant_ax.get_legend().remove()\n",
        "\n",
        "# Bottom panel: cumulative regret.\n",
        "plot_utils.plot_performance_metric(\n",
        "    cumulative_regret,\n",
        "    'Cumulative regret [nats]',\n",
        "    axis=cumulative_ax,\n",
        "    **shared_plot_kwargs,\n",
        ")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {
        "id": "xqKO5OA88cQT"
      },
      "cell_type": "code",
      "source": [
        "# @title Show one trajectory\n",
        "\n",
        "fig = plt.figure(figsize=(9, 4))\n",
        "seq_len = sequences.shape[1]\n",
        "xvec = np.arange(seq_len)\n",
        "\n",
        "# Plot observations from the first eval sequence\n",
        "plt.plot(xvec, sequences[0, :, 0], '.', label='Observations')\n",
        "# Plot the ground-truth probability (gt_log_probs are the same for each timestep)\n",
        "plt.hlines(np.exp(gt_log_probs[0, 0, 0]), xmin=0, xmax=seq_len, label='Ground truth', color='goldenrod', linewidth=4)\n",
        "\n",
        "# Plot the network predictions for the first eval sequence: softmax over the\n",
        "# logits, keeping the probability of token 0.\n",
        "network_logits = {\n",
        "    'Pretrained ' + torso_type: trained_logits,\n",
        "    'Tuned ' + torso_type: tuned_logits,\n",
        "    'Untrained ' + torso_type: untrained_logits,\n",
        "}\n",
        "for curve_label, curve_logits in network_logits.items():\n",
        "  plt.plot(xvec, jax.nn.softmax(curve_logits[0, :, :])[:, 0], label=curve_label, linewidth=2)\n",
        "# Softmax of log-probabilities yields the corresponding normalised probabilities.\n",
        "plt.plot(xvec, jax.nn.softmax(bo_log_probs[0, :, :])[:, 0], label='Bayes-optimal', color='C0', linewidth=2)\n",
        "\n",
        "plt.legend()\n",
        "plt.xlabel('Step')\n",
        "plt.yticks([0, 0.5, 1])\n",
        "plt.grid(True)\n",
        "plt.ylabel('Probability')\n",
        "plt.title('Single trajectory predictions')"
      ],
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python/gpu:ml_notebook",
        "kind": "private"
      },
      "name": "ThunniniDemo.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
