{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Requirements: Install the jaxopt package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%env CUDA_VISIBLE_DEVICES=0\n",
    "\n",
    "import jax\n",
    "from jax import numpy as jnp\n",
    "from jax import nn\n",
    "import numpy as np\n",
    "from flax import linen as nn\n",
    "from flax.linen import initializers as nni\n",
    "from flax.traverse_util import flatten_dict, unflatten_dict\n",
    "from jax import random as jr\n",
    "from util import *\n",
    "import optax\n",
    "import pickle\n",
    "import argparse\n",
    "from matplotlib import cm\n",
    "\n",
    "from model import Transformer\n",
    "from tasks import Task, ModPSeq2SeqTask"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SimpleTask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run(L, seed):\n",
    "    \"\"\"Train single-layer transformers on SimpleTask and measure length generalization.\n",
    "\n",
    "    One model is trained for each maximum training length in ``range(4, 21)``,\n",
    "    always starting from the same initial parameters. Every trained model is then\n",
    "    evaluated on sequences sampled at the test lengths ``4, 8, ..., 80``.\n",
    "\n",
    "    Args:\n",
    "        L: Lipschitz constant, passed as the third argument of ``problem.sample``.\n",
    "        seed: random seed used to construct the ``RNG`` key stream.\n",
    "\n",
    "    Returns:\n",
    "        ``(all_losses, models)``: ``all_losses[i, j]`` is the last-token test loss at\n",
    "        the i-th test length of the model trained with the j-th train length (a NumPy\n",
    "        array), and ``models`` is the list of trained flattened parameter dicts.\n",
    "    \"\"\"\n",
    "    rng = RNG(seed)\n",
    "    vocab_size = 3\n",
    "    problem = Task(vocab_size, 0, 1)\n",
    "\n",
    "    max_length = 80\n",
    "    d_model = 16\n",
    "    heads = 1\n",
    "    width = 256\n",
    "\n",
    "    # optimisation settings, identical for every train length\n",
    "    lr = 3e-3\n",
    "    steps = 2**13\n",
    "    batch_size = 2**10\n",
    "    max_size = 2**20 # 2**24\n",
    "    epoch_len = max_size // batch_size\n",
    "\n",
    "    # initialize model\n",
    "    model = Transformer(vocab_size, max_length, 1, d_model, heads, width, 0, False)\n",
    "    p0 = model.init(rng.next(), vmap(lambda key: problem.sample(max_length, key, 1))(rng.next(2))[0])\n",
    "    p0 = flatten_dict(p0[\"params\"], sep=\".\")\n",
    "\n",
    "    criterion = lambda f, y: (y - f)**2\n",
    "\n",
    "    @partial(jit, static_argnames=\"mutable\")\n",
    "    def f(p, *args, **kwargs):\n",
    "        # the model expects nested params; we carry them around flattened\n",
    "        p = dict(params=unflatten_dict(p, sep=\".\"))\n",
    "        return model.apply(p, *args, **kwargs)\n",
    "\n",
    "    @jit\n",
    "    def loss_fn(p, batch):\n",
    "        x, y = batch\n",
    "        # only compute prediction on second token onwards\n",
    "        return vmap(criterion)(f(p, x)[:,1:], y).mean()\n",
    "\n",
    "    @jit\n",
    "    def test_loss_fn(p, batch):\n",
    "        x, y = batch\n",
    "        # only compute prediction on last token\n",
    "        return vmap(criterion)(f(p, x)[:,-1], y[:,-1]).mean()\n",
    "\n",
    "    # muP scaling for Adam (built once: it does not depend on the train length)\n",
    "    opt = optax.multi_transform(\n",
    "    {\n",
    "        'embed': optax.adam(learning_rate = lr),\n",
    "        'hidden': optax.adam(learning_rate = lr/d_model)\n",
    "    },\n",
    "    {\n",
    "        'wte':'embed',\n",
    "        'unembed':'embed',\n",
    "        'Q':'hidden',\n",
    "        'K':'hidden',\n",
    "        'V':'hidden',\n",
    "        'O':'hidden',\n",
    "        'layer1.kernel': 'hidden',\n",
    "        'layer1.bias': 'hidden',\n",
    "        'layer2.kernel': 'hidden'\n",
    "    })\n",
    "\n",
    "    @jit\n",
    "    def step_fn(p, batch, opt_state):\n",
    "        loss, g = jax.value_and_grad(loss_fn)(p, batch)\n",
    "        updates, opt_state = opt.update(g, opt_state, p)\n",
    "        p = optax.apply_updates(p, updates)\n",
    "        return p, opt_state, loss\n",
    "\n",
    "    def batch_iterator(key, length):\n",
    "        # Endless stream of batches: draw epoch_len * batch_size samples with\n",
    "        # problem.sample(length, ., L), hand them out batch by batch, then redraw.\n",
    "        sample_fn = lambda k: vmap(lambda sk: problem.sample(length, sk, L))(jr.split(k, epoch_len * batch_size))\n",
    "        while True:\n",
    "            key, subkey = jr.split(key)\n",
    "            batches = sample_fn(subkey)\n",
    "            for i in range(epoch_len):\n",
    "                yield tree_map(lambda x: x[batch_size * i : batch_size * (i + 1)], batches)\n",
    "\n",
    "    models = []\n",
    "\n",
    "    # loop over the maximum training lengths\n",
    "    for train_length in range(4, 21):\n",
    "        iterator = batch_iterator(rng.next(), train_length)\n",
    "\n",
    "        p = p0\n",
    "        opt_state = opt.init(p0)\n",
    "        for _ in trange(steps):\n",
    "            batch = next(iterator)\n",
    "            loss = loss_fn(p, batch)\n",
    "            # stop when converged, or when training has diverged (inf *or* NaN loss)\n",
    "            if loss < 1e-5 or not jnp.isfinite(loss):\n",
    "                break\n",
    "            p, opt_state, loss = step_fn(p, batch, opt_state)\n",
    "\n",
    "        print(\"final loss = \", loss)\n",
    "        models.append(p)\n",
    "\n",
    "    # evaluate models on the test lengths, reusing the same keys for every length\n",
    "    test_samples = 2**10\n",
    "    test_rng = rng.next()\n",
    "    test_keys = jr.split(test_rng, test_samples)\n",
    "    all_losses = []\n",
    "    for test_length in range(4, 81, 4):\n",
    "        testx, testy = vmap(lambda key: problem.sample(test_length, key, L))(test_keys)\n",
    "        all_losses.append([test_loss_fn(p, (testx, testy)) for p in models])\n",
    "\n",
    "    all_losses = np.array(all_losses)\n",
    "\n",
    "    return all_losses, models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep configuration: random seeds to average over (each trial takes ~30 minutes)\n",
    "num_trials = 5\n",
    "seeds = list(jnp.arange(0, num_trials))\n",
    "\n",
    "# Lipschitz constants to sweep over\n",
    "Ls = list(jnp.arange(2, 6, 0.5))\n",
    "\n",
    "\n",
    "# For every Lipschitz constant, run all seeds and keep the seed-averaged\n",
    "# test-loss matrix returned by `run` (the trained models are discarded).\n",
    "all_results = []\n",
    "for L in Ls:\n",
    "    per_seed_losses = np.array([run(L, seed)[0] for seed in seeds])\n",
    "    all_results.append(np.mean(per_seed_losses, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test loss against test length, one curve for every other train length.\n",
    "# all_results[1] is the sweep entry for the second Lipschitz constant in Ls.\n",
    "results = all_results[1]\n",
    "train_lengths = np.arange(4, 21)\n",
    "test_lengths = np.arange(4, 81, 4)\n",
    "\n",
    "num_train_lengths = results.shape[1]\n",
    "gradient = np.linspace(0, 1, num_train_lengths)\n",
    "colors = cm.viridis(gradient)\n",
    "\n",
    "plt.style.use('style.mplstyle')\n",
    "for col in range(0, num_train_lengths, 2):\n",
    "    # column `col` holds the losses of the model trained at train_lengths[col]\n",
    "    plt.plot(test_lengths, results[:, col], label = train_lengths[col], color = colors[col])\n",
    "plt.yscale(\"log\")\n",
    "plt.xlabel(\"Test length\")\n",
    "plt.ylabel(\"Test loss\")\n",
    "plt.legend(bbox_to_anchor=(1.22, 1.02), loc=\"upper right\", fontsize = 12)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Limiting test loss (last row of `results`, i.e. the longest test length)\n",
    "# against train length, on log-log axes.\n",
    "limiting_loss = results[-1]\n",
    "plt.plot(train_lengths, limiting_loss, marker = 'o', color = colors[0])\n",
    "plt.xlabel(\"Train length\")\n",
    "plt.ylabel(\"Test loss\")\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot minimum train length needed to obtain limiting test loss < epsilon\n",
    "epsilon = 1e-2\n",
    "Ns = []\n",
    "for i in range(len(Ls)):\n",
    "    reached = np.asarray(all_results[i][-1, :]) < epsilon\n",
    "    # argmax of an all-False mask is 0, which would wrongly report the shortest\n",
    "    # train length; record NaN (a gap in the plot) when the threshold is never reached.\n",
    "    Ns.append(train_lengths[int(np.argmax(reached))] if reached.any() else np.nan)\n",
    "plt.plot(Ls, Ns, color = colors[0], marker = 'o')\n",
    "plt.xlabel(r'$\\omega$')\n",
    "plt.ylabel(\"N\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mod P Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run(delta, seed):\n",
    "    \"\"\"Train single-layer transformers on the Mod P task and measure length generalization.\n",
    "\n",
    "    One model is trained for each train length ``delta, 2*delta, ..., 5*delta``,\n",
    "    always starting from the same initial parameters. Every trained model is then\n",
    "    evaluated on sequences sampled at the test lengths ``delta, 2*delta, ..., 20*delta``.\n",
    "\n",
    "    Args:\n",
    "        delta: periodicity parameter in ModPTask.\n",
    "        seed: random seed used to construct the ``RNG`` key stream.\n",
    "\n",
    "    Returns:\n",
    "        ``(all_losses, models)``: ``all_losses[i, j]`` is the last-token test loss at\n",
    "        the i-th test length of the model trained with the j-th train length (a NumPy\n",
    "        array), and ``models`` is the list of trained flattened parameter dicts.\n",
    "    \"\"\"\n",
    "    rng = RNG(seed)\n",
    "\n",
    "    problem = ModPSeq2SeqTask(delta, 1) # attend to position == 1 mod k.\n",
    "    vocab_size = 2\n",
    "\n",
    "    max_length = 40*delta\n",
    "\n",
    "    d_model = 16\n",
    "    heads = 1\n",
    "    width = 256\n",
    "\n",
    "    # optimisation settings, identical for every train length\n",
    "    lr = 3e-3\n",
    "    steps = 2**13\n",
    "    batch_size = 2**10\n",
    "    max_size = 2**20 # 2**24\n",
    "    epoch_len = max_size // batch_size\n",
    "\n",
    "    # init model\n",
    "    model = Transformer(vocab_size, max_length, 1, d_model, heads, width, delta, True)\n",
    "    p0 = model.init(rng.next(), vmap(lambda key: problem.sample(max_length, key))(rng.next(2))[0])\n",
    "    p0 = flatten_dict(p0[\"params\"], sep=\".\")\n",
    "\n",
    "    criterion = lambda f, y: (y - f)**2\n",
    "\n",
    "    @partial(jit, static_argnames=\"mutable\")\n",
    "    def f(p, *args, **kwargs):\n",
    "        # the model expects nested params; we carry them around flattened\n",
    "        p = dict(params=unflatten_dict(p, sep=\".\"))\n",
    "        return model.apply(p, *args, **kwargs)\n",
    "\n",
    "    @jit\n",
    "    def loss_fn(p, batch):\n",
    "        x, y = batch\n",
    "        # only compute prediction on delta onwards\n",
    "        return vmap(criterion)(f(p, x)[:,delta-1:], y).mean()\n",
    "\n",
    "    @jit\n",
    "    def test_loss_fn(p, batch):\n",
    "        x, y = batch\n",
    "        # only compute prediction on last token\n",
    "        return vmap(criterion)(f(p, x)[:,-1], y[:,-1]).mean()\n",
    "\n",
    "    # muP scaling (built once: it does not depend on the train length)\n",
    "    opt = optax.multi_transform(\n",
    "    {\n",
    "        'embed': optax.adam(learning_rate = lr),\n",
    "        'hidden': optax.adam(learning_rate = lr/d_model)\n",
    "    },\n",
    "    {\n",
    "        'wte':'embed',\n",
    "        'PeriodicPositionalEncoding_0.wpe': 'embed',\n",
    "        'unembed':'embed',\n",
    "        'Q':'hidden',\n",
    "        'K':'hidden',\n",
    "        'V':'hidden',\n",
    "        'O':'hidden',\n",
    "        'layer1.kernel': 'hidden',\n",
    "        'layer1.bias': 'hidden',\n",
    "        'layer2.kernel': 'hidden'\n",
    "    })\n",
    "\n",
    "    @jit\n",
    "    def step_fn(p, batch, opt_state):\n",
    "        loss, g = jax.value_and_grad(loss_fn)(p, batch)\n",
    "        updates, opt_state = opt.update(g, opt_state, p)\n",
    "        p = optax.apply_updates(p, updates)\n",
    "        return p, opt_state, loss\n",
    "\n",
    "    def batch_iterator(key, length):\n",
    "        # Endless stream of batches: draw epoch_len * batch_size samples with\n",
    "        # problem.sample(length, .), hand them out batch by batch, then redraw.\n",
    "        sample_fn = lambda k: vmap(lambda sk: problem.sample(length, sk))(jr.split(k, epoch_len * batch_size))\n",
    "        while True:\n",
    "            key, subkey = jr.split(key)\n",
    "            batches = sample_fn(subkey)\n",
    "            for i in range(epoch_len):\n",
    "                yield tree_map(lambda x: x[batch_size * i : batch_size * (i + 1)], batches)\n",
    "\n",
    "    models = []\n",
    "\n",
    "    # loop over the train lengths\n",
    "    for train_length in range(delta, 5*delta + 1, delta):\n",
    "        iterator = batch_iterator(rng.next(), train_length)\n",
    "\n",
    "        p = p0\n",
    "        opt_state = opt.init(p0)\n",
    "        for _ in trange(steps):\n",
    "            batch = next(iterator)\n",
    "            loss = loss_fn(p, batch)\n",
    "            # stop when converged, or when training has diverged (inf *or* NaN loss)\n",
    "            if loss < 1e-5 or not jnp.isfinite(loss):\n",
    "                break\n",
    "            p, opt_state, loss = step_fn(p, batch, opt_state)\n",
    "\n",
    "        print(\"final loss = \", loss)\n",
    "        models.append(p)\n",
    "\n",
    "    # evaluate on test lengths, reusing the same keys for every length\n",
    "    test_samples = 2**10\n",
    "    test_rng = rng.next()\n",
    "    test_keys = jr.split(test_rng, test_samples)\n",
    "    all_losses = []\n",
    "    for test_length in range(delta, 20*delta + 1, delta):\n",
    "        testx, testy = vmap(lambda key: problem.sample(test_length, key))(test_keys)\n",
    "        all_losses.append([test_loss_fn(p, (testx, testy)) for p in models])\n",
    "\n",
    "    all_losses = np.array(all_losses)\n",
    "\n",
    "    return all_losses, models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep configuration: random seeds to average over (each trial takes ~10 minutes)\n",
    "num_trials = 5\n",
    "seeds = list(jnp.arange(num_trials))\n",
    "\n",
    "# periodicity parameters delta to sweep over\n",
    "deltas = list(range(3, 9))\n",
    "\n",
    "\n",
    "# For every delta, run all seeds and keep the seed-averaged test-loss matrix\n",
    "# returned by `run` (the trained models are discarded).\n",
    "all_results = []\n",
    "for delta in deltas:\n",
    "    per_seed_losses = np.array([run(delta, seed)[0] for seed in seeds])\n",
    "    all_results.append(np.mean(per_seed_losses, axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot test loss as a function of test length, for varying train lengths.\n",
    "plot_idx = -1  # which entry of `deltas` / `all_results` to plot\n",
    "plot_delta = deltas[plot_idx]\n",
    "results = all_results[plot_idx]\n",
    "# Axis values must match the delta being plotted: `run` trains on lengths\n",
    "# delta, 2*delta, ..., 5*delta and tests on delta, 2*delta, ..., 20*delta.\n",
    "train_lengths = np.arange(plot_delta, 5*plot_delta + 1, plot_delta)\n",
    "test_lengths = np.arange(plot_delta, 20*plot_delta + 1, plot_delta)\n",
    "\n",
    "gradient = np.linspace(0, 1, results.shape[1])\n",
    "colors = cm.viridis(gradient)\n",
    "\n",
    "plt.style.use('style.mplstyle')\n",
    "for i in np.arange(0, results.shape[1]):\n",
    "    # column i holds the losses of the model trained at train_lengths[i]\n",
    "    plt.plot(test_lengths, results[:, i], label = train_lengths[i], color = colors[i])\n",
    "plt.xlabel(\"Test length\")\n",
    "plt.ylabel(\"Test loss\")\n",
    "plt.yscale(\"log\")\n",
    "plt.legend(bbox_to_anchor=(1.22, 1.02), loc=\"upper right\", fontsize = 12)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Limiting test loss (last row of `results`, i.e. the longest test length)\n",
    "# against train length, on log-log axes.\n",
    "limiting_loss = results[-1]\n",
    "plt.plot(train_lengths, limiting_loss, marker = 'o', color = colors[0])\n",
    "plt.xlabel(\"Train length\")\n",
    "plt.ylabel(\"Test loss\")\n",
    "plt.xscale(\"log\")\n",
    "plt.yscale(\"log\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Limiting test loss (last row of each result matrix) for every delta, plotted\n",
    "# against the train length expressed in multiples of delta.\n",
    "gradient = np.linspace(0, 1, len(all_results))\n",
    "colors = cm.viridis(gradient)\n",
    "\n",
    "train_length_ratios = np.arange(1, 6)\n",
    "for delta_value, delta_results, color in zip(deltas, all_results, colors):\n",
    "    plt.plot(train_length_ratios, delta_results[-1, :], color = color, marker = 'o', label = r'$\\Delta = $' + str(delta_value))\n",
    "plt.xlabel(\"Train length/\" + r'$\\Delta$')\n",
    "plt.ylabel(\"Test loss\")\n",
    "plt.yscale(\"log\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Attention Probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# settings shared by the attention analysis\n",
    "delta = 5\n",
    "max_length = 80\n",
    "test_samples = 1024\n",
    "test_length = 80\n",
    "\n",
    "# generate test samples\n",
    "rng = RNG(0)\n",
    "problem = ModPSeq2SeqTask(delta, 1)\n",
    "test_keys = jr.split(rng.next(), test_samples)\n",
    "testx, testy = vmap(lambda key: problem.sample(test_length, key))(test_keys)\n",
    "\n",
    "# initialize empty model (the key is drawn before the dummy batch, in that order)\n",
    "model = Transformer(2, max_length, 1, 16, 1, 256, delta, True)\n",
    "init_key = rng.next()\n",
    "init_batch = vmap(lambda key: problem.sample(max_length, key))(rng.next(2))[0]\n",
    "p0 = flatten_dict(model.init(init_key, init_batch)[\"params\"], sep=\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the models for this delta with seed 0. When the notebook is run top to\n",
    "# bottom, `run` is the Mod P definition (it is defined after, and so replaces,\n",
    "# the SimpleTask `run`). Only the list of trained parameter dicts is kept; the\n",
    "# test-loss array is discarded.\n",
    "_, models = run(delta, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_attn_logits(x, p):\n",
    "    \"\"\"Return causal attention probabilities for the token sequences `x`.\n",
    "\n",
    "    Despite the name, the result is post-softmax. Uses the notebook globals\n",
    "    `model`, `max_length` and `delta`.\n",
    "\n",
    "    Args:\n",
    "        x: input token sequences; the last axis is the sequence axis (length T).\n",
    "        p: flattened parameter dict of a trained model.\n",
    "\n",
    "    Returns:\n",
    "        Attention probabilities whose last two axes are (query position, key position).\n",
    "    \"\"\"\n",
    "    # Sequence length comes from the argument itself (not the global `testx`) and\n",
    "    # must be read before `x` is overwritten by its embedding.\n",
    "    T = x.shape[-1]\n",
    "    x = model.embed(x, p['wte'])\n",
    "    # tile the periodic positional encoding and keep the first T positions\n",
    "    repeat_count = (max_length + delta - 1) // delta\n",
    "    wpe_big = jnp.tile(p['PeriodicPositionalEncoding_0.wpe'], (repeat_count, 1))\n",
    "    x = x + wpe_big[:T]\n",
    "\n",
    "    # query/key scores using entry 0 of Q and K (presumably the single head - TODO confirm)\n",
    "    attn = jnp.einsum(\"...ij,jm,km,...lk -> ...il\", x, p['Q'][0], p['K'][0], x)\n",
    "\n",
    "    attn = attn/model.d\n",
    "    # causal mask: position i may only attend to positions <= i\n",
    "    attn = jnp.where(jnp.tri(T), attn, -jnp.inf)\n",
    "    attn = nn.softmax(attn)\n",
    "\n",
    "    return attn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention statistics of the last query position, split by key position mod 5.\n",
    "# yes_*: keys in column 1 of the (-1, 5) regrouping (the k mod p tokens);\n",
    "# no_*: keys in the remaining four columns.\n",
    "yes_mean, yes_std = [], []\n",
    "no_mean, no_std = [], []\n",
    "\n",
    "other_cols = [0, 2, 3, 4]\n",
    "\n",
    "# compute average attention for each model\n",
    "for p in models:\n",
    "    attn = compute_attn_logits(testx, p)\n",
    "\n",
    "    # attention row of the last position, regrouped into rows of 5 consecutive keys\n",
    "    last_attn = attn[:, -1, :].reshape(-1, 5)\n",
    "    rest_attn = last_attn[:, other_cols]\n",
    "\n",
    "    yes_mean.append(last_attn.mean(axis=0)[1])\n",
    "    yes_std.append(last_attn.std(axis=0)[1])\n",
    "    no_mean.append(rest_attn.mean())\n",
    "    no_std.append(rest_attn.std())\n",
    "\n",
    "yes_mean, yes_std = np.array(yes_mean), np.array(yes_std)\n",
    "no_mean, no_std = np.array(no_mean), np.array(no_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot mean attention (+/- one std) against train length for both token groups\n",
    "gradient = np.linspace(0, 1, 2)\n",
    "colors = cm.viridis(gradient)\n",
    "\n",
    "\n",
    "plt.style.use('style.mplstyle')\n",
    "\n",
    "attn_train_lengths = np.arange(5, 26, 5)\n",
    "curves = [\n",
    "    (yes_mean, yes_std, r'$= k$' + 'mod p', colors[0]),\n",
    "    (no_mean, no_std, r'$\\neq k$' + 'mod p', colors[1]),\n",
    "]\n",
    "for mean, std, label, color in curves:\n",
    "    plt.plot(attn_train_lengths, mean, label = label, color = color)\n",
    "    plt.fill_between(attn_train_lengths, mean - std, mean+std, color = color, alpha = 0.5)\n",
    "\n",
    "# dashed reference level at 1/16\n",
    "plt.plot([5, 25], [1./16, 1./16], linestyle = 'dashed', color = 'black')\n",
    "\n",
    "plt.xlabel(\"Train Length\")\n",
    "plt.ylabel(\"Attention probabilities\")\n",
    "\n",
    "plt.legend(loc=\"lower right\", bbox_to_anchor=(1.0, 0.1), fontsize = 12)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
