{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ccd9e14d-1b63-4d49-8795-7905dea1c122",
   "metadata": {},
   "source": [
    "# The EGGROLL Recipe from Scratch\n",
    "\n",
    "In this notebook, we reconstruct a simple version of our primary Jax Codebase to help people get started with their own creations using EGGROLL. This also serves as a detailed description of all the \"quirks\" of the EGGROLL codebase.\n",
    "\n",
    "You can run this notebook on Google Colab with any device type (CPU, GPU, TPU) or on your own system as long as jax, optax, and matplotlib are installed.\n",
    "\n",
    "## High-level Description of EGGROLL\n",
    "\n",
    "**EGGROLL** stands for Evolution Guided General Optimization via Low-rank Learning. We recommend reading the project homepage and [Lilian Weng's blog post on ES](https://lilianweng.github.io/posts/2019-09-05-evolution-strategies/) for a more detailed overview of EGGROLL and Evolution Strategies in general.\n",
    "\n",
    "At a high level, the key idea is to use *low-rank perturbations*, calculate *fitnesses*, and construct an update using the fitness-weighted sum of perturbations.\n",
    "\n",
    "$$ \\nabla_{\\theta}\\mathbb{E}_{\\epsilon_1, \\epsilon_2 \\sim N(0,I_{d})} F(\\theta+\\sigma\\epsilon_2 \\epsilon_1^T) = \\frac{1}{\\sigma}\\mathbb{E}_{\\epsilon_1,\\epsilon_2\\sim N(0,I_{d})}\\{F(\\theta+\\sigma\\epsilon_2 \\epsilon_1^T)\\epsilon_2 \\epsilon_1^T\\} $$\n",
    "\n",
    "Naively computing $F(\\theta+\\sigma\\epsilon_2 \\epsilon_1^T)$ by materializing the full $\\epsilon_2\\epsilon_1^T$ matrix (which would be as large as the original matrix) would be very inefficient since it results in an expensive batched matrix multiplication. We would instead like to convert it into a regular matrix multiplication and cheap batched low-rank products.\n",
    "\n",
    "This codebase is designed to easily extend EGGROLL to new problem settings by automatically handling the conversion of matrix multiplication to the low-rank updates with on-the-fly noise reconstructions.\n",
    "\n",
    "## Part 1: The \"Noiser\"\n",
    "\n",
    "Efficient EGGROLL requires converting the internals of matrix multiplications (between inputs and weights) into the \"perturbed\" matrix multiplication as defined by EGGROLL. For this, we define a **Noiser** which handles the way fundamental operations should be handled by our ES algorithm.\n",
    "\n",
    "Below we show the *base noiser*, which performs no perturbations, but highlights the general interface of a noiser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a962efe5-31ce-48d4-852f-d63ee71d5684",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Noiser:\n",
    "    @classmethod\n",
    "    def init_noiser(cls, params, sigma, lr, *args, solver=None, solver_kwargs=None, **kwargs):\n",
    "        \"\"\"\n",
    "        params: parameters of the model\n",
    "        sigma: initial sigma of noiser\n",
    "        lr: learning rate of optimization process\n",
    "        solver (optional, keyword arg): the optax solver to use (i.e. optax.adamw)\n",
    "        solver_kwargs (optional, keyword arg): the optax solver's keyword arguments (excluding learning rate)\n",
    "        \n",
    "        Return frozen_noiser_params and noiser_params\n",
    "        \"\"\"\n",
    "        return {}, {}\n",
    "    \n",
    "    @classmethod\n",
    "    def do_mm(cls, frozen_noiser_params, noiser_params, param, base_key, iterinfo, x):\n",
    "        return x @ param.T\n",
    "\n",
    "    @classmethod\n",
    "    def do_Tmm(cls, frozen_noiser_params, noiser_params, param, base_key, iterinfo, x):\n",
    "        return x @ param\n",
    "\n",
    "    @classmethod\n",
    "    def get_noisy_standard(cls, frozen_noiser_params, noiser_params, param, base_key, iterinfo):\n",
    "        return param\n",
    "\n",
    "    @classmethod\n",
    "    def convert_fitnesses(cls, frozen_noiser_params, noiser_params, raw_scores):\n",
    "        return raw_scores\n",
    "\n",
    "    @classmethod\n",
    "    def do_updates(cls, frozen_noiser_params, noiser_params, params, base_keys, fitnesses, iterinfos, es_map):\n",
    "        return noiser_params, params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6108cd33-3650-4426-9390-c78959757bb8",
   "metadata": {},
   "source": [
    "There are a handful of \"quirks\" here to keep in mind. First, note that all methods are \"class methods\" which is an intentional choice to preserve the functional style of jax. Throughout this codebase, classes are collections of functions and do not store any internal state, and subclasses can inherit and override these functions, preventing code duplication. (The only exception to the above are NamedTuples which only store data but do not contain logic)\n",
    "\n",
    "You may notice that the output of init_noiser is \"frozen_noiser_params\" and \"noiser_params\" and may ask why there are two separate outputs. Essentially, to preserve the \"jit\" compilation capabilities of jax, we only want to keep \"frozen\" data that the compiler should be aware of within frozen_noiser_params while other components, like large arrays or components that change, should be in noiser_params.\n",
    "\n",
    "The key methods of Noiser for modifying the model behavior are do_mm, do_Tmm, and get_noisy_standard, which return noisy versions of these fundamental operations. The \"get_noisy_standard\" operation just noises a standard parameter that isn't involved in matmuls (like the bias and weight parameters of a layernorm). For each of these operations, you will notice two additional inputs: base_key and iterinfo.\n",
    "\n",
    "The base_key is a jax PRNG key that is the \"base\" key for this parameter. The iterinfo is either None or is a tuple of (epoch, thread_id) to indicate which perturbation is currently being worked on. Note that \"epoch\" refers to the number of update steps that have already occurred and \"thread_id\" refers to the perturbation number within a single epoch.\n",
    "\n",
    "**IMPORTANT:** iterinfo must be either None or a tuple of two ints. If you want to calculate multiple perturbations in parallel (which is the key component of EGGROLL), you need to use jax vmap, which we will explain in our worked-out example.\n",
    "\n",
    "## Part 2: The \"Model\"\n",
    "\n",
    "In our codebase, instead of directly interact with the Noiser, we define Models, which define how some network component is initialized and its forward pass.\n",
    "\n",
    "Below we show the abstract base model and its general interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "885b1dde-79fe-4330-a4b0-fce8920052cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import NamedTuple\n",
    "\n",
    "PARAM = 0\n",
    "MM_PARAM = 1\n",
    "EXCLUDED=3\n",
    "\n",
    "class CommonInit(NamedTuple):\n",
    "    frozen_params: any\n",
    "    params: any\n",
    "    scan_map: any\n",
    "    es_map: any\n",
    "\n",
    "class CommonParams(NamedTuple):\n",
    "    noiser: any\n",
    "    frozen_noiser_params: any\n",
    "    noiser_params: any\n",
    "    frozen_params: any\n",
    "    params: any\n",
    "    es_tree_key: any\n",
    "    iterinfo: any\n",
    "\n",
    "class Model:\n",
    "    @classmethod\n",
    "    def rand_init(cls, key, *args, **kwargs):\n",
    "        \"\"\"\n",
    "        Initialize model\n",
    "\n",
    "        returns frozen_params, params, scan_map, es_map as CommonInit\n",
    "        \"\"\"\n",
    "        raise NotImplementedError(\"Randomize Weights is not implemented\")\n",
    "\n",
    "    @classmethod\n",
    "    def forward(cls,\n",
    "                noiser, frozen_noiser_params, noiser_params,\n",
    "                frozen_params, params, es_tree_key, iterinfo, *args, **kwargs):\n",
    "        \"\"\"\n",
    "        Forward pass of model\n",
    "\n",
    "        returns just the output\n",
    "        \"\"\"\n",
    "        return cls._forward(CommonParams(noiser, frozen_noiser_params, noiser_params, frozen_params, params, es_tree_key, iterinfo), *args, **kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def _forward(cls, common_params, *args, **kwargs):\n",
    "        raise NotImplementedError(\"Forward pass is not implemented\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba18108b-268c-47bd-a613-4fa1e48c4665",
   "metadata": {},
   "source": [
    "As with the Noiser, there are multiple quirks with the Model.\n",
    "\n",
    "The rand_init function takes in some jax PRNG key and outputs a \"CommonInit\" consisting of frozen_params, params, scan_map, and es_map. The frozen_params are just frozen aspects of the model (like some fixed configuration option), and params are the standard parameters of the model as a jax pytree (the actual weights and biases of the model). \n",
    "\n",
    "For the purpose of this tutorial, scan_map can be mostly ignored, but it defines if a dimension of the parameter in a module is \"scanned\" over, which is only really relevant for LLMs; see the rwkv7 implementation for an example usecase.\n",
    "\n",
    "Finally, the es_map is a pytree of the same shape as params which dictates whether a model each component of the model should be treated as a PARAM, MM_PARAM, or EXCLUDED.\n",
    "\n",
    "The forward function just defines the forward pass of the model. Most of the parameters are the same as we have explained before. Noiser is the Noiser (sub-)class and frozen_noiser_params and noiser_params are the outputs of init_noiser. The frozen_params and params are the outputs of rand_init. The iterinfo is the same as described earlier, either None or (epoch, thread_id).\n",
    "\n",
    "The es_tree_key is a new parameter, and it is a tree of jax keys used to define the base_keys of the noiser.\n",
    "\n",
    "Below is a list of general convenience methods used for combining Models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc612902-de61-4395-a4e5-1e257586442f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "def recursive_scan_split(param, base_key, scan_tuple):\n",
    "    # scan_tuple = tuple() implies no split\n",
    "    if len(scan_tuple) == 0:\n",
    "        return base_key\n",
    "    # otherwise, it is (0, 1, ...)\n",
    "    split_keys = jax.random.split(base_key, param.shape[scan_tuple[0]])\n",
    "    return jax.vmap(recursive_scan_split, in_axes=(None, 0, None))(param, split_keys, scan_tuple[1:])\n",
    "\n",
    "def simple_es_tree_key(params, base_key, scan_map):\n",
    "    vals, treedef = jax.tree.flatten(params)\n",
    "    all_keys = jax.random.split(base_key, len(vals))\n",
    "    partial_key_tree = jax.tree.unflatten(treedef, all_keys)\n",
    "    return jax.tree.map(recursive_scan_split, params, partial_key_tree, scan_map)\n",
    "\n",
    "def merge_inits(**kwargs):\n",
    "    params = {}\n",
    "    frozen_params = {}\n",
    "    scan_map = {}\n",
    "    es_map = {}\n",
    "    for k in kwargs:\n",
    "        params[k] = kwargs[k].params #k_params\n",
    "        scan_map[k] = kwargs[k].scan_map #k_scan_map\n",
    "        es_map[k] = kwargs[k].es_map #k_es_map\n",
    "        if kwargs[k].frozen_params is not None:\n",
    "            frozen_params[k] = kwargs[k].frozen_params\n",
    "    if not frozen_params:\n",
    "        frozen_params = None\n",
    "\n",
    "    return CommonInit(frozen_params, params, scan_map, es_map)\n",
    "\n",
    "def merge_frozen(common, **kwargs):\n",
    "    new_frozen_params = common.frozen_params or {}\n",
    "    new_frozen_params = new_frozen_params | kwargs\n",
    "    return common._replace(frozen_params=new_frozen_params)\n",
    "\n",
    "def call_submodule(cls, name, common_params, *args, **kwargs):\n",
    "    sub_common_params = common_params._replace(\n",
    "        frozen_params=common_params.frozen_params[name] if common_params.frozen_params and name in common_params.frozen_params else None,\n",
    "        params=common_params.params[name],\n",
    "        es_tree_key=common_params.es_tree_key[name]\n",
    "    )\n",
    "    return cls._forward(sub_common_params, *args, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57faa6ef-a29b-459f-875f-73d57082e20e",
   "metadata": {},
   "source": [
    "The simple_es_tree_key generates the es_tree_key used for Model forward passes.\n",
    "\n",
    "The other methods are helpers used to compose existing Model \"submodules\" into a larger module.\n",
    "\n",
    "The fundamental, atomic Model classes are Parameter, MM, and TMM, implemented below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5af45d59-15d7-4adb-b4c9-869bbfd3db95",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Parameter(Model):\n",
    "    @classmethod\n",
    "    def rand_init(cls, key, shape, scale, raw_value, dtype, *args, **kwargs):\n",
    "        if raw_value is not None:\n",
    "            params = raw_value.astype(dtype=dtype)\n",
    "        else:\n",
    "            params = (jax.random.normal(key, shape) * scale).astype(dtype=dtype)\n",
    "        \n",
    "        frozen_params = None\n",
    "        scan_map = ()\n",
    "        es_map = PARAM\n",
    "        return CommonInit(frozen_params, params, scan_map, es_map)\n",
    "\n",
    "    @classmethod\n",
    "    def _forward(cls, common_params, *args, **kwargs):\n",
    "        return common_params.noiser.get_noisy_standard(common_params.frozen_noiser_params, common_params.noiser_params, common_params.params, common_params.es_tree_key, common_params.iterinfo)\n",
    "\n",
    "class MM(Model):\n",
    "    @classmethod\n",
    "    def rand_init(cls, key, in_dim, out_dim, dtype, *args, **kwargs):\n",
    "        scale = 1 / jnp.sqrt(in_dim)\n",
    "        params = (jax.random.normal(key, (out_dim, in_dim)) * scale).astype(dtype=dtype)\n",
    "        frozen_params = None\n",
    "        scan_map = ()\n",
    "        es_map = MM_PARAM\n",
    "        return CommonInit(frozen_params, params, scan_map, es_map)\n",
    "\n",
    "    @classmethod\n",
    "    def _forward(cls, common_params, x, *args, **kwargs):\n",
    "        return common_params.noiser.do_mm(common_params.frozen_noiser_params, common_params.noiser_params, common_params.params, common_params.es_tree_key, common_params.iterinfo, x)\n",
    "\n",
    "class TMM(Model):\n",
    "    @classmethod\n",
    "    def rand_init(cls, key, in_dim, out_dim, dtype, *args, **kwargs):\n",
    "        scale = 1 / jnp.sqrt(in_dim)\n",
    "        params = jax.random.normal(key, (in_dim, out_dim), dtype=dtype) * scale\n",
    "        frozen_params = None\n",
    "        scan_map = ()\n",
    "        es_map = MM_PARAM\n",
    "        return CommonInit(frozen_params, params, scan_map, es_map)\n",
    "\n",
    "    @classmethod\n",
    "    def _forward(cls, common_params, x, *args, **kwargs):\n",
    "        return common_params.noiser.do_Tmm(common_params.frozen_noiser_params, common_params.noiser_params, common_params.params, common_params.es_tree_key, common_params.iterinfo, x)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "114ec5cb-19c5-4efd-89c7-fa5b8dac93e0",
   "metadata": {},
   "source": [
    "These fundamental Model classes can then be composed into more complex neural networks. We implement Linear layers and MLPs below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f76328d-f365-4890-a08a-438c0f61c818",
   "metadata": {},
   "outputs": [],
   "source": [
    "def layer_norm(x, eps=1e-5):\n",
    "    mean = jnp.mean(x, axis=-1, keepdims=True)\n",
    "    var = jnp.var(x, axis=-1, keepdims=True)\n",
    "    std = jnp.sqrt(var + eps)\n",
    "    return (x - mean) / std\n",
    "\n",
    "ACTIVATIONS = {\n",
    "    'relu': jax.nn.relu,\n",
    "    'silu': jax.nn.silu,\n",
    "    'pqn': lambda x: jax.nn.relu(layer_norm(x))\n",
    "}\n",
    "\n",
    "class Linear(Model):\n",
    "    @classmethod\n",
    "    def rand_init(cls, key, in_dim, out_dim, use_bias, dtype, *args, **kwargs):\n",
    "        if use_bias:\n",
    "            return merge_inits(\n",
    "                weight=MM.rand_init(key, in_dim, out_dim, dtype),\n",
    "                bias=Parameter.rand_init(key, None, None, jnp.zeros(out_dim, dtype=dtype), dtype)\n",
    "            )\n",
    "        else:\n",
    "            return merge_inits(\n",
    "                weight=MM.rand_init(key, in_dim, out_dim, dtype),\n",
    "            )\n",
    "\n",
    "    @classmethod\n",
    "    def _forward(cls, common_params, x, *args, **kwargs):\n",
    "        ans = call_submodule(MM, 'weight', common_params, x)\n",
    "        if \"bias\" in common_params.params:\n",
    "            ans += call_submodule(Parameter, 'bias', common_params)\n",
    "        return ans\n",
    "            \n",
    "class MLP(Model):\n",
    "    @classmethod\n",
    "    def rand_init(cls, key, in_dim, out_dim, hidden_dims, use_bias, activation, dtype, *args, **kwargs):\n",
    "        input_dims = [in_dim] + list(hidden_dims)\n",
    "        output_dims = list(hidden_dims) + [out_dim]\n",
    "\n",
    "        all_keys = jax.random.split(key, len(input_dims))\n",
    "\n",
    "        merged_params = merge_inits(**{str(t): Linear.rand_init(all_keys[t], input_dims[t], output_dims[t], use_bias, dtype) for t in range(len(input_dims))})\n",
    "        return merge_frozen(merged_params, activation=activation)\n",
    "\n",
    "    @classmethod\n",
    "    def _forward(cls, common_params, x, *args, **kwargs):\n",
    "        num_blocks = len(common_params.params)\n",
    "        for t in range(num_blocks):\n",
    "            x = call_submodule(Linear, str(t), common_params, x)\n",
    "            if t != num_blocks - 1:\n",
    "                x = ACTIVATIONS[common_params.frozen_params['activation']](x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddbf3e37-2bb2-481b-a90d-797f36c95010",
   "metadata": {},
   "source": [
    "## Part 3: The EGGROLL Noiser\n",
    "\n",
    "Now that we have established an interface for the Noiser and Model, we are ready to implement the EGGROLL noiser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d6d9d64-2800-44d5-a9f2-50038ffb1944",
   "metadata": {},
   "outputs": [],
   "source": [
    "import optax\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "def get_sign_key(frozen_noiser_params, iterinfo, key):\n",
    "    epoch, thread_id = iterinfo\n",
    "    true_epoch = 0 if frozen_noiser_params[\"noise_reuse\"] == 0 else epoch // frozen_noiser_params[\"noise_reuse\"]\n",
    "    true_thread_idx = thread_id // 2\n",
    "    sign = jnp.where(thread_id % 2 == 0, 1, -1)\n",
    "    key = jax.random.fold_in(jax.random.fold_in(key, true_epoch), true_thread_idx)\n",
    "    return sign, key\n",
    "\n",
    "def get_lora_update_params(frozen_noiser_params, base_sigma, iterinfo, param, key):\n",
    "    sign, p_key = get_sign_key(frozen_noiser_params, iterinfo, key)\n",
    "    sigma = base_sigma * sign\n",
    "\n",
    "    a, b = param.shape\n",
    "    lora_params = jax.random.normal(p_key, (a+b, frozen_noiser_params[\"rank\"]), dtype=param.dtype)\n",
    "    B = lora_params[:b] # b x r\n",
    "    A = lora_params[b:] # a x r\n",
    "\n",
    "    # update is A @ B.T\n",
    "    return A * sigma, B\n",
    "\n",
    "def get_nonlora_update_params(frozen_noiser_params, base_sigma, iterinfo, param, key):\n",
    "    sign, p_key = get_sign_key(frozen_noiser_params, iterinfo, key)\n",
    "    sigma = base_sigma * sign\n",
    "    updates = jax.random.normal(p_key, param.shape, dtype=param.dtype)\n",
    "    return updates * sigma\n",
    "\n",
    "def _simple_full_update(base_sigma, param, key, scores, iterinfo, frozen_noiser_params):\n",
    "    if frozen_noiser_params[\"freeze_nonlora\"]:\n",
    "        return jnp.zeros_like(param)\n",
    "    _, thread_ids = iterinfo\n",
    "    sigma = jnp.where(thread_ids % 2 == 0, base_sigma, -base_sigma)\n",
    "    updates = jax.vmap(partial(get_nonlora_update_params, frozen_noiser_params), in_axes=(None, 0, None, None))(base_sigma, iterinfo, param, key)\n",
    "    broadcasted_scores = jnp.reshape(scores, scores.shape + (1,) * len(param.shape))\n",
    "    broadcasted_sigma = jnp.reshape(sigma, sigma.shape + (1,) * len(param.shape))\n",
    "    return jnp.astype(jnp.mean(broadcasted_scores * updates, axis=0), param.dtype)\n",
    "\n",
    "def _simple_lora_update(base_sigma, param, key, scores, iterinfo, frozen_noiser_params):\n",
    "    A, B = jax.vmap(partial(get_lora_update_params, frozen_noiser_params), in_axes=(None, 0, None, None))(base_sigma / jnp.sqrt(frozen_noiser_params[\"rank\"]), iterinfo, param, key)\n",
    "    broadcasted_scores = jnp.reshape(scores, scores.shape + (1,1))\n",
    "    A = broadcasted_scores * A # N x a x r for A vs N x b x r for B -> final update is just a x b\n",
    "    num_envs = scores.shape[0]\n",
    "    # return A.T @ B / num_envs\n",
    "    return jnp.einsum('nir,njr->ij', A, B) / num_envs\n",
    "\n",
    "def _noop_update(base_sigma, param, key, scores, iterinfo, frozen_noiser_params):\n",
    "    return jnp.zeros_like(param)\n",
    "\n",
    "class EggRoll(Noiser):\n",
    "    @classmethod\n",
    "    def init_noiser(cls, params, sigma, lr, *args, solver=None, solver_kwargs=None, group_size=0, freeze_nonlora=False, noise_reuse=1, rank=1, **kwargs):\n",
    "        \"\"\"\n",
    "        Return frozen_noiser_params and noiser_params\n",
    "        \"\"\"\n",
    "        if solver is None:\n",
    "            solver = optax.sgd\n",
    "        if solver_kwargs is None:\n",
    "            solver_kwargs = {}\n",
    "        true_solver = solver(lr, **solver_kwargs)\n",
    "        opt_state = true_solver.init(params)\n",
    "        \n",
    "        return {\"group_size\": group_size, \"freeze_nonlora\": freeze_nonlora, \"noise_reuse\": noise_reuse, \"solver\": true_solver, \"rank\": rank}, {\"sigma\": sigma, \"opt_state\": opt_state}\n",
    "    \n",
    "    @classmethod\n",
    "    def do_mm(cls, frozen_noiser_params, noiser_params, param, base_key, iterinfo, x):\n",
    "        base_ans = x @ param.T\n",
    "        if iterinfo is None:\n",
    "            return base_ans\n",
    "        A, B = get_lora_update_params(frozen_noiser_params, noiser_params[\"sigma\"] / jnp.sqrt(frozen_noiser_params[\"rank\"]), iterinfo, param, base_key)\n",
    "        return base_ans + x @ B @ A.T\n",
    "\n",
    "    @classmethod\n",
    "    def do_Tmm(cls, frozen_noiser_params, noiser_params, param, base_key, iterinfo, x):\n",
    "        base_ans = x @ param\n",
    "        if iterinfo is None:\n",
    "            return base_ans\n",
    "        A, B = get_lora_update_params(frozen_noiser_params, noiser_params[\"sigma\"] / jnp.sqrt(frozen_noiser_params[\"rank\"]), iterinfo, param, base_key)\n",
    "        return base_ans + x @ A @ B.T\n",
    "\n",
    "    @classmethod\n",
    "    def get_noisy_standard(cls, frozen_noiser_params, noiser_params, param, base_key, iterinfo):\n",
    "        if iterinfo is None or frozen_noiser_params[\"freeze_nonlora\"]:\n",
    "            return param\n",
    "        return param + get_nonlora_update_params(frozen_noiser_params, noiser_params[\"sigma\"], iterinfo, param, base_key)\n",
    "\n",
    "    @classmethod\n",
    "    def convert_fitnesses(cls, frozen_noiser_params, noiser_params, raw_scores, num_episodes_list=None):\n",
    "        group_size = frozen_noiser_params[\"group_size\"]\n",
    "        if group_size == 0:\n",
    "            true_scores = (raw_scores - jnp.mean(raw_scores, keepdims=True)) / jnp.sqrt(jnp.var(raw_scores, keepdims=True) + 1e-5)\n",
    "        else:\n",
    "            group_scores = raw_scores.reshape((-1, group_size))\n",
    "            true_scores = (group_scores - jnp.mean(group_scores, axis=-1, keepdims=True)) / jnp.sqrt(jnp.var(raw_scores, keepdims=True) + 1e-5)\n",
    "            true_scores = true_scores.ravel()\n",
    "        return true_scores\n",
    "\n",
    "    @classmethod\n",
    "    def _do_update(cls, param, base_key, fitnesses, iterinfos, map_classification, sigma, frozen_noiser_params, **kwargs):\n",
    "        update_fn = [_simple_full_update, _simple_lora_update, _noop_update, _noop_update][map_classification]\n",
    "\n",
    "        if len(base_key.shape) == 0:\n",
    "            new_grad = update_fn(sigma, param, base_key, fitnesses, iterinfos, frozen_noiser_params)\n",
    "        else:\n",
    "            new_grad = jax.lax.scan(lambda _, x: (0, update_fn(sigma, x[0], x[1], fitnesses, iterinfos, frozen_noiser_params)), 0, xs=(param, base_key))[1]\n",
    "\n",
    "        return -(new_grad * jnp.sqrt(fitnesses.size)).astype(param.dtype)\n",
    "\n",
    "    @classmethod\n",
    "    def do_updates(cls, frozen_noiser_params, noiser_params, params, base_keys, fitnesses, iterinfos, es_map):\n",
    "        new_grad = jax.tree.map(lambda p, k, m: cls._do_update(p, k, fitnesses, iterinfos, m, noiser_params[\"sigma\"], frozen_noiser_params), params, base_keys, es_map)\n",
    "        updates, noiser_params[\"opt_state\"] = frozen_noiser_params[\"solver\"].update(new_grad, noiser_params[\"opt_state\"], params)\n",
    "        return noiser_params, optax.apply_updates(params, updates)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e24b1c80-73b1-4299-96d5-be874b81fcf6",
   "metadata": {},
   "source": [
    "Note how EGGROLL transforms do_mm (and do_Tmm) by decomposing it into the base matrix multiplication and then multiplying by the lora components. Also notice how the _do_update function returns the negative gradient, because this allows us to use standard optax optimizers (like sgd or adam) which typically attempt to minimize a loss function instead of maximizing a fitness function.\n",
    "\n",
    "## Part 4: Hands-On Experiments\n",
    "Now that we have an implementation of EGGROLL, let's test it out on a toy problem of modeling a simple function. Let's start by building the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03d338b3-1e0c-4685-9956-72bfac527ac6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "max_x_dataset = 10\n",
    "num_train_points = 1024\n",
    "num_test_points = 1024\n",
    "\n",
    "dataset_seed = 0\n",
    "dataset_key = jax.random.key(dataset_seed)\n",
    "\n",
    "train_key, test_key = jax.random.split(dataset_key)\n",
    "train_x = jax.random.uniform(train_key, shape=(num_train_points,), minval=-max_x_dataset, maxval=max_x_dataset)\n",
    "test_x = jax.random.uniform(test_key, shape=(num_train_points,), minval=-max_x_dataset, maxval=max_x_dataset)\n",
    "\n",
    "def pred_fn(x):\n",
    "    return x ** 2 + x + 1\n",
    "\n",
    "train_y = pred_fn(train_x)\n",
    "test_y = pred_fn(test_x)\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "plt.scatter(train_x, train_y, label='Training Set', color='blue')\n",
    "plt.scatter(test_x, test_y, label='Test Set', color='red')\n",
    "plt.legend()\n",
    "plt.show()\n",
    "\n",
    "@jax.jit\n",
    "def batch_calculate_fitness(y_pred, y_true):\n",
    "    return -((y_pred-y_true) ** 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf9c8fb4-e146-4547-b5b6-94fc3968ec66",
   "metadata": {},
   "source": [
    "Next, let's initialize an MLP model that we will train alongside the EggRoll noiser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "116d6f66-213a-40ae-8f11-80721ba7bc1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "NOISER = EggRoll\n",
    "MODEL = MLP\n",
    "\n",
    "num_envs = 4096  # number of unique perturbations\n",
    "\n",
    "num_epochs = 1000\n",
    "sigma = 1.0\n",
    "lr = optax.schedules.linear_schedule(1.0, 0.0, num_epochs)\n",
    "rank = 1\n",
    "generations_per_prompt = 128  # needs to be even and divisor of num_envs\n",
    "\n",
    "in_dim = 1\n",
    "out_dim = 1\n",
    "hidden_dim = 64\n",
    "n_layer = 3\n",
    "\n",
    "key = jax.random.key(1)\n",
    "model_key = jax.random.fold_in(key, 0)\n",
    "es_key = jax.random.fold_in(key, 1)\n",
    "\n",
    "frozen_params, params, scan_map, es_map = MODEL.rand_init(\n",
    "    model_key, in_dim=in_dim, out_dim=out_dim, hidden_dims=[hidden_dim] * n_layer, use_bias=True, activation=\"pqn\", dtype=\"float32\"\n",
    ")\n",
    "\n",
    "print(\"Frozen params:\", frozen_params)\n",
    "print(\"Params (shape):\", jax.tree.map(lambda x: x.shape, params))\n",
    "print(\"Scan map:\", scan_map)\n",
    "print(\"ES map:\", es_map)\n",
    "\n",
    "es_tree_key = simple_es_tree_key(params, es_key, scan_map)\n",
    "\n",
    "print(\"ES Tree Key:\", es_tree_key)\n",
    "\n",
    "frozen_noiser_params, noiser_params = NOISER.init_noiser(params, sigma, lr, group_size=generations_per_prompt, rank=rank)\n",
    "# frozen_noiser_params, noiser_params = NOISER.init_noiser(params, sigma, lr, solver=optax.adamw, solver_kwargs={\"b1\": 0.9, \"b2\": 0.999}, rank=rank)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3f06c4e-6371-4a32-8d9a-7021292bf75d",
   "metadata": {},
   "source": [
    "Now we can define the forward and update functions. Note that the regular \"jit_forward\" includes iterinfo to include the eggroll perturbations, whereas \"jit_forward_eval\" only uses the unperturbed base parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a16f9d8-e442-41ed-b1a4-4d7d2ea8c14a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inputs are noiser_params, params, iterinfo, input\n",
    "jit_forward = jax.jit(jax.vmap(lambda n, p, i, x: MODEL.forward(NOISER, frozen_noiser_params, n, frozen_params, p, es_tree_key, i, x), in_axes=(None, None, 0, 0)))\n",
    "# inputs are noiser_params, params, input\n",
    "jit_forward_eval = jax.jit(jax.vmap(lambda n, p, x: MODEL.forward(NOISER, frozen_noiser_params, n, frozen_params, p, es_tree_key, None, x), in_axes=(None, None, 0)))\n",
    "# inputs are noiser_params, params, fitnesses, iterinfo\n",
    "jit_update = jax.jit(lambda n, p, f, i: NOISER.do_updates(frozen_noiser_params, n, p, es_tree_key, f, i, es_map))\n",
    "\n",
    "print(\"Warmup\")\n",
    "compile_iterinfos = (jnp.zeros(num_envs, dtype=jnp.int32), jnp.zeros(num_envs, dtype=jnp.int32))\n",
    "tmp_forward_out = jax.block_until_ready(jit_forward(noiser_params, params, compile_iterinfos, jnp.zeros((num_envs, 1))))\n",
    "print(\"forward compiled; output shape is\", tmp_forward_out.shape)\n",
    "tmp_forward_eval_out = jax.block_until_ready(jit_forward_eval(noiser_params, params, jnp.zeros((num_test_points, 1))))\n",
    "print(\"forward eval compiled; output shape is\", tmp_forward_eval_out.shape)\n",
    "tmp_update_out = jax.block_until_ready(jit_update(noiser_params, params, jnp.zeros(num_envs), compile_iterinfos))\n",
    "print(\"update compiled\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f0fbbfc-2a7f-4c79-ab76-238ff06586a8",
   "metadata": {},
   "source": [
    "Finally, we can define the optimization loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e97cdcb-c126-4222-ac04-d2aa2bdb318f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "prompts_per_epoch = num_envs // generations_per_prompt\n",
    "\n",
    "losses = []\n",
    "all_data = []\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    iterinfo = (jnp.full(num_envs, epoch, dtype=jnp.int32), jnp.arange(num_envs))\n",
    "\n",
    "    # data loading\n",
    "    unique_data_input_idxes = (jnp.arange(prompts_per_epoch) + epoch * prompts_per_epoch) % num_train_points\n",
    "    indices = jnp.repeat(unique_data_input_idxes, generations_per_prompt, axis=0)\n",
    "    train_batch_x = train_x[indices]\n",
    "    train_batch_y = train_y[indices]\n",
    "\n",
    "    # getting outputs from perturbations (THE KEY EGGROLL LOGIC)\n",
    "    outputs_batch = jit_forward(noiser_params, params, iterinfo, train_batch_x[:, None])[:, 0]\n",
    "    raw_scores = batch_calculate_fitness(outputs_batch, train_batch_y)\n",
    "    fitnesses = NOISER.convert_fitnesses(frozen_noiser_params, noiser_params, raw_scores)\n",
    "    noiser_params, params = jit_update(noiser_params, params, fitnesses, iterinfo)\n",
    "    \n",
    "    # evaluating the quality of the parameters on the test set\n",
    "    test_batch = jit_forward_eval(noiser_params, params, test_x[:, None])[:, 0]\n",
    "    all_data.append(test_batch)\n",
    "    raw_test_fitness_scores = batch_calculate_fitness(test_batch, test_y)\n",
    "    losses.append(-jnp.mean(raw_test_fitness_scores).item())\n",
    "    if epoch % 10 == 0:\n",
    "        print(f\"({epoch}) Avg validation loss is {-jnp.mean(raw_test_fitness_scores)} (positive, lower is better)\")\n",
    "\n",
    "    # linear sigma decay\n",
    "    noiser_params[\"sigma\"] = sigma * (1 - epoch / num_epochs)\n",
    "\n",
    "all_data = jnp.array(all_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee694b32-222f-4b3e-91b0-60de1279ed43",
   "metadata": {},
   "source": [
    "Let's plot how well it does!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b442113c-3e2f-41b2-ab0d-e0dad61dbe6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot loss curves\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, layout=\"constrained\", figsize=(8, 4))\n",
    "ax1.plot(range(num_epochs), losses)\n",
    "ax1.set_title(\"Test Loss\")\n",
    "\n",
    "ax2.plot(range(num_epochs), losses)\n",
    "ax2.set_yscale(\"log\")\n",
    "ax2.set_title(\"Test Loss (log y)\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "499f0a22-af83-4dc2-bc08-89153f4941f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ipywidgets import interact\n",
    "import ipywidgets as widgets\n",
    "\n",
    "slider = widgets.IntSlider(min=0, max=num_epochs-1, step=1, value=num_epochs-1, description=\"Epoch Number:\")\n",
    "\n",
    "def update_slider(change):\n",
    "    plt.gca().set_ylim(jnp.min(test_y) - 10, jnp.max(test_y) + 10)\n",
    "    plt.scatter(test_x, test_y, label='Ground Truth', color='blue')\n",
    "    plt.scatter(test_x, all_data[change, :], label='Model Predictions', color='green')\n",
    "    plt.legend()\n",
    "    plt.gca().set_title(f'Model reconstruction at Time Step: {change}')\n",
    "\n",
    "_ = interact(update_slider, change=slider)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfc327fb-302a-482f-934f-88b41d7d6df5",
   "metadata": {},
   "source": [
    "Something to note is that we have not done any hyperparameter tuning for learning rate and sigma. Instead we start at 1.0 and perform a linear decay over the epochs. (This is extremely naive, since exponential decays are more typically used for evolution strategies and we would typically want to tune both hyperparameters anyways)\n",
    "\n",
    "A fun follow-up exercise may be to try tuning the hyperparameters or trying different schedules. Additionally, we just use SGD for simplicity, but you can also try out different optimizers and see which ones improve performance.\n",
    "\n",
    "## Conclusion\n",
    "\n",
    "Congratulations! You have now trained your first toy network with EGGROLL. We recommend playing around with this script and checking different hyperparameters. The workflow used for this toy setting is very similar to that of more complicated problems; you just need to change the model, data loading, and fitness function to your new setting.\n",
    "\n",
    "For those interested in extending EGGROLL to the RWKV LLMs or multi-gpu configurations, check out our multi-gpu evolution script."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
