{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All imports used by this notebook, gathered in one cell so Restart & Run All works.\n",
    "import itertools\n",
    "import time\n",
    "from itertools import product\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# NOTE(review): later cells also use `boolean`, `walsh_hadamard_matrix`,\n",
    "# `compute_influence_fourier`, `compute_sgn_Trho_f_influence` and `majority_signature`,\n",
    "# which are not imported or defined before they are used.\n",
    "# TODO: import them here from the project module that provides them."
   ]
  },
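  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook searches for counterexamples to the conjecture that $s(f) \\geq s(f_N^*)$, where $s(\\cdot)$ is the (average) sensitivity, i.e. total influence, and $f_N^*$ is the maximum-likelihood predictor for noisy inputs.\n",
    "\n",
    "\n",
    "### Brute-force for $n \\leq 4$\n",
    "We are looking for a function $f$ such that\n",
    "$$\n",
    "s(f) < s(f_N^*)\n",
    "$$\n",
    "where $f_N^*$ is the maximum-likelihood predictor for the noisy data $(Z, f(X))$. We brute-force search over _all_ Boolean functions on $n$ bits. They can be enumerated because there is a one-to-one mapping between $F_n:=\\{f: \\{0,1\\}^n \\rightarrow \\{0,1\\}\\}$ and $\\{0,1\\}^{2^n}$. Since $|F_n| = 2^{2^n}$ (already about $4.3 \\times 10^9$ for $n = 5$), exhaustive search is only feasible for $n \\leq 4$."
   ]
  },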
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 4\n",
    "# there are 2 ** (2 ** n) Boolean functions on n bits: 65,536 for n = 4, ~4.3e9 for n = 5\n",
    "assert n <= 4\n",
    "k = n\n",
    "# pvals = [0.01, 0.25, 0.49]\n",
    "pvals = [0.49]\n",
    "\n",
    "# iterate over all boolean functions\n",
    "# idea: do something parallelizable by having a matrix of all boolean functions, and vectorize FT?\n",
    "for i, f_signature in enumerate(itertools.product([0, 1], repeat=2 ** n)):\n",
    "    # compute its fourier transform.\n",
    "    # `n` must be passed explicitly; omitting it raised\n",
    "    # TypeError: boolean_function_from_signature() missing 1 required positional argument: 'n'\n",
    "    f_dct, f_func = boolean.boolean_function_from_signature(f_signature, n=n)\n",
    "\n",
    "\n",
    "# # a \"signature\" of a boolean function is a length 2**n bitstring S where f(x) = S[bin(x)]\n",
    "# sig_arr = np.array(list(itertools.product([0, 1], repeat=2 ** n)))\n",
    "# X_arr = np.array(list(itertools.product([0, 1], repeat=n)))\n",
    "# p_x = 1 / (2 ** k) # uniform distribution over x\n",
    "# for i, signature in enumerate(sig_arr):\n",
    "#     # iterate over signatures\n",
    "#     dct, func = boolean.boolean_function_from_signature(signature, n=n)\n",
    "#     for p in pvals:        \n",
    "#         # noisy_lookup[row,col] is the JOINT probability Pr(f(z)=row| x=col)\n",
    "#         noisy_lookup = np.zeros((2, 2**n))\n",
    "#         true_lookup = np.zeros((2, 2**n))\n",
    "#         # simulate a noisy dataset essentially\n",
    "#         for i, x in enumerate(product([0,1], repeat=k)):\n",
    "#             func_value = func(x)\n",
    "#             # true lookup is an array with 2 rows; there is a p_x at [row, column] if \n",
    "#             # f[column] = row]. so, true_lookup[i, j] = pr(f(x) = i| x=j)\n",
    "#             true_lookup[func(x), i] = 1\n",
    "#             # iterate over all of the z values that contribute to \n",
    "#             for e in product([0, 1], repeat=k):\n",
    "#                 z = np.array(x) ^ np.array(e)\n",
    "#                 p_x_given_z = p ** sum(e) * (1-p)**(k - sum(e))\n",
    "#                 # increment noisy_lookup at the binary index of z\n",
    "#                 # noisy_lookup[i, j] = pr(f(z) = i,  x=j) \n",
    "#                 noisy_lookup[func_value, int(''.join(map(str, z)), 2)] += p_x_given_z \n",
    "        \n",
    "#         # the function is balanced if the sums of the two rows of true_lookup are equal\n",
    "#         # imbal = abs(true_lookup[0,:].sum() - true_lookup[1,:].sum())  / 2 ** n\n",
    "#         # round up to get argmax \n",
    "#         noisy_mle = np.round(noisy_lookup)  \n",
    "#         out = np.multiply(noisy_mle, true_lookup) / 2 ** n # \"inner product\" of the functions\n",
    "#         diff = out.sum()\n",
    "#         fnstar_dct = {}\n",
    "#         for i, x in enumerate(X_arr):\n",
    "#             fnstar_dct[tuple(x)] = np.argmax(noisy_lookup[:, i])\n",
    "#         def fnstar(x):\n",
    "#             return fnstar_dct[tuple(x)]\n",
    "        \n",
    "#         # compute sensitivity at i of f and fnstar to check for unformity of bound?\n",
    "#         # for kk in range(n):\n",
    "#         #     sensitivity_i_f = boolean.average_s_i(func, kk, X_arr)\n",
    "#         #     sensitivity_i_fnstar = boolean.average_s_i(fnstar, kk, X_arr)\n",
    "#         #     if sensitivity_i_f < sensitivity_i_fnstar:\n",
    "#         #         sensitivity_i_f = boolean.average_s_i(func, kk, X_arr, verbose=True)\n",
    "#         #         sensitivity_i_fnstar = boolean.average_s_i(fnstar, kk, X_arr, verbose=True)\n",
    "#         #         print(f\"ERROR: fn* has higher sensitivity than f at bit {kk}\")\n",
    "#         #         print(true_lookup)\n",
    "#         #         print(noisy_mle)\n",
    "#         #         for k, v in dct.items():\n",
    "#         #             print(k, v)\n",
    "\n",
    "#         #         raise ValueError\n",
    "\n",
    "#         sensitivity_f = boolean.average_sensitivity(func, X_arr)\n",
    "#         sensitivity_fnstar = boolean.average_sensitivity(fnstar, X_arr)\n",
    "#         sensitivity_diff = sensitivity_f - sensitivity_fnstar\n",
    "#         # accuracies on dataset\n",
    "#         p_zy = boolean.generate_noisy_distr(k, p, func)\n",
    "#         noisy_f_acc = boolean.compute_acc_noisytest(p_zy, func, n) # accuracy of f on noisy data\n",
    "#         # noiseless_fnstar_acc = compute_acc_test(fnstar, func, n) # accuracy of fN* on noiseless data\n",
    "#         noisy_fnstar_acc = boolean.compute_acc_noisytest(p_zy, fnstar, n) # accuracy of fN* MLE on noisy data\n",
    "#         if sensitivity_f < sensitivity_fnstar:\n",
    "#             print(\"ERROR: fn* has higher sensitivity than f\")\n",
    "#             raise ValueError\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Searching for counterexamples to $f_N^* = f$ for parity and majority."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def weight_by_lambda_S(fhat, lambda_func):\n",
    "    \"\"\"\n",
    "    Weight each Fourier coefficient by a function of the size of its subset.\n",
    "\n",
    "    Index S of `fhat` is read as the bitmask of a subset of [n], so |S| is the\n",
    "    number of set bits of S.\n",
    "\n",
    "    Args:\n",
    "        fhat: A length 2^n array-like of Fourier coefficients, indexed by subset bitmask S\n",
    "        lambda_func: given |S| (an int), returns the weight applied to fhat[S]\n",
    "\n",
    "    Returns:\n",
    "        A length 2^n numpy array with out[S] = fhat[S] * lambda_func(|S|).\n",
    "        The dtype follows numpy type promotion of fhat and the weights, so integer\n",
    "        coefficients weighted by fractional lambdas are not truncated to integers.\n",
    "    \"\"\"\n",
    "    fhat = np.asarray(fhat)\n",
    "    n = int(np.log2(len(fhat)))\n",
    "    assert len(fhat) == 2**n, \"Input array must have length 2^n\"\n",
    "    # bin(S).count(\"1\") is |S| for the subset encoded by the bitmask S\n",
    "    weights = np.array([lambda_func(bin(S).count(\"1\")) for S in range(2**n)])\n",
    "    return fhat * weights\n",
    "\n",
    "\n",
    "def guide(n):\n",
    "    \"\"\"Width-n binary labels of the indices 0 .. 2^n - 1, e.g. guide(2) == ['00', '01', '10', '11'].\"\"\"\n",
    "    return [format(x, f'0{n}b') for x in range(2**n)]\n",
    "\n",
    "def bool_list(n):\n",
    "    \"\"\"All points of {0, 1}^n as tuples, in itertools.product order.\"\"\"\n",
    "    return list(itertools.product([0,1], repeat=n))\n",
    "\n",
    "\n",
    "\n",
    "def guess_ghat(fhat, rho):\n",
    "    \"\"\"Damp each Fourier coefficient by rho^|S| and rescale to unit L2 norm.\n",
    "\n",
    "    This doesn't give a valid boolean function necessarily: it is the Fourier\n",
    "    transform of T_rho f, rescaled so that sum_S ghat[S]^2 == 1.\n",
    "\n",
    "    Args:\n",
    "        fhat: length 2^n array-like of Fourier coefficients, indexed by subset bitmask S\n",
    "        rho: coefficient S is multiplied by rho ** |S|\n",
    "\n",
    "    Returns:\n",
    "        Length 2^n float array. If every damped coefficient is zero (e.g. fhat is all\n",
    "        zeros), the zero array is returned instead of dividing 0 / 0 into NaNs.\n",
    "    \"\"\"\n",
    "    fhat = np.asarray(fhat, dtype=float)\n",
    "    n = int(np.log2(len(fhat)))\n",
    "    assert len(fhat) == 2**n, \"Input array must have length 2^n\"\n",
    "    subset_sizes = np.array([bin(S).count(\"1\") for S in range(2**n)])\n",
    "    ghat = rho ** subset_sizes * fhat\n",
    "    norm = np.sqrt(np.sum(ghat ** 2))\n",
    "    if norm == 0:\n",
    "        return ghat\n",
    "    return ghat / norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a more efficient counterexample search\n",
    "\n",
    "# we will use linear algebra to do fourier transforms in bulk.\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building Walsh-Hadamard matrix...\n",
      "Walsh-Hadamard matrix built in 0.0000 seconds\n",
      "Generating all boolean functions...\n",
      "All boolean functions generated in 0.0000 seconds\n",
      "Computing Fourier transforms...\n",
      "Fourier transforms computed in 0.0112 seconds\n",
      "Building noise operator mask...\n",
      "Noise operator mask built in 0.0010 seconds\n",
      "Applying noise operator...\n",
      "Noise operator applied in 0.0002 seconds\n",
      "Influences computed in 0.0000 seconds\n",
      "all checks passed.\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "\n",
    "# print(\"Computing inverse Fourier transforms...\")\n",
    "# t11 = time.time()\n",
    "# Trho_f = H @ Trho_H_f\n",
    "# t12 = time.time()\n",
    "# print(f\"Inverse Fourier transforms computed in {t12-t11:.4f} seconds\")\n",
    "\n",
    "# print(\"Computing signs...\")\n",
    "# t13 = time.time()\n",
    "# all_sgn_Trho_f = np.sign(Trho_f)\n",
    "# t14 = time.time()\n",
    "# print(f\"Signs computed in {t14-t13:.4f} seconds\")\n",
    "\n",
    "# print(\"Checking for counterexamples...\")\n",
    "# t15 = time.time()\n",
    "# counterexample_count = 0\n",
    "# for col in range(all_sgn_Trho_f.shape[1]):\n",
    "#     g = all_sgn_Trho_f[:, col]\n",
    "#     f = all_f[:, col]``\n",
    "#     inf_g = boolean.total_inf(g)\n",
    "#     inf_f = boolean.total_inf(f)\n",
    "#     if inf_g - inf_f > 0:\n",
    "#         print(\"counterexample!\", f)\n",
    "#         counterexample_count += 1\n",
    "# t16 = time.time()\n",
    "# print(f\"Counterexample search completed in {t16-t15:.4f} seconds\")\n",
    "# print(f\"Found {counterexample_count} counterexamples\")\n",
    "\n",
    "# end = time.time()\n",
    "# print(f\"Total time: {end - start:.4f} seconds\")\n",
    "# prod_noisy = prod * trho_mask\n",
    "# unfortunatly this still takes 3 days to check n=5; for n=2 we get .001 s for 2^(2^2) checks, = 6.25*10^-5 seconds per check, times 2^(2^5) checks for n=5 = 3 days...\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 8\n",
    "\n",
    "# # generate some random ltf signatures\n",
    "# a = np.random.randint(0, 200, size=n+1)\n",
    "\n",
    "# generate random functions and get their fourier xform\n",
    "start = time.time()\n",
    "n_samples = 10000\n",
    "# seeded generator so the sampled functions (and the reported gaps) are reproducible\n",
    "sample_rng = np.random.default_rng(0)\n",
    "sampled_f = sample_rng.choice([-1, 1], size=(2**n, n_samples))\n",
    "# for doing the entire set of functions, do not use for n > 4\n",
    "# sampled_f = np.array(list(itertools.product([-1, 1], repeat=2**n))).T\n",
    "\n",
    "# sampled_f = []\n",
    "# def maj(x):\n",
    "#     return np.sign(np.sum(x))\n",
    "# for x in itertools.product([-1, 1], repeat=n):\n",
    "#     sampled_f.append(maj(x))\n",
    "# sampled_f = np.array(sampled_f).reshape(-1, 1)\n",
    "\n",
    "H = walsh_hadamard_matrix(n)\n",
    "sampled_fhat = H @ sampled_f\n",
    "inf_f = compute_influence_fourier(sampled_fhat)\n",
    "# show a counter on inf_f\n",
    "# from collections import Counter\n",
    "# counter = Counter(inf_f)\n",
    "# print(counter)\n",
    "\n",
    "# build sgn(Trho(f)) and compute its influence\n",
    "rho = .56\n",
    "inf_g = compute_sgn_Trho_f_influence(sampled_fhat, rho)\n",
    "# counter = Counter(inf_g)\n",
    "# print(counter)\n",
    "# assert np.allclose(sampled_fhat, sampled_ghat)\n",
    "\n",
    "# a counterexample is any sample with Inf[g] > Inf[f], i.e. a negative gap\n",
    "gap = np.asarray(inf_f - inf_g)\n",
    "print(\"max Inf[f] - Inf[g]:\", np.max(gap))\n",
    "print(\"min Inf[f] - Inf[g]:\", np.min(gap))\n",
    "end = time.time()\n",
    "print(f\"Influences computed in {end - start:.4f} seconds\")\n",
    "# small tolerance so floating-point round-off in the Fourier computations\n",
    "# is not reported as a counterexample\n",
    "assert np.all(gap >= -1e-9), f\"counterexample found: min Inf[f] - Inf[g] = {np.min(gap)}\"\n",
    "print(\"all checks passed.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "def ltf(x, a0, a1):\n",
    "    \"\"\"Evaluate the linear threshold function sgn(a0 + <a1, x>) at the point x.\n",
    "\n",
    "    np.sign returns 0 when a0 + <a1, x> == 0, which made f take the value 0 at such\n",
    "    points (so f was not a +-1 valued Boolean function). Ties are broken towards +1.\n",
    "    \"\"\"\n",
    "    return 1 if a0 + np.dot(x, a1) >= 0 else -1\n",
    "\n",
    "# a balanced LTF for n=11\n",
    "# a0 = 0\n",
    "# a1 = (13, 43, 67, 67, 67, 117, 153, 165, 165, 179, 179)\n",
    "# a0 = 0\n",
    "# a1 = (2, 1, 1, 1)\n",
    "a0 = 1\n",
    "a1 = (1, 2, 3, 4, 5)\n",
    "n = len(a1)\n",
    "\n",
    "f = tuple(ltf(x, a0, a1) for x in itertools.product([-1, 1], repeat=n))\n",
    "assert set(f) <= {-1, 1}, \"f must be a +-1 valued Boolean function\"\n",
    "\n",
    "start = time.time()\n",
    "# FIXME: right now, i compute sgn(trho(f)) in the fourier domain, not sure if that is actually efficient or not. \n",
    "# For large batches it might be efficient...\n",
    "rho = .45\n",
    "fhat = boolean.boolean_fourier_transform(f)\n",
    "trhof_hat = boolean.noise_operator_on(fhat, rho, input_fourier=True, return_fourier=True)\n",
    "trho_f = boolean.inverse_boolean_fourier_transform(trhof_hat)\n",
    "# same tie-breaking as ltf, sgn(0) := +1, so g is also +-1 valued\n",
    "g = np.where(np.asarray(trho_f) >= 0, 1.0, -1.0)\n",
    "\n",
    "print(\"f:\", f)\n",
    "print(\"g:\", g)\n",
    "\n",
    "sens_f = boolean.total_inf(f)\n",
    "sens_g = boolean.total_inf(g)\n",
    "end = time.time()\n",
    "print(\"time:\", end - start)\n",
    "print(\"sens_f:\", sens_f)\n",
    "print(\"sens_g:\", sens_g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.7722161664\n"
     ]
    }
   ],
   "source": [
    "print(min(f_influences - g_influences))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.30448337, 0.23790028, 0.28384501, 0.25401338, 0.27978237,\n",
       "       0.25123264, 0.29622459, 0.27199464, 0.25802301, 0.27833354,\n",
       "       0.28614691, 0.27251038, 0.27439829, 0.26909267, 0.2666106 ,\n",
       "       0.26876362, 0.23088781, 0.26074889, 0.25886327, 0.27227496,\n",
       "       0.23771108, 0.30651288, 0.28411603, 0.2483439 , 0.26875881,\n",
       "       0.2469599 , 0.26735753, 0.23637999, 0.26615679, 0.26704046])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(f_influences - g_influences)[:30]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "74.56540444444444"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "((.001 / 2**4) * 2 ** 32) / 3600"
   ]
  },
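  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Investigating LTFs (linear threshold functions)\n",
    "\n",
    "An LTF has the form $f(x) = \\mathrm{sgn}(a_0 + \\langle a, x \\rangle)$ for $x \\in \\{-1, 1\\}^n$. Below we compare the total influence of $f$ with that of $g = \\mathrm{sgn}(T_\\rho f)$."
   ]
  },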
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `bool_list` and `guess_ghat` are defined in the helper cell earlier in the notebook;\n",
    "# they are deliberately not redefined here, since a second `def` silently shadows the first.\n",
    "\n",
    "def bool_list_pm(n):\n",
    "    \"\"\"All points of {-1, 1}^n as tuples, in itertools.product order.\"\"\"\n",
    "    return list(itertools.product([-1, 1], repeat=n))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "def ltf(x, a0, a1):\n",
    "    \"\"\"Evaluate the linear threshold function sgn(a0 + <a1, x>) at the point x.\n",
    "\n",
    "    np.sign returns 0 when a0 + <a1, x> == 0, which made f take the value 0 at such\n",
    "    points (so f was not a +-1 valued Boolean function). Ties are broken towards +1.\n",
    "    \"\"\"\n",
    "    return 1 if a0 + np.dot(x, a1) >= 0 else -1\n",
    "\n",
    "# a balanced LTF for n=11\n",
    "# a0 = 0\n",
    "# a1 = (13, 43, 67, 67, 67, 117, 153, 165, 165, 179, 179)\n",
    "# a0 = 0\n",
    "# a1 = (2, 1, 1, 1)\n",
    "a0 = 1\n",
    "a1 = (1, 2, 3, 4, 5)\n",
    "n = len(a1)\n",
    "\n",
    "f = tuple(ltf(x, a0, a1) for x in bool_list_pm(n))\n",
    "assert set(f) <= {-1, 1}, \"f must be a +-1 valued Boolean function\"\n",
    "\n",
    "start = time.time()\n",
    "# FIXME: right now, i compute sgn(trho(f)) in the fourier domain, not sure if that is actually efficient or not. \n",
    "# For large batches it might be efficient...\n",
    "rho = .45\n",
    "fhat = boolean.boolean_fourier_transform(f)\n",
    "trhof_hat = boolean.noise_operator_on(fhat, rho, input_fourier=True, return_fourier=True)\n",
    "trho_f = boolean.inverse_boolean_fourier_transform(trhof_hat)\n",
    "# same tie-breaking as ltf, sgn(0) := +1, so g is also +-1 valued\n",
    "g = np.where(np.asarray(trho_f) >= 0, 1.0, -1.0)\n",
    "\n",
    "print(\"f:\", f)\n",
    "print(\"g:\", g)\n",
    "\n",
    "sens_f = boolean.total_inf(f)\n",
    "sens_g = boolean.total_inf(g)\n",
    "end = time.time()\n",
    "print(\"time:\", end - start)\n",
    "print(\"sens_f:\", sens_f)\n",
    "print(\"sens_g:\", sens_g)\n",
    "\n",
    "\n",
    "# f = (1,-1, -1, -1, -1, -1, 1, -1)\n",
    "# n = int(np.log2(len(f)))\n",
    "# print(guide(n))\n",
    "# xvals = np.array(bool_list(n)) * 2 - 1\n",
    "\n",
    "# print(f)\n",
    "# rho = .02\n",
    "# p_bitflip = (1 - rho) / 2\n",
    "\n",
    "# S_weights = np.array([sum([int(i) for i in s]) for s in guide(n)])\n",
    "# S_by_weight = np.argsort(S_weights)\n",
    "\n",
    "# _, _, g_func = boolean.compute_fnstar_err_sens(f, p_bitflip, signature_type=\"all\", signature_signed=True)\n",
    "# g = [int(g_func(x)) for x in bool_list(n)]\n",
    "# print(\"f:\", f)\n",
    "# print(\"g:\", g)\n",
    "# fhat = boolean.boolean_fourier_transform(f)\n",
    "# ghat = boolean.boolean_fourier_transform(g)\n",
    "# print(\"fhat:\", fhat)\n",
    "# print(\"ghat:\", ghat)\n",
    "# trhof_hat = boolean.noise_operator_on(fhat, rho, input_fourier=True, return_fourier=True)\n",
    "# trho_f = boolean.inverse_boolean_fourier_transform(trhof_hat)\n",
    "\n",
    "# print(\"T_rho_f_hat:\", trhof_hat)\n",
    "# print(\"T_rho(f):\", trho_f)\n",
    "\n",
    "# g_maybe = np.sign(trho_f)\n",
    "# print(\"=g?\",g_maybe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fhat norm: 1.004987562112089\n",
      "hhat norm: 0.8717797887081347\n",
      "f: [-1.  1. -1.  1.  1.  1.  1.  1.]\n",
      "h: [-1. -1. -1.  1. -1.  1.  1.  1.]\n",
      "i: 0\n",
      "Inf_i[f]: 0.5\n",
      "Inf_i[h]: 0.5\n",
      "Inf_i[g]: 0.5\n",
      "\n",
      "i: 1\n",
      "Inf_i[f]: 0.0\n",
      "Inf_i[h]: 0.5\n",
      "Inf_i[g]: 0.5\n",
      "\n",
      "i: 2\n",
      "Inf_i[f]: 0.5\n",
      "Inf_i[h]: 0.5\n",
      "Inf_i[g]: 0.5\n",
      "\n"
     ]
    }
   ],
   "source": [
    "fhat = (0.1, 0.5, 0.5, 0, 0.5, -0.5, 0, 0)\n",
    "hhat = (0.1, 0.5, 0.5, 0, 0.5, 0, 0, 0)\n",
    "# hhat =  boolean.noise_operator_on(fhat, 1, input_fourier=True, return_fourier=True)\n",
    "print(\"fhat norm:\", np.linalg.norm(fhat))\n",
    "print(\"hhat norm:\", np.linalg.norm(hhat))\n",
    "\n",
    "f = np.sign(boolean.inverse_boolean_fourier_transform(fhat))\n",
    "h = np.sign(boolean.inverse_boolean_fourier_transform(hhat))\n",
    "g = np.sign(h)\n",
    "print(\"f:\", f)\n",
    "print(\"h:\", h)\n",
    "\n",
    "for i in range(3):\n",
    "    print(f\"i: {i}\")\n",
    "    print(f\"Inf_i[f]: {boolean.inf_i(f, i)}\")\n",
    "    print(f\"Inf_i[h]: {boolean.inf_i(h, i)}\")\n",
    "    print(f\"Inf_i[g]: {boolean.inf_i(g, i)}\")\n",
    "    print()\n",
    "\n",
    "# print(\"fhat:\", fhat)\n",
    "# print(\"hhat:\", hhat)\n",
    "# trhof_hat = boolean.noise_operator_on(fhat, rho, input_fourier=True, return_fourier=True)\n",
    "# trho_f = boolean.inverse_boolean_fourier_transform(trhof_hat)\n",
    "\n",
    "# print(\"T_rho_f_hat:\", trhof_hat)\n",
    "# print(\"T_rho(f):\", trho_f)\n",
    "\n",
    "# g_maybe = np.sign(trho_f)\n",
    "# print(\"=g?\",g_maybe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1. 1. 1. 1. 1. 1. 1. 1.]\n",
      "[-1.  1. -1.  1. -1.  1. -1.  1.]\n",
      "[-1. -1.  1.  1. -1. -1.  1.  1.]\n",
      "[ 1. -1. -1.  1.  1. -1. -1.  1.]\n",
      "[-1. -1. -1. -1.  1.  1.  1.  1.]\n",
      "[ 1. -1.  1. -1. -1.  1. -1.  1.]\n",
      "[ 1.  1. -1. -1. -1. -1.  1.  1.]\n",
      "[-1.  1.  1. -1.  1. -1. -1.  1.]\n"
     ]
    }
   ],
   "source": [
    "for s, s_bin in enumerate(itertools.product([0, 1], repeat=3)):\n",
    "    print(boolean.chi(s_bin, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "signature= [0, 0, 1, 1]\n",
      "noisy lookup table\n",
      "[[1. 1. 1. 0. 1. 0. 0. 0.]\n",
      " [0. 0. 0. 1. 0. 1. 1. 1.]]\n",
      "true lookup table\n",
      "[[1. 1. 1. 0. 1. 0. 0. 0.]\n",
      " [0. 0. 0. 1. 0. 1. 1. 1.]]\n",
      "\n",
      "signature= [0, 0, 0, 1, 1, 1]\n",
      "noisy lookup table\n",
      "[[1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 0. 0. 0. 1. 1. 1. 0. 1. 0. 0. 0.\n",
      "  1. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 1. 0. 1. 1. 1.\n",
      "  0. 1. 1. 1. 1. 1. 1. 1.]]\n",
      "true lookup table\n",
      "[[1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 0. 0. 0. 1. 1. 1. 0. 1. 0. 0. 0.\n",
      "  1. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 1. 0. 1. 1. 1.\n",
      "  0. 1. 1. 1. 1. 1. 1. 1.]]\n",
      "\n",
      "signature= [0, 0, 0, 0, 1, 1, 1, 1]\n",
      "noisy lookup table\n",
      "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0.\n",
      "  1. 1. 1. 0. 1. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 0. 0. 0.\n",
      "  1. 1. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0.\n",
      "  1. 1. 1. 0. 1. 0. 0. 0. 1. 1. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
      "  1. 1. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
      "  0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1.\n",
      "  0. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 1.\n",
      "  0. 0. 0. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1.\n",
      "  0. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1.\n",
      "  0. 0. 0. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1.\n",
      "  1. 1. 1. 1. 1. 1. 1. 1.]]\n",
      "true lookup table\n",
      "[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0.\n",
      "  1. 1. 1. 0. 1. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 0. 1. 0. 0. 0.\n",
      "  1. 1. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0.\n",
      "  1. 1. 1. 0. 1. 0. 0. 0. 1. 1. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
      "  1. 1. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.\n",
      "  0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1.\n",
      "  0. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 1.\n",
      "  0. 0. 0. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1.\n",
      "  0. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1.\n",
      "  0. 0. 0. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1.\n",
      "  1. 1. 1. 1. 1. 1. 1. 1.]]\n",
      "\n",
      "signature= [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n",
      "noisy lookup table\n",
      "[[1. 1. 1. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 1. 1. 1.]]\n",
      "true lookup table\n",
      "[[1. 1. 1. ... 0. 0. 0.]\n",
      " [0. 0. 0. ... 1. 1. 1.]]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# n = 3 # total number of bits\n",
    "p = 0.49# probability of each bit of X flipping, this DOESN'T flip Y ever.\n",
    "# signature = [0, 1, 0, 1]\n",
    "\n",
    "\n",
    "# signatures = [itertools.product([0, 1], repeat=n+1)]\n",
    "# signatures = [signature]\n",
    "signatures = [majority_signature(n) for n in [3, 5, 7, 9]]\n",
    "for signature in signatures:\n",
    "    n = len(signature) - 1\n",
    "    X_arr = np.array(list(itertools.product([0, 1], repeat=n)))\n",
    "    \n",
    "    print(\"signature=\", signature)\n",
    "    hash = dict(zip(range(n+1), signature))\n",
    "    func = lambda b: hash[sum(b)]\n",
    "\n",
    "    f_accs = []\n",
    "    fn_accs = []\n",
    "    # noisy_lookup[row,col] is the JOINT probability Pr(f(z)=row| x=col)\n",
    "    noisy_lookup = np.zeros((2, 2**n))\n",
    "    true_lookup = np.zeros((2, 2**n))\n",
    "    # simulate a noisy dataset essentially\n",
    "    for i, x in enumerate(X_arr):\n",
    "        func_value = func(x) # compute y=f(x)\n",
    "        # true lookup is an array with 2 rows; there is a p_x at [row, column] if \n",
    "        # f[column] = row]. so, true_lookup[i, j] = pr(f(x) = i| x=j)\n",
    "        true_lookup[func(x), i] = 1\n",
    "        # iterate over all of the z values that contribute to \n",
    "        for e in product([0, 1], repeat=n):\n",
    "            z = np.array(x) ^ np.array(e)\n",
    "            p_x_given_z = p ** sum(e) * (1-p)**(n - sum(e))\n",
    "            # increment noisy_lookup at the binary index of z\n",
    "            # noisy_lookup[i, j] = pr(f(z) = i,  x=j) \n",
    "            noisy_lookup[func_value, int(''.join(map(str, z)), 2)] += p_x_given_z \n",
    "\n",
    "    # # round up to get argmax \n",
    "    noisy_mle = np.round(noisy_lookup)  \n",
    "    print(\"noisy lookup table\")\n",
    "    print(noisy_mle)\n",
    "    print(\"true lookup table\")\n",
    "    print(true_lookup)\n",
    "    assert(np.all(noisy_mle == true_lookup))\n",
    "\n",
    "    # print(noisy_mle)\n",
    "    print()\n",
    "\n",
    "    # # the function is balanced if the sums of the two rows of true_lookup are equal\n",
    "    # # GOAL: f should not be too imbalanced <=> `imbal` should be close to 0\n",
    "    # imbal = abs(true_lookup[0,:].sum() - true_lookup[1,:].sum())  / 2 ** n\n",
    "\n",
    "    # # if not balanced:\n",
    "    # #     continue\n",
    "    # out = np.multiply(noisy_mle, true_lookup) / 2 ** n # \"inner product\" of the functions\n",
    "    # diff = out.sum()\n",
    "\n",
    "\n",
    "    # fnstar_dct = {}\n",
    "    # for i, x in enumerate(X_arr):\n",
    "    #     fnstar_dct[tuple(x)] = np.argmax(noisy_lookup[:, i])\n",
    "    # def fnstar(x):\n",
    "    #     return fnstar_dct[tuple(x)]\n",
    "\n",
    "    # sensitivity_f = boolean.average_sensitivity(func, X_arr)\n",
    "    # sensitivity_fnstar = boolean.average_sensitivity(fnstar, X_arr)\n",
    "    # sensitivity_diff = sensitivity_f - sensitivity_fnstar\n",
    "    # # accuracies on dataset\n",
    "    # p_zy = boolean.generate_noisy_distr(n, p, func)\n",
    "    # noisy_f_acc = boolean.compute_acc_noisytest(p_zy, func, n) # accuracy of f on noisy data\n",
    "    # noiseless_fnstar_acc = boolean.compute_acc_test(fnstar, func, n) # accuracy of fN* on noiseless data\n",
    "    # noisy_fnstar_acc = boolean.compute_acc_noisytest(p_zy, fnstar, n) # accuracy of fN* MLE on noisy data\n",
    "\n",
    "    # acc_diff = noisy_fnstar_acc - noisy_f_acc\n",
    "    # f_accs.append(noisy_f_acc)\n",
    "    # fn_accs.append(noisy_fnstar_acc)\n",
    "\n",
    "    # print(\"\\t noisy fN* acc=\", fn_accs)\n",
    "    # print(\"\\t noisy fN* sensitivity=\", sensitivity_fnstar)\n",
    "\n",
    "    # print(\"\\t noisy f acc=\", f_accs)\n",
    "    # print(\"\\t noisy f sensitivity=\", sensitivity_f)\n",
    "\n",
    "    # print((fn_accs[0], sensitivity_fnstar), (f_accs[0], sensitivity_f))\n",
    "    # print()b\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Computing $\\text{Pr}_{X|Z=z}(f(X) = f(z))$\n",
    "\n",
    "This tests our theoretical claim that the majority function is optimal for noisy prediction, where $Z$ is obtained from $X$ by flipping each bit independently with probability $p$. We numerically compute the noisy performance of majority and compare it to an analytical formula."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prob_fx_equals_fz_given_z(signature, p):\n",
    "    \"\"\"Return Pr_{X|Z=z}(f(X) = f(z)) for every z in {0,1}^n.\n",
    "\n",
    "    f is the symmetric boolean function f(x) = signature[sum(x)], so n = len(signature) - 1.\n",
    "    Z = X xor E, where each bit of E is 1 independently with probability p. Every x is\n",
    "    weighted equally (uniform X), so Pr(X=x | Z=z) = Pr(Z=z | X=x) = p^d (1-p)^(n-d),\n",
    "    where d is the Hamming distance between x and z.\n",
    "\n",
    "    Entry i of the returned length-2**n array belongs to the i-th z produced by\n",
    "    itertools.product([0, 1], repeat=n), i.e. i is the MSB-first binary value of z.\n",
    "    \"\"\"\n",
    "    n = len(signature) - 1\n",
    "    X_arr = np.array(list(itertools.product([0, 1], repeat=n)))\n",
    "    # f(x) for every x, looked up by Hamming weight\n",
    "    f_vals = np.asarray(signature)[X_arr.sum(axis=1)]\n",
    "    # dist[i, j] = Hamming distance between the i-th and j-th points of {0,1}^n\n",
    "    dist = (X_arr[:, None, :] != X_arr[None, :, :]).sum(axis=2)\n",
    "    # posterior[x, z] = Pr(X=x | Z=z); each column sums to (p + 1 - p)^n = 1\n",
    "    posterior = p ** dist * (1 - p) ** (n - dist)\n",
    "    same_label = f_vals[:, None] == f_vals[None, :]  # same_label[x, z] = (f(x) == f(z))\n",
    "    return (posterior * same_label).sum(axis=0)\n",
    "\n",
    "\n",
    "N_BITS = 7  # input length of the majority function; also fixes the width of all_given_z\n",
    "pvals = np.linspace(0.01, 0.5, 30)  # per-bit flip probabilities to sweep\n",
    "signature = majority_signature(N_BITS)\n",
    "assert len(signature) == N_BITS + 1, 'signature must give f for each Hamming weight 0..n'\n",
    "\n",
    "# all_given_z[k, i] = Pr_{X|Z=z}(f(X) = f(z)) at noise level pvals[k], for the i-th z in {0,1}^N_BITS\n",
    "all_given_z = np.array([prob_fx_equals_fz_given_z(signature, p) for p in pvals])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
