{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from zoology.analysis.utils import fetch_wandb_runs\n",
    "import wandb\n",
    "\n",
    "##################################################################################################################\n",
    "# get these from latex, see below\n",
    "COLUMNWIDTH = 245.71811 # width of a single column\n",
    "TEXTWIDTH = 397.48499 # width of two columns\n",
    "PPP = 439.3701 # powerpoint\n",
    "\n",
    "FONTSIZE = 10\n",
    "CONTEXT = \"paper\" # either 'paper' or 'talk'\n",
    "##################################################################################################################\n",
    "\n",
    "def set_size(width, fraction=1, subplot=(1, 1)):\n",
    "    r\"\"\"Set aesthetic figure dimensions to avoid scaling in LaTeX.\n",
    "\n",
    "    The height is the golden ratio of the width, scaled by rows / columns.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    width: float\n",
    "            Width in pts. Run \"\\showthe\\textwidth\" after \\begin{document} in LaTeX and search for the\n",
    "            textwidth in the log file. Alternatively use \"\\showthe\\columnwidth\" if the paper is double\n",
    "            column. The output looks like this:\n",
    "                > xyz.0pt.\n",
    "                X.Y \\showthe\\textwidth\n",
    "            Pass xyz to this function.\n",
    "    fraction: float\n",
    "            Fraction of the width which you wish the figure to occupy\n",
    "    subplot: sequence of two numbers\n",
    "            [rows, columns] of subplots; the height is multiplied by rows / columns\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of float\n",
    "            (width, height) in inches, suitable for ``plt.subplots(figsize=...)``\n",
    "    \"\"\"\n",
    "    # Width of figure\n",
    "    fig_width_pt = width * fraction\n",
    "\n",
    "    # Convert from TeX pt (1/72.27 inch) to inches\n",
    "    inches_per_pt = 1 / 72.27\n",
    "\n",
    "    # Golden ratio to set aesthetic figure height\n",
    "    golden_ratio = (5**.5 - 1) / 2\n",
    "\n",
    "    fig_width_in = fig_width_pt * inches_per_pt\n",
    "    fig_height_in = fig_width_in * golden_ratio * (float(subplot[0]) / float(subplot[1]))\n",
    "\n",
    "    return (fig_width_in, fig_height_in)\n",
    "\n",
    "# Scale for line widths / tick sizes per plotting context. Unlisted contexts are rejected\n",
    "# instead of being mapped to a zero factor (which would draw zero-width axes and ticks).\n",
    "_CONTEXT_FACTORS = {\"paper\": 1., \"talk\": 1.5}\n",
    "if CONTEXT not in _CONTEXT_FACTORS:\n",
    "    raise ValueError(\"CONTEXT must be one of {}, got {!r}\".format(sorted(_CONTEXT_FACTORS), CONTEXT))\n",
    "FACTOR = _CONTEXT_FACTORS[CONTEXT]\n",
    "\n",
    "# 'text.usetex' requires a working LaTeX installation on the machine running this notebook.\n",
    "sns.set(CONTEXT, \"whitegrid\", rc={'xtick.bottom': True,\n",
    "                                  'ytick.left': True,\n",
    "                                  'xtick.color': '0.1',\n",
    "                                  'ytick.color': '0.1',\n",
    "                                  'text.usetex': True,\n",
    "                                  'font.family': 'serif',\n",
    "                                  'font.serif': 'Times',\n",
    "                                  'figure.titlesize': 0.9*FONTSIZE,\n",
    "                                  'axes.titlesize': 0.9*FONTSIZE,\n",
    "                                  'axes.labelsize': 0.9*FONTSIZE,\n",
    "                                  'font.size': 0.9*FONTSIZE,\n",
    "                                  'legend.fontsize': int(0.8*FONTSIZE),\n",
    "                                  'legend.title_fontsize': int(0.8*FONTSIZE),\n",
    "                                  'xtick.labelsize': int(0.8*FONTSIZE),\n",
    "                                  'ytick.labelsize': int(0.8*FONTSIZE),\n",
    "                                  'axes.linewidth': 0.8*FACTOR,\n",
    "                                  'grid.linewidth': 0.5*FACTOR,\n",
    "                                  'xtick.major.size': 4.5*FACTOR,\n",
    "                                  'xtick.major.width': 0.8*FACTOR,\n",
    "                                  'xtick.minor.width': 0.64*FACTOR,\n",
    "                                  'ytick.major.size': 4.5*FACTOR,\n",
    "                                  'ytick.major.width': 0.8*FACTOR,\n",
    "                                  'ytick.minor.width': 0.64*FACTOR,})\n",
    "sns.set_palette(\"deep\")\n",
    "\n",
    "### other definitions ######################################################################\n",
    "# The preamble is given as a single string; the list form seen in older examples is deprecated.\n",
    "plt.rcParams['text.latex.preamble'] = r\"\\usepackage{newtxmath}\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Attention & SSMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequence mixers compared in this section (keys of the split_df_* dicts built below).\n",
    "MODELS = [\n",
    "    \"sm-attention\",\n",
    "    \"lin-attention\",\n",
    "    \"norm-attention\",\n",
    "    \"s6-2\",\n",
    "    \"s6\",\n",
    "]\n",
    "\n",
    "# For L=64 the S6 / SSD runs come from a separate sweep; for longer sequences every\n",
    "# mixer is taken from the same sweep.\n",
    "df_64_mamba = fetch_wandb_runs(sweep_id=[\"run2024-05-17-seqlen64-kv4\"], project_name=\"neurips-2024\")\n",
    "\n",
    "# One sweep per (sequence length, number of KV pairs), fetched in the order 64, 128, 256, 512.\n",
    "df_64, df_128, df_256, df_512 = (\n",
    "    fetch_wandb_runs(sweep_id=[sweep], project_name=\"neurips-2024\")\n",
    "    for sweep in (\n",
    "        \"run2024-05-20-seqlen64-kv4\",\n",
    "        \"run2024-05-20-seqlen128-kv8\",\n",
    "        \"run2024-05-20-seqlen256-kv16\",\n",
    "        \"run2024-05-20-seqlen512-kv64\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_runs(runs, models, mamba_runs=None):\n",
    "    \"\"\"Split one sweep's runs into a DataFrame per sequence mixer.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    runs: pd.DataFrame\n",
    "            Runs of one sweep, as returned by ``fetch_wandb_runs``.\n",
    "    models: list of str\n",
    "            Mixer identifiers. \"s6\" / \"s6-2\" are matched exactly on\n",
    "            ``model.sequence_mixer.name``; all others are matched as a prefix of ``name``.\n",
    "    mamba_runs: pd.DataFrame, optional\n",
    "            Frame to take the \"s6\" / \"s6-2\" runs from (defaults to ``runs``).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "            model -> DataFrame of its runs. A model whose filter column is missing is\n",
    "            omitted; a model that matches no run is kept as an empty frame. Both cases\n",
    "            print \"No <model> found!\".\n",
    "    \"\"\"\n",
    "    if mamba_runs is None:\n",
    "        mamba_runs = runs\n",
    "    splits = {}\n",
    "    for model in models:\n",
    "        try:\n",
    "            if model in (\"s6\", \"s6-2\"):\n",
    "                splits[model] = mamba_runs[mamba_runs[\"model.sequence_mixer.name\"] == model]\n",
    "            else:\n",
    "                # na=False: a run without a name must not put NaN into the boolean mask.\n",
    "                splits[model] = runs[runs[\"name\"].str.startswith(model, na=False)]\n",
    "        except KeyError as err:\n",
    "            # Only a missing column is expected here; any other error should surface.\n",
    "            print(\"No {} found! (missing column {})\".format(model, err))\n",
    "            continue\n",
    "        if splits[model].empty:\n",
    "            print(\"No {} found!\".format(model))\n",
    "    return splits\n",
    "\n",
    "split_df_64 = split_runs(df_64, MODELS, mamba_runs=df_64_mamba)\n",
    "split_df_128 = split_runs(df_128, MODELS)\n",
    "split_df_256 = split_runs(df_256, MODELS)\n",
    "split_df_512 = split_runs(df_512, MODELS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best validation accuracy per (model size d, state size n) grid cell, for every mixer and\n",
    "# sequence length. Rows follow the sorted d_model values, columns the sorted d_qk values.\n",
    "EVAL = {}\n",
    "seq_len = [\"64\", \"128\", \"256\", \"512\"]\n",
    "EXCLUDED_D_QK = (16,)  # state sizes n left out of the attention/SSM figures\n",
    "for seq, split_df in zip(seq_len, [split_df_64, split_df_128, split_df_256, split_df_512]):\n",
    "    model_eval = {}\n",
    "    for model in MODELS:\n",
    "        runs = split_df[model]\n",
    "        # Remove excluded n by value (not by position), so a smaller swept d_qk is never dropped.\n",
    "        D_QK = [d for d in sorted(runs[\"model.d_qk\"].unique()) if d not in EXCLUDED_D_QK]\n",
    "        D_MODEL = sorted(runs[\"model.d_model\"].unique())\n",
    "\n",
    "        eval_mat = np.zeros((len(D_MODEL), len(D_QK)))  # always arranged as row=d_model, col=d_qk\n",
    "        for i, d_model in enumerate(D_MODEL):\n",
    "            cache = runs[runs[\"model.d_model\"] == d_model]\n",
    "            for j, d_qk in enumerate(D_QK):\n",
    "                # Best accuracy over the runs in this cell; NaN if none reported an accuracy.\n",
    "                eval_mat[i, j] = cache[cache[\"model.d_qk\"] == d_qk][\"valid/accuracy\"].dropna().max()\n",
    "\n",
    "        model_eval[model] = {\"d_qk\": D_QK, \"d_model\": D_MODEL, \"acc\": eval_mat}\n",
    "\n",
    "    EVAL[seq] = model_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(5, 4, figsize=set_size(TEXTWIDTH, subplot=[8, 4]), sharey=True, sharex=True, gridspec_kw={'width_ratios': [1, 1, 1, 1.25]})\n",
    "LABELS = {\"sm-attention\": \"Softmax att.\",\n",
    "          \"lin-attention\": \"Linear att.\",\n",
    "          \"norm-attention\": \"Normalized att.\",\n",
    "          \"s6\": \"S6\",\n",
    "          \"s6-2\": \"SSD\"}\n",
    "KV_PAIRS = {\"64\": 4, \"128\": 8, \"256\": 16, \"512\": 64}\n",
    "cmap = sns.color_palette(\"coolwarm\", as_cmap=True)\n",
    "\n",
    "\n",
    "def format_acc(val):\n",
    "    \"\"\"Annotation text for one accuracy value in [0, 1].\n",
    "\n",
    "    Returns the percentage with one decimal, \"$>$99\" above 0.99, and an empty string for a\n",
    "    missing grid cell (NaN compares False with 0.99 and would otherwise be labelled \"$>$99\").\n",
    "    \"\"\"\n",
    "    if np.isnan(val):\n",
    "        return \"\"\n",
    "    if val > 0.99:\n",
    "        return \"$>$99\"\n",
    "    return \"{0:.1f}\".format(val * 100)\n",
    "\n",
    "\n",
    "first_seq = next(iter(EVAL))  # sequence length shown in the left-most column\n",
    "last_col = len(EVAL) - 1  # only this column draws a colorbar (it has the wider gridspec slot)\n",
    "for i, seq in enumerate(EVAL):  # columns: sequence lengths\n",
    "    for j, model in enumerate(MODELS):  # rows: sequence mixers\n",
    "        res = EVAL[seq][model]\n",
    "        acc = res[\"acc\"]\n",
    "        annot = np.asarray([format_acc(val) for val in acc.flatten()]).reshape(acc.shape)\n",
    "        sns.heatmap(acc, annot=annot, fmt=\"\", annot_kws={\"fontsize\": int(0.7*FONTSIZE)}, cmap=cmap, vmin=0, vmax=1,\n",
    "                    xticklabels=res[\"d_qk\"], yticklabels=res[\"d_model\"],\n",
    "                    linewidths=0.3, linecolor='white', ax=ax[j, i], cbar=(i == last_col))\n",
    "        if j == 0:\n",
    "            ax[j, i].set_title(\"L: {0}, KV-pairs: {1}\".format(seq, KV_PAIRS[seq]))\n",
    "        if i == 0:\n",
    "            # The vertical model label reuses the title slot; on ax[0, 0] it replaces the\n",
    "            # column title, which is therefore redrawn as plain text after the loop.\n",
    "            ax[j, i].set_title(LABELS[model], rotation='vertical', x=-0.55, y=0.5, ha=\"center\", va=\"center\", fontsize=0.9*FONTSIZE)\n",
    "            ax[j, i].set_ylabel(\"d\")\n",
    "        else:\n",
    "            ax[j, i].tick_params(axis='y', which='both', left=False)\n",
    "        if j == len(MODELS) - 1:\n",
    "            ax[j, i].set_xlabel(\"n\")\n",
    "        else:\n",
    "            ax[j, i].tick_params(axis='x', which='both', bottom=False)\n",
    "\n",
    "        ax[j, i].grid(False)\n",
    "ax[0, 0].text(2., -0.32, \"L: {0}, KV-pairs: {1}\".format(first_seq, KV_PAIRS[first_seq]), ha='center', fontsize=0.9*FONTSIZE)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('full.pdf', format='pdf', bbox_inches='tight', pad_inches=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Figure 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear vs. softmax attention results at the two longest sequence lengths (used in Figure 1).\n",
    "lin_att = {seq: EVAL[seq][\"lin-attention\"] for seq in (\"256\", \"512\")}\n",
    "sm_att = {seq: EVAL[seq][\"sm-attention\"] for seq in (\"256\", \"512\")}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=set_size(TEXTWIDTH, subplot=[1, 2]), sharey=True)\n",
    "LABELS = [\"Linear attention (16)\", \"Softmax attention (2)\"]\n",
    "colors = sns.color_palette(\"deep\")\n",
    "\n",
    "# Accuracy vs. state expansion n at the largest model size d (last row of each grid),\n",
    "# one panel per (sequence length, number of KV pairs).\n",
    "for panel, (seq, kv_pairs) in enumerate([(\"256\", 16), (\"512\", 64)]):\n",
    "    lin_eval = lin_att[seq]\n",
    "    sm_eval = sm_att[seq]\n",
    "    panel_ax = ax[panel]\n",
    "\n",
    "    panel_ax.plot(lin_eval[\"acc\"][-1, :], c=colors[0], marker=\"o\", label=LABELS[0])\n",
    "    panel_ax.plot(sm_eval[\"acc\"][-1, :], c=colors[1], marker=\"o\", label=LABELS[1])\n",
    "\n",
    "    panel_ax.set_ylim([0.0, 1.04])\n",
    "    panel_ax.set_xticks(ticks=np.arange(len(lin_eval[\"d_qk\"])), labels=lin_eval[\"d_qk\"])\n",
    "    panel_ax.set_xlabel(\"State expansion $n$\")\n",
    "    if panel == 0:\n",
    "        panel_ax.set_ylabel(\"Accuracy\")\n",
    "    else:\n",
    "        # y-axis is shared; hide the duplicate ticks on the right panel\n",
    "        panel_ax.tick_params(axis='y', which='both', left=False)\n",
    "    panel_ax.set_title(\"L: {0}, KV-pairs: {1}\".format(seq, kv_pairs))\n",
    "    if panel == 1:\n",
    "        panel_ax.legend(loc='lower right')\n",
    "    panel_ax.grid(True)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('state-expansion.pdf', format='pdf', bbox_inches='tight', pad_inches=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Figure 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 1, figsize=set_size(TEXTWIDTH, fraction=0.55, subplot=[0.9, 1]))\n",
    "LABELS = [\"Softmax att. (2)\",\n",
    "          \"Linear att. (16)\",\n",
    "          \"Normalized att. (21)\",\n",
    "          \"SSD [Dao and Gu, 2024]\",\n",
    "          \"S6 [Gu and Dao, 2023]\"]\n",
    "colors = sns.color_palette(\"deep\")\n",
    "\n",
    "model_eval = EVAL[\"512\"]\n",
    "n = 128  # state expansion n shown for every mixer\n",
    "\n",
    "# Shared x-axis: the union of model sizes d over all mixers. Each curve is placed at the\n",
    "# positions of its own d values rather than assuming every mixer swept the same sizes.\n",
    "d_models = sorted(set().union(*(model_eval[m][\"d_model\"] for m in MODELS)))\n",
    "x_pos = {d: k for k, d in enumerate(d_models)}\n",
    "\n",
    "for i, model in enumerate(MODELS):\n",
    "    res = model_eval[model]\n",
    "    cols = np.flatnonzero(np.asarray(res[\"d_qk\"]) == n)\n",
    "    if cols.size == 0:\n",
    "        print(\"{}: no runs with n={}, skipped\".format(model, n))\n",
    "        continue\n",
    "    ax.plot([x_pos[d] for d in res[\"d_model\"]], res[\"acc\"][:, cols[0]], c=colors[i], marker=\"o\", label=LABELS[i])\n",
    "\n",
    "ax.set_ylim([-0.04, 1.04])\n",
    "ax.set_xticks(ticks=np.arange(len(d_models)), labels=d_models)\n",
    "ax.set_xlabel(\"Model size $d$\")\n",
    "ax.set_ylabel(\"Accuracy\")\n",
    "# Title and legend are disabled for this figure:\n",
    "#ax.set_title(\"L: {0}, KV-pairs: {1}\".format(\"512\",64))\n",
    "#ax.legend(ncol=2, loc=\"center\", bbox_to_anchor=(0.48, 1.23))\n",
    "ax.grid(True)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('normalization.pdf', format='pdf', bbox_inches='tight', pad_inches=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### RNNs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same sweep IDs as in the attention/SSM section; this rebinds df_64 / df_128 / df_256,\n",
    "# fetched in the order 64, 128, 256.\n",
    "df_64, df_128, df_256 = (\n",
    "    fetch_wandb_runs(sweep_id=[sweep], project_name=\"neurips-2024\")\n",
    "    for sweep in (\n",
    "        \"run2024-05-20-seqlen64-kv4\",\n",
    "        \"run2024-05-20-seqlen128-kv8\",\n",
    "        \"run2024-05-20-seqlen256-kv16\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_qlstm(runs):\n",
    "    \"\"\"Split one sweep's runs into the plain and reversed qLSTM variants.\n",
    "\n",
    "    Runs are selected on ``model.sequence_mixer.kwargs.reversed`` only; runs where the flag\n",
    "    is missing (NaN) match neither variant.\n",
    "    NOTE(review): assumes only qLSTM runs set this flag in these sweeps - confirm.\n",
    "\n",
    "    Returns a dict {\"qlstm\": ..., \"qlstm-rev\": ...} of DataFrames. Both variants are\n",
    "    omitted if the column does not exist; an empty variant is kept. Both cases print\n",
    "    \"No <model> found!\".\n",
    "    \"\"\"\n",
    "    splits = {}\n",
    "    for model, is_reversed in ((\"qlstm\", False), (\"qlstm-rev\", True)):\n",
    "        try:\n",
    "            flag = runs[\"model.sequence_mixer.kwargs.reversed\"]\n",
    "        except KeyError:\n",
    "            # Only a missing column is expected here; any other error should surface.\n",
    "            print(\"No {} found!\".format(model))\n",
    "            continue\n",
    "        splits[model] = runs[flag == is_reversed]\n",
    "        if splits[model].empty:\n",
    "            print(\"No {} found!\".format(model))\n",
    "    return splits\n",
    "\n",
    "split_df_64 = split_qlstm(df_64)\n",
    "split_df_128 = split_qlstm(df_128)\n",
    "split_df_256 = split_qlstm(df_256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best validation accuracy per (d_model, d_qk) grid cell for both qLSTM variants.\n",
    "# Unlike the attention/SSM grids, no d_qk value is excluded here.\n",
    "EVAL = {}\n",
    "seq_len = [\"64\", \"128\", \"256\"]\n",
    "for seq, split_df in zip(seq_len, (split_df_64, split_df_128, split_df_256)):\n",
    "    model_eval = {}\n",
    "    for model in (\"qlstm\", \"qlstm-rev\"):\n",
    "        runs = split_df[model]\n",
    "        D_QK = sorted(runs[\"model.d_qk\"].unique())\n",
    "        D_MODEL = sorted(runs[\"model.d_model\"].unique())\n",
    "\n",
    "        grid = np.zeros((len(D_MODEL), len(D_QK)))  # always arranged as row=d_model, col=d_qk\n",
    "        for row, d_model in enumerate(D_MODEL):\n",
    "            same_d = runs[runs[\"model.d_model\"] == d_model]\n",
    "            for col, d_qk in enumerate(D_QK):\n",
    "                # NaN when no run in this cell reported an accuracy\n",
    "                grid[row, col] = same_d.loc[same_d[\"model.d_qk\"] == d_qk, \"valid/accuracy\"].dropna().max()\n",
    "\n",
    "        model_eval[model] = {\"d_qk\": D_QK, \"d_model\": D_MODEL, \"acc\": grid}\n",
    "\n",
    "    EVAL[seq] = model_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 3, figsize=set_size(TEXTWIDTH, subplot=[1.4, 3]), sharey=False)\n",
    "LABELS = {\"qlstm\": \"qLSTM\", \"qlstm-rev\": \"qLSTM w/ (22)\"}\n",
    "colors = sns.color_palette(\"deep\")\n",
    "\n",
    "# One panel per sequence length: (L, number of KV pairs, lower y-limit).\n",
    "PANELS = [(\"64\", 4, 0.8), (\"128\", 8, 0.5), (\"256\", 16, 0.)]\n",
    "\n",
    "for panel_ax, (seq, kv_pairs, y_min) in zip(ax, PANELS):\n",
    "    model_eval = EVAL[seq]\n",
    "    # Shared x-axis over the model sizes of both variants; each curve is placed at the\n",
    "    # positions of its own d values, so a variant with fewer sizes stays aligned.\n",
    "    d_models = sorted(set().union(*(model_eval[m][\"d_model\"] for m in (\"qlstm\", \"qlstm-rev\"))))\n",
    "    x_pos = {d: k for k, d in enumerate(d_models)}\n",
    "    for i, model in enumerate((\"qlstm\", \"qlstm-rev\")):\n",
    "        res = model_eval[model]\n",
    "        # NOTE(review): acc has one column per d_qk; with several d_qk values this draws\n",
    "        # several lines sharing one color and label.\n",
    "        panel_ax.plot([x_pos[d] for d in res[\"d_model\"]], res[\"acc\"], c=colors[i], marker=\"o\", label=LABELS[model])\n",
    "    panel_ax.set_ylim([y_min, 1.01])\n",
    "    panel_ax.set_xticks(ticks=np.arange(len(d_models)), labels=d_models)\n",
    "    panel_ax.set_xlabel(\"Model size $d$\")\n",
    "    panel_ax.set_ylabel(\"Accuracy\")\n",
    "    panel_ax.set_title(\"L: {0}, KV-pairs: {1}\".format(seq, kv_pairs))\n",
    "    panel_ax.grid(True)\n",
    "ax[0].legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('qlstm.pdf', format='pdf', bbox_inches='tight', pad_inches=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
