{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "import itertools\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from mindreadingautobots.entropy_and_bayesian import boolean\n",
    "from mindreadingautobots.sequence_generators import make_datasets, data_io\n",
    "import os\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finding a trick (joker) function\n",
    "\n",
    "Recall that we start with uniformly random bitstrings $X \\in \\{0,1\\}^n$, generate labels $Y=f(X)$, and form pairs $(X, Y)$ with joint distribution $p_{XY}$. Flipping each bit of $X$ independently with probability $p$ turns $X$ into $Z$. This gives a new joint distribution $p_{ZY}$, in which $Z$ is the bitflipped version of $X$ while the label is still $Y=f(X)$.\n",
    "\n",
    "By a trick function we mean a generating function $f$ for which $f^*$, the optimal predictor for the _noisy_ data, is more accurate on the noisy data and less sensitive than $f$ itself.\n",
    "\n",
    "We want an example of data and noise such that the MLD (maximum-likelihood decoder) for the noisy data, evaluated on noiseless data, is _worse_ than the MLD for the noiseless data, also evaluated on noiseless data.\n",
    "\n",
    "Suppose $f^*(z^{n-1}) = \\arg\\max_{x} p_{X|Z^{n-1}}(x \\mid z^{n-1})$ is our MLD for noisy data and $g^*(x^{n-1}) = \\arg\\max_{x} p_{X|X^{n-1}}(x \\mid x^{n-1})$ is our MLD for noiseless data. We compute $f^*$ exactly, by enumerating every input and every bitflip pattern, and then compare it to $g^*$.\n",
    "\n",
    "Our first example is the $k=3$ majority function, for which $g^*$ is simply majority."
   ]
  },
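  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check of the $k=3$ majority example, here is a minimal sketch that builds the noisy MLD $f^*$ by exact enumeration. It does not use the `boolean` helpers from later cells, whose API is not shown here. The names `noisy_mld` and `majority` are hypothetical. The sketch assumes i.i.d. bitflips with probability $p$ and a uniform prior on $X$. For $p=0.2$ it finds that $f^*$ coincides with majority, i.e. with $g^*$, so majority by itself is not a trick function at that noise level.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "\n",
    "\n",
    "def noisy_mld(f, n, p):\n",
    "    # Map each noisy string z to argmax_y Pr(Y = y | Z = z), where X is uniform on {0,1}^n,\n",
    "    # Y = f(X), and Z flips each bit of X independently with probability p.\n",
    "    inputs = list(itertools.product([0, 1], repeat=n))\n",
    "    # joint[z][y] is proportional to Pr(Z = z, Y = y); the uniform prior 1/2**n cancels in the argmax\n",
    "    joint = {z: [0.0, 0.0] for z in inputs}\n",
    "    for x in inputs:\n",
    "        y = f(x)\n",
    "        for z in inputs:\n",
    "            flips = sum(a != b for a, b in zip(x, z))\n",
    "            joint[z][y] += p ** flips * (1 - p) ** (n - flips)\n",
    "    # ties are broken towards 0, like np.argmax\n",
    "    return {z: int(joint[z][1] > joint[z][0]) for z in inputs}\n",
    "\n",
    "\n",
    "def majority(x):\n",
    "    return int(sum(x) > len(x) // 2)\n",
    "\n",
    "\n",
    "f_star = noisy_mld(majority, n=3, p=0.2)\n",
    "print(all(f_star[z] == majority(z) for z in f_star))  # True\n",
    "```"
   ]
  },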
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Searching for weight-based counterexamples**\n",
    "\n",
    "Instead of searching over all boolean functions $f:\\{0,1\\}^n \\rightarrow \\{0,1\\}$, we restrict to functions of the Hamming weight $\\text{wt}(x)$ (the number of ones in $x$):\n",
    "$$\n",
    "f(x) = g(\\text{wt}(x)), \\qquad g:\\{0, 1, \\dots, n\\} \\rightarrow \\{0,1\\}.\n",
    "$$\n",
    "Since $\\text{wt}(x)$ takes only $n+1$ values, there are only $2^{n+1}$ possible functions $g$, instead of $2^{2^n}$ possible functions $f$. For example, if $n=3$, the parity function corresponds to the $g$\n",
    "\\begin{equation}\n",
    "     \\text{wt}(x) \\rightarrow \\begin{cases}\n",
    "        0 \\rightarrow 0 \\\\\n",
    "        1 \\rightarrow 1 \\\\\n",
    "        2 \\rightarrow 0 \\\\\n",
    "        3 \\rightarrow 1\n",
    "    \\end{cases}\n",
    "\\end{equation}"
   ]
  },
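  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick illustration of the counting argument and of the $n=3$ example above, which is parity. This is a sketch, and `weight_based` is a hypothetical helper, not part of `mindreadingautobots`. A weight signature $(g(0), \\dots, g(n))$ is just a tuple of $n+1$ bits, and the signature `(0, 1, 0, 1)` reproduces parity on $n=3$ bits.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "\n",
    "n = 3\n",
    "signatures = list(itertools.product([0, 1], repeat=n + 1))\n",
    "print(len(signatures), 2 ** (2 ** n))  # 16 256\n",
    "\n",
    "\n",
    "def weight_based(signature):\n",
    "    # f(x) = g(wt(x)) with signature = (g(0), ..., g(n))\n",
    "    return lambda x: signature[sum(x)]\n",
    "\n",
    "\n",
    "parity = weight_based((0, 1, 0, 1))\n",
    "inputs = list(itertools.product([0, 1], repeat=n))\n",
    "print(all(parity(x) == sum(x) % 2 for x in inputs))  # True\n",
    "```"
   ]
  },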
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Searching for joker functions: weight-based functions\n",
    "\n",
    "**Constraints**:\n",
    "- both average sensitivities, of $f$ and of $f^*$, should be $\\gg 0$\n",
    "- both accuracies on the noisy data should be $\\gg 1/2$ (in practice, the bitflip probability $p$ must not be too large)\n",
    "- functions should be close to balanced, i.e. `imbal` $= |\\#\\{x : f(x)=0\\} - \\#\\{x : f(x)=1\\}| / 2^n$ should be close to 0"
   ]
  },
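  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `imbal` and sensitivity quantities in these constraints can be computed by brute force. This is a minimal sketch, assuming that `boolean.average_sensitivity` uses the standard definition $\\mathrm{as}(f) = \\mathbb{E}_x\\,\\#\\{i : f(x) \\neq f(x^{\\oplus i})\\}$ with uniform $x$. The helper names are hypothetical. For $n=5$ it gives imbalance $0$ and sensitivity $5$ for parity, and $0.9375$ and $0.3125$ for AND. These match the rows `[0 1 0 1 0 1]` and `[0 0 0 0 0 1]` of the printed $p=0$ table further down.\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "\n",
    "\n",
    "def imbalance(f, n):\n",
    "    # |#{x: f(x)=0} - #{x: f(x)=1}| / 2**n, as in the `imbal` column\n",
    "    values = [f(x) for x in itertools.product([0, 1], repeat=n)]\n",
    "    return abs(values.count(0) - values.count(1)) / 2 ** n\n",
    "\n",
    "\n",
    "def average_sensitivity(f, n):\n",
    "    # E_x[#{i : f(x) != f(x with bit i flipped)}] for uniform x\n",
    "    total = 0\n",
    "    for x in itertools.product([0, 1], repeat=n):\n",
    "        for i in range(n):\n",
    "            x_flip = x[:i] + (1 - x[i],) + x[i + 1:]\n",
    "            total += f(x) != f(x_flip)\n",
    "    return total / 2 ** n\n",
    "\n",
    "\n",
    "def parity(x):\n",
    "    return sum(x) % 2\n",
    "\n",
    "\n",
    "def and_fn(x):\n",
    "    return int(all(x))\n",
    "\n",
    "\n",
    "print(imbalance(parity, 5), average_sensitivity(parity, 5))  # 0.0 5.0\n",
    "print(imbalance(and_fn, 5), average_sensitivity(and_fn, 5))  # 0.9375 0.3125\n",
    "```"
   ]
  },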
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A larger search space is explored in the standalone python script; here we only scan n=4.\n",
    "os.makedirs('dentsets', exist_ok=True)  # to_csv below fails if the output directory is missing\n",
    "\n",
    "for n in [4]:\n",
    "    # all 2**n inputs in lexicographic order, so row i of X_arr is the binary expansion of i\n",
    "    # (x is uniform over {0,1}^n, i.e. p_x = 1 / 2**n)\n",
    "    X_arr = np.array(list(itertools.product([0, 1], repeat=n)))\n",
    "\n",
    "    for p in [0.2, 0.22]:\n",
    "\n",
    "        # WEIGHT-BASED FUNCTIONS: each signature (g(0), ..., g(n)) defines f(x) = g(wt(x))\n",
    "        signatures_list = list(itertools.product([0, 1], repeat=n+1))\n",
    "        f_accs = []\n",
    "        fn_accs = []\n",
    "        fn_noiseless_accs = []\n",
    "        imbal_list = []\n",
    "        sensitivity_f_list = []\n",
    "        sensitivity_fnstar_list = []\n",
    "        sensitivity_diff_list = []\n",
    "\n",
    "        for signature in signatures_list:\n",
    "\n",
    "            weight_to_label = dict(zip(range(n+1), signature))\n",
    "\n",
    "            def func(b):\n",
    "                return weight_to_label[sum(b)]\n",
    "\n",
    "            # noisy_lookup[y, j] = sum over x with f(x) = y of Pr(Z = z_j | X = x), i.e. 2**n * Pr(Y = y, Z = z_j)\n",
    "            noisy_lookup = np.zeros((2, 2**n))\n",
    "            # true_lookup[y, j] = 1 if f(x_j) = y, else 0\n",
    "            true_lookup = np.zeros((2, 2**n))\n",
    "\n",
    "            for i, x in enumerate(X_arr):\n",
    "\n",
    "                func_value = func(x)\n",
    "                true_lookup[func_value, i] = 1\n",
    "\n",
    "                # iterate over all bitflip patterns e; z = x XOR e is the noisy string\n",
    "                for e in product([0, 1], repeat=n):\n",
    "\n",
    "                    z = np.array(x) ^ np.array(e)\n",
    "                    p_z_given_x = p ** sum(e) * (1-p)**(n - sum(e))\n",
    "\n",
    "                    # the column index of z is its binary value (same ordering as X_arr)\n",
    "                    noisy_lookup[func_value, int(''.join(map(str, z)), 2)] += p_z_given_x\n",
    "\n",
    "            imbal = abs(true_lookup[0,:].sum() - true_lookup[1,:].sum()) / 2 ** n\n",
    "            imbal_list.append(imbal)\n",
    "\n",
    "            # fN*(z) = argmax_y Pr(Y = y | Z = z): the MLE predictor for the noisy data (ties go to 0)\n",
    "            fnstar_dct = {}\n",
    "            for i, x in enumerate(X_arr):\n",
    "                fnstar_dct[tuple(x)] = np.argmax(noisy_lookup[:, i])\n",
    "\n",
    "            def fnstar(x):\n",
    "                return fnstar_dct[tuple(x)]\n",
    "\n",
    "            sensitivity_f = boolean.average_sensitivity(func, X_arr)\n",
    "            sensitivity_fnstar = boolean.average_sensitivity(fnstar, X_arr)\n",
    "            sensitivity_diff = sensitivity_f - sensitivity_fnstar\n",
    "\n",
    "            # accuracies on dataset\n",
    "            p_zy = boolean.generate_noisy_distr(n, p, func)\n",
    "            noisy_f_acc = boolean.compute_acc_noisytest(p_zy, func, n) # accuracy of f on noisy data\n",
    "            noiseless_fnstar_acc = boolean.compute_acc_test(fnstar, func, n) # accuracy of fN* on noiseless data\n",
    "            noisy_fnstar_acc = boolean.compute_acc_noisytest(p_zy, fnstar, n) # accuracy of fN* MLE on noisy data\n",
    "\n",
    "            f_accs.append(noisy_f_acc)\n",
    "            fn_accs.append(noisy_fnstar_acc)\n",
    "            fn_noiseless_accs.append(noiseless_fnstar_acc)\n",
    "            sensitivity_f_list.append(sensitivity_f)\n",
    "            sensitivity_fnstar_list.append(sensitivity_fnstar)\n",
    "            sensitivity_diff_list.append(sensitivity_diff)\n",
    "\n",
    "        df = pd.DataFrame({'signature': signatures_list, 'f_acc': f_accs, 'fn_acc': fn_accs, 'fn_noiseless_acc': fn_noiseless_accs,\n",
    "                        'imbalance': imbal_list, 'sensitivity_f': sensitivity_f_list, 'sensitivity_fnstar': sensitivity_fnstar_list, 'sensitivity_diff': sensitivity_diff_list,\n",
    "                        'bitflip': len(signatures_list)*[p]})\n",
    "        df['acc_diff'] = df['f_acc'] - df['fn_acc']\n",
    "\n",
    "        df_filtered = df[\n",
    "            (df['imbalance'] < 1) &\n",
    "            (df['sensitivity_f'] > 0) &\n",
    "            (df['sensitivity_fnstar'] > 0) &\n",
    "            (df['sensitivity_diff'] > 0) &\n",
    "            (df['fn_acc'] > 0.6) &\n",
    "            (df['f_acc'] > 0.6) &\n",
    "            (df['acc_diff'] != 0)\n",
    "        ]\n",
    "\n",
    "        if len(df_filtered) > 0:\n",
    "\n",
    "            df_filtered.to_csv(f'dentsets/weight_functions_n={n}_p={p}.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Making datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Surviving samples: 8\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>signature</th>\n",
       "      <th>f_acc</th>\n",
       "      <th>fn_acc</th>\n",
       "      <th>fn_noiseless_acc</th>\n",
       "      <th>imbalance</th>\n",
       "      <th>sensitivity_f</th>\n",
       "      <th>sensitivity_fnstar</th>\n",
       "      <th>sensitivity_diff</th>\n",
       "      <th>bitflip</th>\n",
       "      <th>acc_diff</th>\n",
       "      <th>distance</th>\n",
       "      <th>n</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>957</th>\n",
       "      <td>(1, 1, 1, 1, 0, 0, 1, 1, 1)</td>\n",
       "      <td>0.620041</td>\n",
       "      <td>0.623912</td>\n",
       "      <td>0.890625</td>\n",
       "      <td>0.015625</td>\n",
       "      <td>3.500000</td>\n",
       "      <td>2.62500</td>\n",
       "      <td>0.875000</td>\n",
       "      <td>0.2</td>\n",
       "      <td>-0.003870</td>\n",
       "      <td>0.875009</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>949</th>\n",
       "      <td>(1, 1, 1, 0, 0, 1, 1, 1, 1)</td>\n",
       "      <td>0.620041</td>\n",
       "      <td>0.623912</td>\n",
       "      <td>0.890625</td>\n",
       "      <td>0.015625</td>\n",
       "      <td>3.500000</td>\n",
       "      <td>2.62500</td>\n",
       "      <td>0.875000</td>\n",
       "      <td>0.2</td>\n",
       "      <td>-0.003870</td>\n",
       "      <td>0.875009</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>776</th>\n",
       "      <td>(0, 0, 0, 1, 1, 0, 0, 0, 0)</td>\n",
       "      <td>0.620041</td>\n",
       "      <td>0.623912</td>\n",
       "      <td>0.890625</td>\n",
       "      <td>0.015625</td>\n",
       "      <td>3.500000</td>\n",
       "      <td>2.62500</td>\n",
       "      <td>0.875000</td>\n",
       "      <td>0.2</td>\n",
       "      <td>-0.003870</td>\n",
       "      <td>0.875009</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>768</th>\n",
       "      <td>(0, 0, 0, 0, 1, 1, 0, 0, 0)</td>\n",
       "      <td>0.620041</td>\n",
       "      <td>0.623912</td>\n",
       "      <td>0.890625</td>\n",
       "      <td>0.015625</td>\n",
       "      <td>3.500000</td>\n",
       "      <td>2.62500</td>\n",
       "      <td>0.875000</td>\n",
       "      <td>0.2</td>\n",
       "      <td>-0.003870</td>\n",
       "      <td>0.875009</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>459</th>\n",
       "      <td>(1, 1, 1, 0, 0, 1, 1, 0)</td>\n",
       "      <td>0.612732</td>\n",
       "      <td>0.615028</td>\n",
       "      <td>0.992188</td>\n",
       "      <td>0.109375</td>\n",
       "      <td>3.390625</td>\n",
       "      <td>3.28125</td>\n",
       "      <td>0.109375</td>\n",
       "      <td>0.2</td>\n",
       "      <td>-0.002296</td>\n",
       "      <td>0.109399</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>418</th>\n",
       "      <td>(1, 0, 0, 1, 1, 0, 0, 0)</td>\n",
       "      <td>0.612732</td>\n",
       "      <td>0.615028</td>\n",
       "      <td>0.992188</td>\n",
       "      <td>0.109375</td>\n",
       "      <td>3.390625</td>\n",
       "      <td>3.28125</td>\n",
       "      <td>0.109375</td>\n",
       "      <td>0.2</td>\n",
       "      <td>-0.002296</td>\n",
       "      <td>0.109399</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>399</th>\n",
       "      <td>(0, 1, 1, 0, 0, 1, 1, 1)</td>\n",
       "      <td>0.612732</td>\n",
       "      <td>0.615028</td>\n",
       "      <td>0.992188</td>\n",
       "      <td>0.109375</td>\n",
       "      <td>3.390625</td>\n",
       "      <td>3.28125</td>\n",
       "      <td>0.109375</td>\n",
       "      <td>0.2</td>\n",
       "      <td>-0.002296</td>\n",
       "      <td>0.109399</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>358</th>\n",
       "      <td>(0, 0, 0, 1, 1, 0, 0, 1)</td>\n",
       "      <td>0.612732</td>\n",
       "      <td>0.615028</td>\n",
       "      <td>0.992188</td>\n",
       "      <td>0.109375</td>\n",
       "      <td>3.390625</td>\n",
       "      <td>3.28125</td>\n",
       "      <td>0.109375</td>\n",
       "      <td>0.2</td>\n",
       "      <td>-0.002296</td>\n",
       "      <td>0.109399</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                       signature     f_acc    fn_acc  fn_noiseless_acc  \\\n",
       "957  (1, 1, 1, 1, 0, 0, 1, 1, 1)  0.620041  0.623912          0.890625   \n",
       "949  (1, 1, 1, 0, 0, 1, 1, 1, 1)  0.620041  0.623912          0.890625   \n",
       "776  (0, 0, 0, 1, 1, 0, 0, 0, 0)  0.620041  0.623912          0.890625   \n",
       "768  (0, 0, 0, 0, 1, 1, 0, 0, 0)  0.620041  0.623912          0.890625   \n",
       "459     (1, 1, 1, 0, 0, 1, 1, 0)  0.612732  0.615028          0.992188   \n",
       "418     (1, 0, 0, 1, 1, 0, 0, 0)  0.612732  0.615028          0.992188   \n",
       "399     (0, 1, 1, 0, 0, 1, 1, 1)  0.612732  0.615028          0.992188   \n",
       "358     (0, 0, 0, 1, 1, 0, 0, 1)  0.612732  0.615028          0.992188   \n",
       "\n",
       "     imbalance  sensitivity_f  sensitivity_fnstar  sensitivity_diff  bitflip  \\\n",
       "957   0.015625       3.500000             2.62500          0.875000      0.2   \n",
       "949   0.015625       3.500000             2.62500          0.875000      0.2   \n",
       "776   0.015625       3.500000             2.62500          0.875000      0.2   \n",
       "768   0.015625       3.500000             2.62500          0.875000      0.2   \n",
       "459   0.109375       3.390625             3.28125          0.109375      0.2   \n",
       "418   0.109375       3.390625             3.28125          0.109375      0.2   \n",
       "399   0.109375       3.390625             3.28125          0.109375      0.2   \n",
       "358   0.109375       3.390625             3.28125          0.109375      0.2   \n",
       "\n",
       "     acc_diff  distance  n  \n",
       "957 -0.003870  0.875009  8  \n",
       "949 -0.003870  0.875009  8  \n",
       "776 -0.003870  0.875009  8  \n",
       "768 -0.003870  0.875009  8  \n",
       "459 -0.002296  0.109399  7  \n",
       "418 -0.002296  0.109399  7  \n",
       "399 -0.002296  0.109399  7  \n",
       "358 -0.002296  0.109399  7  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import glob\n",
    "import ast\n",
    "\n",
    "# Read all CSV files matching the pattern (written by the search cell above or by the standalone script)\n",
    "file_pattern = 'dentsets/weight_functions_n=*_p=*.csv'\n",
    "all_files = glob.glob(file_pattern)\n",
    "\n",
    "minimum_acc = 0.6\n",
    "maximum_imbalance = 0.25\n",
    "minimum_sensitivity = 1.0\n",
    "minimum_acc_diff = 1e-8\n",
    "maximum_acc_diff = 0.004\n",
    "\n",
    "# Concatenate all dataframes\n",
    "df_big = pd.concat((pd.read_csv(file) for file in all_files), ignore_index=True)\n",
    "\n",
    "# .copy() gives an independent frame, which avoids pandas' SettingWithCopyWarning when adding columns below\n",
    "df_big_filtered = df_big[\n",
    "    (df_big['imbalance'] < maximum_imbalance) &\n",
    "    (df_big['sensitivity_f'] > minimum_sensitivity) &\n",
    "    (df_big['sensitivity_fnstar'] > minimum_sensitivity) &\n",
    "    (df_big['fn_acc'] > minimum_acc) &\n",
    "    (df_big['f_acc'] > minimum_acc) &\n",
    "    (np.abs(df_big['acc_diff']) > minimum_acc_diff) &\n",
    "    (np.abs(df_big['acc_diff']) < maximum_acc_diff)\n",
    "].copy()\n",
    "\n",
    "# Euclidean distance from the origin in the (acc_diff, sensitivity_diff) plane\n",
    "df_big_filtered['distance'] = np.sqrt(df_big_filtered['acc_diff']**2 + df_big_filtered['sensitivity_diff']**2)\n",
    "# signatures are read back from CSV as strings such as '(0, 1, 1, 0)'; n = len(signature) - 1\n",
    "df_big_filtered['n'] = df_big_filtered['signature'].apply(lambda x: len(ast.literal_eval(x)) - 1)\n",
    "df_big_filtered.sort_values(by='distance', ascending=False, inplace=True)\n",
    "print(f'Surviving samples: {len(df_big_filtered)}')\n",
    "df_big_filtered.head(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# golden_signatures = df_big_filtered['signature'].values[:5]\n",
    "# golden_signatures = [ast.literal_eval(sig) for sig in golden_signatures]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating counterexample000110000_nbits10_n10000_bf20_seed1234 with p_bitflip=0.2\n",
      "idx for sparse function: save these: [0, 1, 2, 3, 5, 6, 7, 8]\n",
      "Saved /u/mhzambia/ResearchDocuments/MindReadingAutobot/mindreadingautobots/data/counterexample000110000_nbits10_n10000_bf20_seed1234/train.pkl, /u/mhzambia/ResearchDocuments/MindReadingAutobot/mindreadingautobots/data/counterexample000110000_nbits10_n10000_bf20_seed1234/val.pkl, /u/mhzambia/ResearchDocuments/MindReadingAutobot/mindreadingautobots/data/counterexample000110000_nbits10_n10000_bf20_seed1234/noiseless_train.pkl, /u/mhzambia/ResearchDocuments/MindReadingAutobot/mindreadingautobots/data/counterexample000110000_nbits10_n10000_bf20_seed1234/noiseless_val.pkl\n"
     ]
    }
   ],
   "source": [
    "# Positions (out of n_bits) that each sparse function depends on, one list per selected signature.\n",
    "# If your dataset has a hidden subset, update this list. Previous choices:\n",
    "# idx = [[3, 6, 7, 11, 13, 14, 18], \n",
    "#        [1, 2, 5, 7, 11, 12, 13], \n",
    "#        [0, 3, 5, 6, 7, 12, 13], \n",
    "#        [2, 5, 6, 9, 10, 15, 18],\n",
    "#        [4, 7, 8, 11, 14, 16, 17]]\n",
    "\n",
    "idx = [[0, 1, 2, 3, 5, 6, 7, 8]]\n",
    "\n",
    "p_bitflips = [0.2]\n",
    "n_bits = 10\n",
    "seed = 1234\n",
    "n_val = 15000 # number of validation examples\n",
    "n_train = 10000 # number of training examples\n",
    "\n",
    "# weight signatures (g(0), ..., g(k)), one per entry of idx, with k = len(subseq_idx)\n",
    "selected_signatures = [(0, 0, 0, 1, 1, 0, 0, 0, 0)]\n",
    "\n",
    "for signature, subseq_idx in zip(selected_signatures, idx):\n",
    "\n",
    "    k = len(subseq_idx)\n",
    "    gen_name = \"counterexample\" + \"\".join([str(i) for i in signature])\n",
    "    signature = dict(zip(range(len(signature)), signature))  # weight -> label\n",
    "\n",
    "    for p_bitflip in p_bitflips:\n",
    "\n",
    "        p100 = round(p_bitflip * 100)  # round() avoids int() truncation, e.g. int(0.29*100) == 28\n",
    "        suffix = f\"_nbits{n_bits}_n{n_train}_bf{p100}_seed{seed}\"\n",
    "        dirname = gen_name + suffix\n",
    "        print(f\"Generating {dirname} with p_bitflip={p_bitflip}\")\n",
    "\n",
    "        X, Z, subseq_idx = make_datasets.sparse_boolean_weightbased_k_n(n_bits, k, n_train + n_val, signature, p_bitflip=p_bitflip, seed=seed, subseq_idx=subseq_idx)\n",
    "        print(\"idx for sparse function: save these:\", subseq_idx)\n",
    "\n",
    "        if p_bitflip == 0:\n",
    "            Z = X\n",
    "\n",
    "        Z_train = Z[:n_train]\n",
    "        Z_val = Z[n_train:]\n",
    "\n",
    "        # Datasets are written to ../../../../data/<dirname>, resolved from the current working directory\n",
    "        base_dir = os.path.abspath(os.path.join(os.getcwd(), \"../../../../data\"))\n",
    "        target_dir = os.path.join(base_dir, dirname)\n",
    "\n",
    "        # Skip datasets that already exist so they are never overwritten\n",
    "        if os.path.exists(target_dir):\n",
    "            print(f\"{target_dir} already exists, skipping\")\n",
    "            continue\n",
    "        os.makedirs(target_dir)\n",
    "\n",
    "        train_path = os.path.join(target_dir, \"train.pkl\")\n",
    "        val_path = os.path.join(target_dir, \"val.pkl\")\n",
    "        data_io.save_numpy_as_dict(Z_train, train_path)\n",
    "        data_io.save_numpy_as_dict(Z_val, val_path)\n",
    "\n",
    "        # also save the noiseless versions of the same samples\n",
    "        X_train = X[:n_train]\n",
    "        X_val = X[n_train:]\n",
    "        noiseless_train_path = os.path.join(target_dir, \"noiseless_train.pkl\")\n",
    "        noiseless_val_path = os.path.join(target_dir, \"noiseless_val.pkl\")\n",
    "        data_io.save_numpy_as_dict(X_train, noiseless_train_path)\n",
    "        data_io.save_numpy_as_dict(X_val, noiseless_val_path)\n",
    "        print(f\"Saved {train_path}, {val_path}, {noiseless_train_path}, {noiseless_val_path}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A different search: functions that are not weight-based"
   ]
  },
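  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For scale: a perfectly balanced function on $n$ bits has exactly $2^{n-1}$ ones in its signature. The filter on `sig_arr` below therefore keeps $\\binom{2^n}{2^{n-1}}$ of the $2^{2^n}$ signatures, i.e. $\\binom{16}{8} = 12870$ out of $65536$ for $n=4$. Here is a minimal check that uses the same filter:\n",
    "\n",
    "```python\n",
    "import itertools\n",
    "import math\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "n = 4\n",
    "# every length-2**n signature, one per row\n",
    "sig_arr = np.array(list(itertools.product([0, 1], repeat=2**n)))\n",
    "# keep the signatures with exactly 2**(n-1) ones (perfectly balanced functions)\n",
    "balanced = sig_arr[sig_arr.sum(axis=1) == 2**(n-1)]\n",
    "print(len(sig_arr), len(balanced), math.comb(2**n, 2**(n-1)))  # 65536 12870 12870\n",
    "```"
   ]
  },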
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exhaustively check all perfectly balanced boolean functions on n bits (not only weight-based ones)\n",
    "\n",
    "def boolean_function_from_signature(f, n):\n",
    "    \"\"\"Given a length 2^n binary array f, return (lookup, func) where func(x) = f[binary index of x].\"\"\"\n",
    "    X_arr = [tuple(x) for x in itertools.product([0, 1], repeat=n)]\n",
    "    lookup = dict(zip(X_arr, f))\n",
    "    def func(x):\n",
    "        return lookup[tuple(x)]\n",
    "    return lookup, func\n",
    "\n",
    "n = 4\n",
    "# WARNING: n=4 might take 5 minutes (there are C(16, 8) = 12870 balanced functions)\n",
    "assert n <= 4\n",
    "k = n\n",
    "# pvals = [0.01, 0.25, 0.49]\n",
    "pvals = [0.49]\n",
    "\n",
    "# a \"signature\" of a boolean function is a length 2**n bitstring S where f(x) = S[bin(x)]\n",
    "# we will check signatures for perfectly balanced functions\n",
    "sig_arr = np.array(list(itertools.product([0, 1], repeat=2**n)))\n",
    "sig_arr = sig_arr[sig_arr.sum(axis=1) == 2**(n-1)] # only balanced functions\n",
    "\n",
    "X_arr = np.array(list(itertools.product([0, 1], repeat=n)))\n",
    "for signature in sig_arr:\n",
    "    print(signature)\n",
    "    dct, func = boolean_function_from_signature(signature, n)\n",
    "    for p in pvals:\n",
    "        # noisy_lookup[y, j] = sum over x with f(x) = y of Pr(Z = z_j | X = x), i.e. 2**n * Pr(Y = y, Z = z_j)\n",
    "        noisy_lookup = np.zeros((2, 2**n))\n",
    "        # true_lookup[y, j] = 1 if f(x_j) = y, else 0\n",
    "        true_lookup = np.zeros((2, 2**n))\n",
    "        # exact computation over all inputs x and all bitflip patterns e (no sampling)\n",
    "        for i, x in enumerate(product([0,1], repeat=k)):\n",
    "            func_value = func(x)\n",
    "            true_lookup[func_value, i] = 1\n",
    "            # iterate over all noisy strings z = x XOR e\n",
    "            for e in product([0, 1], repeat=k):\n",
    "                z = np.array(x) ^ np.array(e)\n",
    "                p_z_given_x = p ** sum(e) * (1-p)**(k - sum(e))\n",
    "                # increment noisy_lookup at the binary index of z\n",
    "                noisy_lookup[func_value, int(''.join(map(str, z)), 2)] += p_z_given_x\n",
    "\n",
    "        # fN*(z) = argmax_y Pr(Y = y | Z = z): the MLE predictor for the noisy data (ties go to 0)\n",
    "        fnstar_dct = {}\n",
    "        for i, x in enumerate(X_arr):\n",
    "            fnstar_dct[tuple(x)] = np.argmax(noisy_lookup[:, i])\n",
    "        def fnstar(x):\n",
    "            return fnstar_dct[tuple(x)]\n",
    "\n",
    "        sensitivity_f = boolean.average_sensitivity(func, X_arr)\n",
    "        sensitivity_fnstar = boolean.average_sensitivity(fnstar, X_arr)\n",
    "        sensitivity_diff = sensitivity_f - sensitivity_fnstar\n",
    "        # accuracies on dataset\n",
    "        p_zy = boolean.generate_noisy_distr(k, p, func)\n",
    "        noisy_f_acc = boolean.compute_acc_noisytest(p_zy, func, n) # accuracy of f on noisy data\n",
    "        # noiseless_fnstar_acc = boolean.compute_acc_test(fnstar, func, n) # accuracy of fN* on noiseless data\n",
    "        noisy_fnstar_acc = boolean.compute_acc_noisytest(p_zy, fnstar, n) # accuracy of fN* MLE on noisy data\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Random search over boolean functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "!!!p=0\n",
      "boolean signature | imbal? | nacc(fN*) | taccfN* |naccfN*-naccf| S(f) | S(fN*) | S(f) - S(fN*)\n",
      "--------------------------------------------------------------------------\n",
      "  [0 0 0 0 0 0]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "  [0 0 0 0 0 1]   | 0.9375 |   1.0000  | 1.0000 | 0.0000    |0.3125|0.3125  |  0.0000 \n",
      "  [0 0 0 0 1 0]   | 0.6875 |   1.0000  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [0 0 0 0 1 1]   | 0.6250 |   1.0000  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [0 0 0 1 0 0]   | 0.3750 |   1.0000  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [0 0 0 1 0 1]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [0 0 0 1 1 0]   | 0.0625 |   1.0000  | 1.0000 | 0.0000    |2.1875|2.1875  |  0.0000 \n",
      "  [0 0 0 1 1 1]   | 0.0000 |   1.0000  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [0 0 1 0 0 0]   | 0.3750 |   1.0000  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [0 0 1 0 0 1]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [0 0 1 0 1 0]   | 0.0625 |   1.0000  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [0 0 1 0 1 1]   | 0.0000 |   1.0000  | 1.0000 | 0.0000    |4.3750|4.3750  |  0.0000 \n",
      "  [0 0 1 1 0 0]   | 0.2500 |   1.0000  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      "  [0 0 1 1 0 1]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [0 0 1 1 1 0]   | 0.5625 |   1.0000  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [0 0 1 1 1 1]   | 0.6250 |   1.0000  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [0 1 0 0 0 0]   | 0.6875 |   1.0000  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [0 1 0 0 0 1]   | 0.6250 |   1.0000  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [0 1 0 0 1 0]   | 0.3750 |   1.0000  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [0 1 0 0 1 1]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [0 1 0 1 0 0]   | 0.0625 |   1.0000  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [0 1 0 1 0 1]   | 0.0000 |   1.0000  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [0 1 0 1 1 0]   | 0.2500 |   1.0000  | 1.0000 | 0.0000    |3.7500|3.7500  |  0.0000 \n",
      "  [0 1 0 1 1 1]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      ">>[0 1 1 0 0 0]   | 0.0625 |   1.0000  | 1.0000 | 0.0000    |2.1875|2.1875  |  0.0000 \n",
      "  [0 1 1 0 0 1]   | 0.0000 |   1.0000  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      ">>[0 1 1 0 1 0]   | 0.2500 |   1.0000  | 1.0000 | 0.0000    |3.7500|3.7500  |  0.0000 \n",
      "  [0 1 1 0 1 1]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [0 1 1 1 0 0]   | 0.5625 |   1.0000  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [0 1 1 1 0 1]   | 0.6250 |   1.0000  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [0 1 1 1 1 0]   | 0.8750 |   1.0000  | 1.0000 | 0.0000    |0.6250|0.6250  |  0.0000 \n",
      "  [0 1 1 1 1 1]   | 0.9375 |   1.0000  | 1.0000 | 0.0000    |0.3125|0.3125  |  0.0000 \n",
      "  [1 0 0 0 0 0]   | 0.9375 |   1.0000  | 1.0000 | 0.0000    |0.3125|0.3125  |  0.0000 \n",
      "  [1 0 0 0 0 1]   | 0.8750 |   1.0000  | 1.0000 | 0.0000    |0.6250|0.6250  |  0.0000 \n",
      "  [1 0 0 0 1 0]   | 0.6250 |   1.0000  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [1 0 0 0 1 1]   | 0.5625 |   1.0000  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [1 0 0 1 0 0]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [1 0 0 1 0 1]   | 0.2500 |   1.0000  | 1.0000 | 0.0000    |3.7500|3.7500  |  0.0000 \n",
      ">>[1 0 0 1 1 0]   | 0.0000 |   1.0000  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      "  [1 0 0 1 1 1]   | 0.0625 |   1.0000  | 1.0000 | 0.0000    |2.1875|2.1875  |  0.0000 \n",
      "  [1 0 1 0 0 0]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [1 0 1 0 0 1]   | 0.2500 |   1.0000  | 1.0000 | 0.0000    |3.7500|3.7500  |  0.0000 \n",
      "  [1 0 1 0 1 0]   | 0.0000 |   1.0000  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [1 0 1 0 1 1]   | 0.0625 |   1.0000  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [1 0 1 1 0 0]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [1 0 1 1 0 1]   | 0.3750 |   1.0000  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [1 0 1 1 1 0]   | 0.6250 |   1.0000  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [1 0 1 1 1 1]   | 0.6875 |   1.0000  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [1 1 0 0 0 0]   | 0.6250 |   1.0000  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [1 1 0 0 0 1]   | 0.5625 |   1.0000  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [1 1 0 0 1 0]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [1 1 0 0 1 1]   | 0.2500 |   1.0000  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      "  [1 1 0 1 0 0]   | 0.0000 |   1.0000  | 1.0000 | 0.0000    |4.3750|4.3750  |  0.0000 \n",
      "  [1 1 0 1 0 1]   | 0.0625 |   1.0000  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [1 1 0 1 1 0]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [1 1 0 1 1 1]   | 0.3750 |   1.0000  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [1 1 1 0 0 0]   | 0.0000 |   1.0000  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [1 1 1 0 0 1]   | 0.0625 |   1.0000  | 1.0000 | 0.0000    |2.1875|2.1875  |  0.0000 \n",
      "  [1 1 1 0 1 0]   | 0.3125 |   1.0000  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [1 1 1 0 1 1]   | 0.3750 |   1.0000  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [1 1 1 1 0 0]   | 0.6250 |   1.0000  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [1 1 1 1 0 1]   | 0.6875 |   1.0000  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [1 1 1 1 1 0]   | 0.9375 |   1.0000  | 1.0000 | 0.0000    |0.3125|0.3125  |  0.0000 \n",
      "  [1 1 1 1 1 1]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "\n",
      "!!!p=0.1\n",
      "boolean signature | imbal? | nacc(fN*) | taccfN* |naccfN*-naccf| S(f) | S(fN*) | S(f) - S(fN*)\n",
      "--------------------------------------------------------------------------\n",
      "  [0 0 0 0 0 0]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "  [0 0 0 0 0 1]   | 0.9375 |   0.9744  | 1.0000 | 0.0000    |0.3125|0.3125  |  0.0000 \n",
      "  [0 0 0 0 1 0]   | 0.6875 |   0.8811  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [0 0 0 0 1 1]   | 0.6250 |   0.8966  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [0 0 0 1 0 0]   | 0.3750 |   0.7716  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [0 0 0 1 0 1]   | 0.3125 |   0.7551  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [0 0 0 1 1 0]   | 0.0625 |   0.8198  | 1.0000 | 0.0000    |2.1875|2.1875  |  0.0000 \n",
      "  [0 0 0 1 1 1]   | 0.0000 |   0.8443  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [0 0 1 0 0 0]   | 0.3750 |   0.7716  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [0 0 1 0 0 1]   | 0.3125 |   0.7470  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [0 0 1 0 1 0]   | 0.0625 |   0.6803  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [0 0 1 0 1 1]   | 0.0000 |   0.6967  | 1.0000 | 0.0000    |4.3750|4.3750  |  0.0000 \n",
      "  [0 0 1 1 0 0]   | 0.2500 |   0.7953  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      "  [0 0 1 1 0 1]   | 0.3125 |   0.7798  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [0 0 1 1 1 0]   | 0.5625 |   0.8710  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [0 0 1 1 1 1]   | 0.6250 |   0.8966  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [0 1 0 0 0 0]   | 0.6875 |   0.8811  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [0 1 0 0 0 1]   | 0.6250 |   0.8556  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [0 1 0 0 1 0]   | 0.3750 |   0.7643  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [0 1 0 0 1 1]   | 0.3125 |   0.7798  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [0 1 0 1 0 0]   | 0.0625 |   0.6803  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [0 1 0 1 0 1]   | 0.0000 |   0.6638  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [0 1 0 1 1 0]   | 0.2500 |   0.7305  | 1.0000 | 0.0000    |3.7500|3.7500  |  0.0000 \n",
      "  [0 1 0 1 1 1]   | 0.3125 |   0.7551  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      ">>[0 1 1 0 0 0]   | 0.0625 |   0.8198  | 1.0000 | 0.0000    |2.1875|2.1875  |  0.0000 \n",
      "  [0 1 1 0 0 1]   | 0.0000 |   0.7952  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      ">>[0 1 1 0 1 0]   | 0.2500 |   0.7305  | 1.0000 | 0.0000    |3.7500|3.7500  |  0.0000 \n",
      "  [0 1 1 0 1 1]   | 0.3125 |   0.7470  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [0 1 1 1 0 0]   | 0.5625 |   0.8710  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [0 1 1 1 0 1]   | 0.6250 |   0.8556  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [0 1 1 1 1 0]   | 0.8750 |   0.9488  | 1.0000 | 0.0000    |0.6250|0.6250  |  0.0000 \n",
      "  [0 1 1 1 1 1]   | 0.9375 |   0.9744  | 1.0000 | 0.0000    |0.3125|0.3125  |  0.0000 \n",
      "  [1 0 0 0 0 0]   | 0.9375 |   0.9744  | 1.0000 | 0.0000    |0.3125|0.3125  |  0.0000 \n",
      "  [1 0 0 0 0 1]   | 0.8750 |   0.9488  | 1.0000 | 0.0000    |0.6250|0.6250  |  0.0000 \n",
      "  [1 0 0 0 1 0]   | 0.6250 |   0.8556  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [1 0 0 0 1 1]   | 0.5625 |   0.8710  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [1 0 0 1 0 0]   | 0.3125 |   0.7470  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [1 0 0 1 0 1]   | 0.2500 |   0.7305  | 1.0000 | 0.0000    |3.7500|3.7500  |  0.0000 \n",
      ">>[1 0 0 1 1 0]   | 0.0000 |   0.7952  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      "  [1 0 0 1 1 1]   | 0.0625 |   0.8198  | 1.0000 | 0.0000    |2.1875|2.1875  |  0.0000 \n",
      "  [1 0 1 0 0 0]   | 0.3125 |   0.7551  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [1 0 1 0 0 1]   | 0.2500 |   0.7305  | 1.0000 | 0.0000    |3.7500|3.7500  |  0.0000 \n",
      "  [1 0 1 0 1 0]   | 0.0000 |   0.6638  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [1 0 1 0 1 1]   | 0.0625 |   0.6803  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [1 0 1 1 0 0]   | 0.3125 |   0.7798  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [1 0 1 1 0 1]   | 0.3750 |   0.7643  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [1 0 1 1 1 0]   | 0.6250 |   0.8556  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [1 0 1 1 1 1]   | 0.6875 |   0.8811  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [1 1 0 0 0 0]   | 0.6250 |   0.8966  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [1 1 0 0 0 1]   | 0.5625 |   0.8710  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [1 1 0 0 1 0]   | 0.3125 |   0.7798  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [1 1 0 0 1 1]   | 0.2500 |   0.7953  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      "  [1 1 0 1 0 0]   | 0.0000 |   0.6967  | 1.0000 | 0.0000    |4.3750|4.3750  |  0.0000 \n",
      "  [1 1 0 1 0 1]   | 0.0625 |   0.6803  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [1 1 0 1 1 0]   | 0.3125 |   0.7470  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [1 1 0 1 1 1]   | 0.3750 |   0.7716  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [1 1 1 0 0 0]   | 0.0000 |   0.8443  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [1 1 1 0 0 1]   | 0.0625 |   0.8198  | 1.0000 | 0.0000    |2.1875|2.1875  |  0.0000 \n",
      "  [1 1 1 0 1 0]   | 0.3125 |   0.7551  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [1 1 1 0 1 1]   | 0.3750 |   0.7716  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [1 1 1 1 0 0]   | 0.6250 |   0.8966  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [1 1 1 1 0 1]   | 0.6875 |   0.8811  | 1.0000 | 0.0000    |1.5625|1.5625  |  0.0000 \n",
      "  [1 1 1 1 1 0]   | 0.9375 |   0.9744  | 1.0000 | 0.0000    |0.3125|0.3125  |  0.0000 \n",
      "  [1 1 1 1 1 1]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "\n",
      "!!!p=0.15\n",
      "boolean signature | imbal? | nacc(fN*) | taccfN* |naccfN*-naccf| S(f) | S(fN*) | S(f) - S(fN*)\n",
      "--------------------------------------------------------------------------\n",
      "  [0 0 0 0 0 0]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "  [0 0 0 0 0 1]   | 0.9375 |   0.9687  | 0.9688 | 0.0035    |0.3125|0.0000  |  0.3125 \n",
      "  [0 0 0 0 1 0]   | 0.6875 |   0.8437  | 0.8438 | 0.0003    |1.5625|0.0000  |  1.5625 \n",
      "  [0 0 0 0 1 1]   | 0.6250 |   0.8576  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [0 0 0 1 0 0]   | 0.3750 |   0.7049  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [0 0 0 1 0 1]   | 0.3125 |   0.6874  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [0 0 0 1 1 0]   | 0.0625 |   0.7551  | 0.9688 | 0.0019    |2.1875|1.8750  |  0.3125 \n",
      "  [0 0 0 1 1 1]   | 0.0000 |   0.7847  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [0 0 1 0 0 0]   | 0.3750 |   0.7049  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [0 0 1 0 0 1]   | 0.3125 |   0.6752  | 0.9688 | 0.0020    |3.4375|3.1250  |  0.3125 \n",
      "  [0 0 1 0 1 0]   | 0.0625 |   0.6013  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [0 0 1 0 1 1]   | 0.0000 |   0.6185  | 1.0000 | 0.0000    |4.3750|4.3750  |  0.0000 \n",
      "  [0 0 1 1 0 0]   | 0.2500 |   0.7219  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      "  [0 0 1 1 0 1]   | 0.3125 |   0.7074  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [0 0 1 1 1 0]   | 0.5625 |   0.8265  | 0.9688 | 0.0034    |1.5625|1.2500  |  0.3125 \n",
      "  [0 0 1 1 1 1]   | 0.6250 |   0.8576  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [0 1 0 0 0 0]   | 0.6875 |   0.8437  | 0.8438 | 0.0003    |1.5625|0.0000  |  1.5625 \n",
      "  [0 1 0 0 0 1]   | 0.6250 |   0.8125  | 0.8125 | 0.0036    |1.8750|0.0000  |  1.8750 \n",
      "  [0 1 0 0 1 0]   | 0.3750 |   0.6930  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [0 1 0 0 1 1]   | 0.3125 |   0.7074  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [0 1 0 1 0 0]   | 0.0625 |   0.6013  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [0 1 0 1 0 1]   | 0.0000 |   0.5840  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [0 1 0 1 1 0]   | 0.2500 |   0.6577  | 0.9688 | 0.0020    |3.7500|3.4375  |  0.3125 \n",
      "  [0 1 0 1 1 1]   | 0.3125 |   0.6874  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      ">>[0 1 1 0 0 0]   | 0.0625 |   0.7551  | 0.9688 | 0.0019    |2.1875|1.8750  |  0.3125 \n",
      "  [0 1 1 0 0 1]   | 0.0000 |   0.7255  | 0.9375 | 0.0037    |2.5000|1.8750  |  0.6250 \n",
      ">>[0 1 1 0 1 0]   | 0.2500 |   0.6577  | 0.9688 | 0.0020    |3.7500|3.4375  |  0.3125 \n",
      "  [0 1 1 0 1 1]   | 0.3125 |   0.6752  | 0.9688 | 0.0020    |3.4375|3.1250  |  0.3125 \n",
      "  [0 1 1 1 0 0]   | 0.5625 |   0.8265  | 0.9688 | 0.0034    |1.5625|1.2500  |  0.3125 \n",
      "  [0 1 1 1 0 1]   | 0.6250 |   0.8125  | 0.8125 | 0.0036    |1.8750|0.0000  |  1.8750 \n",
      "  [0 1 1 1 1 0]   | 0.8750 |   0.9375  | 0.9375 | 0.0070    |0.6250|0.0000  |  0.6250 \n",
      "  [0 1 1 1 1 1]   | 0.9375 |   0.9687  | 0.9688 | 0.0035    |0.3125|0.0000  |  0.3125 \n",
      "  [1 0 0 0 0 0]   | 0.9375 |   0.9687  | 0.9688 | 0.0035    |0.3125|0.0000  |  0.3125 \n",
      "  [1 0 0 0 0 1]   | 0.8750 |   0.9375  | 0.9375 | 0.0070    |0.6250|0.0000  |  0.6250 \n",
      "  [1 0 0 0 1 0]   | 0.6250 |   0.8125  | 0.8125 | 0.0036    |1.8750|0.0000  |  1.8750 \n",
      "  [1 0 0 0 1 1]   | 0.5625 |   0.8265  | 0.9688 | 0.0034    |1.5625|1.2500  |  0.3125 \n",
      "  [1 0 0 1 0 0]   | 0.3125 |   0.6752  | 0.9688 | 0.0020    |3.4375|3.1250  |  0.3125 \n",
      "  [1 0 0 1 0 1]   | 0.2500 |   0.6577  | 0.9688 | 0.0020    |3.7500|3.4375  |  0.3125 \n",
      ">>[1 0 0 1 1 0]   | 0.0000 |   0.7255  | 0.9375 | 0.0037    |2.5000|1.8750  |  0.6250 \n",
      "  [1 0 0 1 1 1]   | 0.0625 |   0.7551  | 0.9688 | 0.0019    |2.1875|1.8750  |  0.3125 \n",
      "  [1 0 1 0 0 0]   | 0.3125 |   0.6874  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [1 0 1 0 0 1]   | 0.2500 |   0.6577  | 0.9688 | 0.0020    |3.7500|3.4375  |  0.3125 \n",
      "  [1 0 1 0 1 0]   | 0.0000 |   0.5840  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [1 0 1 0 1 1]   | 0.0625 |   0.6013  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [1 0 1 1 0 0]   | 0.3125 |   0.7074  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [1 0 1 1 0 1]   | 0.3750 |   0.6930  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [1 0 1 1 1 0]   | 0.6250 |   0.8125  | 0.8125 | 0.0036    |1.8750|0.0000  |  1.8750 \n",
      "  [1 0 1 1 1 1]   | 0.6875 |   0.8437  | 0.8438 | 0.0003    |1.5625|0.0000  |  1.5625 \n",
      "  [1 1 0 0 0 0]   | 0.6250 |   0.8576  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [1 1 0 0 0 1]   | 0.5625 |   0.8265  | 0.9688 | 0.0034    |1.5625|1.2500  |  0.3125 \n",
      "  [1 1 0 0 1 0]   | 0.3125 |   0.7074  | 1.0000 | 0.0000    |2.8125|2.8125  |  0.0000 \n",
      "  [1 1 0 0 1 1]   | 0.2500 |   0.7219  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      "  [1 1 0 1 0 0]   | 0.0000 |   0.6185  | 1.0000 | 0.0000    |4.3750|4.3750  |  0.0000 \n",
      "  [1 1 0 1 0 1]   | 0.0625 |   0.6013  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [1 1 0 1 1 0]   | 0.3125 |   0.6752  | 0.9688 | 0.0020    |3.4375|3.1250  |  0.3125 \n",
      "  [1 1 0 1 1 1]   | 0.3750 |   0.7049  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [1 1 1 0 0 0]   | 0.0000 |   0.7847  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [1 1 1 0 0 1]   | 0.0625 |   0.7551  | 0.9688 | 0.0019    |2.1875|1.8750  |  0.3125 \n",
      "  [1 1 1 0 1 0]   | 0.3125 |   0.6874  | 1.0000 | 0.0000    |3.4375|3.4375  |  0.0000 \n",
      "  [1 1 1 0 1 1]   | 0.3750 |   0.7049  | 1.0000 | 0.0000    |3.1250|3.1250  |  0.0000 \n",
      "  [1 1 1 1 0 0]   | 0.6250 |   0.8576  | 1.0000 | 0.0000    |1.2500|1.2500  |  0.0000 \n",
      "  [1 1 1 1 0 1]   | 0.6875 |   0.8437  | 0.8438 | 0.0003    |1.5625|0.0000  |  1.5625 \n",
      "  [1 1 1 1 1 0]   | 0.9375 |   0.9687  | 0.9688 | 0.0035    |0.3125|0.0000  |  0.3125 \n",
      "  [1 1 1 1 1 1]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "\n",
      "!!!p=0.2\n",
      "boolean signature | imbal? | nacc(fN*) | taccfN* |naccfN*-naccf| S(f) | S(fN*) | S(f) - S(fN*)\n",
      "--------------------------------------------------------------------------\n",
      "  [0 0 0 0 0 0]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "  [0 0 0 0 0 1]   | 0.9375 |   0.9688  | 0.9688 | 0.0108    |0.3125|0.0000  |  0.3125 \n",
      "  [0 0 0 0 1 0]   | 0.6875 |   0.8438  | 0.8438 | 0.0282    |1.5625|0.0000  |  1.5625 \n",
      "  [0 0 0 0 1 1]   | 0.6250 |   0.8273  | 0.8438 | 0.0026    |1.2500|0.3125  |  0.9375 \n",
      "  [0 0 0 1 0 0]   | 0.3750 |   0.6875  | 0.6875 | 0.0285    |3.1250|0.0000  |  3.1250 \n",
      "  [0 0 0 1 0 1]   | 0.3125 |   0.6583  | 0.6875 | 0.0157    |3.4375|0.3125  |  3.1250 \n",
      "  [0 0 0 1 1 0]   | 0.0625 |   0.7057  | 0.9688 | 0.0071    |2.1875|1.8750  |  0.3125 \n",
      "  [0 0 0 1 1 1]   | 0.0000 |   0.7333  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [0 0 1 0 0 0]   | 0.3750 |   0.6875  | 0.6875 | 0.0285    |3.1250|0.0000  |  3.1250 \n",
      "  [0 0 1 0 0 1]   | 0.3125 |   0.6563  | 0.6562 | 0.0329    |3.4375|0.0000  |  3.4375 \n",
      "  [0 0 1 0 1 0]   | 0.0625 |   0.5545  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [0 0 1 0 1 1]   | 0.0000 |   0.5701  | 1.0000 | 0.0000    |4.3750|4.3750  |  0.0000 \n",
      "  [0 0 1 1 0 0]   | 0.2500 |   0.6640  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      "  [0 0 1 1 0 1]   | 0.3125 |   0.6753  | 0.8438 | 0.0213    |2.8125|1.2500  |  1.5625 \n",
      "  [0 0 1 1 1 0]   | 0.5625 |   0.7961  | 0.8125 | 0.0126    |1.5625|0.3125  |  1.2500 \n",
      "  [0 0 1 1 1 1]   | 0.6250 |   0.8273  | 0.8438 | 0.0027    |1.2500|0.3125  |  0.9375 \n",
      "  [0 1 0 0 0 0]   | 0.6875 |   0.8438  | 0.8438 | 0.0283    |1.5625|0.0000  |  1.5625 \n",
      "  [0 1 0 0 0 1]   | 0.6250 |   0.8125  | 0.8125 | 0.0382    |1.8750|0.0000  |  1.8750 \n",
      "  [0 1 0 0 1 0]   | 0.3750 |   0.6875  | 0.6875 | 0.0435    |3.1250|0.0000  |  3.1250 \n",
      "  [0 1 0 0 1 1]   | 0.3125 |   0.6753  | 0.8438 | 0.0214    |2.8125|1.2500  |  1.5625 \n",
      "  [0 1 0 1 0 0]   | 0.0625 |   0.5545  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [0 1 0 1 0 1]   | 0.0000 |   0.5389  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [0 1 0 1 1 0]   | 0.2500 |   0.6271  | 0.6562 | 0.0200    |3.7500|0.3125  |  3.4375 \n",
      "  [0 1 0 1 1 1]   | 0.3125 |   0.6583  | 0.6875 | 0.0157    |3.4375|0.3125  |  3.1250 \n",
      ">>[0 1 1 0 0 0]   | 0.0625 |   0.7056  | 0.9688 | 0.0071    |2.1875|1.8750  |  0.3125 \n",
      "  [0 1 1 0 0 1]   | 0.0000 |   0.6780  | 0.9375 | 0.0143    |2.5000|1.8750  |  0.6250 \n",
      ">>[0 1 1 0 1 0]   | 0.2500 |   0.6271  | 0.6562 | 0.0200    |3.7500|0.3125  |  3.4375 \n",
      "  [0 1 1 0 1 1]   | 0.3125 |   0.6563  | 0.6562 | 0.0329    |3.4375|0.0000  |  3.4375 \n",
      "  [0 1 1 1 0 0]   | 0.5625 |   0.7961  | 0.8125 | 0.0126    |1.5625|0.3125  |  1.2500 \n",
      "  [0 1 1 1 0 1]   | 0.6250 |   0.8125  | 0.8125 | 0.0382    |1.8750|0.0000  |  1.8750 \n",
      "  [0 1 1 1 1 0]   | 0.8750 |   0.9375  | 0.9375 | 0.0215    |0.6250|0.0000  |  0.6250 \n",
      "  [0 1 1 1 1 1]   | 0.9375 |   0.9688  | 0.9688 | 0.0108    |0.3125|0.0000  |  0.3125 \n",
      "  [1 0 0 0 0 0]   | 0.9375 |   0.9688  | 0.9688 | 0.0108    |0.3125|0.0000  |  0.3125 \n",
      "  [1 0 0 0 0 1]   | 0.8750 |   0.9375  | 0.9375 | 0.0215    |0.6250|0.0000  |  0.6250 \n",
      "  [1 0 0 0 1 0]   | 0.6250 |   0.8125  | 0.8125 | 0.0382    |1.8750|0.0000  |  1.8750 \n",
      "  [1 0 0 0 1 1]   | 0.5625 |   0.7961  | 0.8125 | 0.0126    |1.5625|0.3125  |  1.2500 \n",
      "  [1 0 0 1 0 0]   | 0.3125 |   0.6563  | 0.6562 | 0.0329    |3.4375|0.0000  |  3.4375 \n",
      "  [1 0 0 1 0 1]   | 0.2500 |   0.6271  | 0.6562 | 0.0200    |3.7500|0.3125  |  3.4375 \n",
      ">>[1 0 0 1 1 0]   | 0.0000 |   0.6780  | 0.9375 | 0.0143    |2.5000|1.8750  |  0.6250 \n",
      "  [1 0 0 1 1 1]   | 0.0625 |   0.7056  | 0.9688 | 0.0071    |2.1875|1.8750  |  0.3125 \n",
      "  [1 0 1 0 0 0]   | 0.3125 |   0.6583  | 0.6875 | 0.0157    |3.4375|0.3125  |  3.1250 \n",
      "  [1 0 1 0 0 1]   | 0.2500 |   0.6271  | 0.6562 | 0.0200    |3.7500|0.3125  |  3.4375 \n",
      "  [1 0 1 0 1 0]   | 0.0000 |   0.5389  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [1 0 1 0 1 1]   | 0.0625 |   0.5545  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [1 0 1 1 0 0]   | 0.3125 |   0.6753  | 0.8438 | 0.0214    |2.8125|1.2500  |  1.5625 \n",
      "  [1 0 1 1 0 1]   | 0.3750 |   0.6875  | 0.6875 | 0.0435    |3.1250|0.0000  |  3.1250 \n",
      "  [1 0 1 1 1 0]   | 0.6250 |   0.8125  | 0.8125 | 0.0382    |1.8750|0.0000  |  1.8750 \n",
      "  [1 0 1 1 1 1]   | 0.6875 |   0.8438  | 0.8438 | 0.0283    |1.5625|0.0000  |  1.5625 \n",
      "  [1 1 0 0 0 0]   | 0.6250 |   0.8273  | 0.8438 | 0.0027    |1.2500|0.3125  |  0.9375 \n",
      "  [1 1 0 0 0 1]   | 0.5625 |   0.7961  | 0.8125 | 0.0126    |1.5625|0.3125  |  1.2500 \n",
      "  [1 1 0 0 1 0]   | 0.3125 |   0.6753  | 0.8438 | 0.0213    |2.8125|1.2500  |  1.5625 \n",
      "  [1 1 0 0 1 1]   | 0.2500 |   0.6640  | 1.0000 | 0.0000    |2.5000|2.5000  |  0.0000 \n",
      "  [1 1 0 1 0 0]   | 0.0000 |   0.5701  | 1.0000 | 0.0000    |4.3750|4.3750  |  0.0000 \n",
      "  [1 1 0 1 0 1]   | 0.0625 |   0.5545  | 1.0000 | 0.0000    |4.6875|4.6875  |  0.0000 \n",
      "  [1 1 0 1 1 0]   | 0.3125 |   0.6563  | 0.6562 | 0.0329    |3.4375|0.0000  |  3.4375 \n",
      "  [1 1 0 1 1 1]   | 0.3750 |   0.6875  | 0.6875 | 0.0285    |3.1250|0.0000  |  3.1250 \n",
      "  [1 1 1 0 0 0]   | 0.0000 |   0.7333  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [1 1 1 0 0 1]   | 0.0625 |   0.7057  | 0.9688 | 0.0071    |2.1875|1.8750  |  0.3125 \n",
      "  [1 1 1 0 1 0]   | 0.3125 |   0.6583  | 0.6875 | 0.0157    |3.4375|0.3125  |  3.1250 \n",
      "  [1 1 1 0 1 1]   | 0.3750 |   0.6875  | 0.6875 | 0.0285    |3.1250|0.0000  |  3.1250 \n",
      "  [1 1 1 1 0 0]   | 0.6250 |   0.8273  | 0.8438 | 0.0026    |1.2500|0.3125  |  0.9375 \n",
      "  [1 1 1 1 0 1]   | 0.6875 |   0.8438  | 0.8438 | 0.0282    |1.5625|0.0000  |  1.5625 \n",
      "  [1 1 1 1 1 0]   | 0.9375 |   0.9688  | 0.9688 | 0.0108    |0.3125|0.0000  |  0.3125 \n",
      "  [1 1 1 1 1 1]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "\n",
      "!!!p=0.25\n",
      "boolean signature | imbal? | nacc(fN*) | taccfN* |naccfN*-naccf| S(f) | S(fN*) | S(f) - S(fN*)\n",
      "--------------------------------------------------------------------------\n",
      "  [0 0 0 0 0 0]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "  [0 0 0 0 0 1]   | 0.9375 |   0.9688  | 0.9688 | 0.0164    |0.3125|0.0000  |  0.3125 \n",
      "  [0 0 0 0 1 0]   | 0.6875 |   0.8438  | 0.8438 | 0.0491    |1.5625|0.0000  |  1.5625 \n",
      "  [0 0 0 0 1 1]   | 0.6250 |   0.8208  | 0.8438 | 0.0244    |1.2500|0.3125  |  0.9375 \n",
      "  [0 0 0 1 0 0]   | 0.3750 |   0.6875  | 0.6875 | 0.0598    |3.1250|0.0000  |  3.1250 \n",
      "  [0 0 0 1 0 1]   | 0.3125 |   0.6563  | 0.6875 | 0.0433    |3.4375|0.3125  |  3.1250 \n",
      "  [0 0 0 1 1 0]   | 0.0625 |   0.6630  | 0.9688 | 0.0099    |2.1875|1.8750  |  0.3125 \n",
      "  [0 0 0 1 1 1]   | 0.0000 |   0.6877  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [0 0 1 0 0 0]   | 0.3750 |   0.6875  | 0.6875 | 0.0598    |3.1250|0.0000  |  3.1250 \n",
      "  [0 0 1 0 0 1]   | 0.3125 |   0.6562  | 0.6562 | 0.0652    |3.4375|0.0000  |  3.4375 \n",
      "  [0 0 1 0 1 0]   | 0.0625 |   0.5352  | 0.6875 | 0.0067    |4.6875|1.5625  |  3.1250 \n",
      "  [0 0 1 0 1 1]   | 0.0000 |   0.5437  | 0.3750 | 0.0024    |4.3750|1.8750  |  2.5000 \n",
      "  [0 0 1 1 0 0]   | 0.2500 |   0.6436  | 0.6875 | 0.0244    |2.5000|0.6250  |  1.8750 \n",
      "  [0 0 1 1 0 1]   | 0.3125 |   0.6655  | 0.6875 | 0.0500    |2.8125|0.3125  |  2.5000 \n",
      "  [0 0 1 1 1 0]   | 0.5625 |   0.7896  | 0.8125 | 0.0389    |1.5625|0.3125  |  1.2500 \n",
      "  [0 0 1 1 1 1]   | 0.6250 |   0.8208  | 0.8438 | 0.0244    |1.2500|0.3125  |  0.9375 \n",
      "  [0 1 0 0 0 0]   | 0.6875 |   0.8438  | 0.8438 | 0.0491    |1.5625|0.0000  |  1.5625 \n",
      "  [0 1 0 0 0 1]   | 0.6250 |   0.8125  | 0.8125 | 0.0637    |1.8750|0.0000  |  1.8750 \n",
      "  [0 1 0 0 1 0]   | 0.3750 |   0.6875  | 0.6875 | 0.0757    |3.1250|0.0000  |  3.1250 \n",
      "  [0 1 0 0 1 1]   | 0.3125 |   0.6655  | 0.6875 | 0.0500    |2.8125|0.3125  |  2.5000 \n",
      "  [0 1 0 1 0 0]   | 0.0625 |   0.5352  | 0.6875 | 0.0067    |4.6875|1.5625  |  3.1250 \n",
      "  [0 1 0 1 0 1]   | 0.0000 |   0.5156  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [0 1 0 1 1 0]   | 0.2500 |   0.6251  | 0.6562 | 0.0487    |3.7500|0.3125  |  3.4375 \n",
      "  [0 1 0 1 1 1]   | 0.3125 |   0.6563  | 0.6875 | 0.0433    |3.4375|0.3125  |  3.1250 \n",
      ">>[0 1 1 0 0 0]   | 0.0625 |   0.6630  | 0.9688 | 0.0099    |2.1875|1.8750  |  0.3125 \n",
      "  [0 1 1 0 0 1]   | 0.0000 |   0.6382  | 0.9375 | 0.0200    |2.5000|1.8750  |  0.6250 \n",
      ">>[0 1 1 0 1 0]   | 0.2500 |   0.6251  | 0.6562 | 0.0487    |3.7500|0.3125  |  3.4375 \n",
      "  [0 1 1 0 1 1]   | 0.3125 |   0.6562  | 0.6562 | 0.0652    |3.4375|0.0000  |  3.4375 \n",
      "  [0 1 1 1 0 0]   | 0.5625 |   0.7896  | 0.8125 | 0.0389    |1.5625|0.3125  |  1.2500 \n",
      "  [0 1 1 1 0 1]   | 0.6250 |   0.8125  | 0.8125 | 0.0637    |1.8750|0.0000  |  1.8750 \n",
      "  [0 1 1 1 1 0]   | 0.8750 |   0.9375  | 0.9375 | 0.0327    |0.6250|0.0000  |  0.6250 \n",
      "  [0 1 1 1 1 1]   | 0.9375 |   0.9688  | 0.9688 | 0.0164    |0.3125|0.0000  |  0.3125 \n",
      "  [1 0 0 0 0 0]   | 0.9375 |   0.9688  | 0.9688 | 0.0164    |0.3125|0.0000  |  0.3125 \n",
      "  [1 0 0 0 0 1]   | 0.8750 |   0.9375  | 0.9375 | 0.0327    |0.6250|0.0000  |  0.6250 \n",
      "  [1 0 0 0 1 0]   | 0.6250 |   0.8125  | 0.8125 | 0.0637    |1.8750|0.0000  |  1.8750 \n",
      "  [1 0 0 0 1 1]   | 0.5625 |   0.7896  | 0.8125 | 0.0389    |1.5625|0.3125  |  1.2500 \n",
      "  [1 0 0 1 0 0]   | 0.3125 |   0.6562  | 0.6562 | 0.0652    |3.4375|0.0000  |  3.4375 \n",
      "  [1 0 0 1 0 1]   | 0.2500 |   0.6251  | 0.6562 | 0.0487    |3.7500|0.3125  |  3.4375 \n",
      ">>[1 0 0 1 1 0]   | 0.0000 |   0.6382  | 0.9375 | 0.0200    |2.5000|1.8750  |  0.6250 \n",
      "  [1 0 0 1 1 1]   | 0.0625 |   0.6630  | 0.9688 | 0.0099    |2.1875|1.8750  |  0.3125 \n",
      "  [1 0 1 0 0 0]   | 0.3125 |   0.6563  | 0.6875 | 0.0433    |3.4375|0.3125  |  3.1250 \n",
      "  [1 0 1 0 0 1]   | 0.2500 |   0.6251  | 0.6562 | 0.0487    |3.7500|0.3125  |  3.4375 \n",
      "  [1 0 1 0 1 0]   | 0.0000 |   0.5156  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [1 0 1 0 1 1]   | 0.0625 |   0.5352  | 0.6875 | 0.0067    |4.6875|1.5625  |  3.1250 \n",
      "  [1 0 1 1 0 0]   | 0.3125 |   0.6655  | 0.6875 | 0.0500    |2.8125|0.3125  |  2.5000 \n",
      "  [1 0 1 1 0 1]   | 0.3750 |   0.6875  | 0.6875 | 0.0757    |3.1250|0.0000  |  3.1250 \n",
      "  [1 0 1 1 1 0]   | 0.6250 |   0.8125  | 0.8125 | 0.0637    |1.8750|0.0000  |  1.8750 \n",
      "  [1 0 1 1 1 1]   | 0.6875 |   0.8438  | 0.8438 | 0.0491    |1.5625|0.0000  |  1.5625 \n",
      "  [1 1 0 0 0 0]   | 0.6250 |   0.8208  | 0.8438 | 0.0244    |1.2500|0.3125  |  0.9375 \n",
      "  [1 1 0 0 0 1]   | 0.5625 |   0.7896  | 0.8125 | 0.0389    |1.5625|0.3125  |  1.2500 \n",
      "  [1 1 0 0 1 0]   | 0.3125 |   0.6655  | 0.6875 | 0.0500    |2.8125|0.3125  |  2.5000 \n",
      "  [1 1 0 0 1 1]   | 0.2500 |   0.6436  | 0.6875 | 0.0244    |2.5000|0.6250  |  1.8750 \n",
      "  [1 1 0 1 0 0]   | 0.0000 |   0.5437  | 0.3750 | 0.0024    |4.3750|1.8750  |  2.5000 \n",
      "  [1 1 0 1 0 1]   | 0.0625 |   0.5352  | 0.6875 | 0.0067    |4.6875|1.5625  |  3.1250 \n",
      "  [1 1 0 1 1 0]   | 0.3125 |   0.6562  | 0.6562 | 0.0652    |3.4375|0.0000  |  3.4375 \n",
      "  [1 1 0 1 1 1]   | 0.3750 |   0.6875  | 0.6875 | 0.0598    |3.1250|0.0000  |  3.1250 \n",
      "  [1 1 1 0 0 0]   | 0.0000 |   0.6877  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [1 1 1 0 0 1]   | 0.0625 |   0.6630  | 0.9688 | 0.0099    |2.1875|1.8750  |  0.3125 \n",
      "  [1 1 1 0 1 0]   | 0.3125 |   0.6563  | 0.6875 | 0.0433    |3.4375|0.3125  |  3.1250 \n",
      "  [1 1 1 0 1 1]   | 0.3750 |   0.6875  | 0.6875 | 0.0598    |3.1250|0.0000  |  3.1250 \n",
      "  [1 1 1 1 0 0]   | 0.6250 |   0.8208  | 0.8438 | 0.0244    |1.2500|0.3125  |  0.9375 \n",
      "  [1 1 1 1 0 1]   | 0.6875 |   0.8438  | 0.8438 | 0.0491    |1.5625|0.0000  |  1.5625 \n",
      "  [1 1 1 1 1 0]   | 0.9375 |   0.9688  | 0.9688 | 0.0164    |0.3125|0.0000  |  0.3125 \n",
      "  [1 1 1 1 1 1]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "\n",
      "!!!p=0.3\n",
      "boolean signature | imbal? | nacc(fN*) | taccfN* |naccfN*-naccf| S(f) | S(fN*) | S(f) - S(fN*)\n",
      "--------------------------------------------------------------------------\n",
      "  [0 0 0 0 0 0]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n",
      "  [0 0 0 0 0 1]   | 0.9375 |   0.9687  | 0.9688 | 0.0207    |0.3125|0.0000  |  0.3125 \n",
      "  [0 0 0 0 1 0]   | 0.6875 |   0.8437  | 0.8438 | 0.0651    |1.5625|0.0000  |  1.5625 \n",
      "  [0 0 0 0 1 1]   | 0.6250 |   0.8143  | 0.8438 | 0.0426    |1.2500|0.3125  |  0.9375 \n",
      "  [0 0 0 1 0 0]   | 0.3750 |   0.6875  | 0.6875 | 0.0811    |3.1250|0.0000  |  3.1250 \n",
      "  [0 0 0 1 0 1]   | 0.3125 |   0.6562  | 0.6562 | 0.0632    |3.4375|0.0000  |  3.4375 \n",
      "  [0 0 0 1 1 0]   | 0.0625 |   0.6253  | 0.9688 | 0.0106    |2.1875|1.8750  |  0.3125 \n",
      "  [0 0 0 1 1 1]   | 0.0000 |   0.6463  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [0 0 1 0 0 0]   | 0.3750 |   0.6875  | 0.6875 | 0.0811    |3.1250|0.0000  |  3.1250 \n",
      "  [0 0 1 0 0 1]   | 0.3125 |   0.6562  | 0.6562 | 0.0853    |3.4375|0.0000  |  3.4375 \n",
      "  [0 0 1 0 1 0]   | 0.0625 |   0.5312  | 0.5312 | 0.0163    |4.6875|0.0000  |  4.6875 \n",
      "  [0 0 1 0 1 1]   | 0.0000 |   0.5402  | 0.3750 | 0.0157    |4.3750|1.8750  |  2.5000 \n",
      "  [0 0 1 1 0 0]   | 0.2500 |   0.6324  | 0.6875 | 0.0471    |2.5000|0.6250  |  1.8750 \n",
      "  [0 0 1 1 0 1]   | 0.3125 |   0.6598  | 0.6875 | 0.0714    |2.8125|0.3125  |  2.5000 \n",
      "  [0 0 1 1 1 0]   | 0.5625 |   0.7832  | 0.8125 | 0.0597    |1.5625|0.3125  |  1.2500 \n",
      "  [0 0 1 1 1 1]   | 0.6250 |   0.8143  | 0.8438 | 0.0426    |1.2500|0.3125  |  0.9375 \n",
      "  [0 1 0 0 0 0]   | 0.6875 |   0.8437  | 0.8438 | 0.0651    |1.5625|0.0000  |  1.5625 \n",
      "  [0 1 0 0 0 1]   | 0.6250 |   0.8125  | 0.8125 | 0.0823    |1.8750|0.0000  |  1.8750 \n",
      "  [0 1 0 0 1 0]   | 0.3750 |   0.6875  | 0.6875 | 0.0957    |3.1250|0.0000  |  3.1250 \n",
      "  [0 1 0 0 1 1]   | 0.3125 |   0.6598  | 0.6875 | 0.0714    |2.8125|0.3125  |  2.5000 \n",
      "  [0 1 0 1 0 0]   | 0.0625 |   0.5312  | 0.5312 | 0.0163    |4.6875|0.0000  |  4.6875 \n",
      "  [0 1 0 1 0 1]   | 0.0000 |   0.5051  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [0 1 0 1 1 0]   | 0.2500 |   0.6250  | 0.6250 | 0.0671    |3.7500|0.0000  |  3.7500 \n",
      "  [0 1 0 1 1 1]   | 0.3125 |   0.6562  | 0.6562 | 0.0632    |3.4375|0.0000  |  3.4375 \n",
      ">>[0 1 1 0 0 0]   | 0.0625 |   0.6253  | 0.9688 | 0.0106    |2.1875|1.8750  |  0.3125 \n",
      "  [0 1 1 0 0 1]   | 0.0000 |   0.6042  | 0.9375 | 0.0214    |2.5000|1.8750  |  0.6250 \n",
      ">>[0 1 1 0 1 0]   | 0.2500 |   0.6250  | 0.6250 | 0.0671    |3.7500|0.0000  |  3.7500 \n",
      "  [0 1 1 0 1 1]   | 0.3125 |   0.6562  | 0.6562 | 0.0853    |3.4375|0.0000  |  3.4375 \n",
      "  [0 1 1 1 0 0]   | 0.5625 |   0.7832  | 0.8125 | 0.0597    |1.5625|0.3125  |  1.2500 \n",
      "  [0 1 1 1 0 1]   | 0.6250 |   0.8125  | 0.8125 | 0.0823    |1.8750|0.0000  |  1.8750 \n",
      "  [0 1 1 1 1 0]   | 0.8750 |   0.9375  | 0.9375 | 0.0412    |0.6250|0.0000  |  0.6250 \n",
      "  [0 1 1 1 1 1]   | 0.9375 |   0.9687  | 0.9688 | 0.0207    |0.3125|0.0000  |  0.3125 \n",
      "  [1 0 0 0 0 0]   | 0.9375 |   0.9687  | 0.9688 | 0.0207    |0.3125|0.0000  |  0.3125 \n",
      "  [1 0 0 0 0 1]   | 0.8750 |   0.9375  | 0.9375 | 0.0412    |0.6250|0.0000  |  0.6250 \n",
      "  [1 0 0 0 1 0]   | 0.6250 |   0.8125  | 0.8125 | 0.0823    |1.8750|0.0000  |  1.8750 \n",
      "  [1 0 0 0 1 1]   | 0.5625 |   0.7832  | 0.8125 | 0.0597    |1.5625|0.3125  |  1.2500 \n",
      "  [1 0 0 1 0 0]   | 0.3125 |   0.6562  | 0.6562 | 0.0853    |3.4375|0.0000  |  3.4375 \n",
      "  [1 0 0 1 0 1]   | 0.2500 |   0.6250  | 0.6250 | 0.0671    |3.7500|0.0000  |  3.7500 \n",
      ">>[1 0 0 1 1 0]   | 0.0000 |   0.6042  | 0.9375 | 0.0214    |2.5000|1.8750  |  0.6250 \n",
      "  [1 0 0 1 1 1]   | 0.0625 |   0.6253  | 0.9688 | 0.0106    |2.1875|1.8750  |  0.3125 \n",
      "  [1 0 1 0 0 0]   | 0.3125 |   0.6562  | 0.6562 | 0.0632    |3.4375|0.0000  |  3.4375 \n",
      "  [1 0 1 0 0 1]   | 0.2500 |   0.6250  | 0.6250 | 0.0671    |3.7500|0.0000  |  3.7500 \n",
      "  [1 0 1 0 1 0]   | 0.0000 |   0.5051  | 1.0000 | 0.0000    |5.0000|5.0000  |  0.0000 \n",
      "  [1 0 1 0 1 1]   | 0.0625 |   0.5312  | 0.5312 | 0.0163    |4.6875|0.0000  |  4.6875 \n",
      "  [1 0 1 1 0 0]   | 0.3125 |   0.6598  | 0.6875 | 0.0714    |2.8125|0.3125  |  2.5000 \n",
      "  [1 0 1 1 0 1]   | 0.3750 |   0.6875  | 0.6875 | 0.0957    |3.1250|0.0000  |  3.1250 \n",
      "  [1 0 1 1 1 0]   | 0.6250 |   0.8125  | 0.8125 | 0.0823    |1.8750|0.0000  |  1.8750 \n",
      "  [1 0 1 1 1 1]   | 0.6875 |   0.8437  | 0.8438 | 0.0651    |1.5625|0.0000  |  1.5625 \n",
      "  [1 1 0 0 0 0]   | 0.6250 |   0.8143  | 0.8438 | 0.0426    |1.2500|0.3125  |  0.9375 \n",
      "  [1 1 0 0 0 1]   | 0.5625 |   0.7832  | 0.8125 | 0.0597    |1.5625|0.3125  |  1.2500 \n",
      "  [1 1 0 0 1 0]   | 0.3125 |   0.6598  | 0.6875 | 0.0714    |2.8125|0.3125  |  2.5000 \n",
      "  [1 1 0 0 1 1]   | 0.2500 |   0.6324  | 0.6875 | 0.0471    |2.5000|0.6250  |  1.8750 \n",
      "  [1 1 0 1 0 0]   | 0.0000 |   0.5402  | 0.3750 | 0.0157    |4.3750|1.8750  |  2.5000 \n",
      "  [1 1 0 1 0 1]   | 0.0625 |   0.5312  | 0.5312 | 0.0163    |4.6875|0.0000  |  4.6875 \n",
      "  [1 1 0 1 1 0]   | 0.3125 |   0.6562  | 0.6562 | 0.0853    |3.4375|0.0000  |  3.4375 \n",
      "  [1 1 0 1 1 1]   | 0.3750 |   0.6875  | 0.6875 | 0.0811    |3.1250|0.0000  |  3.1250 \n",
      "  [1 1 1 0 0 0]   | 0.0000 |   0.6463  | 1.0000 | 0.0000    |1.8750|1.8750  |  0.0000 \n",
      "  [1 1 1 0 0 1]   | 0.0625 |   0.6253  | 0.9688 | 0.0106    |2.1875|1.8750  |  0.3125 \n",
      "  [1 1 1 0 1 0]   | 0.3125 |   0.6562  | 0.6562 | 0.0632    |3.4375|0.0000  |  3.4375 \n",
      "  [1 1 1 0 1 1]   | 0.3750 |   0.6875  | 0.6875 | 0.0811    |3.1250|0.0000  |  3.1250 \n",
      "  [1 1 1 1 0 0]   | 0.6250 |   0.8143  | 0.8438 | 0.0426    |1.2500|0.3125  |  0.9375 \n",
      "  [1 1 1 1 0 1]   | 0.6875 |   0.8437  | 0.8438 | 0.0651    |1.5625|0.0000  |  1.5625 \n",
      "  [1 1 1 1 1 0]   | 0.9375 |   0.9687  | 0.9688 | 0.0207    |0.3125|0.0000  |  0.3125 \n",
      "  [1 1 1 1 1 1]   | 1.0000 |   1.0000  | 1.0000 | 0.0000    |0.0000|0.0000  |  0.0000 \n"
     ]
    }
   ],
   "source": [
    "# n = 5\n",
    "n = 5\n",
    "k = n\n",
    "# p = 0.2\n",
    "pvals = [0, 0.1, 0.15, 0.2, 0.25, 0.3]\n",
    "\n",
    "X_arr = np.array(list(itertools.product([0, 1], repeat=n)))\n",
    "\n",
    "\n",
    "p_x = 1 / (2 ** k) # uniform distribution over x\n",
    "\n",
    "H = np.array(list(itertools.product([0, 1], repeat=n+1)))\n",
    "mine = [\n",
    "    [1,0,0,1,1,0],\n",
    "    [0,1,1,0,1,0],\n",
    "    [0,1,1,0,0,0]\n",
    "]\n",
    "# H = [[0, 1, 0, 0, 0]]\n",
    "\n",
    "for p in pvals:\n",
    "    print()\n",
    "    print(f\"!!!p={p}\")\n",
    "    print(\"boolean signature | imbal? | nacc(fN*) | taccfN* |naccfN*-naccf| S(f) | S(fN*) | S(f) - S(fN*)\")\n",
    "    print(\"--------------------------------------------------------------------------\")\n",
    "    for i, signature in enumerate(H):\n",
    "        sss = \"  \"\n",
    "        if list(signature) in mine:\n",
    "            sss = \">>\"\n",
    "        hash = dict(zip(range(n+1), signature))\n",
    "        func = lambda b: hash[sum(b)]\n",
    "        \n",
    "        # noisy_lookup[row,col] is the JOINT probability Pr(f(z)=row| x=col)\n",
    "        noisy_lookup = np.zeros((2, 2**n))\n",
    "        true_lookup = np.zeros((2, 2**n))\n",
    "        # simulate a noisy dataset essentially\n",
    "        for i, x in enumerate(product([0,1], repeat=k)):\n",
    "            func_value = func(x)\n",
    "            # true lookup is an array with 2 rows; there is a p_x at [row, column] if \n",
    "            # f[column] = row]. so, true_lookup[i, j] = pr(f(x) = i| x=j)\n",
    "            true_lookup[func(x), i] = 1\n",
    "            # iterate over all of the z values that contribute to \n",
    "            for e in product([0, 1], repeat=k):\n",
    "                z = np.array(x) ^ np.array(e)\n",
    "                p_x_given_z = p ** sum(e) * (1-p)**(k - sum(e))\n",
    "                # increment noisy_lookup at the binary index of z\n",
    "                # noisy_lookup[i, j] = pr(f(z) = i,  x=j) \n",
    "                noisy_lookup[func_value, int(''.join(map(str, z)), 2)] += p_x_given_z \n",
    "        \n",
    "        # the function is balanced if the sums of the two rows of true_lookup are equal\n",
    "        imbal = abs(true_lookup[0,:].sum() - true_lookup[1,:].sum())  / 2 ** n\n",
    "        \n",
    "        # if not balanced:\n",
    "        #     continue\n",
    "        # round up to get argmax \n",
    "        noisy_mle = np.round(noisy_lookup)  \n",
    "        out = np.multiply(noisy_mle, true_lookup) / 2 ** n # \"inner product\" of the functions\n",
    "        noiseless_fnstar_acc = out.sum()\n",
    "\n",
    "\n",
    "        fnstar_dct = {}\n",
    "        for i, x in enumerate(X_arr):\n",
    "            fnstar_dct[tuple(x)] = np.argmax(noisy_lookup[:, i])\n",
    "        def fnstar(x):\n",
    "            return fnstar_dct[tuple(x)]\n",
    "        \n",
    "        sensitivity_f = average_sensitivity(func, X_arr)\n",
    "        sensitivity_fnstar = average_sensitivity(fnstar, X_arr)\n",
    "        sensitivity_diff = sensitivity_f - sensitivity_fnstar\n",
    "        # accuracies on dataset\n",
    "        #  = compute_acc_test(fnstar, func, n) # accuracy of fN* on noiseless data\n",
    "\n",
    "        p_zy = generate_noisy_distr(k, p, func)\n",
    "        noisy_f_acc = compute_acc_noisytest(p_zy, func, n) # accuracy of f on noisy data\n",
    "        noisy_fnstar_acc = compute_acc_noisytest(p_zy, fnstar, n) # accuracy of fN* MLE on noisy data\n",
    "        nacc_diff = noisy_fnstar_acc - noisy_f_acc\n",
    "\n",
    "\n",
    "\n",
    "        print(f\"{sss}{signature}   | {imbal:0.4f} |   {noisy_fnstar_acc:1.4f}  | {noiseless_fnstar_acc:1.4f} | {nacc_diff:1.4f}    |{sensitivity_f:1.4f}|{sensitivity_fnstar:1.4f}  |  {sensitivity_diff:1.4f} \")\n",
    "        if sensitivity_fnstar == 0:\n",
    "            topr = sum(noisy_mle[0,:])\n",
    "            botr = sum(noisy_mle[1,:])\n",
    "            assert (np.allclose(topr, 0) or np.allclose(topr, 1 << n))\n",
    "            assert (np.allclose(botr, 0) or np.allclose(botr, 1 << n))\n",
    "\n",
    "    \n",
    "    \n"
   ]
  }
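  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Sketch: the same noisy joint distribution without the inner loop**\n",
    "\n",
    "The cell above accumulates $p_{ZY}$ one noise pattern at a time. Under the same assumptions (uniform $X$ on $\\{0,1\\}^n$, each bit flipped independently with probability $p$), $p_{Z|X}(z|x) = p^{d(x,z)}(1-p)^{n-d(x,z)}$, where $d$ is the Hamming distance, so the whole table can be built from a matrix of pairwise distances. The cell below is a minimal sketch of that idea. The helpers `noisy_joint` and `mld_accuracies` are hypothetical illustrations, not functions from `mindreadingautobots`. For a signature $s$ (so $f(x) = s_{|x|}$) it forms the MLD $f_N^*(z) = \\argmax_y p_{ZY}(z, y)$, its accuracy $\\sum_z \\max_y p_{ZY}(z,y)$ on noisy data, and the fraction of $x$ with $f_N^*(x) = f(x)$ on noiseless data. These are intended to mirror the `nacc(fN*)` and `taccfN*` columns printed above, up to how ties are broken."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def noisy_joint(signature, p):\n",
    "    # Hypothetical helper. Returns (joint, y) where joint[y, j] = Pr(Y = y, Z = z_j)\n",
    "    # for Y = f(X), f(x) = signature[|x|], X uniform on n bits, Z = X with\n",
    "    # independent bit flips of probability p, and y[j] = f(x_j).\n",
    "    n = len(signature) - 1\n",
    "    # rows in itertools.product order, so row j is the bitstring with binary index j\n",
    "    X = np.array(list(itertools.product([0, 1], repeat=n)))\n",
    "    y = np.array([signature[w] for w in X.sum(axis=1)])\n",
    "    # dist[i, j] = Hamming distance between x_i and z_j\n",
    "    dist = (X[:, None, :] != X[None, :, :]).sum(axis=2)\n",
    "    p_z_given_x = p ** dist * (1 - p) ** (n - dist)\n",
    "    joint = np.zeros((2, 2 ** n))\n",
    "    for label in (0, 1):\n",
    "        # marginalize over the x with f(x) = label, each with probability 2^-n\n",
    "        joint[label] = p_z_given_x[y == label].sum(axis=0) / 2 ** n\n",
    "    return joint, y\n",
    "\n",
    "\n",
    "def mld_accuracies(signature, p):\n",
    "    # Hypothetical helper. Returns (noisy accuracy, noiseless accuracy) of the MLD fN*.\n",
    "    joint, y = noisy_joint(signature, p)\n",
    "    fnstar = joint.argmax(axis=0)  # fN*(z_j); np.argmax breaks ties toward 0\n",
    "    noisy_acc = joint[fnstar, np.arange(joint.shape[1])].sum()  # Pr(fN*(Z) = Y)\n",
    "    noiseless_acc = np.mean(fnstar == y)  # Pr(fN*(X) = f(X)) with X uniform\n",
    "    return noisy_acc, noiseless_acc\n",
    "\n",
    "\n",
    "# usage: one of the highlighted signatures at a few noise levels\n",
    "for p in [0.0, 0.1, 0.2]:\n",
    "    print(p, mld_accuracies([0, 1, 1, 0, 1, 0], p))"
   ]
  }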
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "autobots",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
