{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "465c397d-f676-4c2c-994e-b81f278a1454",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Computer algebra system for appendix of \"When can Transformers reason with abstract symbols?\"\n",
    "\n",
    "This iPython notebook contains code helpful for analyzing the random features kernel of an attention layer.\n",
    "\n",
    "We fuzz-test the computer algebra system that we implement for correctness (see below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6eae43b1-4f7d-43d7-ab2c-25d8d8a3947e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'disjoint_set'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[1], line 4\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mIPython\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdisplay\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m display, Math, Latex\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mcopy\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mdisjoint_set\u001b[39;00m\n\u001b[1;32m      5\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mitertools\u001b[39;00m\n\u001b[1;32m      6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtqdm\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m tqdm\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'disjoint_set'"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from IPython.display import display, Math, Latex\n",
    "import copy\n",
    "import disjoint_set\n",
    "import itertools\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import os\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e12fc696-106d-43e6-b0cc-13222a7d9718",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Definition of the terms $T_{r,n,\\mathbf{i},\\mathbf{j},\\mathbf{a},\\mathbf{b},\\mathbf{c},\\mathbf{d}}$ that we focus on\n",
    "\n",
    "We will consider terms given by $r \\in \\mathbb{Z}$, $\\mathbf{i} = [i_1,\\ldots,i_k] \\in [r]^k, \\mathbf{j} = [j_1,\\ldots,j_l] \\in [r]^l, \\mathbf{a} = [a_1,\\ldots,a_m] \\in [r]^m, \\mathbf{b} = [b_1,\\ldots,b_o] = [r]^o, \\mathbf{c} \\in [c_1,\\ldots,c_{\\mu}] \\in [r]^{\\mu}, \\mathbf{d} = [(d_{1,1}, d_{1,2}),\\ldots, (d_{z,1},d_{z,2})]$, which are given by\n",
    "\n",
    "$$T_{r,\\mathbf{i},\\mathbf{j},\\mathbf{a},\\mathbf{b},\\mathbf{c},\\mathbf{d}} = \\sum_{w_1,\\ldots,w_r \\in [k]} \\left(\\prod_{q \\in [k]} s_{w_{i_q}}\\right) \\cdot \\left(\\prod_{q \\in [l]} u_{w_{j_q}}\\right) \\cdot \\left(\\prod_{q \\in [m]} t_{w_{a_q}}\\right) \\cdot \\left(\\prod_{q \\in [o]} v_{w_{b_q}}\\right) \\cdot \\left(\\prod_{q \\in [\\mu]} p_{w_{c_{\\mu}}} \\right) \\cdot \\left(\\prod_{q \\in [z]} 1(x_{w_{d_{q,1}}} = y_{w_{d_{q,2}}})\\right)$$\n",
    "\n",
    "Here\n",
    "* $p_1,\\ldots,p_k \\in \\mathbb{R}$\n",
    "* $x_1,\\ldots,x_k \\in \\mathbb{R}$\n",
    "* $y_1,\\ldots,y_k \\in \\mathbb{R}$.\n",
    "* $\\zeta_1,\\ldots,\\zeta_k \\in \\mathbb{R}$.\n",
    "* $\\xi_1,\\ldots,\\xi_k \\in \\mathbb{R}$.\n",
    "* $u_1,\\ldots,u_k,v_1,\\ldots,v_k \\in \\mathbb{R}$ are defined as $u_i = \\zeta_i + \\gamma p_i$ for all $i \\in [k]$ and $v_i = \\xi_i + \\gamma p_i$ for all $i \\in [k]$\n",
    "* $\\mathbf{s} = \\mathrm{softmax}([\\beta u_1,\\ldots,\\beta u_k]) \\in \\mathbb{R}^k$ and $\\mathbf{t} = \\mathrm{softmax}([\\beta v_1,\\ldots,\\beta v_k]) \\in \\mathbb{R}^k$ for some $\\beta \\in \\mathbb{R}$."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6644c78b-9717-4bf2-9ba4-48fdcff5ea2e",
   "metadata": {},
   "source": [
    "## Code to display a term"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "affcb41f-1cee-4c90-b6f0-ab1d51345f62",
   "metadata": {},
   "outputs": [],
   "source": [
    "INDEX_TO_VAR_NAME = ['ERROR','a','b','c','d','e','f','g','h','i','j','k','l', r'\\alpha', r'\\delta', r'\\epsilon', r'\\tau', r'\\sigma']\n",
    "\n",
    "def display_terms(currterms):\n",
    "    tot_str = ''\n",
    "    for t in currterms:\n",
    "        tot_str += term_string(t)\n",
    "    display(Math(tot_str))\n",
    "    \n",
    "def coeff_term_str(coeff):\n",
    "    term_str = ''\n",
    "    if coeff > 0:\n",
    "        term_str += '+'\n",
    "    if coeff == 1:\n",
    "        pass\n",
    "    elif coeff == -1:\n",
    "        term_str += '-'\n",
    "    else:\n",
    "        term_str += str(coeff)\n",
    "    return term_str\n",
    "    \n",
    "def term_string(t):\n",
    "    i_terms, js, a_terms, bs, ps, diracs, coeff = t\n",
    "\n",
    "    term_str = coeff_term_str(coeff)\n",
    "    \n",
    "    terms_set = set(i_terms + a_terms)\n",
    "    term_str += r'\\sum_{'\n",
    "    for i_idx, i in enumerate(terms_set):\n",
    "        term_str += INDEX_TO_VAR_NAME[i]\n",
    "        if i_idx < len(terms_set) - 1:\n",
    "            term_str += ','\n",
    "    term_str += r'}'\n",
    "    \n",
    "    for i in i_terms:\n",
    "        term_str += r's_{' + INDEX_TO_VAR_NAME[i] + '}'\n",
    "    for a in a_terms:\n",
    "        term_str += r't_{' + INDEX_TO_VAR_NAME[a] + '}'\n",
    "    for j in js:\n",
    "        term_str += r'u_{' + INDEX_TO_VAR_NAME[j] + '}'\n",
    "    for b in bs:\n",
    "        term_str += r'v_{' + INDEX_TO_VAR_NAME[b] + '}'\n",
    "    for p in ps:\n",
    "        term_str += r'p_{' + INDEX_TO_VAR_NAME[p] + '}'\n",
    "    for v1, v2 in diracs:\n",
    "        term_str += r'1(x_{' + INDEX_TO_VAR_NAME[v1] + '} = y_{' + INDEX_TO_VAR_NAME[v2] + '})'\n",
    "    \n",
    "    return term_str"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bede1a4e-d739-46c5-91ba-90bd586ae6ef",
   "metadata": {},
   "source": [
    "## Examples of terms\n",
    "\n",
    "We give some examples of terms, for illustrative purposes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a4088eb9-f0d1-47ae-b829-a6f1d9244451",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term representation [[1], [], [2], [], [], [(1, 2)], 1]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a,b}s_{a}t_{b}1(x_{a} = y_{b})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term representation [[1], [], [1], [], [], [], 1]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a}s_{a}t_{a}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term representation [[1], [2], [1], [2], [3], [(3, 1)], 1]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a}s_{a}t_{a}u_{b}v_{b}p_{c}1(x_{c} = y_{a})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# i_terms, js, a_terms, bs, ps, diracs, coeff\n",
    "term = [[1],[],[2],[], [], [(1,2)], 1]\n",
    "print('Term representation', term)\n",
    "display_terms([term])\n",
    "\n",
    "term = [[1],[],[1],[], [], [], 1]\n",
    "print('Term representation', term)\n",
    "display_terms([term])\n",
    "\n",
    "term = [[1],[2],[1],[2], [3], [(3,1)], 1]\n",
    "print('Term representation', term)\n",
    "display_terms([term])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41ddf4cc-4b1a-43c9-bbf1-25e3ec7874b1",
   "metadata": {},
   "source": [
    "## Derivatives in $\\beta$\n",
    "\n",
    "We care about computing $\\frac{\\partial}{\\partial \\beta} T_{r,\\mathbf{i},\\mathbf{j},\\mathbf{a},\\mathbf{b},\\mathbf{c},\\mathbf{d}}$. The observation is that we can express this as a sum of terms of the same form. Since only $\\mathbf{s}$ and $\\mathbf{t}$ depend on $\\beta$, the following code successfully computes derivatives in $\\beta$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d15d34bf-dde2-44aa-911b-79d6bed7d0ac",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def take_beta_deriv(term):\n",
    "    i_terms, js, a_terms, bs, ps, diracs, coeff = term\n",
    "    assert(len(ps) == 0)\n",
    "\n",
    "    newterms = []\n",
    "\n",
    "    # Construct the k index by picking the first that does not appear in i_terms or j_terms\n",
    "    k = max(np.max(i_terms),np.max(a_terms))+1\n",
    "    \n",
    "    for i in i_terms:\n",
    "        # Notice that \\pd{s_i}{beta} = \\pd{}{\\beta} e^{\\beta x_i} / (\\sum_j e^{\\beta x_j})\n",
    "        # = x_i e^{\\beta x_i} / (\\sum_j e^{\\beta x_j}) - e^{\\beta x_i} (\\pd{}{\\beta} \\sum_j e^{\\beta x_j}) / (\\sum_j e^{\\beta x_j})^2\n",
    "        # = x_i s_i - s_i \\sum_k x_k s_k\n",
    "        # = s_i x_i - \\sum_{k} s_i s_k x_k\n",
    "\n",
    "        ## Add the x_i s_i term\n",
    "        newterm = [i_terms, js + [i], a_terms, bs, ps, diracs, coeff]\n",
    "        newterms.append(newterm)\n",
    "\n",
    "        ## Add the -\\sum_k s_i s_k x_k term\n",
    "        newterm = [i_terms + [k], js + [k], a_terms, bs, ps, diracs, -coeff]\n",
    "        newterms.append(newterm)\n",
    "    \n",
    "    for a in a_terms:\n",
    "        ## Add the x_i s_i term\n",
    "        newterm = [i_terms, js, a_terms, bs + [a], ps, diracs, coeff]\n",
    "        newterms.append(newterm)\n",
    "\n",
    "        ## Add the -\\sum_k s_i s_k x_k term\n",
    "        newterm = [i_terms, js, a_terms + [k], bs + [k], ps, diracs, -coeff]\n",
    "        newterms.append(newterm)\n",
    "\n",
    "    for a in newterms:\n",
    "        assert(len(a) == 7)\n",
    "    return newterms\n",
    "\n",
    "\n",
    "def take_beta_deriv_terms(currterms):\n",
    "    newterms = []\n",
    "    for t in currterms:\n",
    "        newterms.extend(take_beta_deriv(t))\n",
    "    return newterms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f3499b8-e57f-42ce-81dd-7ba14e28debe",
   "metadata": {},
   "source": [
    "## Examples of beta derivatives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f2d17ba2-8c64-4bf8-b024-d1dd4d3f2020",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term representation [[1], [], [2], [], [], [(1, 2)], 1]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a,b}s_{a}t_{b}1(x_{a} = y_{b})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Derivative [[[1], [1], [2], [], [], [(1, 2)], 1], [[1, 3], [3], [2], [], [], [(1, 2)], -1], [[1], [], [2], [2], [], [(1, 2)], 1], [[1], [], [2, 3], [3], [], [(1, 2)], -1]]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a,b}s_{a}t_{b}u_{a}1(x_{a} = y_{b})-\\sum_{a,b,c}s_{a}s_{c}t_{b}u_{c}1(x_{a} = y_{b})+\\sum_{a,b}s_{a}t_{b}v_{b}1(x_{a} = y_{b})-\\sum_{a,b,c}s_{a}t_{b}t_{c}v_{c}1(x_{a} = y_{b})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Term representation [[1], [], [1], [], [], [], 1]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a}s_{a}t_{a}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Derivative [[[1], [1], [1], [], [], [], 1], [[1, 2], [2], [1], [], [], [], -1], [[1], [], [1], [1], [], [], 1], [[1], [], [1, 2], [2], [], [], -1]]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a}s_{a}t_{a}u_{a}-\\sum_{a,b}s_{a}s_{b}t_{a}u_{b}+\\sum_{a}s_{a}t_{a}v_{a}-\\sum_{a,b}s_{a}t_{a}t_{b}v_{b}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Term representation [[1], [1], [1], [2], [], [], 1]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a}s_{a}t_{a}u_{a}v_{b}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Derivative [[[1], [1, 1], [1], [2], [], [], 1], [[1, 2], [1, 2], [1], [2], [], [], -1], [[1], [1], [1], [2, 1], [], [], 1], [[1], [1], [1, 2], [2, 2], [], [], -1]]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a}s_{a}t_{a}u_{a}u_{a}v_{b}-\\sum_{a,b}s_{a}s_{b}t_{a}u_{a}u_{b}v_{b}+\\sum_{a}s_{a}t_{a}u_{a}v_{b}v_{a}-\\sum_{a,b}s_{a}t_{a}t_{b}u_{a}v_{b}v_{b}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# i_terms, js, a_terms, bs, ps, diracs, coeff\n",
    "term = [[1],[],[2],[], [], [(1,2)], 1]\n",
    "print('Term representation', term)\n",
    "display_terms([term])\n",
    "deriv = take_beta_deriv(term)\n",
    "print('Derivative',deriv)\n",
    "display_terms(deriv)\n",
    "print()\n",
    "\n",
    "term = [[1],[],[1],[], [], [], 1]\n",
    "print('Term representation', term)\n",
    "display_terms([term])\n",
    "deriv = take_beta_deriv(term)\n",
    "print('Derivative',deriv)\n",
    "display_terms(deriv)\n",
    "print()\n",
    "\n",
    "\n",
    "term = [[1],[1],[1],[2], [], [], 1]\n",
    "print('Term representation', term)\n",
    "display_terms([term])\n",
    "deriv = take_beta_deriv(term)\n",
    "print('Derivative',deriv)\n",
    "display_terms(deriv)\n",
    "print()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "827173f8-f32a-4f14-90ed-2f15d69781e2",
   "metadata": {},
   "source": [
    "## Simplifying sums of terms\n",
    "If we iteratively take the $\\beta$ derivatives, we may end up with sums multiple terms. In order to avoid blow-up in the length of the expression, we have code that groups together like terms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8e575100-3979-4283-9d5f-8516da7c2481",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def check_terms_equiv(term1,term2):\n",
    "    \n",
    "    i1, j1, a1, b1, ps1, diracs1, coeff1 = term1\n",
    "    i2, j2, a2, b2, ps2, diracs2, coeff2 = term2\n",
    "\n",
    "    if len(i1) != len(i2):\n",
    "        return False\n",
    "    if len(j1) != len(j2):\n",
    "        return False\n",
    "    if len(a1) != len(a2):\n",
    "        return False\n",
    "    if len(b1) != len(b2):\n",
    "        return False\n",
    "    if len(ps1) != len(ps2):\n",
    "        return False\n",
    "    if len(diracs1) != len(diracs2):\n",
    "        return False\n",
    "    \n",
    "    indices_1 = set(i1 + a1)\n",
    "    indices_2 = set(i2 + a2)\n",
    "    if len(indices_1) != len(indices_2):\n",
    "        return False\n",
    "    for j in j1:\n",
    "        assert(j in indices_1)\n",
    "    for b in b1:\n",
    "        assert(b in indices_1)\n",
    "    for j in j2:\n",
    "        assert(j in indices_2)\n",
    "    for b in b2:\n",
    "        assert(b in indices_2)\n",
    "    for p in ps1:\n",
    "        assert(p in indices_1)\n",
    "    for p in ps2:\n",
    "        assert(p in indices_2)\n",
    "    for v1, v2 in diracs1:\n",
    "        assert(v1 in indices_1)\n",
    "        assert(v2 in indices_1)\n",
    "    for v1, v2 in diracs2:\n",
    "        assert(v1 in indices_2)\n",
    "        assert(v2 in indices_2)\n",
    "        \n",
    "    if len(diracs1) == 0: \n",
    "        rel_dict = {}\n",
    "        indices_used = set()\n",
    "        idx_counts1 = {}\n",
    "        idx_counts2 = {}\n",
    "        for i in indices_1:\n",
    "            idx_counts1[i] = (i1.count(i), j1.count(i), a1.count(i), b1.count(i), ps1.count(i))\n",
    "        for i in indices_2:\n",
    "            idx_counts2[i] = (i2.count(i), j2.count(i), a2.count(i), b2.count(i), ps2.count(i))\n",
    "\n",
    "        for i in indices_1:\n",
    "            found_idx = False\n",
    "            for j in indices_2:\n",
    "                if j in indices_used:\n",
    "                    continue\n",
    "                if idx_counts1[i] == idx_counts2[j]:\n",
    "                    rel_dict[i] = j\n",
    "                    indices_used.add(j)\n",
    "                    found_idx = True\n",
    "                    break\n",
    "            if not found_idx:\n",
    "                return False\n",
    "        return True\n",
    "    elif len(diracs1) == 1:\n",
    "        rel_dict = {}\n",
    "        indices_used = set()\n",
    "        idx_counts1 = {}\n",
    "        idx_counts2 = {}\n",
    "        for i in indices_1:\n",
    "            idx_counts1[i] = (i1.count(i), j1.count(i), a1.count(i), b1.count(i), ps1.count(i))\n",
    "        for i in indices_2:\n",
    "            idx_counts2[i] = (i2.count(i), j2.count(i), a2.count(i), b2.count(i), ps2.count(i))\n",
    "\n",
    "        # Now that we added the dirac 1(x_i = y_j), the index counts alone are not sufficient.\n",
    "        # The dirac terms have to be matched.\n",
    "        v11, v12 = diracs1[0]\n",
    "        v21, v22 = diracs2[0]\n",
    "        if idx_counts1[v11] != idx_counts2[v21]:\n",
    "            return False\n",
    "        if idx_counts1[v12] != idx_counts2[v22]:\n",
    "            return False\n",
    "        rel_dict[v11] = v21\n",
    "        rel_dict[v12] = v22\n",
    "        indices_used.add(v21)\n",
    "        indices_used.add(v22)\n",
    "\n",
    "        for i in indices_1:\n",
    "            if i in rel_dict.keys():\n",
    "                assert(i in [v11, v12])\n",
    "                continue\n",
    "            found_idx = False\n",
    "            for j in indices_2:\n",
    "                if j in indices_used:\n",
    "                    continue\n",
    "                if idx_counts1[i] == idx_counts2[j]:\n",
    "                    rel_dict[i] = j\n",
    "                    indices_used.add(j)\n",
    "                    found_idx = True\n",
    "                    break\n",
    "            if not found_idx:\n",
    "                return False\n",
    "        return True\n",
    "    else:\n",
    "        assert(False) # Case not implemented\n",
    "\n",
    "def simplify_terms(currterms):\n",
    "    covered = set()\n",
    "    equiv_groups = []\n",
    "    for i in tqdm(range(len(currterms))):\n",
    "        if i in covered:\n",
    "            continue\n",
    "        curr_group = [i]\n",
    "        for j in range(i+1,len(currterms)):\n",
    "            if j in covered:\n",
    "                continue\n",
    "            if check_terms_equiv(currterms[i], currterms[j]):\n",
    "                covered.add(j)\n",
    "                curr_group.append(j)\n",
    "        covered.add(i)\n",
    "        equiv_groups.append(curr_group)\n",
    "        \n",
    "    for curr_group in equiv_groups:\n",
    "        for i in curr_group:\n",
    "            for j in curr_group:\n",
    "                assert(check_terms_equiv(currterms[i], currterms[j]))\n",
    "    \n",
    "    simplified_terms = []\n",
    "    for curr_group in equiv_groups:\n",
    "        tot_coeff = 0\n",
    "        for i in curr_group:\n",
    "            tot_coeff += currterms[i][-1]\n",
    "        base_term = copy.deepcopy(currterms[curr_group[0]])\n",
    "        base_term[-1] = tot_coeff\n",
    "        # if tot_coeff == 0:\n",
    "        #     continue\n",
    "        simplified_terms.append(base_term)\n",
    "    return simplified_terms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c3c0d9c-078b-4b5a-a8cf-d3a4aa46a258",
   "metadata": {},
   "source": [
    "## Example of simplifying sums of terms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "550c55fa-9fbb-4c1c-b798-6ea1992a2fc8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term representation [[1], [], [2], [2], [], [], 1]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a,b}s_{a}t_{b}v_{b}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Derivative [[[1], [1], [2], [2], [], [], 1], [[1, 3], [3], [2], [2], [], [], -1], [[1], [], [2], [2, 2], [], [], 1], [[1], [], [2, 3], [2, 3], [], [], -1]]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a,b}s_{a}t_{b}u_{a}v_{b}-\\sum_{a,b,c}s_{a}s_{c}t_{b}u_{c}v_{b}+\\sum_{a,b}s_{a}t_{b}v_{b}v_{b}-\\sum_{a,b,c}s_{a}t_{b}t_{c}v_{b}v_{c}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Second derivative [[[1], [1, 1], [2], [2], [], [], 1], [[1, 3], [1, 3], [2], [2], [], [], -1], [[1], [1], [2], [2, 2], [], [], 1], [[1], [1], [2, 3], [2, 3], [], [], -1], [[1, 3], [3, 1], [2], [2], [], [], -1], [[1, 3, 4], [3, 4], [2], [2], [], [], 1], [[1, 3], [3, 3], [2], [2], [], [], -1], [[1, 3, 4], [3, 4], [2], [2], [], [], 1], [[1, 3], [3], [2], [2, 2], [], [], -1], [[1, 3], [3], [2, 4], [2, 4], [], [], 1], [[1], [1], [2], [2, 2], [], [], 1], [[1, 3], [3], [2], [2, 2], [], [], -1], [[1], [], [2], [2, 2, 2], [], [], 1], [[1], [], [2, 3], [2, 2, 3], [], [], -1], [[1], [1], [2, 3], [2, 3], [], [], -1], [[1, 4], [4], [2, 3], [2, 3], [], [], 1], [[1], [], [2, 3], [2, 3, 2], [], [], -1], [[1], [], [2, 3, 4], [2, 3, 4], [], [], 1], [[1], [], [2, 3], [2, 3, 3], [], [], -1], [[1], [], [2, 3, 4], [2, 3, 4], [], [], 1]]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a,b}s_{a}t_{b}u_{a}u_{a}v_{b}-\\sum_{a,b,c}s_{a}s_{c}t_{b}u_{a}u_{c}v_{b}+\\sum_{a,b}s_{a}t_{b}u_{a}v_{b}v_{b}-\\sum_{a,b,c}s_{a}t_{b}t_{c}u_{a}v_{b}v_{c}-\\sum_{a,b,c}s_{a}s_{c}t_{b}u_{c}u_{a}v_{b}+\\sum_{a,b,c,d}s_{a}s_{c}s_{d}t_{b}u_{c}u_{d}v_{b}-\\sum_{a,b,c}s_{a}s_{c}t_{b}u_{c}u_{c}v_{b}+\\sum_{a,b,c,d}s_{a}s_{c}s_{d}t_{b}u_{c}u_{d}v_{b}-\\sum_{a,b,c}s_{a}s_{c}t_{b}u_{c}v_{b}v_{b}+\\sum_{a,b,c,d}s_{a}s_{c}t_{b}t_{d}u_{c}v_{b}v_{d}+\\sum_{a,b}s_{a}t_{b}u_{a}v_{b}v_{b}-\\sum_{a,b,c}s_{a}s_{c}t_{b}u_{c}v_{b}v_{b}+\\sum_{a,b}s_{a}t_{b}v_{b}v_{b}v_{b}-\\sum_{a,b,c}s_{a}t_{b}t_{c}v_{b}v_{b}v_{c}-\\sum_{a,b,c}s_{a}t_{b}t_{c}u_{a}v_{b}v_{c}+\\sum_{a,b,c,d}s_{a}s_{d}t_{b}t_{c}u_{d}v_{b}v_{c}-\\sum_{a,b,c}s_{a}t_{b}t_{c}v_{b}v_{c}v_{b}+\\sum_{a,b,c,d}s_{a}t_{b}t_{c}t_{d}v_{b}v_{c}v_{d}-\\sum_{a,b,c}s_{a}t_{b}t_{c}v_{b}v_{c}v_{c}+\\sum_{a,b,c,d}s_{a}t_{b}t_{c}t_{d}v_{b}v_{c}v_{d}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 36615.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simplified second derivative [[[1], [1, 1], [2], [2], [], [], 1], [[1, 3], [1, 3], [2], [2], [], [], -2], [[1], [1], [2], [2, 2], [], [], 2], [[1], [1], [2, 3], [2, 3], [], [], -2], [[1, 3, 4], [3, 4], [2], [2], [], [], 2], [[1, 3], [3, 3], [2], [2], [], [], -1], [[1, 3], [3], [2], [2, 2], [], [], -2], [[1, 3], [3], [2, 4], [2, 4], [], [], 2], [[1], [], [2], [2, 2, 2], [], [], 1], [[1], [], [2, 3], [2, 2, 3], [], [], -3], [[1], [], [2, 3, 4], [2, 3, 4], [], [], 2]]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a,b}s_{a}t_{b}u_{a}u_{a}v_{b}-2\\sum_{a,b,c}s_{a}s_{c}t_{b}u_{a}u_{c}v_{b}+2\\sum_{a,b}s_{a}t_{b}u_{a}v_{b}v_{b}-2\\sum_{a,b,c}s_{a}t_{b}t_{c}u_{a}v_{b}v_{c}+2\\sum_{a,b,c,d}s_{a}s_{c}s_{d}t_{b}u_{c}u_{d}v_{b}-\\sum_{a,b,c}s_{a}s_{c}t_{b}u_{c}u_{c}v_{b}-2\\sum_{a,b,c}s_{a}s_{c}t_{b}u_{c}v_{b}v_{b}+2\\sum_{a,b,c,d}s_{a}s_{c}t_{b}t_{d}u_{c}v_{b}v_{d}+\\sum_{a,b}s_{a}t_{b}v_{b}v_{b}v_{b}-3\\sum_{a,b,c}s_{a}t_{b}t_{c}v_{b}v_{b}v_{c}+2\\sum_{a,b,c,d}s_{a}t_{b}t_{c}t_{d}v_{b}v_{c}v_{d}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "term = [[1],[],[2],[2], [], [], 1]\n",
    "print('Term representation', term)\n",
    "display_terms([term])\n",
    "deriv = take_beta_deriv_terms([term])\n",
    "print('Derivative',deriv)\n",
    "display_terms(deriv)\n",
    "deriv2 = take_beta_deriv_terms(deriv)\n",
    "print('Second derivative', deriv2)\n",
    "display_terms(deriv2)\n",
    "simplified_deriv2 = simplify_terms(deriv2)\n",
    "print('Simplified second derivative', simplified_deriv2)\n",
    "display_terms(simplified_deriv2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "463c7c3a-0cc4-4598-aa06-62126d2db90b",
   "metadata": {},
   "source": [
    "## Derivatives in $\\gamma$, when $\\beta = 0$\n",
    "\n",
    "Now, consider derivatives of a term in $\\gamma$, in the case that we are evaluating $\\beta = 0$. I.e., consider\n",
    "\n",
    "$$\\frac{\\partial}{\\partial \\gamma} T_{r,\\mathbf{i},\\mathbf{j},\\mathbf{a},\\mathbf{b},\\mathbf{c},\\mathbf{d}} \\mid_{\\beta = 0}$$\n",
    "\n",
    "Because $\\beta = 0$, the only dependence on $\\gamma$ is through the terms $u_{j_{k}}$ or $v_{b_k}$.\n",
    "\n",
    "These can again be written in terms of sums of terms of the same form.\n",
    "\n",
    "WARNING: We use a formula that requires $\\beta = 0$. If we wish to evaluate an expression of the form $\\frac{\\partial^{s_1}}{\\partial^{s_1} \\beta}\\frac{\\partial^{s_2}}{\\partial \\gamma^{s_2}} T_{r,\\mathbf{i},\\mathbf{j},\\mathbf{a},\\mathbf{b},\\mathbf{c},\\mathbf{d}} \\mid_{\\beta = 0}$, it is important to take all $\\beta$ derivatives first, and then all $\\gamma$ derivatives."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7fe80a9a-dec9-468a-b43c-a9691eeaf1d8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: only use take_gamma_deriv only after you have taken all beta derivs first\n"
     ]
    }
   ],
   "source": [
    "print('WARNING: only use take_gamma_deriv only after you have taken all beta derivs first')\n",
    "def take_gamma_deriv(term):\n",
    "    i_terms, js, a_terms, bs, ps, diracs, coeff = term\n",
    "\n",
    "    newterms = []\n",
    "    \n",
    "    for j_idx, j in enumerate(js):\n",
    "        ## Convert j term to p term\n",
    "        newjs = copy.deepcopy(js)\n",
    "        del newjs[j_idx]\n",
    "        newterm = [i_terms, newjs, a_terms, bs, ps + [j], diracs, coeff]\n",
    "        newterms.append(newterm)\n",
    "    \n",
    "    for b_idx, b in enumerate(bs):\n",
    "        ## Convert b term to p term\n",
    "        newbs = copy.deepcopy(bs)\n",
    "        del newbs[b_idx]\n",
    "        newterm = [i_terms, js, a_terms, newbs, ps + [b], diracs, coeff]\n",
    "        newterms.append(newterm)\n",
    "\n",
    "    return newterms\n",
    "\n",
    "\n",
    "def take_gamma_deriv_terms(currterms):\n",
    "    newterms = []\n",
    "    for t in currterms:\n",
    "        newterms.extend(take_gamma_deriv(t))\n",
    "    return newterms\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5fea6b4-07f2-43c5-9e76-40219f149bc5",
   "metadata": {},
   "source": [
    "## Examples of mixed $\\beta$ derivatives and $\\gamma$ derivatives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9951acfa-db3b-4bc4-a76e-614ea730af73",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term representation [[1], [], [2], [], [], [], 1]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a,b}s_{a}t_{b}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Second derivative in beta and gamma, at beta = 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 45814.35it/s]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +2\\sum_{a,b}s_{a}t_{b}p_{a}p_{a}-4\\sum_{a,b,c}s_{a}s_{c}t_{b}p_{a}p_{c}+4\\sum_{a,b}s_{a}t_{b}p_{a}p_{b}-4\\sum_{a,b,c}s_{a}t_{b}t_{c}p_{a}p_{c}+4\\sum_{a,b,c,d}s_{a}s_{c}s_{d}t_{b}p_{c}p_{d}-2\\sum_{a,b,c}s_{a}s_{c}t_{b}p_{c}p_{c}-4\\sum_{a,b,c}s_{a}s_{c}t_{b}p_{c}p_{b}+4\\sum_{a,b,c,d}s_{a}s_{c}t_{b}t_{d}p_{c}p_{d}+2\\sum_{a,b}s_{a}t_{b}p_{b}p_{b}-4\\sum_{a,b,c}s_{a}t_{b}t_{c}p_{b}p_{c}+4\\sum_{a,b,c,d}s_{a}t_{b}t_{c}t_{d}p_{c}p_{d}-2\\sum_{a,b,c}s_{a}t_{b}t_{c}p_{c}p_{c}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "term = [[1],[],[2],[], [], [], 1]\n",
    "print('Term representation', term)\n",
    "display_terms([term])\n",
    "deriv = take_beta_deriv_terms([term])\n",
    "deriv2 = take_beta_deriv_terms(deriv)\n",
    "deriv3 = take_gamma_deriv_terms(deriv2)\n",
    "deriv4 = take_gamma_deriv_terms(deriv3)\n",
    "print('Second derivative in beta and gamma, at beta = 0')\n",
    "display_terms(simplify_terms(deriv4))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb258e0c-c811-4647-8856-f8ea94e43ac0",
   "metadata": {},
   "source": [
    "## Computing expectation over random $\\zeta,\\xi,p$ at $\\beta = 0, \\gamma = 0$\n",
    "\n",
    "Now consider setting $\\beta = 0, \\gamma = 0$, and taking the expectation over Gaussian $\\mathbf{\\zeta} = [\\zeta_1,\\ldots,\\zeta_k]$ and $\\mathbf{\\xi} = [\\xi_1,\\ldots,\\xi_k]$ and $\\mathbf{p} = [p_1,\\ldots,p_k]$ which have the following covariance structure:\n",
    "* $E[\\zeta_i \\zeta_j] = 1(x_i = x_j)$\n",
    "* $E[\\zeta_i \\xi_j] = 1(x_i = y_j)$\n",
    "* $E[\\zeta_i \\zeta_j] = 1(y_i = y_j)$\n",
    "* $E[p_i p_j] = \\delta_{ij}$\n",
    "* $E[p_i \\zeta_j] = 0$\n",
    "* $E[p_i \\xi_j] = 0$\n",
    "\n",
    "These are the random variables that appear in the expression for the attention kernel.\n",
    "\n",
    "Since $\\beta = 0$, we know that $\\mathbf{s} = [1/k,\\ldots,1/k]$ and $\\mathbf{t} = [1/k,\\ldots,1/k]$. Therefore, the expetation of $T_{r,\\mathbf{i},\\mathbf{j},\\mathbf{a},\\mathbf{b},\\mathbf{c},\\mathbf{d}} \\mid_{\\beta = 0,\\gamma = 0}$ can be computed using Wick's formula, as a sum over matchings. This is done below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "93302e0f-a52c-42b3-b11d-0fcc5557222c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def list_of_matchings(a):\n",
    "    \"\"\"\n",
    "    Utility function that outputs a list of perfect matchings between elements of a; needed for Wick's formula\n",
    "    \"\"\"\n",
    "    if len(a) % 2 != 0:\n",
    "        return []\n",
    "    if len(a) == 0:\n",
    "        return [[]]\n",
    "    if len(a) == 2:\n",
    "        return [[(a[0], a[1])]]\n",
    "    new_matchings = []\n",
    "    for i in range(len(a)-1):\n",
    "        a_copy = copy.deepcopy(a)\n",
    "        a_copy = a_copy[:i] + a_copy[i+1:]\n",
    "        a_copy = a_copy[:-1]\n",
    "        sub_matchings = list_of_matchings(a_copy)\n",
    "        for x in sub_matchings:\n",
    "            new_matchings.append(x + [(a[i], a[-1])])\n",
    "    return new_matchings\n",
    "\n",
    "def compute_expectation_terms_from_term(term):\n",
    "    i_terms, js, a_terms, bs, ps, old_diracs, coeff = term\n",
    "    exp_terms = []\n",
    "    \n",
    "    \n",
    "    vocab_terms = [('x',j) for j in js] + [('y',b) for b in bs]\n",
    "    sum_indices = set(i_terms + a_terms)\n",
    "    \n",
    "    # All the s_i and t_a terms have entries equal to 1/k now\n",
    "    # So they contribute a 1/k^{front_coeff} scaling to the overall output\n",
    "    # We sum over tuples of front_coeff-1 indices, i.e., over k^{front_coeff-1} terms\n",
    "    front_coeff = len(i_terms) + len(a_terms)\n",
    "    \n",
    "    ## SANITY CHECK, HOLDS ONLY BECAUSE WE ONLY CARE ABOUT DERIVATIVES OF ONE OF TWO TERMS IN SOFTMAX SELF-ATTENTION KERNEL:\n",
    "    if len(old_diracs) == 0:\n",
    "        assert(front_coeff - len(sum_indices) == 1)\n",
    "    elif len(old_diracs) == 1:\n",
    "        assert(front_coeff - len(sum_indices) == 0)\n",
    "    else:\n",
    "        assert(False)\n",
    "    \n",
    "    # calc terms using Wick's theorem\n",
    "    # Match the ps terms and the js + bs terms\n",
    "    for mp in list_of_matchings(ps):\n",
    "        # Add a dirac delta for each pair of indices in mp\n",
    "        # We can keep track of equal indices via a union-find data structure\n",
    "        ds = disjoint_set.DisjointSet()\n",
    "        for i in sum_indices:\n",
    "            ds.find(i)\n",
    "        for p1, p2 in mp:\n",
    "            ds.union(p1,p2)\n",
    "        ds_list = list(ds)\n",
    "        rel_dict = {}\n",
    "        for i, j in ds_list:\n",
    "            rel_dict[i] = j\n",
    "        \n",
    "        # For each connected component, we keep one element\n",
    "        used_indices = set([i for i in rel_dict.keys() if rel_dict[i] == i])\n",
    "        \n",
    "        # Note that the sum still remains on the lower-order index\n",
    "        for mv in list_of_matchings(vocab_terms):\n",
    "            # print(mp,mv)\n",
    "            \n",
    "            dirac_terms = []\n",
    "            for v1,v2 in mv:\n",
    "                assert(v1[0] in ['x','y'])\n",
    "                assert(v2[0] in ['x','y'])\n",
    "                if v1[0] == v2[0] and rel_dict[v1[1]] == rel_dict[v2[1]]: # Terms of form 1(x_a = x_a) don't need to be added, since they are 1\n",
    "                    continue\n",
    "                sorted_terms = [(v1[0], rel_dict[v1[1]]),(v2[0], rel_dict[v2[1]])]\n",
    "                sorted_terms.sort()\n",
    "                dirac_terms.append(tuple(sorted_terms))\n",
    "            for v1, v2 in old_diracs:\n",
    "                dirac_terms.append((('x', rel_dict[v1]), ('y', rel_dict[v2])))\n",
    "                \n",
    "            # If there are multiple equivalent terms of the form 1(x_a = x_b), say, we can remove them and keep only the first one, since their product is 1\n",
    "            dirac_terms = list(set(dirac_terms))\n",
    "        \n",
    "            # If there are no terms with a certain variable, we can simplify by summing out that variable, which multiplies the sum by k\n",
    "            actually_used_terms = set([v1[1] for v1,_ in dirac_terms] + [v2[1] for _,v2 in dirac_terms])\n",
    "            for i in actually_used_terms:\n",
    "                assert(i in used_indices)\n",
    "            gap_ind = len(used_indices) - len(actually_used_terms)\n",
    "            \n",
    "            exp_term = front_coeff - gap_ind, actually_used_terms, dirac_terms, coeff\n",
    "            exp_terms.append(exp_term)\n",
    "            \n",
    "    return exp_terms\n",
    "\n",
    "\n",
    "def compute_expectation_terms(terms):\n",
    "    exp_terms = []\n",
    "    for t in terms:\n",
    "        exp_terms += compute_expectation_terms_from_term(t)\n",
    "    return exp_terms\n",
    "\n",
    "def display_expectation_terms(exp_terms):\n",
    "    tot_str = ''\n",
    "    for t in exp_terms:\n",
    "        tot_str += get_expectation_term_str(t)\n",
    "    display(Math(tot_str))\n",
    "        \n",
    "def get_expectation_term_str(exp_term):\n",
    "    \n",
    "    k_exp, actually_used_indices, dirac_terms, coeff = exp_term\n",
    "    \n",
    "    math_expr = coeff_term_str(coeff)\n",
    "    math_expr += r'\\frac{1}{k^{' + str(k_exp) + '}}'\n",
    "    if len(actually_used_indices) > 0:\n",
    "        math_expr += r'\\sum_{' + ','.join([INDEX_TO_VAR_NAME[i] for i in actually_used_indices]) + '}'\n",
    "\n",
    "    for v1, v2 in dirac_terms:\n",
    "        math_expr += r'1(' + v1[0] + '_{' + INDEX_TO_VAR_NAME[v1[1]] + '}=' + v2[0] + '_{' + INDEX_TO_VAR_NAME[v2[1]] + '})'\n",
    "    return math_expr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44c42439-026c-4c1d-92a0-a06ebf57d734",
   "metadata": {},
   "source": [
    "## Example expectation over random $\\zeta,\\xi,p$ at $\\beta = 0, \\gamma = 0$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4c061305-bbdb-4a2a-bf35-b317f5ac963b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term representation [[1], [], [1], [], [], [], 1]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a}s_{a}t_{a}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Second derivative in beta\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 76052.66it/s]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\sum_{a}s_{a}t_{a}u_{a}u_{a}-2\\sum_{a,b}s_{a}s_{b}t_{a}u_{a}u_{b}+2\\sum_{a}s_{a}t_{a}u_{a}v_{a}-2\\sum_{a,b}s_{a}t_{a}t_{b}u_{a}v_{b}+2\\sum_{a,b,c}s_{a}s_{b}s_{c}t_{a}u_{b}u_{c}-\\sum_{a,b}s_{a}s_{b}t_{a}u_{b}u_{b}-2\\sum_{a,b}s_{a}s_{b}t_{a}u_{b}v_{a}+2\\sum_{a,b,c}s_{a}s_{b}t_{a}t_{c}u_{b}v_{c}+\\sum_{a}s_{a}t_{a}v_{a}v_{a}-2\\sum_{a,b}s_{a}t_{a}t_{b}v_{a}v_{b}+2\\sum_{a,b,c}s_{a}t_{a}t_{b}t_{c}v_{b}v_{c}-\\sum_{a,b}s_{a}t_{a}t_{b}v_{b}v_{b}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expectation of the above expression over random zeta, xi, p, at beta = 0, gamma = 0\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +\\frac{1}{k^{1}}-\\frac{1}{k^{3}}\\sum_{a,b}1(x_{a}=x_{b})+\\frac{1}{k^{2}}\\sum_{a}1(x_{a}=y_{a})-\\frac{1}{k^{3}}\\sum_{a,b}1(x_{a}=y_{b})-\\frac{1}{k^{3}}\\sum_{a,b}1(x_{a}=x_{b})+\\frac{1}{k^{3}}\\sum_{b,c}1(x_{b}=x_{c})-\\frac{1}{k^{1}}+\\frac{1}{k^{3}}\\sum_{b,c}1(x_{b}=x_{c})-\\frac{1}{k^{3}}\\sum_{a,b}1(x_{b}=y_{a})+\\frac{1}{k^{3}}\\sum_{b,c}1(x_{b}=y_{c})+\\frac{1}{k^{2}}\\sum_{a}1(x_{a}=y_{a})-\\frac{1}{k^{3}}\\sum_{a,b}1(x_{b}=y_{a})+\\frac{1}{k^{1}}-\\frac{1}{k^{3}}\\sum_{a,b}1(y_{a}=y_{b})-\\frac{1}{k^{3}}\\sum_{a,b}1(x_{a}=y_{b})+\\frac{1}{k^{3}}\\sum_{b,c}1(x_{c}=y_{b})-\\frac{1}{k^{3}}\\sum_{a,b}1(y_{a}=y_{b})+\\frac{1}{k^{3}}\\sum_{b,c}1(y_{b}=y_{c})-\\frac{1}{k^{1}}+\\frac{1}{k^{3}}\\sum_{b,c}1(y_{b}=y_{c})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "term = [[1],[],[1],[], [], [], 1]\n",
    "print('Term representation', term)\n",
    "display_terms([term])\n",
    "deriv = take_beta_deriv_terms([term])\n",
    "deriv2 = take_beta_deriv_terms(deriv)\n",
    "print('Second derivative in beta')\n",
    "display_terms(simplify_terms(deriv2))\n",
    "exp_terms = compute_expectation_terms(deriv2)\n",
    "print('Expectation of the above expression over random zeta, xi, p, at beta = 0, gamma = 0')\n",
    "display_expectation_terms(exp_terms)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ccb5591-5f10-4e26-a580-f55ee7612c1f",
   "metadata": {},
   "source": [
    "## Simplifying sums of expectations\n",
    "Similarly to before, we can simplify sums of expectation terms by grouping together like terms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d9fcd89b-c6c5-4557-9c75-74976e963c2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def simplify_expectation_terms(exp_terms, full_simplify=True):\n",
    "\n",
    "    exp_terms = [exp_term_to_first_indices(t) for t in exp_terms]\n",
    "    \n",
    "    ds = disjoint_set.DisjointSet()\n",
    "    for i in range(len(exp_terms)):\n",
    "        ds.find(i)\n",
    "    indices_left = set(list(range(len(exp_terms))))\n",
    "    for i in tqdm(range(len(exp_terms))):\n",
    "        if i not in indices_left:\n",
    "            continue\n",
    "        # print(len(indices_left))\n",
    "        for j in range(i+1,len(exp_terms)):\n",
    "            if j not in indices_left:\n",
    "                continue\n",
    "            if check_equiv_expectation_terms(exp_terms[i], exp_terms[j], full_simplify=full_simplify):\n",
    "                ds.union(i,j)\n",
    "                indices_left.remove(j)\n",
    "        indices_left.remove(i)\n",
    "    # print(list(ds.itersets()))\n",
    "    # print('Regular itersets', list(ds.itersets()))\n",
    "        \n",
    "    \n",
    "    tot_terms = []\n",
    "    for iterset in ds.itersets():\n",
    "        curr_coeff = 0\n",
    "        for idx in iterset:\n",
    "            curr_coeff += exp_terms[idx][-1]\n",
    "        curr_exp_term = tuple(list(exp_terms[list(iterset)[0]][:-1]) + [curr_coeff])\n",
    "        if curr_coeff != 0:\n",
    "            tot_terms.append(curr_exp_term)\n",
    "    return tot_terms\n",
    "\n",
    "def exp_term_to_first_indices(exp_term):\n",
    "    k_exp, sum_indices, dirac_terms, coeff = exp_term\n",
    "\n",
    "    for v1, v2 in dirac_terms:\n",
    "        assert(v1[1] in sum_indices)\n",
    "        assert(v2[1] in sum_indices)\n",
    "    \n",
    "    n = len(sum_indices)\n",
    "    new_sum_indices = set(range(1,n+1))\n",
    "    rel_dict = {}\n",
    "    for i, idx in enumerate(sum_indices):\n",
    "        rel_dict[idx] = i+1\n",
    "    \n",
    "    new_dirac_terms = []\n",
    "    for x in dirac_terms:\n",
    "        new_term = [(x[0][0], rel_dict[x[0][1]]), (x[1][0], rel_dict[x[1][1]])]\n",
    "        new_term.sort()\n",
    "        new_dirac_terms.append(tuple(new_term))\n",
    "    \n",
    "    return (k_exp, new_sum_indices, new_dirac_terms, coeff)\n",
    "\n",
    "\n",
    "def check_equiv_expectation_terms(exp_term1, exp_term2, full_simplify=True,ignore_kexp=False):\n",
    "    k_exp1, sum_indices1, dirac_terms1, coeff1 = exp_term1\n",
    "    k_exp2, sum_indices2, dirac_terms2, coeff2 = exp_term2\n",
    "    \n",
    "    for v1, v2 in dirac_terms1:\n",
    "        assert(v1[1] in sum_indices1)\n",
    "        assert(v2[1] in sum_indices1)\n",
    "    for v1, v2 in dirac_terms2:\n",
    "        assert(v1[1] in sum_indices2)\n",
    "        assert(v2[1] in sum_indices2)\n",
    "    \n",
    "    if not ignore_kexp:\n",
    "        if k_exp1 != k_exp2:\n",
    "            return False\n",
    "    if len(sum_indices1) != len(sum_indices2):\n",
    "        return False\n",
    "    if len(sum_indices1) == 0:\n",
    "        assert(len(dirac_terms1) == 0)\n",
    "        assert(len(dirac_terms2) == 0)\n",
    "        return True\n",
    "    # return False\n",
    "    \n",
    "    sig1 = get_signature_from_expectation_term(exp_term1)\n",
    "    sig2 = get_signature_from_expectation_term(exp_term2)\n",
    "    if sig1 != sig2:\n",
    "        return False\n",
    "    \n",
    "    # Map each of the sum_indices1 to one of the sum_indices2, if possible\n",
    "    # Try each of the permutations\n",
    "    to_idx1 = list(sum_indices1)\n",
    "    to_idx2 = list(sum_indices2)\n",
    "    \n",
    "    n = len(to_idx1)\n",
    "    to_idx1.sort()\n",
    "    to_idx2.sort()\n",
    "    \n",
    "    ## NO LONGER REQUIRE THE SUM INDICES TO BE 1...n+1\n",
    "    assert(to_idx1 == list(range(1,n+1)))\n",
    "    assert(to_idx2 == list(range(1,n+1)))\n",
    "    \n",
    "    mat1 = get_transitive_closure_of_incidence_mat(exp_term1)\n",
    "    mat2 = get_transitive_closure_of_incidence_mat(exp_term2)\n",
    "    \n",
    "    if np.all(mat1 == mat2):\n",
    "        return True\n",
    "    \n",
    "    if not full_simplify:\n",
    "        return False\n",
    "    \n",
    "    for perm in itertools.permutations(range(n)):\n",
    "        # print(perm)\n",
    "        doubleperm = np.zeros(2*n, dtype=np.int32)\n",
    "        doubleperm[0:n] = np.asarray(perm)\n",
    "        doubleperm[n:2*n] = np.asarray(perm)+n\n",
    "        permmat2 = np.array(mat2)\n",
    "        for i in range(2*n):\n",
    "            for j in range(2*n):\n",
    "                permmat2[i,j] = mat2[doubleperm[i], doubleperm[j]]\n",
    "\n",
    "        if np.all(mat1 == permmat2):\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "\n",
    "def get_signature_from_expectation_term(exp_term):\n",
    "    k_exp, sum_indices, dirac_terms, coeff = exp_term\n",
    "    \n",
    "    # Break diracs into connected components\n",
    "    ds = disjoint_set.DisjointSet()\n",
    "    for i in sum_indices:\n",
    "        ds.find(i)\n",
    "    for v1, v2 in dirac_terms:\n",
    "        ds.union(v1[1],v2[1])\n",
    "    # print(list(ds.itersets()))\n",
    "    \n",
    "    # For each connected component, keep track of the number of x, y, and xy variables\n",
    "    x_vars = set()\n",
    "    y_vars = set()\n",
    "    for v1, v2 in dirac_terms:\n",
    "        if v1[0] == 'x':\n",
    "            x_vars.add(v1[1])\n",
    "        elif v1[0] == 'y':\n",
    "            y_vars.add(v1[1])\n",
    "        else:\n",
    "            assert(False)\n",
    "        \n",
    "        if v2[0] == 'x':\n",
    "            x_vars.add(v2[1])\n",
    "        elif v2[0] == 'y':\n",
    "            y_vars.add(v2[1])\n",
    "        else:\n",
    "            assert(False)\n",
    "            \n",
    "    xy_vars = x_vars.intersection(y_vars)\n",
    "    x_only_vars = x_vars.difference(xy_vars)\n",
    "    y_only_vars = y_vars.difference(xy_vars)\n",
    "        \n",
    "    signature = []\n",
    "    for iterset in ds.itersets():\n",
    "        xy_ct = len(iterset.intersection(xy_vars))\n",
    "        y_ct = len(iterset.intersection(y_only_vars))\n",
    "        x_ct = len(iterset.intersection(x_only_vars))\n",
    "        signature.append((x_ct,y_ct,xy_ct))\n",
    "    signature.sort()\n",
    "    return signature\n",
    "\n",
    "\n",
    "def get_incidence_mat(exp_term):\n",
    "    ## Assumes that sum_indices are 1...n+1\n",
    "    k_exp, sum_indices, dirac_terms, coeff = exp_term\n",
    "    n = len(sum_indices)\n",
    "    assert(list(sum_indices) == list(range(1,n+1)))\n",
    "    n = len(sum_indices)\n",
    "    mat = np.zeros((2*n,2*n))\n",
    "    for term in dirac_terms:\n",
    "        v1, v2 = term\n",
    "        i1 = v1[1]-1\n",
    "        i2 = v2[1]-1\n",
    "        if v1[0] == 'y':\n",
    "            i1 += n\n",
    "        if v2[0] == 'y':\n",
    "            i2 += n\n",
    "        mat[i1,i2] = 1\n",
    "        mat[i2,i1] = 1\n",
    "    for i in range(2*n):\n",
    "        mat[i,i] = 1\n",
    "    return mat\n",
    "\n",
    "def transitive_closure_of_mat(mat):\n",
    "    # Floyd warshall algorithm\n",
    "    n = mat.shape[0]\n",
    "    for i in range(n):\n",
    "        mat = mat @ mat\n",
    "        mat = 1 * (mat > 0)\n",
    "    return mat\n",
    "\n",
    "def get_transitive_closure_of_incidence_mat(exp_term):\n",
    "    return transitive_closure_of_mat(get_incidence_mat(exp_term))\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc41fc94-2370-4eaa-9d4c-77ad69f0fb62",
   "metadata": {},
   "source": [
    "## Example of simplifying sum of expectations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "438fb600-689a-4316-a2b6-64570bad9d06",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 8071.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simplified version of above terms\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +2\\frac{1}{k^{2}}\\sum_{a}1(x_{a}=y_{a})-2\\frac{1}{k^{3}}\\sum_{a,b}1(x_{a}=y_{b})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "exp_terms = simplify_expectation_terms(exp_terms)\n",
    "print('Simplified version of above terms')\n",
    "display_expectation_terms(exp_terms)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01cc3aaf-ebeb-427c-9788-33e7b14ec6e4",
   "metadata": {},
   "source": [
    "## More compact output format\n",
    "The equations in the above format can be complicated to read, especially if there are many indices. Here, we provide code that writes them in more compact matrix notation. This works by creating a list of terms `EXP_TERMS_REF_LIST` for which we manually determine an equivalent linear-algebraic expression. While outputting, if a term corresponds to a term from `EXP_TERMS_REF_LIST`, we replace it with the corresponding expression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "63b2085c-17ce-477c-a760-d10a99d1cb1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "EXP_TERMS_REF_LIST = [((2, set(), [], -1152), ''),\n",
    "                      \n",
    "             ((0, {1, 2}, [(('x', 2), ('y', 1))], 1), '{\\color{green} {1^TXY^T 1}}'),\n",
    "             ((3, {1, 2, 3}, [(('x', 1), ('y', 2)), (('x', 1), ('x', 3))], -4), '{\\color{green} {1^TXX^TXY^T1}}'),     \n",
    "             ((3, {1, 2, 3}, [(('x', 1), ('y', 2)), (('y', 2), ('y', 3))], -4),   '{\\color{green} {1^TXY^TYY^T1}}'),\n",
    "             ((4, {1, 2, 3, 4}, [(('x', 1), ('x', 2)), (('x', 2), ('x', 3)), (('x', 3), ('y', 4))], 192), '{\\color{green} {1^TXX^TXX^TXY^T1}}'),\n",
    "             ((4, {1, 2, 3, 4}, [(('x', 1), ('y', 2)), (('y', 2), ('x', 3)), (('x', 3), ('y', 4))], 192), '{\\color{green} {1^TXY^TYX^TXY^T1}}'),\n",
    "             ((4, {1, 2, 3, 4}, [(('y', 1), ('x', 2)), (('x', 2), ('y', 3)), (('y', 3), ('y', 4))], 192), '{\\color{green} {1^TYX^TXY^TYY^T1}}'),\n",
    "             ((0, {1, 2, 3, 4, 5}, [(('x', 1), ('y', 2)), (('x', 4), ('x', 5)), (('x', 3), ('y', 2)), (('x', 3), ('x', 4))], 1), '{\\color{green} {1^TXY^TYX^TXX^TXX^T1}}'),\n",
    "             ((0, {1, 2, 3, 4, 5}, [(('x', 1), ('y', 2)), (('y', 2), ('y', 3)), (('y', 3), ('x', 4)), (('x', 4), ('x', 5))], 1), '{\\color{green} {1^TXY^TYY^TYX^TXX^T1}}'),\n",
    "             ((0, {1, 2, 3, 4, 5}, [(('y', 1), ('x', 2)), (('x', 2), ('x', 3)), (('x', 3), ('y', 4)), (('y', 4), ('y', 5))], 1), '{\\color{green} {1^TYX^TXX^TXY^TYY^T1}}'),\n",
    "             ((0, {1, 2, 3, 4, 5}, [(('y', 1), ('y', 2)), (('y', 2), ('x', 3)), (('x', 3), ('y', 4)), (('y', 4), ('y', 5))], 1), '{\\color{green} {1^TYY^TYX^TXY^TYY^T1}}'),        \n",
    "             ((3, {1, 2}, [(('x', 1), ('x', 2)), (('x', 2), ('y', 1))], -336), '{\\color{green} {tr(XX^TXY^T)}}'),\n",
    "             ((3, {1, 2}, [(('x', 1), ('y', 2)), (('y', 2), ('y', 1))], -336), '{\\color{green} {tr(XY^TYY^T)}}'),\n",
    "                      \n",
    "                      \n",
    "             ((4, {1, 2}, [(('x', 1), ('x', 2))], 576), '{\\color{orange} {1^TXX^T1}}'),\n",
    "             ((4, {1, 2}, [(('y', 1), ('y', 2))], 576), '{\\color{orange} {1^TYY^T1}}'),\n",
    "                      \n",
    "             ((2, {1}, [(('x', 1), ('y', 1))], 864), '{\\color{red} {tr(XY^T)}}'),\n",
    "             ((4, {1, 2, 3}, [(('x', 2), ('y', 1)), (('x', 1), ('y', 3))], 144),  '{\\color{red} {1^TXY^TXY^T1}}'),\n",
    "             ((0, {1, 2, 3}, [(('x', 1), ('y', 2)), (('x', 2), ('x', 3))], 1),    '{\\color{red} {1^TXX^TYX^T1}}'),\n",
    "             ((4, {1, 2, 3}, [(('y', 1), ('y', 2)), (('x', 1), ('y', 3))], 192),  '{\\color{red} {1^TYX^TYY^T1}}'),\n",
    "             ((0, {1, 2}, [(('x', 2), ('y', 1)), (('x', 1), ('y', 2))], 1), '{\\color{red} {tr(XY^T XY^T)}}'),\n",
    "                      \n",
    "             ((0, {1, 2, 3}, [(('x', 1), ('x', 2)), (('y', 1), ('y', 3))], 1),    '{\\color{purple} {1^TXX^TYY^T1}}'),\n",
    "             ((0, {1, 2}, [(('y', 1), ('y', 2)), (('x', 1), ('x', 2))], 1), '{\\color{blue}{ tr(XX^T YY^T)}}'),\n",
    "                      \n",
    "             ((0, {1, 2, 3}, [(('x', 1), ('x', 2)), (('x', 1), ('x', 3))], 1),    '1^TXX^TXX^T1'),\n",
    "             ((0, {1, 2, 3}, [(('y', 1), ('y', 3)), (('y', 1), ('y', 2))], 1),    '1^TYY^TYY^T1'),\n",
    "    \n",
    "#              ((0, {1, 2, 3}, [(('y', 1), ('x', 2)), (('x', 2), ('y', 3)), (('y', 3), ('x', 1))], 1), 'tr(YX^TXY^TYX^T)'),\n",
    "#              ((0, {1, 2, 3}, [(('x', 1), ('x', 2)), (('x', 2), ('x', 3)), (('x', 3), ('y', 3))], 1), 'tr(XX^TXX^TXY^T)'),\n",
    "#              ((0, {1, 2, 3}, [(('y', 1), ('y', 2)), (('y', 2), ('y', 3)), (('y', 3), ('x', 1))], 1), 'tr(YY^TYY^TYX^T)'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('x', 2)), (('x', 2), ('x', 3)), (('x', 3), ('x', 4))], 1), '1^TXX^TXX^TXX^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('y', 1), ('y', 2)), (('y', 2), ('y', 3)), (('y', 3), ('y', 4))], 1), '1^TYY^TYY^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('y', 2)), (('x', 2), ('x', 3)), (('x', 3), ('y', 4))], 1), '1^TXY^TXX^TXY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('y', 2)), (('y', 2), ('x', 3)), (('y', 3), ('y', 4))], 1), '1^TXY^TYX^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('y', 1), ('y', 2)), (('y', 2), ('x', 3)), (('y', 3), ('y', 4))], 1), '1^TYY^TYX^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('x', 2)), (('y', 2), ('x', 3)), (('x', 3), ('y', 4))], 1), '1^TXX^TYX^TXY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('y', 2)), (('y', 2), ('y', 3)), (('x', 3), ('y', 4))], 1), '1^TXY^TYY^TXY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('y', 1), ('x', 2)), (('y', 2), ('y', 3)), (('y', 3), ('y', 4))], 1), '1^TYX^TYY^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('y', 2)), (('x', 2), ('x', 3)), (('x', 3), ('x', 4))], 1), '1^TXY^TXX^TXX^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('y', 1), ('y', 2)), (('x', 2), ('x', 3)), (('x', 3), ('x', 4))], 1), '1^TYY^TXX^TXX^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('x', 2)), (('y', 2), ('y', 3)), (('y', 3), ('y', 4))], 1), '1^TXX^TYY^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('x', 2)), (('x', 2), ('y', 3)), (('x', 3), ('x', 4))], 1), '1^TXX^TXY^TXX^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 2), ('y', 1)), (('x', 1), ('y', 4)), (('y', 4), ('y', 3))], 1), '1^TXY^TXY^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('x', 3)), (('x', 3), ('y', 2)), (('x', 2), ('y', 4))], 1), '1^TXX^TXY^TXY^T1'),           \n",
    "            ]\n",
    "\n",
    "def get_expectation_term_str_for_equiv_terms(exp_terms,add_coeff=True, simple_mode=True):\n",
    "    if add_coeff:\n",
    "        math_expr = '+('\n",
    "        for i in range(len(exp_terms)):\n",
    "            k_exp, actually_used_indices, dirac_terms, coeff = exp_terms[i]\n",
    "            math_expr += coeff_term_str(coeff)\n",
    "            math_expr += r'\\frac{1}{k^{' + str(k_exp) + '}}'\n",
    "\n",
    "        math_expr += ')'\n",
    "    else:\n",
    "        math_expr=''\n",
    "    \n",
    "    k_exp, actually_used_indices, dirac_terms, coeff = exp_terms[0]\n",
    "    \n",
    "    if simple_mode:\n",
    "        if len(actually_used_indices) > 0:\n",
    "            math_expr += r'\\sum_{' + ','.join([INDEX_TO_VAR_NAME[i] for i in actually_used_indices]) + '}'\n",
    "\n",
    "        for v1, v2 in dirac_terms:\n",
    "            math_expr += r'1(' + v1[0] + '_{' + INDEX_TO_VAR_NAME[v1[1]] + '}=' + v2[0] + '_{' + INDEX_TO_VAR_NAME[v2[1]] + '})'\n",
    "        return math_expr\n",
    "    else:\n",
    "        # Search for equivalent expression in EXP_TERMS_REF_LIST, and use the representation from there\n",
    "        for k, v in EXP_TERMS_REF_LIST:\n",
    "            if check_equiv_expectation_terms(exp_terms[0], k, full_simplify=True,ignore_kexp=True):\n",
    "                return math_expr + v\n",
    "        \n",
    "        display_expectation_terms([exp_terms[0]])\n",
    "        print(exp_terms[0])\n",
    "        assert(False)\n",
    "\n",
    "def get_expectation_terms_compact_str(exp_terms,simple_mode=True, add_coeff=False):\n",
    "    # Group together expectation terms by front coeff\n",
    "    exp_terms = [exp_term_to_first_indices(t) for t in exp_terms]\n",
    "    ds = disjoint_set.DisjointSet()\n",
    "    for i in range(len(exp_terms)):\n",
    "        ds.find(i)\n",
    "    for i in range(len(exp_terms)):\n",
    "        for j in range(i+1,len(exp_terms)):\n",
    "            if check_equiv_expectation_terms(exp_terms[i], exp_terms[j], full_simplify=True,ignore_kexp=True):\n",
    "                ds.union(i,j)\n",
    "                break\n",
    "\n",
    "                \n",
    "    tot_str_list = []\n",
    "    for i, iterset in enumerate(ds.itersets()):\n",
    "        tot_str = ''\n",
    "        if add_coeff:\n",
    "            tot_str += get_expectation_term_str_for_equiv_terms([exp_terms[i] for i in iterset], add_coeff=True, simple_mode=simple_mode)\n",
    "        else:\n",
    "            tot_str += '+c_{' + str(i+1) + '}' + get_expectation_term_str_for_equiv_terms([exp_terms[i] for i in iterset], add_coeff=False, simple_mode=simple_mode)\n",
    "        tot_str_list.append(tot_str)\n",
    "    return tot_str_list\n",
    "\n",
    "def display_expectation_terms_compact(exp_terms,simple_mode=True, add_coeff=False):\n",
    "    tot_str_list = get_expectation_terms_compact_str(exp_terms, simple_mode=simple_mode, add_coeff=add_coeff)\n",
    "    display(Math(''.join(tot_str_list)))\n",
    "\n",
    "def get_expectation_term_str_for_equiv_terms(exp_terms,add_coeff=True, simple_mode=True):\n",
    "    if add_coeff:\n",
    "        math_expr = '+('\n",
    "        for i in range(len(exp_terms)):\n",
    "            k_exp, actually_used_indices, dirac_terms, coeff = exp_terms[i]\n",
    "            math_expr += coeff_term_str(coeff)\n",
    "            math_expr += r'\\frac{1}{k^{' + str(k_exp) + '}}'\n",
    "\n",
    "        math_expr += ')'\n",
    "    else:\n",
    "        math_expr=''\n",
    "    \n",
    "    k_exp, actually_used_indices, dirac_terms, coeff = exp_terms[0]\n",
    "    \n",
    "    if simple_mode:\n",
    "        if len(actually_used_indices) > 0:\n",
    "            math_expr += r'\\sum_{' + ','.join([INDEX_TO_VAR_NAME[i] for i in actually_used_indices]) + '}'\n",
    "\n",
    "        for v1, v2 in dirac_terms:\n",
    "            math_expr += r'1(' + v1[0] + '_{' + INDEX_TO_VAR_NAME[v1[1]] + '}=' + v2[0] + '_{' + INDEX_TO_VAR_NAME[v2[1]] + '})'\n",
    "        return math_expr\n",
    "    else:\n",
    "        # Break the products of diracs into connected components\n",
    "        # The term is the product of these connected components\n",
    "        \n",
    "        assert(actually_used_indices == set([dterm[0][1] for dterm in dirac_terms] + [dterm[1][1] for dterm in dirac_terms]))\n",
    "        \n",
    "        ds = disjoint_set.DisjointSet()\n",
    "        for i in actually_used_indices:\n",
    "            ds.find(i)\n",
    "        for v1, v2 in dirac_terms:\n",
    "            ds.union(v1[1],v2[1])\n",
    "        # print('DS itersets', list(ds.itersets()))\n",
    "\n",
    "        found_term_strs = []\n",
    "        for iterset in ds.itersets():\n",
    "            # print(iterset)\n",
    "            dirac_subset = [dterm for dterm in dirac_terms if dterm[0][1] in iterset or dterm[1][1] in iterset]\n",
    "            # print(dirac_subset)\n",
    "            currstr = get_dirac_term_str(iterset, dirac_subset)\n",
    "            \n",
    "            found_term_strs.append(currstr)\n",
    "        # print(found_term_strs)\n",
    "        if len(found_term_strs) == 1:\n",
    "            math_expr += found_term_strs[0]\n",
    "            return math_expr\n",
    "        elif len(found_term_strs) > 0:\n",
    "            math_expr += '('\n",
    "            math_expr += ')('.join(found_term_strs)\n",
    "            math_expr += ')'\n",
    "            return math_expr\n",
    "        else:\n",
    "            math_expr += ''\n",
    "            assert(len(actually_used_indices) == 0)\n",
    "            return math_expr\n",
    "            # display_expectation_terms([exp_terms[0]])\n",
    "            # print(exp_terms[0])\n",
    "            # assert(False)\n",
    "\n",
    "def get_dirac_term_str(actually_used_indices, dirac_terms):\n",
    "    \n",
    "    math_expr = ''\n",
    "    \n",
    "    curr_exp_term = (0, actually_used_indices, dirac_terms, 1)\n",
    "    curr_exp_term_first_ind = exp_term_to_first_indices(curr_exp_term)\n",
    "    \n",
    "    # Search for equivalent expression in EXP_TERMS_REF_LIST, and use the representation from there\n",
    "    for k, v in EXP_TERMS_REF_LIST:\n",
    "        if check_equiv_expectation_terms(curr_exp_term_first_ind, k, full_simplify=True,ignore_kexp=True):\n",
    "            if v == 'default':\n",
    "                print('USING DEFAULT')\n",
    "                math_expr += '{\\color{red}{'\n",
    "                if len(actually_used_indices) > 0:\n",
    "                    math_expr += r'\\sum_{' + ','.join([INDEX_TO_VAR_NAME[i] for i in actually_used_indices]) + '}'\n",
    "\n",
    "                for v1, v2 in dirac_terms:\n",
    "                    math_expr += r'1(' + v1[0] + '_{' + INDEX_TO_VAR_NAME[v1[1]] + '}=' + v2[0] + '_{' + INDEX_TO_VAR_NAME[v2[1]] + '})'\n",
    "                math_expr += '}}'\n",
    "                return math_expr\n",
    "            else:\n",
    "\n",
    "                return math_expr + v\n",
    "    display_expectation_terms([curr_exp_term])\n",
    "    print(curr_exp_term)\n",
    "    print(curr_exp_term_first_ind)\n",
    "    return '{\\color{red}{ERROR}}'\n",
    "    # assert(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bfe2d4b6-2d8a-4cbc-93b6-3041f589c7ab",
   "metadata": {},
   "source": [
    "## Example of more compact output format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f413128e-dffe-45e7-8e5b-08deb553806c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The expectation terms from before in simpler linear-algebraic format\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +(+2\\frac{1}{k^{2}}){\\color{red} {tr(XY^T)}}+(-2\\frac{1}{k^{3}}){\\color{green} {1^TXY^T 1}}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The expectation terms from before in simpler linear-algebraic format, hiding the coefficients that depend on k\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +c_{1}{\\color{red} {tr(XY^T)}}+c_{2}{\\color{green} {1^TXY^T 1}}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "print('The expectation terms from before in simpler linear-algebraic format')\n",
    "display_expectation_terms_compact(exp_terms,simple_mode=False,add_coeff=True)\n",
    "print('The expectation terms from before in simpler linear-algebraic format, hiding the coefficients that depend on k')\n",
    "display_expectation_terms_compact(exp_terms,simple_mode=False,add_coeff=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37058ed3-b98a-44c1-8e0e-8b09f17a68f5",
   "metadata": {},
   "source": [
    "## Fuzz-testing to ensure correctness\n",
    "To check correctness of the computer algebra system -- and specifically the simplification step, which is the most complex, we fuzz-test the computed expectations. For this, we evaluate the derivatives for random inputs $x_1,\\ldots,x_k$ and $y_1,\\ldots,y_k$. Fuzz-testing code is below, and fuzz tests for validity are conducted when computing the derivatives that we require for our paper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "430df13f-0550-44be-9407-d20982b3257b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rand_sequence(k,m):\n",
    "    Xinds = np.random.randint(0,m,k)\n",
    "    return Xinds\n",
    "\n",
    "def sequence_to_one_hot_matrix(Xinds,k,m):\n",
    "    assert(len(Xinds) == k)\n",
    "    X = np.zeros((k,m))\n",
    "    z = [[i,Xinds[i]] for i in range(k)]\n",
    "    z = np.asarray(z,dtype=np.int32)\n",
    "    X[z[:,0],z[:,1]] = 1\n",
    "    return X\n",
    "\n",
    "def eval_exp_terms(X,Y,exp_terms,ignore_coeff=True, ignore_k_exp=False):\n",
    "    tot = 0\n",
    "    for t in tqdm(exp_terms):\n",
    "        # print(t)\n",
    "        tot += eval_exp_term(X,Y,t,ignore_coeff=ignore_coeff, ignore_k_exp=ignore_k_exp)\n",
    "    return tot\n",
    "\n",
    "def eval_exp_term(X,Y,exp_term,ignore_coeff=True, ignore_k_exp=False):\n",
    "    k = X.shape[0]\n",
    "    # m = X.shape[1]\n",
    "    assert(k == Y.shape[0])\n",
    "    assert(len(X.shape) == 1)\n",
    "    assert(len(Y.shape) == 1)\n",
    "    # assert(m == Y.shape[1])\n",
    "    \n",
    "    k_exp, actually_used_indices, dirac_terms, coeff = exp_term\n",
    "    \n",
    "    scaling = 1\n",
    "    if not ignore_k_exp:\n",
    "        scaling = scaling * math.pow(k,-k_exp)\n",
    "    if not ignore_coeff:\n",
    "        scaling = scaling * coeff\n",
    "    \n",
    "    actually_used_indices = list(actually_used_indices)\n",
    "    idx_dict = {}\n",
    "    for i in range(len(actually_used_indices)):\n",
    "        idx_dict[actually_used_indices[i]] = i\n",
    "    \n",
    "    tot_sum = 0\n",
    "    \n",
    "    for idx_vec in itertools.product(range(k),repeat=len(actually_used_indices)):\n",
    "        # print(idx_vec)\n",
    "        curr_dirac_prod = 1\n",
    "        for dterm in dirac_terms:\n",
    "            if dterm[0][0] == 'x':\n",
    "                v1 = X[idx_vec[idx_dict[dterm[0][1]]]]\n",
    "            elif dterm[0][0] == 'y':\n",
    "                v1 = Y[idx_vec[idx_dict[dterm[0][1]]]]\n",
    "            else:\n",
    "                assert(False)\n",
    "            if dterm[1][0] == 'x':\n",
    "                v2 = X[idx_vec[idx_dict[dterm[1][1]]]]\n",
    "            elif dterm[1][0] == 'y':\n",
    "                v2 = Y[idx_vec[idx_dict[dterm[1][1]]]]\n",
    "            else:\n",
    "                assert(False)\n",
    "            if v1 != v2:\n",
    "                curr_dirac_prod = 0\n",
    "                break\n",
    "        tot_sum += curr_dirac_prod\n",
    "    \n",
    "    return scaling * tot_sum    \n",
    "    \n",
    "\n",
    "def fuzz_for_equiv(exp_terms,k,m,num_fuzzers,ignore_coeff=True, ignore_k_exp=False):\n",
    "    fuzzers = []\n",
    "    for i in range(num_fuzzers):\n",
    "        X = rand_sequence(k=k,m=m)\n",
    "        Y = rand_sequence(k=k,m=m)\n",
    "        fuzzers.append((X,Y))\n",
    "\n",
    "    fuzzer_outs = []\n",
    "    for i in tqdm(range(len(exp_terms))):\n",
    "        # print(exp_terms[i])\n",
    "        fuzzer_outs.append([])\n",
    "        for f in fuzzers:\n",
    "            ev = eval_exp_term(f[0],f[1],exp_terms[i], ignore_coeff=ignore_coeff, ignore_k_exp=ignore_k_exp)\n",
    "            fuzzer_outs[i].append(ev)\n",
    "        fuzzer_outs[i] = tuple(fuzzer_outs[i])\n",
    "    \n",
    "    ds = disjoint_set.DisjointSet()\n",
    "    fuzz_dict = {}\n",
    "    for i in range(len(exp_terms)):\n",
    "        ds.find(i)\n",
    "        if fuzzer_outs[i] in fuzz_dict:\n",
    "            ds.union(i, fuzz_dict[fuzzer_outs[i]])\n",
    "        else:\n",
    "            fuzz_dict[fuzzer_outs[i]] = i\n",
    "    return list(ds.itersets())\n",
    "\n",
    "\n",
    "def eval_on_fuzzers(exp_terms,fuzzers):\n",
    "    exp_evals = []\n",
    "    for i in range(len(fuzzers)):\n",
    "        ev = eval_exp_terms(fuzzers[i][0], fuzzers[i][1],exp_terms,ignore_coeff=False, ignore_k_exp=False)\n",
    "        exp_evals.append(ev)\n",
    "    return exp_evals\n",
    "\n",
    "def fuzz_compare(exp_terms1,exp_terms2):\n",
    "    k_fuzz = 5\n",
    "    m_fuzz = 8\n",
    "    num_fuzzers = 10\n",
    "\n",
    "    fuzzers = []\n",
    "    for i in range(num_fuzzers):\n",
    "        X = rand_sequence(k=k_fuzz,m=m_fuzz)\n",
    "        Y = rand_sequence(k=k_fuzz,m=m_fuzz)\n",
    "        fuzzers.append((X,Y))\n",
    "\n",
    "    print('Fuzzing')\n",
    "    evals1 = eval_on_fuzzers(exp_terms1, fuzzers)\n",
    "    evals2 = eval_on_fuzzers(exp_terms2, fuzzers)\n",
    "    print(evals1)\n",
    "    print(evals2)\n",
    "    for i in range(num_fuzzers):\n",
    "        assert(abs(evals1[i] - evals2[i]) < 1e-6)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c123a99-39bf-4800-9e9b-b149619b0ad4",
   "metadata": {},
   "source": [
    "## Computations on the transformer random features kernel\n",
    "\n",
    "The transformer random features kernel is\n",
    "\n",
    "$$\\kappa_{X,Y}(\\beta,\\gamma) = E_{\\zeta,\\xi,p}[\\mathrm{softmax}(\\beta u)^T XY^T \\mathrm{softmax}(\\beta v) + \\gamma^2 \\mathrm{softmax}(\\beta u)^T \\mathrm{softmax}(\\beta v)]$$\n",
    "\n",
    "Changing notation slightly, this is a sum of two terms that can be written in our computer algebra system:\n",
    "$$\\kappa_{X,Y}(\\beta,\\gamma) = E_{\\zeta,\\xi,p}[\\sum_{a,b} s_at_b 1(x_a=y_b) + \\gamma^2 \\sum_{a} s_at_a]\\,.$$\n",
    "\n",
    "So we for any $s_1,s_2$ we can compute\n",
    "$$\\frac{\\partial^{s_1}}{\\partial \\beta^{s_1}} \\frac{\\partial^{s_2}}{\\partial \\gamma^{s_2}} \\kappa_{X,Y}(\\beta,\\gamma)\\,,$$\n",
    "which we do below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ee02d57e-efdc-4b76-9ea8-ffec3d570633",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_transformer_rf_deriv(num_beta_derivs, num_gamma_derivs, fuzz_test=False):\n",
    "\n",
    "    ## Term 1:\n",
    "    # # s, m, t, n, p, diracs, coeff\n",
    "    # ## smax(beta u)^T X Y^T smax(beta v)\n",
    "    startterm = [[1],[],[2],[], [], [(1,2)], 1]\n",
    "\n",
    "    currterms = [startterm]\n",
    "    for i in range(num_beta_derivs):\n",
    "        currterms = simplify_terms(take_beta_deriv_terms(currterms))\n",
    "        print(f'beta {i+1}, Simplified len', len(currterms))\n",
    "        # display_terms(currterms)\n",
    "\n",
    "    for i in range(num_gamma_derivs):\n",
    "        currterms = simplify_terms(take_gamma_deriv_terms(currterms))\n",
    "        print(f'gamma {i+1}, Simplified len', len(currterms))\n",
    "        # display_terms(currterms)\n",
    "    # display_terms(currterms)\n",
    "\n",
    "    exp_terms = compute_expectation_terms(currterms)\n",
    "    simplified_exp_terms = exp_terms\n",
    "    print('Exp terms, len', len(exp_terms))\n",
    "    simplified_exp_terms = simplify_expectation_terms(simplified_exp_terms, full_simplify=False)\n",
    "    print('Exp terms, partially simplified len', len(simplified_exp_terms))\n",
    "    if fuzz_test:\n",
    "        fuzz_compare(exp_terms,simplified_exp_terms)\n",
    "    # display_expectation_terms_compact(simplified_exp_terms, simple_mode=True)\n",
    "    simplified_exp_terms2 = simplify_expectation_terms(simplified_exp_terms, full_simplify=True)\n",
    "    print('Exp terms, fully simplified len', len(simplified_exp_terms2))\n",
    "    if fuzz_test:\n",
    "        fuzz_compare(simplified_exp_terms,simplified_exp_terms2)\n",
    "    display_expectation_terms_compact(simplified_exp_terms2, simple_mode=True)\n",
    "    print('Term 1 Expectation',len(simplified_exp_terms2))\n",
    "    # display_expectation_terms(simplified_exp_terms)\n",
    "\n",
    "    exp_terms1 = simplified_exp_terms2\n",
    "\n",
    "    ## Term 2:\n",
    "    # # s, m, t, n, p, diracs, coeff\n",
    "    # # ## gamma^2 smax(beta u)^T smax(beta v)\n",
    "    startterm = [[1],[],[1],[], [], [], num_gamma_derivs * (num_gamma_derivs-1)]\n",
    "\n",
    "    currterms = [startterm]\n",
    "    for i in range(num_beta_derivs):\n",
    "        currterms = simplify_terms(take_beta_deriv_terms(currterms))\n",
    "        print(f'beta {i+1}, Simplified len', len(currterms))\n",
    "        # display_terms(currterms)\n",
    "\n",
    "    for i in range(num_gamma_derivs-2):\n",
    "        currterms = simplify_terms(take_gamma_deriv_terms(currterms))\n",
    "        print(f'gamma {i+1}, Simplified len', len(currterms))\n",
    "        # display_terms(currterms)\n",
    "    # display_terms(currterms)\n",
    "\n",
    "    exp_terms = compute_expectation_terms(currterms)\n",
    "    simplified_exp_terms = exp_terms\n",
    "    print('Exp terms, len', len(exp_terms))\n",
    "    simplified_exp_terms = simplify_expectation_terms(simplified_exp_terms, full_simplify=False)\n",
    "    print('Exp terms, partially simplified len', len(simplified_exp_terms))\n",
    "    # display_expectation_terms_compact(simplified_exp_terms, simple_mode=True)\n",
    "    simplified_exp_terms = simplify_expectation_terms(simplified_exp_terms, full_simplify=True)\n",
    "    print('Exp terms, fully simplified len', len(simplified_exp_terms))\n",
    "    display_expectation_terms_compact(simplified_exp_terms, simple_mode=True)\n",
    "    print('Term 2 Expectation',len(simplified_exp_terms))\n",
    "    # display_expectation_terms(simplified_exp_terms)\n",
    "    exp_terms2 = simplified_exp_terms\n",
    "    \n",
    "    exp_terms = exp_terms1 + exp_terms2\n",
    "    simplified_exp_terms = exp_terms\n",
    "    simplified_exp_terms = simplify_expectation_terms(simplified_exp_terms, full_simplify=False)\n",
    "    simplified_exp_terms = simplify_expectation_terms(simplified_exp_terms, full_simplify=True)\n",
    "    print('Simplified len', len(simplified_exp_terms))\n",
    "    return simplified_exp_terms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcf0b258-677c-4a41-acfe-6a75b44523b4",
   "metadata": {},
   "source": [
    "## Computing and saving the derivatives of the transformer random features kernel that we care about"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f891e3a1-7a00-44b6-bf54-87f632606b9c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, len 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5133.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, partially simplified len 1\n",
      "Fuzzing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8701.88it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9341.43it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4116.10it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4080.06it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9020.01it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8322.03it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8848.74it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9000.65it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9709.04it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4096.00it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7244.05it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7810.62it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9619.96it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9238.56it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7869.24it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7810.62it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8050.49it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3953.16it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8542.37it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8405.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.16, 0.2, 0.08, 0.2, 0.2, 0.16, 0.12, 0.08, 0.2, 0.2]\n",
      "[0.16, 0.2, 0.08, 0.2, 0.2, 0.16, 0.12, 0.08, 0.2, 0.2]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12372.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, fully simplified len 1\n",
      "Fuzzing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4634.59it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8473.34it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7884.03it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7869.24it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8289.14it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4044.65it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7530.17it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8256.50it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8756.38it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8905.10it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9258.95it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7626.01it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8322.03it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7943.76it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7681.88it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8192.00it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8774.69it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 7943.76it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8081.51it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 8065.97it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.16, 0.12, 0.04, 0.12, 0.12, 0.16, 0.08, 0.2, 0.16, 0.08]\n",
      "[0.16, 0.12, 0.04, 0.12, 0.12, 0.16, 0.08, 0.2, 0.16, 0.08]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +c_{1}\\sum_{a,b}1(x_{a}=y_{b})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term 1 Expectation 1\n",
      "Exp terms, len 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13148.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, partially simplified len 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, fully simplified len 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term 2 Expectation 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 10538.45it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 11366.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simplified len 1\n",
      "Compact beta deriv=0, gamma deriv=0, len=1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 22104.37it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 1, Simplified len 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 45964.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 2, Simplified len 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 49200.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gamma 1, Simplified len 18\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 57675.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gamma 2, Simplified len 12\n",
      "Exp terms, len 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 5385.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, partially simplified len 3\n",
      "Fuzzing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 20945.34it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 22260.79it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 23172.95it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 22723.09it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 22399.49it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 23442.78it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 19768.91it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 22671.91it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 22641.32it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 21854.82it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 17747.41it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 18864.94it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 10066.33it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 18052.96it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 18613.78it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 19093.95it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 19478.19it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 21112.27it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 17119.61it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 17772.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.128, -0.064, 0.03199999999999997, -0.096, -0.192, 0.128, -0.096, 0.064, -0.096, -0.096]\n",
      "[-0.128, -0.064, 0.032, -0.096, -0.192, 0.128, -0.096, 0.064, -0.096, -0.096]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 5769.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, fully simplified len 2\n",
      "Fuzzing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 11859.48it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 21254.92it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 10280.16it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 11814.94it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 10951.19it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 11234.74it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 19722.43it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 18157.16it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 17236.87it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 16958.10it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 15391.94it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 15621.24it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 15505.74it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 9565.12it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 13486.51it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 13888.42it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 15363.75it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 15169.27it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 15738.48it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 16946.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.16, -0.032, 0.032, -0.16, 0.0, 0.0, 0.0, -0.032, -0.064, 0.096]\n",
      "[-0.16, -0.032, 0.032, -0.16, 0.0, 0.0, 0.0, -0.032, -0.064, 0.096]\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +c_{1}\\sum_{a,b}1(x_{b}=y_{a})+c_{2}\\sum_{a}1(x_{a}=y_{a})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term 1 Expectation 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 17313.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 1, Simplified len 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 53464.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 2, Simplified len 12\n",
      "Exp terms, len 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 12777.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, partially simplified len 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 8322.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, fully simplified len 2\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +c_{1}\\sum_{a}1(x_{a}=y_{a})+c_{2}\\sum_{a,b}1(x_{b}=y_{a})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term 2 Expectation 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 10286.46it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 22982.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simplified len 2\n",
      "Compact beta deriv=2, gamma deriv=2, len=2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 16070.13it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 1, Simplified len 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 45990.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 2, Simplified len 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72/72 [00:00<00:00, 46987.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 3, Simplified len 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 208/208 [00:00<00:00, 25625.36it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 4, Simplified len 68\n",
      "Exp terms, len 204\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 3295.91it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, partially simplified len 40\n",
      "Fuzzing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 424.56it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 393.46it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 427.21it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 408.15it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 425.72it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 398.89it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 420.87it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 425.48it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 446.99it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 426.69it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 118.22it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 114.50it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 119.15it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 114.50it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 118.96it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 115.27it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 119.37it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 120.27it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 128.31it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 121.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.6435840000000022, -0.7403520000000039, -0.5514240000000007, -1.1089920000000062, -0.5314560000000016, -0.5160959999999997, -0.5514240000000007, -0.3240960000000006, 0.0, -0.5314560000000004]\n",
      "[-0.6435840000000002, -0.7403520000000025, -0.5514240000000008, -1.1089920000000006, -0.5314560000000002, -0.5160960000000006, -0.5514240000000012, -0.32409600000000016, 0.0, -0.5314560000000011]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 1465.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, fully simplified len 23\n",
      "Fuzzing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 116.56it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 126.88it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 102.94it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 120.68it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 112.09it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 123.57it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 109.96it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 122.07it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 113.37it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:00<00:00, 122.84it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 160.11it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 169.27it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 137.80it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 161.23it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 164.23it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 161.13it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 145.00it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 163.11it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 150.41it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 161.61it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.3317760000000012, -0.14592000000000008, -0.3210239999999964, -0.5084160000000003, -0.26265599999999983, -0.40704000000000023, -1.1980800000000005, -0.25344000000000044, -0.7403520000000006, -0.40704000000000023]\n",
      "[-0.33177600000000085, -0.14592, -0.32102399999999887, -0.5084159999999998, -0.2626559999999998, -0.40703999999999974, -1.1980800000000014, -0.2534399999999999, -0.7403519999999987, -0.4070399999999994]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +c_{1}\\sum_{a,b}1(x_{a}=y_{b})+c_{2}\\sum_{a,b,c}1(x_{a}=y_{b})1(x_{c}=y_{b})+c_{3}\\sum_{a,b,c}1(x_{a}=y_{c})1(x_{a}=y_{b})+c_{4}\\sum_{a,b,c,d}1(x_{a}=y_{b})1(x_{a}=x_{c})1(x_{c}=x_{d})+c_{5}\\sum_{a,b,c,d}1(x_{a}=y_{b})1(x_{c}=x_{d})+c_{6}\\sum_{a,b,c,d}1(x_{c}=y_{d})1(x_{a}=y_{b})1(x_{a}=y_{d})+c_{7}\\sum_{a,b,c,d}1(x_{c}=y_{d})1(x_{a}=y_{b})+c_{8}\\sum_{a,b,c,d}1(y_{b}=y_{c})1(x_{a}=y_{b})1(x_{a}=y_{d})+c_{9}\\sum_{a,b,c,d}1(x_{a}=y_{b})1(y_{c}=y_{d})+c_{10}\\sum_{a,b,c,d,e}1(x_{a}=y_{b})1(x_{e}=y_{b})1(x_{c}=x_{d})+c_{11}\\sum_{a,b,c,d,e}1(x_{a}=y_{b})1(x_{a}=y_{e})1(x_{c}=x_{d})+c_{12}\\sum_{a,b,c,d,e}1(x_{c}=y_{e})1(x_{a}=y_{b})1(x_{c}=x_{d})+c_{13}\\sum_{a,b,c,d,e}1(x_{c}=y_{e})1(x_{a}=y_{b})1(x_{a}=y_{d})+c_{14}\\sum_{a,b,c,d,e}1(x_{a}=y_{b})1(y_{d}=y_{e})1(x_{a}=x_{c})+c_{15}\\sum_{a,b,c,d,e}1(x_{a}=y_{b})1(x_{a}=y_{e})1(y_{c}=y_{d})+c_{16}\\sum_{a,b,c,d,e,f}1(x_{c}=x_{f})1(x_{d}=x_{e})1(x_{a}=y_{b})+c_{17}\\sum_{a,b,c,d,e}1(x_{a}=y_{b})1(x_{c}=x_{e})1(x_{c}=x_{d})+c_{18}\\sum_{a,b,c,d,e,f}1(x_{a}=y_{b})1(x_{d}=x_{e})1(x_{c}=y_{f})+c_{19}\\sum_{a,b,c,d,e,f}1(x_{c}=y_{f})1(x_{a}=y_{b})1(x_{d}=y_{e})+c_{20}\\sum_{a,b,c,d,e,f}1(x_{a}=y_{b})1(y_{e}=y_{f})1(x_{c}=x_{d})+c_{21}\\sum_{a,b,c,d,e,f}1(x_{c}=y_{f})1(y_{d}=y_{e})1(x_{a}=y_{b})+c_{22}\\sum_{a,b,c,d,e,f}1(y_{c}=y_{f})1(x_{a}=y_{b})1(y_{d}=y_{e})+c_{23}\\sum_{a,b,c,d,e}1(x_{a}=y_{b})1(y_{c}=y_{e})1(y_{c}=y_{d})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term 1 Expectation 23\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 49344.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 1, Simplified len 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 85510.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 2, Simplified len 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72/72 [00:00<00:00, 51843.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 3, Simplified len 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 208/208 [00:00<00:00, 30147.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 4, Simplified len 68\n",
      "Exp terms, len 204\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 3733.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, partially simplified len 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, fully simplified len 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle $"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term 2 Expectation 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 9754.20it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 9799.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simplified len 23\n",
      "Compact beta deriv=4, gamma deriv=0, len=23\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 18957.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 1, Simplified len 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 44034.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 2, Simplified len 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72/72 [00:00<00:00, 42623.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 3, Simplified len 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 208/208 [00:00<00:00, 26476.14it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 4, Simplified len 68\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 272/272 [00:00<00:00, 12063.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gamma 1, Simplified len 154\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 462/462 [00:00<00:00, 8285.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gamma 2, Simplified len 197\n",
      "Exp terms, len 197\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 3658.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, partially simplified len 29\n",
      "Fuzzing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 4095.39it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 4296.44it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 4363.14it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 4571.61it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 4426.58it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 4181.36it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 4605.27it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 3785.70it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 4496.58it/s]\n",
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 197/197 [00:00<00:00, 4515.77it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3141.64it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3272.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3373.31it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3500.38it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3385.52it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3343.36it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3394.21it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 2877.57it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3357.30it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3453.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.090559999999999, -2.488320000000005, -0.1843200000000036, -1.628160000000003, -1.489920000000001, -2.8415999999999983, -0.6758400000000004, -1.029120000000006, -1.6128000000000018, -0.8448000000000024]\n",
      "[1.0905600000000009, -2.488320000000002, -0.18431999999999965, -1.6281600000000005, -1.4899200000000006, -2.8416000000000006, -0.67584, -1.029119999999995, -1.6127999999999996, -0.8448]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3071.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, fully simplified len 15\n",
      "Fuzzing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3095.27it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3479.65it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3021.16it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 2707.51it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 1932.31it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3470.52it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3139.53it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3372.09it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 3362.68it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 2937.97it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 3664.00it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 4227.84it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 4187.60it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 3579.16it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 3542.69it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 3861.92it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 3803.32it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 4090.14it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 4019.33it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 3722.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.7065599999999999, -1.35168, -0.8448, -2.18112, -2.181120000000001, -0.90624, -2.04288, -2.08896, -1.3055999999999999, -0.5836800000000002]\n",
      "[-0.7065600000000014, -1.35168, -0.8448000000000002, -2.18112, -2.18112, -0.9062400000000002, -2.0428800000000003, -2.088960000000001, -1.3055999999999996, -0.5836799999999984]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +c_{1}\\sum_{a,b}1(x_{b}=y_{a})+c_{2}\\sum_{a}1(x_{a}=y_{a})+c_{3}\\sum_{a,b,c}1(x_{b}=y_{c})1(x_{a}=y_{c})+c_{4}\\sum_{a,b}1(x_{a}=y_{a})1(x_{b}=y_{a})+c_{5}\\sum_{a,b,c}1(y_{b}=y_{c})1(x_{a}=y_{b})+c_{6}\\sum_{a,b}1(x_{a}=y_{a})1(x_{a}=y_{b})+c_{7}\\sum_{a,b,c,d}1(x_{c}=x_{d})1(x_{b}=y_{a})+c_{8}\\sum_{a,b,c}1(x_{a}=y_{a})1(x_{b}=x_{c})+c_{9}\\sum_{a,b,c,d}1(x_{b}=y_{d})1(x_{c}=y_{a})+c_{10}\\sum_{a,b,c}1(x_{c}=y_{c})1(x_{a}=y_{b})+c_{11}\\sum_{a,b,c,d}1(y_{c}=y_{d})1(x_{b}=y_{a})+c_{12}\\sum_{a,b,c}1(x_{a}=y_{b})1(x_{b}=y_{c})+c_{13}\\sum_{a,b,c}1(x_{a}=y_{a})1(y_{b}=y_{c})+c_{14}\\sum_{a,b,c}1(x_{b}=y_{a})1(y_{b}=y_{c})+c_{15}\\sum_{a,b,c}1(x_{b}=x_{c})1(x_{a}=y_{b})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term 1 Expectation 15\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 44858.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 1, Simplified len 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 49257.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 2, Simplified len 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72/72 [00:00<00:00, 50848.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 3, Simplified len 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 208/208 [00:00<00:00, 29705.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 4, Simplified len 68\n",
      "Exp terms, len 204\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:00<00:00, 3647.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, partially simplified len 24\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 3966.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, fully simplified len 17\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +c_{1}\\sum_{a}1(x_{a}=y_{a})+c_{2}\\sum_{a,b}1(x_{a}=y_{a})1(x_{a}=x_{b})+c_{3}\\sum_{a,b}1(x_{b}=y_{a})+c_{4}\\sum_{a,b}1(x_{a}=y_{a})1(y_{a}=y_{b})+c_{5}\\sum_{a,b,c}1(x_{a}=y_{a})1(x_{b}=x_{c})+c_{6}\\sum_{a,b,c}1(x_{b}=y_{a})1(x_{a}=x_{c})+c_{7}\\sum_{a,b,c}1(x_{a}=y_{c})1(x_{b}=y_{a})+c_{8}\\sum_{a,b,c}1(x_{a}=y_{a})1(x_{b}=y_{c})+c_{9}\\sum_{a,b,c}1(x_{a}=x_{b})1(y_{a}=y_{c})+c_{10}\\sum_{a,b,c}1(x_{a}=y_{c})1(y_{a}=y_{b})+c_{11}\\sum_{a,b,c}1(x_{a}=y_{a})1(y_{b}=y_{c})+c_{12}\\sum_{a,b,c,d}1(x_{b}=y_{a})1(x_{c}=x_{d})+c_{13}\\sum_{a,b,c}1(x_{c}=y_{a})1(x_{b}=y_{a})+c_{14}\\sum_{a,b,c,d}1(x_{c}=y_{a})1(x_{b}=y_{d})+c_{15}\\sum_{a,b,c,d}1(y_{a}=y_{d})1(x_{b}=x_{c})+c_{16}\\sum_{a,b,c}1(x_{b}=y_{a})1(x_{b}=y_{c})+c_{17}\\sum_{a,b,c,d}1(x_{b}=y_{d})1(y_{a}=y_{c})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term 2 Expectation 17\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4929.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [00:00<00:00, 3798.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simplified len 17\n",
      "Compact beta deriv=4, gamma deriv=2, len=17\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 44501.90it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 1, Simplified len 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 71392.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 2, Simplified len 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72/72 [00:00<00:00, 44228.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 3, Simplified len 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 208/208 [00:00<00:00, 26560.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 4, Simplified len 68\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 532/532 [00:00<00:00, 14305.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 5, Simplified len 142\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1232/1232 [00:00<00:00, 7849.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 6, Simplified len 281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1686/1686 [00:00<00:00, 2749.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gamma 1, Simplified len 798\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3990/3990 [00:02<00:00, 1575.73it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gamma 2, Simplified len 1345\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5380/5380 [00:04<00:00, 1267.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gamma 3, Simplified len 1570\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4710/4710 [00:03<00:00, 1362.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gamma 4, Simplified len 1345\n",
      "Exp terms, len 4035\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:01<00:00, 2338.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, partially simplified len 91\n",
      "Fuzzing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:00<00:00, 4474.53it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:00<00:00, 4640.74it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:00<00:00, 4534.28it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:00<00:00, 4385.35it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:00<00:00, 4353.21it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:00<00:00, 4529.86it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:00<00:00, 4654.47it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:00<00:00, 4072.39it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:00<00:00, 4394.64it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4035/4035 [00:00<00:00, 4690.01it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2756.58it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2869.85it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2791.17it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2632.34it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2665.78it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2833.21it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2910.80it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2590.55it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2783.23it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2908.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[74.64960000001554, 94.0031999999953, 655.2576000000022, -414.72000000001015, -52.53119999997903, 171.4176000000046, -364.9536000000099, -74.64959999998287, 367.71840000001237, -348.3648000000089]\n",
      "[74.64959999999917, 94.00319999999968, 655.2575999999999, -414.72000000000094, -52.53120000000081, 171.4176000000004, -364.95360000000005, -74.64960000000406, 367.71839999999884, -348.3648000000002]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2102.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, fully simplified len 32\n",
      "Fuzzing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2737.15it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2771.83it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2803.45it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2806.63it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2906.50it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2822.84it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2740.23it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2681.29it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2948.90it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 91/91 [00:00<00:00, 2733.29it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4473.78it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4556.40it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4357.86it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4406.65it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4776.77it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4403.04it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4168.90it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4164.50it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4594.92it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4376.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[483.8400000000004, 658.0223999999994, 436.8383999999992, 55.295999999999935, 376.0127999999995, 326.24639999999937, 276.4799999999984, 539.1360000000013, 293.0687999999996, 470.01599999999866]\n",
      "[483.8399999999988, 658.0224, 436.83839999999987, 55.295999999999395, 376.0127999999999, 326.2464000000001, 276.4799999999989, 539.1359999999983, 293.06879999999984, 470.01599999999956]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +c_{1}\\sum_{a,b}1(x_{b}=y_{a})+c_{2}\\sum_{a}1(x_{a}=y_{a})+c_{3}\\sum_{a,b,c}1(x_{a}=x_{c})1(x_{a}=y_{b})+c_{4}\\sum_{a,b}1(x_{b}=y_{b})1(x_{a}=y_{b})+c_{5}\\sum_{a,b,c}1(x_{b}=y_{c})1(x_{b}=y_{a})+c_{6}\\sum_{a,b}1(y_{a}=y_{b})1(x_{b}=y_{a})+c_{7}\\sum_{a,b,c,d}1(x_{a}=y_{d})1(x_{b}=x_{c})+c_{8}\\sum_{a,b,c}1(x_{a}=y_{a})1(x_{b}=x_{c})+c_{9}\\sum_{a,b,c}1(x_{a}=y_{a})1(x_{b}=y_{c})+c_{10}\\sum_{a,b,c,d}1(x_{a}=y_{b})1(y_{c}=y_{d})+c_{11}\\sum_{a,b,c,d}1(x_{b}=y_{c})1(x_{a}=y_{d})+c_{12}\\sum_{a,b,c}1(x_{b}=y_{c})1(x_{c}=y_{a})+c_{13}\\sum_{a,b,c}1(x_{a}=y_{a})1(y_{b}=y_{c})+c_{14}\\sum_{a,b,c}1(x_{b}=x_{c})1(x_{a}=y_{b})+c_{15}\\sum_{a,b,c}1(x_{b}=y_{a})1(y_{b}=y_{c})+c_{16}\\sum_{a,b}1(x_{b}=y_{a})1(x_{a}=y_{b})+c_{17}\\sum_{a,b}1(x_{b}=y_{b})1(x_{a}=y_{a})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term 1 Expectation 32\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 17403.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 1, Simplified len 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 83635.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 2, Simplified len 12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 72/72 [00:00<00:00, 52401.51it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 3, Simplified len 30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 208/208 [00:00<00:00, 30145.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 4, Simplified len 68\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 532/532 [00:00<00:00, 16220.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 5, Simplified len 142\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1232/1232 [00:00<00:00, 8930.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 6, Simplified len 281\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1686/1686 [00:00<00:00, 3547.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gamma 1, Simplified len 656\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3280/3280 [00:01<00:00, 2580.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gamma 2, Simplified len 979\n",
      "Exp terms, len 2937\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2937/2937 [00:00<00:00, 2949.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, partially simplified len 90\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:00<00:00, 2477.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp terms, fully simplified len 44\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +c_{1}+c_{2}\\sum_{a}1(x_{a}=y_{a})+c_{3}\\sum_{a,b}1(x_{a}=x_{b})+c_{4}\\sum_{a,b}1(x_{a}=y_{a})1(x_{b}=y_{a})+c_{5}\\sum_{a,b}1(x_{b}=y_{a})+c_{6}\\sum_{a,b}1(y_{a}=y_{b})1(x_{b}=y_{a})+c_{7}\\sum_{a,b,c}1(x_{a}=x_{c})1(x_{a}=x_{b})+c_{8}\\sum_{a,b,c}1(x_{a}=y_{a})1(x_{b}=x_{c})+c_{9}\\sum_{a,b,c}1(x_{a}=x_{c})1(x_{b}=y_{a})+c_{10}\\sum_{a,b,c}1(x_{c}=y_{a})1(x_{b}=y_{a})+c_{11}\\sum_{a,b,c}1(x_{b}=y_{a})1(x_{a}=y_{c})+c_{12}\\sum_{a,b,c}1(x_{a}=y_{a})1(x_{b}=y_{c})+c_{13}\\sum_{a,b,c}1(y_{b}=y_{c})1(x_{a}=x_{b})+c_{14}\\sum_{a,b}1(y_{a}=y_{b})+c_{15}\\sum_{a,b,c}1(x_{c}=y_{b})1(y_{a}=y_{b})+c_{16}\\sum_{a,b,c}1(x_{a}=y_{c})1(y_{a}=y_{b})+c_{17}\\sum_{a,b,c}1(x_{a}=y_{a})1(y_{b}=y_{c})+c_{18}\\sum_{a,b,c,d}1(x_{a}=x_{d})1(x_{b}=x_{c})+c_{19}\\sum_{a,b,c,d}1(x_{a}=y_{d})1(x_{b}=x_{c})+c_{20}\\sum_{a,b,c,d}1(x_{b}=y_{d})1(x_{a}=y_{c})+c_{21}\\sum_{a,b,c,d}1(y_{a}=y_{d})1(x_{b}=x_{c})+c_{22}\\sum_{a,b,c,d}1(x_{b}=y_{d})1(y_{a}=y_{c})+c_{23}\\sum_{a,b}1(x_{a}=y_{b})1(x_{b}=y_{a})+c_{24}\\sum_{a,b}1(x_{a}=y_{a})1(x_{b}=y_{b})+c_{25}\\sum_{a,b}1(y_{a}=y_{b})1(x_{a}=x_{b})+c_{26}\\sum_{a,b,c}1(y_{a}=y_{b})1(y_{b}=y_{c})+c_{27}\\sum_{a,b,c,d}1(y_{b}=y_{c})1(y_{a}=y_{d})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Term 2 Expectation 44\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 76/76 [00:00<00:00, 5058.27it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:00<00:00, 3191.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Simplified len 44\n",
      "Compact beta deriv=6, gamma deriv=4, len=44\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# for num_beta_derivs in [0,2,4,6]:\n",
    "#     for num_gamma_derivs in [0,2,4,6,8,10]:\n",
    "for num_beta_derivs, num_gamma_derivs in [(0,0), (2,2), (4, 0), (4,2), (6, 4)]:\n",
    "    curr_file = f'exp_terms/beta{num_beta_derivs}_gamma{num_gamma_derivs}.pkl'\n",
    "    if os.path.exists(curr_file):\n",
    "        continue\n",
    "    exp_terms = get_transformer_rf_deriv(num_beta_derivs, num_gamma_derivs,fuzz_test=True)\n",
    "    pickle.dump(exp_terms, open(curr_file, 'wb'))\n",
    "    # display_expectation_terms(exp_terms)\n",
    "    print(f'Compact beta deriv={num_beta_derivs}, gamma deriv={num_gamma_derivs}, len={len(exp_terms)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "8d1137e5-d1da-440f-bdfe-db1906388084",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EXP_TERMS_REF_LIST = [((2, set(), [], -1152), ''),\n",
    "#              ((2, {1}, [(('x', 1), ('y', 1))], 864), 'tr(XY^T)'),\n",
    "#              ((0, {1, 2}, [(('x', 2), ('y', 1))], 1), '1^T X Y^T 1'),\n",
    "#              ((3, {1, 2}, [(('x', 1), ('y', 1)), (('x', 1), ('x', 2))], -336), 'tr(XY^T diag(XX^T1))'),\n",
    "#              ((3, {1, 2}, [(('y', 1), ('y', 2)), (('x', 1), ('y', 1))], -336), 'tr(XY^T diag(YY^T1))'),\n",
    "#              ((4, {1, 2}, [(('x', 1), ('x', 2))], 576), '1^TXX^T1'),\n",
    "#              ((4, {1, 2}, [(('y', 1), ('y', 2))], 576), '1^TYY^T 1'),\n",
    "#              ((0, {1, 2}, [(('x', 2), ('y', 1)), (('x', 1), ('y', 2))], 1), 'tr(XY^T XY^T)'),\n",
    "#              ((0, {1, 2}, [(('y', 1), ('y', 2)), (('x', 1), ('x', 2))], 1), '\\color{blue}{ tr(XX^T YY^T)}'),\n",
    "                      \n",
    "#              ((0, {1, 2, 3}, [(('x', 1), ('x', 2)), (('x', 1), ('x', 3))], 1),    '1^TXX^TXX^T1'),\n",
    "#              ((3, {1, 2, 3}, [(('x', 1), ('y', 2)), (('x', 1), ('x', 3))], -4),   '1^TXX^TXY^T1'),\n",
    "#              ((0, {1, 2, 3}, [(('x', 1), ('y', 2)), (('x', 2), ('x', 3))], 1),    '1^TXX^TYX^T1'),\n",
    "#              ((0, {1, 2, 3}, [(('x', 1), ('x', 2)), (('y', 1), ('y', 3))], 1),    '1^TXX^TYY^T1'),\n",
    "#              ((4, {1, 2, 3}, [(('x', 2), ('y', 1)), (('x', 1), ('y', 3))], 144),  '1^TXY^TXY^T1'),\n",
    "#              ((3, {1, 2, 3}, [(('x', 1), ('y', 2)), (('y', 2), ('y', 3))], -4),   '1^TXY^TYY^T1'),\n",
    "#              ((4, {1, 2, 3}, [(('y', 1), ('y', 2)), (('x', 1), ('y', 3))], 192),  '1^TYX^TYY^T1'),\n",
    "#              ((0, {1, 2, 3}, [(('y', 1), ('y', 3)), (('y', 1), ('y', 2))], 1),    '1^TYY^TYY^T1'),\n",
    "            \n",
    "#              # ((0, {1, 2, 3}, [(('x', 1), ('x', 2)), (('x', 1), ('x', 3)), (('x', 1), ('y', 1))], 1), '\\color{green}{tr(XY^T diag(XX^T1)diag(XX^T1))}'),\n",
    "#              # ((0, {1, 2, 3}, [(('y', 1), ('y', 3)), (('x', 1), ('y', 1)), (('x', 2), ('y', 3))], 1), '\\color{green}{tr(XY^T diag(YY^T YX^T 1)}'),\n",
    "#              # ((0, {1, 2, 3}, [(('x', 1), ('y', 2)), (('x', 1), ('y', 1)), (('x', 1), ('y', 3))], 1), '\\color{green}{tr(XY^T diag(XY^T1)diag(XY^T1))}'),\n",
    "                      \n",
    "#              ((0, {1, 2, 3}, [(('x', 1), ('x', 2)), (('y', 1), ('y', 3)), (('x', 1), ('y', 1))], 1), 'tr(diag(XX^T1)XY^Tdiag(YY^T1))'),\n",
    "#              ((0, {1, 2, 3}, [(('x', 1), ('x', 3)), (('x', 1), ('y', 1)), (('x', 1), ('x', 2))], 1), 'tr(diag(XX^T1)XY^Tdiag(XX^T1))'),\n",
    "#              ((0, {1, 2, 3}, [(('y', 1), ('y', 2)), (('x', 1), ('y', 1)), (('y', 1), ('y', 3))], 1), 'tr(diag(YY^T1)XY^Tdiag(YY^T1))'),\n",
    "                      \n",
    "#              ((4, {1, 2, 3, 4}, [(('x', 1), ('x', 2)), (('x', 2), ('x', 3)), (('x', 3), ('y', 4))], 192), '1^TXX^TXX^TXY^T1'),\n",
    "#              ((4, {1, 2, 3, 4}, [(('x', 1), ('y', 2)), (('y', 2), ('x', 3)), (('x', 3), ('y', 4))], 192), '1^TXY^TYX^TXY^T1'),\n",
    "#              ((4, {1, 2, 3, 4}, [(('y', 1), ('x', 2)), (('x', 2), ('y', 3)), (('y', 3), ('y', 4))], 192), '1^TYX^TXY^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('x', 2)), (('x', 2), ('x', 3)), (('x', 3), ('x', 4))], 1), '1^TXX^TXX^TXX^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('y', 1), ('y', 2)), (('y', 2), ('y', 3)), (('y', 3), ('y', 4))], 1), '1^TYY^TYY^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('y', 2)), (('x', 2), ('x', 3)), (('x', 3), ('y', 4))], 1), '1^TXY^TXX^TXY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('y', 2)), (('y', 2), ('x', 3)), (('y', 3), ('y', 4))], 1), '1^TXY^TYX^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('y', 1), ('y', 2)), (('y', 2), ('x', 3)), (('y', 3), ('y', 4))], 1), '1^TYY^TYX^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('x', 2)), (('y', 2), ('x', 3)), (('x', 3), ('y', 4))], 1), '1^TXX^TYX^TXY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('y', 2)), (('y', 2), ('y', 3)), (('x', 3), ('y', 4))], 1), '1^TXY^TYY^TXY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('y', 1), ('x', 2)), (('y', 2), ('y', 3)), (('y', 3), ('y', 4))], 1), '1^TYX^TYY^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('y', 2)), (('x', 2), ('x', 3)), (('x', 3), ('x', 4))], 1), '1^TXY^TXX^TXX^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('y', 1), ('y', 2)), (('x', 2), ('x', 3)), (('x', 3), ('x', 4))], 1), '1^TYY^TXX^TXX^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('x', 2)), (('y', 2), ('y', 3)), (('y', 3), ('y', 4))], 1), '1^TXX^TYY^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('x', 2)), (('x', 2), ('y', 3)), (('x', 3), ('x', 4))], 1), '1^TXX^TXY^TXX^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 2), ('y', 1)), (('x', 1), ('y', 4)), (('y', 4), ('y', 3))], 1), '1^TXY^TXY^TYY^T1'),\n",
    "#              ((0, {1, 2, 3, 4}, [(('x', 1), ('x', 3)), (('x', 3), ('y', 2)), (('x', 2), ('y', 4))], 1), '1^TXX^TXY^TXY^T1'),\n",
    "                      \n",
    "#              ((0, {1, 2, 3, 4, 5}, [(('x', 1), ('y', 2)), (('x', 4), ('x', 5)), (('x', 3), ('y', 2)), (('x', 3), ('x', 4))], 1), '1^TXY^TYX^TXX^TXX^T1'),\n",
    "#              ((0, {1, 2, 3, 4, 5}, [(('x', 1), ('y', 2)), (('x', 1), ('x', 3)), (('x', 3), ('x', 4)), (('x', 1), ('y', 5))], 1), 'tr(diag(XY^T1)diag(XY^T1)XX^Tdiag(XX^T1))'),\n",
    "#              ((0, {1, 2, 3, 4, 5}, [(('x', 1), ('y', 2)), (('x', 1), ('x', 3)), (('x', 3), ('y', 5)), (('x', 3), ('y', 4))], 1), '1^Tdiag(XY^T1)XX^Tdiag(XY^T1)diag(XY^T1)1'),\n",
    "#              ((0, {1, 2, 3, 4, 5}, [(('x', 1), ('y', 2)), (('x', 1), ('y', 4)), (('y', 2), ('y', 3)), (('x', 1), ('y', 5))], 1), 'tr(diag(XY^T1)diag(XY^T1)diag(YY^T1)diag(XY^T1))'),]\n",
    "\n",
    "# for exp_term, v in EXP_TERMS_REF_LIST:\n",
    "#     if v == 'default':\n",
    "#         print(exp_term)\n",
    "#         display_expectation_terms([exp_term])\n",
    "#         display_expectation_terms_compact([exp_term], simple_mode=False)\n",
    "#         print(v)\n",
    "#         assert(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8da9435-fc8d-49b1-bd5b-88dbf7eeb2de",
   "metadata": {},
   "source": [
    "# LaTeX output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "3f07bcbd-30f9-4a09-a278-3d18121e746d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 0 gamma 0\n",
      "Len 1\n",
      "beta 4 gamma 0\n",
      "Len 23\n",
      "beta 2 gamma 2\n",
      "Len 2\n",
      "beta 4 gamma 2\n",
      "Len 17\n",
      "beta 6 gamma 4\n",
      "Len 44\n",
      "\\begin{longtable}{p{0.2\\linewidth} >{\\raggedright\\arraybackslash}p{0.8\\linewidth}<{}}\n",
      "\\toprule \\textbf{Derivative} & \\textbf{Expansion} \\\\* \\midrule\n",
      "$ \\kappa_{X,Y}(0,0) = $ & $ +c_{1}{\\color{cblind1} {1^TXY^T 1}} $  \\\\* \n",
      "\\midrule\n",
      "$ \\frac{\\partial^{2}}{\\partial \\beta^{2}}\\frac{\\partial^{2}}{\\partial \\gamma^{2}}\\kappa_{X,Y}(0,0) = $ & $ +c_{1}{\\color{cblind1} {1^TXY^T 1}} $ $ +c_{2}{\\color{cblind3} {tr(XY^T)}} $  \\\\* \n",
      "\\midrule\n",
      "$ \\frac{\\partial^{4}}{\\partial \\beta^{4}}\\kappa_{X,Y}(0,0) = $ & $ +c_{1}{\\color{cblind1} {1^TXY^T 1}} $ $ +c_{2}{\\color{cblind1} {1^TXX^TXY^T1}} $ $ +c_{3}{\\color{cblind1} {1^TXY^TYY^T1}} $ $ +c_{4}{\\color{cblind1} {1^TXX^TXX^TXY^T1}} $ $ +c_{5}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{6}{\\color{cblind1} {1^TXY^TYX^TXY^T1}} $ $ +c_{7}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind1} {1^TXY^T 1}}) $ $ +c_{8}{\\color{cblind1} {1^TYX^TXY^TYY^T1}} $ $ +c_{9}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TYY^T1}}) $ $ +c_{10}({\\color{cblind1} {1^TXX^TXY^T1}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{11}({\\color{cblind1} {1^TXY^TYY^T1}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{12}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind1} {1^TXX^TXY^T1}}) $ $ +c_{13}({\\color{cblind1} {1^TXY^TYY^T1}})({\\color{cblind1} {1^TXY^T 1}}) $ $ +c_{14}({\\color{cblind1} {1^TXX^TXY^T1}})({\\color{cblind2} {1^TYY^T1}}) $ $ +c_{15}({\\color{cblind1} {1^TXY^TYY^T1}})({\\color{cblind2} {1^TYY^T1}}) $ $ +c_{16}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TXX^T1}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{17}({\\color{cblind1} {1^TXY^T 1}})(1^TXX^TXX^T1) $ $ +c_{18}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{19}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind1} {1^TXY^T 1}}) $ $ +c_{20}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TXX^T1}})({\\color{cblind2} {1^TYY^T1}}) $ $ +c_{21}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TYY^T1}}) $ $ +c_{22}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TYY^T1}})({\\color{cblind2} {1^TYY^T1}}) $ $ +c_{23}({\\color{cblind1} {1^TXY^T 1}})(1^TYY^TYY^T1) $  \\\\* \n",
      "\\midrule\n",
      "$ \\frac{\\partial^{4}}{\\partial \\beta^{4}}\\frac{\\partial^{2}}{\\partial \\gamma^{2}}\\kappa_{X,Y}(0,0) = $ & $ +c_{1}{\\color{cblind1} {1^TXY^T 1}} $ $ +c_{2}{\\color{cblind3} {tr(XY^T)}} $ $ +c_{3}{\\color{cblind1} {1^TXX^TXY^T1}} $ $ +c_{4}{\\color{cblind1} {tr(XX^TXY^T)}} $ $ +c_{5}{\\color{cblind1} {1^TXY^TYY^T1}} $ $ +c_{6}{\\color{cblind1} {tr(XY^TYY^T)}} $ $ +c_{7}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{8}({\\color{cblind3} {tr(XY^T)}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{9}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind1} {1^TXY^T 1}}) $ $ +c_{10}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind3} {tr(XY^T)}}) $ $ +c_{11}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TYY^T1}}) $ $ +c_{12}{\\color{cblind3} {1^TXY^TXY^T1}} $ $ +c_{13}({\\color{cblind3} {tr(XY^T)}})({\\color{cblind2} {1^TYY^T1}}) $ $ +c_{14}{\\color{cblind3} {1^TYX^TYY^T1}} $ $ +c_{15}{\\color{cblind3} {1^TXX^TYX^T1}} $ $ +c_{16}{\\color{cblind4} {1^TXX^TYY^T1}} $ $ +c_{17}({\\color{cblind2} {1^TYY^T1}})({\\color{cblind2} {1^TXX^T1}}) $  \\\\* \n",
      "\\midrule\n",
      "$ \\frac{\\partial^{6}}{\\partial \\beta^{6}}\\frac{\\partial^{4}}{\\partial \\gamma^{4}}\\kappa_{X,Y}(0,0) = $ & $ +c_{1}{\\color{cblind1} {1^TXY^T 1}} $ $ +c_{2}{\\color{cblind3} {tr(XY^T)}} $ $ +c_{3}{\\color{cblind1} {1^TXX^TXY^T1}} $ $ +c_{4}{\\color{cblind1} {tr(XX^TXY^T)}} $ $ +c_{5}{\\color{cblind1} {1^TXY^TYY^T1}} $ $ +c_{6}{\\color{cblind1} {tr(XY^TYY^T)}} $ $ +c_{7}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{8}({\\color{cblind3} {tr(XY^T)}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{9}({\\color{cblind3} {tr(XY^T)}})({\\color{cblind1} {1^TXY^T 1}}) $ $ +c_{10}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind2} {1^TYY^T1}}) $ $ +c_{11}({\\color{cblind1} {1^TXY^T 1}})({\\color{cblind1} {1^TXY^T 1}}) $ $ +c_{12}{\\color{cblind3} {1^TXY^TXY^T1}} $ $ +c_{13}({\\color{cblind3} {tr(XY^T)}})({\\color{cblind2} {1^TYY^T1}}) $ $ +c_{14}{\\color{cblind3} {1^TXX^TYX^T1}} $ $ +c_{15}{\\color{cblind3} {1^TYX^TYY^T1}} $ $ +c_{16}{\\color{cblind3} {tr(XY^T XY^T)}} $ $ +c_{17}({\\color{cblind3} {tr(XY^T)}})({\\color{cblind3} {tr(XY^T)}}) $ $ +c_{18} $ $ +c_{19}{\\color{cblind2} {1^TXX^T1}} $ $ +c_{20}1^TXX^TXX^T1 $ $ +c_{21}{\\color{cblind4} {1^TXX^TYY^T1}} $ $ +c_{22}{\\color{cblind2} {1^TYY^T1}} $ $ +c_{23}({\\color{cblind2} {1^TXX^T1}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{24}({\\color{cblind2} {1^TYY^T1}})({\\color{cblind2} {1^TXX^T1}}) $ $ +c_{25}{\\color{cblind5}{ tr(XX^T YY^T)}} $ $ +c_{26}1^TYY^TYY^T1 $ $ +c_{27}({\\color{cblind2} {1^TYY^T1}})({\\color{cblind2} {1^TYY^T1}}) $  \\\\* \n",
      "\\midrule\n",
      "\\end{longtable}\n"
     ]
    }
   ],
   "source": [
    "latex_strs = {}\n",
    "colors = ['CC79A7', 'D55E00', '0072B2', 'F0E442', '009E73', '56B4E9', 'E69F00', '000000']\n",
    "color_dict = {r'\\color{green}' : r'\\color{cblind1}',\n",
    "              r'\\color{orange}' : r'\\color{cblind2}',\n",
    "              r'\\color{red}' : r'\\color{cblind3}',\n",
    "              r'\\color{purple}' : r'\\color{cblind4}',\n",
    "              r'\\color{blue}' : r'\\color{cblind5}',\n",
    "             }\n",
    "# \\definecolor{cblind1}{HTML}{CC79A7}\n",
    "# \\definecolor{cblind2}{HTML}{0072B2}\n",
    "# \\definecolor{cblind3}{HTML}{D55E00}\n",
    "# \\definecolor{cblind4}{HTML}{009E73}\n",
    "# \\definecolor{cblind5}{HTML}{F0E442}\n",
    "\n",
    "for num_gamma_derivs in [0,2,4,6,8,10]:\n",
    "    for num_beta_derivs in [0,2,4,6]:\n",
    "        curr_file = f'exp_terms/beta{num_beta_derivs}_gamma{num_gamma_derivs}.pkl'\n",
    "        if not os.path.exists(curr_file):\n",
    "            continue\n",
    "       \n",
    "            \n",
    "        print('beta', num_beta_derivs, 'gamma', num_gamma_derivs)\n",
    "        exp_terms = pickle.load(open(curr_file, 'rb'))\n",
    "        print('Len', len(exp_terms))\n",
    "        curr_str = get_expectation_terms_compact_str(exp_terms,simple_mode=False, add_coeff=False)\n",
    "        curr_str = '$ ' + f' $ $ '.join(curr_str) + ' $ '\n",
    "        for k, v in color_dict.items():\n",
    "            curr_str = curr_str.replace(k,v)\n",
    "        latex_strs[(num_beta_derivs, num_gamma_derivs)] = curr_str\n",
    "\n",
    "        \n",
    "latex_output = r'\\begin{longtable}{p{0.2\\linewidth} >{\\raggedright\\arraybackslash}p{0.8\\linewidth}<{}}' + '\\n' + r'\\toprule \\textbf{Derivative} & \\textbf{Expansion} \\\\* \\midrule' + '\\n'\n",
    "for b in [0,2,4,6]:\n",
    "    for l in [0,2,4]:\n",
    "        if (b,l) not in [(0,0), (4,0), (2,2), (4,2), (6,4)]:\n",
    "            continue\n",
    "        if len(latex_strs[(b,l)]) == 0:\n",
    "            latex_strs[(b,l)] = '$0$'\n",
    "        # latex_output += r'\\rule{0pt}{4ex} \\vspace{2ex} $ '\n",
    "        latex_output += r'$ '\n",
    "        if b > 0:\n",
    "            latex_output += r'\\frac{\\partial^{' + f'{b}' + r'}}{\\partial \\beta^{' + f'{b}' + '}}' \n",
    "        if l > 0:\n",
    "            latex_output += r'\\frac{\\partial^{' + f'{l}' + r'}}{\\partial \\gamma^{' + f'{l}' + '}}'\n",
    "        latex_output += r'\\kappa_{X,Y}(0,0) = $ & ' + latex_strs[(b,l)] + r' \\\\* ' + '\\n' + r'\\midrule' + '\\n'\n",
    "latex_output += r'\\end{longtable}'\n",
    "print(latex_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "92359bb5-fcbf-4720-9181-a904cb14e849",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 0 gamma 0\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +(+\\frac{1}{k^{2}}){\\color{green} {1^TXY^T 1}}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 2 gamma 2\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +(-8\\frac{1}{k^{3}}){\\color{green} {1^TXY^T 1}}+(+8\\frac{1}{k^{2}}){\\color{red} {tr(XY^T)}}$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 4 gamma 0\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +(+12\\frac{1}{k^{2}}){\\color{green} {1^TXY^T 1}}+(-96\\frac{1}{k^{3}}){\\color{green} {1^TXX^TXY^T1}}+(-96\\frac{1}{k^{3}}){\\color{green} {1^TXY^TYY^T1}}+(+192\\frac{1}{k^{4}}){\\color{green} {1^TXX^TXX^TXY^T1}}+(+36\\frac{1}{k^{4}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TXX^T1}})+(+192\\frac{1}{k^{4}}){\\color{green} {1^TXY^TYX^TXY^T1}}+(+36\\frac{1}{k^{4}})({\\color{green} {1^TXY^T 1}})({\\color{green} {1^TXY^T 1}})+(+192\\frac{1}{k^{4}}){\\color{green} {1^TYX^TXY^TYY^T1}}+(+36\\frac{1}{k^{4}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TYY^T1}})+(-144\\frac{1}{k^{5}})({\\color{green} {1^TXX^TXY^T1}})({\\color{orange} {1^TXX^T1}})+(-48\\frac{1}{k^{5}})({\\color{green} {1^TXY^TYY^T1}})({\\color{orange} {1^TXX^T1}})+(-168\\frac{1}{k^{5}})({\\color{green} {1^TXY^T 1}})({\\color{green} {1^TXX^TXY^T1}})+(-168\\frac{1}{k^{5}})({\\color{green} {1^TXY^TYY^T1}})({\\color{green} {1^TXY^T 1}})+(-48\\frac{1}{k^{5}})({\\color{green} {1^TXX^TXY^T1}})({\\color{orange} {1^TYY^T1}})+(-144\\frac{1}{k^{5}})({\\color{green} {1^TXY^TYY^T1}})({\\color{orange} {1^TYY^T1}})+(+72\\frac{1}{k^{6}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TXX^T1}})({\\color{orange} {1^TXX^T1}})+(-72\\frac{1}{k^{5}})({\\color{green} {1^TXY^T 1}})(1^TXX^TXX^T1)+(+72\\frac{1}{k^{6}})({\\color{green} {1^TXY^T 1}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TXX^T1}})+(+48\\frac{1}{k^{6}})({\\color{green} {1^TXY^T 1}})({\\color{green} {1^TXY^T 1}})({\\color{green} {1^TXY^T 1}})+(+24\\frac{1}{k^{6}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TXX^T1}})({\\color{orange} {1^TYY^T1}})+(+72\\frac{1}{k^{6}})({\\color{green} {1^TXY^T 1}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TYY^T1}})+(+72\\frac{1}{k^{6}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TYY^T1}})({\\color{orange} {1^TYY^T1}})+(-72\\frac{1}{k^{5}})({\\color{green} {1^TXY^T 1}})(1^TYY^TYY^T1)$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 4 gamma 2\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +(-168\\frac{1}{k^{3}}){\\color{green} {1^TXY^T 1}}+(+72\\frac{1}{k^{2}}){\\color{red} {tr(XY^T)}}+(+576\\frac{1}{k^{4}}){\\color{green} {1^TXX^TXY^T1}}+(-336\\frac{1}{k^{3}}){\\color{green} {tr(XX^TXY^T)}}+(+576\\frac{1}{k^{4}}){\\color{green} {1^TXY^TYY^T1}}+(-336\\frac{1}{k^{3}}){\\color{green} {tr(XY^TYY^T)}}+(-432\\frac{1}{k^{5}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TXX^T1}})+(+96\\frac{1}{k^{4}})({\\color{red} {tr(XY^T)}})({\\color{orange} {1^TXX^T1}})+(-384\\frac{1}{k^{5}})({\\color{green} {1^TXY^T 1}})({\\color{green} {1^TXY^T 1}})+(+144\\frac{1}{k^{4}})({\\color{green} {1^TXY^T 1}})({\\color{red} {tr(XY^T)}})+(-432\\frac{1}{k^{5}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TYY^T1}})+(+144\\frac{1}{k^{4}}){\\color{red} {1^TXY^TXY^T1}}+(+96\\frac{1}{k^{4}})({\\color{red} {tr(XY^T)}})({\\color{orange} {1^TYY^T1}})+(+192\\frac{1}{k^{4}}){\\color{red} {1^TYX^TYY^T1}}+(+192\\frac{1}{k^{4}}){\\color{red} {1^TXX^TYX^T1}}+(+48\\frac{1}{k^{4}}){\\color{purple} {1^TXX^TYY^T1}}+(-48\\frac{1}{k^{5}})({\\color{orange} {1^TYY^T1}})({\\color{orange} {1^TXX^T1}})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 6 gamma 4\n"
     ]
    },
    {
     "data": {
      "text/latex": [
       "$\\displaystyle +(-34560\\frac{1}{k^{3}}+552960\\frac{1}{k^{4}}){\\color{green} {1^TXY^T 1}}+(+17280\\frac{1}{k^{2}}-432000\\frac{1}{k^{3}}){\\color{red} {tr(XY^T)}}+(-1693440\\frac{1}{k^{5}}+129600\\frac{1}{k^{4}}){\\color{green} {1^TXX^TXY^T1}}+(+1313280\\frac{1}{k^{4}}-86400\\frac{1}{k^{3}}){\\color{green} {tr(XX^TXY^T)}}+(-1693440\\frac{1}{k^{5}}+129600\\frac{1}{k^{4}}){\\color{green} {1^TXY^TYY^T1}}+(+1313280\\frac{1}{k^{4}}-86400\\frac{1}{k^{3}}){\\color{green} {tr(XY^TYY^T)}}+(-103680\\frac{1}{k^{5}}+1451520\\frac{1}{k^{6}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TXX^T1}})+(+25920\\frac{1}{k^{4}}-397440\\frac{1}{k^{5}})({\\color{red} {tr(XY^T)}})({\\color{orange} {1^TXX^T1}})+(+34560\\frac{1}{k^{4}}-587520\\frac{1}{k^{5}})({\\color{red} {tr(XY^T)}})({\\color{green} {1^TXY^T 1}})+(+1451520\\frac{1}{k^{6}}-103680\\frac{1}{k^{5}})({\\color{green} {1^TXY^T 1}})({\\color{orange} {1^TYY^T1}})+(-86400\\frac{1}{k^{5}}+1226880\\frac{1}{k^{6}})({\\color{green} {1^TXY^T 1}})({\\color{green} {1^TXY^T 1}})+(+34560\\frac{1}{k^{4}}-587520\\frac{1}{k^{5}}){\\color{red} {1^TXY^TXY^T1}}+(+25920\\frac{1}{k^{4}}-397440\\frac{1}{k^{5}})({\\color{red} {tr(XY^T)}})({\\color{orange} {1^TYY^T1}})+(-794880\\frac{1}{k^{5}}+51840\\frac{1}{k^{4}}){\\color{red} {1^TXX^TYX^T1}}+(-794880\\frac{1}{k^{5}}+51840\\frac{1}{k^{4}}){\\color{red} {1^TYX^TYY^T1}}+(+34560\\frac{1}{k^{4}}){\\color{red} {tr(XY^T XY^T)}}+(+34560\\frac{1}{k^{4}})({\\color{red} {tr(XY^T)}})({\\color{red} {tr(XY^T)}})+(-17280\\frac{1}{k^{2}})+(+60480\\frac{1}{k^{4}}){\\color{orange} {1^TXX^T1}}+(-207360\\frac{1}{k^{5}})1^TXX^TXX^T1+(+17280\\frac{1}{k^{4}}-241920\\frac{1}{k^{5}}){\\color{purple} {1^TXX^TYY^T1}}+(+60480\\frac{1}{k^{4}}){\\color{orange} {1^TYY^T1}}+(+155520\\frac{1}{k^{6}})({\\color{orange} {1^TXX^T1}})({\\color{orange} {1^TXX^T1}})+(+224640\\frac{1}{k^{6}}-17280\\frac{1}{k^{5}})({\\color{orange} {1^TYY^T1}})({\\color{orange} {1^TXX^T1}})+(+17280\\frac{1}{k^{4}}){\\color{blue}{ tr(XX^T YY^T)}}+(-207360\\frac{1}{k^{5}})1^TYY^TYY^T1+(+155520\\frac{1}{k^{6}})({\\color{orange} {1^TYY^T1}})({\\color{orange} {1^TYY^T1}})$"
      ],
      "text/plain": [
       "<IPython.core.display.Math object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# For checking that the relevant coefficients are strictly positive\n",
    "for num_beta_derivs, num_gamma_derivs in [(0,0), (2,2), (4, 0), (4,2), (6, 4)]:\n",
    "    curr_file = f'exp_terms/beta{num_beta_derivs}_gamma{num_gamma_derivs}.pkl'\n",
    "    print('beta', num_beta_derivs, 'gamma', num_gamma_derivs)\n",
    "    exp_terms = pickle.load(open(curr_file, 'rb'))\n",
    "    display_expectation_terms_compact(exp_terms,simple_mode=False,add_coeff=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51637e03-edf0-4a8b-9673-b5f568af8612",
   "metadata": {},
   "source": [
    "# Bonus: extra fuzzing code that was used to sanity-check that the computed functions were symmetric in x,y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "0a697b44-dd8a-4500-bde3-a5fd66cb7a7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beta 0 gamma 0\n",
      "Len 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 733.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "\n",
      "1\n",
      "1\n",
      "beta 2 gamma 2\n",
      "Len 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 1366.89it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "\n",
      "1 1\n",
      "\n",
      "2\n",
      "2\n",
      "beta 4 gamma 0\n",
      "Len 23\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 23/23 [00:24<00:00,  1.08s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "\n",
      "1 2\n",
      "\n",
      "3 7\n",
      "\n",
      "4 8\n",
      "\n",
      "5 5\n",
      "\n",
      "6 6\n",
      "\n",
      "9 14\n",
      "\n",
      "10 13\n",
      "\n",
      "11 12\n",
      "\n",
      "15 21\n",
      "\n",
      "16 22\n",
      "\n",
      "17 20\n",
      "\n",
      "18 18\n",
      "\n",
      "19 19\n",
      "\n",
      "23\n",
      "23\n",
      "beta 4 gamma 2\n",
      "Len 17\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 48.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "\n",
      "1 1\n",
      "\n",
      "2 4\n",
      "\n",
      "3 5\n",
      "\n",
      "6 10\n",
      "\n",
      "7 12\n",
      "\n",
      "8 8\n",
      "\n",
      "9 9\n",
      "\n",
      "11 11\n",
      "\n",
      "13 14\n",
      "\n",
      "15 15\n",
      "\n",
      "16 16\n",
      "\n",
      "17\n",
      "17\n",
      "beta 6 gamma 4\n",
      "Len 44\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 44/44 [00:00<00:00, 52.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0\n",
      "\n",
      "1 1\n",
      "\n",
      "2 6\n",
      "\n",
      "3 3\n",
      "\n",
      "4 7\n",
      "\n",
      "5 5\n",
      "\n",
      "8 15\n",
      "\n",
      "9 12\n",
      "\n",
      "10 18\n",
      "\n",
      "11 14\n",
      "\n",
      "13 13\n",
      "\n",
      "16 16\n",
      "\n",
      "17 17\n",
      "\n",
      "19 24\n",
      "\n",
      "20 26\n",
      "\n",
      "21 21\n",
      "\n",
      "22 31\n",
      "\n",
      "23 23\n",
      "\n",
      "25 25\n",
      "\n",
      "27 28\n",
      "\n",
      "29 29\n",
      "\n",
      "30 30\n",
      "\n",
      "32 32\n",
      "\n",
      "33 36\n",
      "\n",
      "34 42\n",
      "\n",
      "35 35\n",
      "\n",
      "37 43\n",
      "\n",
      "38 38\n",
      "\n",
      "39 39\n",
      "\n",
      "40 40\n",
      "\n",
      "41 41\n",
      "\n",
      "44\n",
      "44\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def fuzz_compare_symm(exp_terms):\n",
    "    k_fuzz = 5\n",
    "    m_fuzz = 8\n",
    "    num_fuzzers = 10\n",
    "\n",
    "    fuzzers = []\n",
    "    fuzzers2 = []\n",
    "    for i in range(num_fuzzers):\n",
    "        X = rand_sequence(k=k_fuzz,m=m_fuzz)\n",
    "        Y = rand_sequence(k=k_fuzz,m=m_fuzz)\n",
    "        fuzzers.append((X,Y))\n",
    "        fuzzers2.append((Y,X))\n",
    "\n",
    "    print('Fuzzing for symm')\n",
    "    evals1 = eval_on_fuzzers(exp_terms, fuzzers)\n",
    "    evals2 = eval_on_fuzzers(exp_terms, fuzzers2)\n",
    "    print(evals1)\n",
    "    print(evals2)\n",
    "    for i in range(num_fuzzers):\n",
    "        assert(abs(evals1[i] - evals2[i]) < 1e-6)\n",
    "        \n",
    "def fuzz_match_terms_by_symm(exp_terms,k,m,num_fuzzers,ignore_coeff=True, ignore_k_exp=False):\n",
    "    fuzzers = []\n",
    "    for i in range(num_fuzzers):\n",
    "        X = rand_sequence(k=k,m=m)\n",
    "        Y = rand_sequence(k=k,m=m)\n",
    "        fuzzers.append((X,Y))\n",
    "\n",
    "    fuzzer_outs = {}\n",
    "    for i in tqdm(range(len(exp_terms))):\n",
    "        # print(exp_terms[i])\n",
    "        fuzzer_outs[(i,0)] = []\n",
    "        fuzzer_outs[(i,1)] = []\n",
    "        for f in fuzzers:\n",
    "            ev = eval_exp_term(f[0],f[1],exp_terms[i], ignore_coeff=ignore_coeff, ignore_k_exp=ignore_k_exp)\n",
    "            ev2 = eval_exp_term(f[1],f[0],exp_terms[i], ignore_coeff=ignore_coeff, ignore_k_exp=ignore_k_exp)\n",
    "            fuzzer_outs[(i,0)].append(ev)\n",
    "            fuzzer_outs[(i,1)].append(ev2)\n",
    "        fuzzer_outs[(i,0)] = tuple(fuzzer_outs[(i,0)])\n",
    "        fuzzer_outs[(i,1)] = tuple(fuzzer_outs[(i,1)])\n",
    "    return fuzzer_outs\n",
    "\n",
    "\n",
    "for num_beta_derivs in [0,2,4,6]:\n",
    "    for num_gamma_derivs in [0,2,4,6,8,10]:\n",
    "        curr_file = f'exp_terms/beta{num_beta_derivs}_gamma{num_gamma_derivs}.pkl'\n",
    "        if not os.path.exists(curr_file):\n",
    "            continue\n",
    "       \n",
    "            \n",
    "        print('beta', num_beta_derivs, 'gamma', num_gamma_derivs)\n",
    "        exp_terms = pickle.load(open(curr_file, 'rb'))\n",
    "        print('Len', len(exp_terms))\n",
    "        # print(get_expectation_terms_compact_str(exp_terms,simple_mode=False, add_coeff=False))\n",
    "\n",
    "        fuzzer_outs = fuzz_match_terms_by_symm(exp_terms,k=7,m=3,num_fuzzers=10, ignore_k_exp=False, ignore_coeff=False)\n",
    "        matched = set()\n",
    "        nf = len(fuzzer_outs[(0,0)])\n",
    "        for i in range(len(exp_terms)):\n",
    "            for s in range(i,len(exp_terms)):\n",
    "                if np.all([abs(fuzzer_outs[(i,0)][j] - fuzzer_outs[(s,1)][j]) < 1e-6 for j in range(nf)]):\n",
    "                    # print(i,s)\n",
    "                    print(i,s)\n",
    "                    if i in matched:\n",
    "                        print(i)\n",
    "                        assert(False)\n",
    "                    if s in matched:\n",
    "                        print(s)\n",
    "                        assert(False)\n",
    "                    matched.add(i)\n",
    "                    matched.add(s)\n",
    "                    # print(exp_terms[i])\n",
    "                    # print(exp_terms[s])\n",
    "                    # display_expectation_terms_compact([exp_terms[i]],simple_mode=False)\n",
    "                    # display_expectation_terms_compact([exp_terms[s]],simple_mode=False)\n",
    "                    print()\n",
    "\n",
    "\n",
    "        # All terms are matched\n",
    "        print(len(matched))\n",
    "        print(len(exp_terms))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b2af25d-051c-4f81-8313-ad8476a83393",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
