{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### New generation scheme\n",
    "\n",
    "Here we will generate counterexample _datasets_ by first using a proposed counterexample signature where $P_{X,Z}(f_N^*(z)=f(x))$ is close to $P_{X,Z}(f(z)=f(x))$, and then we fine-tune this by generate a dataset where the optimal test accuracy of $f_N^*(z)$ is _also_ close to f (i.e. we postselect for validation sets that are very \"average\")."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# autoreload magic\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindreadingautobots.sequence_generators import make_datasets, data_io\n",
    "import numpy as np\n",
    "from mindreadingautobots.entropy_and_bayesian import boolean\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f TRUE error: 0.379958879999999, f sensitivity: 3.5\n",
      "fnstar TRUE error: 0.37608843999999875, fnstar sensitivity: 2.625\n",
      "num validation= 10000\n",
      "fnstar true err |   fnstar val err | f val error | Zval lookup error | fval - Zval lookup | fval - fnstar |   seed\n",
      "    0.3761           0.3676           0.3734       0.3561                 0.0173                 -0.0027   2468\n",
      "    0.3761           0.3791           0.3823       0.3708                 0.0115                  0.0062   2469\n",
      "    0.3761           0.3858           0.3860       0.3700                 0.0160                  0.0099   2470\n",
      "    0.3761           0.3762           0.3856       0.3693                 0.0163                  0.0095   2471\n",
      "    0.3761           0.3734           0.3707       0.3580                 0.0127                 -0.0054   2472\n",
      "    0.3761           0.3793           0.3822       0.3677                 0.0145                  0.0061   2473\n",
      "    0.3761           0.3873           0.3865       0.3759                 0.0106                  0.0104   2474\n",
      "    0.3761           0.3766           0.3781       0.3640                 0.0141                  0.0020   2475\n",
      "    0.3761           0.3725           0.3757       0.3624                 0.0133                 -0.0004   2476\n",
      "    0.3761           0.3785           0.3877       0.3688                 0.0189                  0.0116   2477\n",
      "    0.3761           0.3755           0.3834       0.3631                 0.0203                  0.0073   2478\n",
      "    0.3761           0.3772           0.3786       0.3675                 0.0111                  0.0025   2479\n",
      "    0.3761           0.3760           0.3820       0.3678                 0.0142                  0.0059   2480\n",
      "    0.3761           0.3817           0.3910       0.3707                 0.0203                  0.0149   2481\n",
      "    0.3761           0.3713           0.3773       0.3605                 0.0168                  0.0012   2482\n",
      "    0.3761           0.3772           0.3824       0.3668                 0.0156                  0.0063   2483\n",
      "    0.3761           0.3762           0.3822       0.3681                 0.0141                  0.0061   2484\n",
      "    0.3761           0.3662           0.3681       0.3569                 0.0112                 -0.0080   2485\n",
      "    0.3761           0.3817           0.3816       0.3661                 0.0155                  0.0055   2486\n",
      "    0.3761           0.3716           0.3717       0.3580                 0.0137                 -0.0044   2487\n",
      "    0.3761           0.3689           0.3677       0.3582                 0.0095                 -0.0084   2488\n",
      "    0.3761           0.3759           0.3833       0.3673                 0.0160                  0.0072   2489\n",
      "    0.3761           0.3745           0.3778       0.3614                 0.0164                  0.0017   2490\n",
      "    0.3761           0.3810           0.3817       0.3666                 0.0151                  0.0056   2491\n",
      "    0.3761           0.3719           0.3726       0.3620                 0.0106                 -0.0035   2492\n",
      "Seed with lowest diff: 2488 with diff=0.009499999999998954\n",
      "num validation= 20000\n",
      "fnstar true err |   fnstar val err | f val error | Zval lookup error | fval - Zval lookup | fval - fnstar |   seed\n",
      "    0.3761           0.3667           0.3712       0.3626                 0.0086                 -0.0049   2468\n",
      "    0.3761           0.3770           0.3785       0.3712                 0.0072                  0.0024   2469\n",
      "    0.3761           0.3779           0.3811       0.3725                 0.0085                  0.0050   2470\n",
      "    0.3761           0.3745           0.3765       0.3698                 0.0067                  0.0004   2471\n",
      "    0.3761           0.3726           0.3780       0.3685                 0.0095                  0.0019   2472\n",
      "    0.3761           0.3774           0.3823       0.3725                 0.0097                  0.0062   2473\n",
      "    0.3761           0.3764           0.3814       0.3725                 0.0088                  0.0053   2474\n",
      "    0.3761           0.3770           0.3826       0.3720                 0.0106                  0.0065   2475\n",
      "    0.3761           0.3755           0.3761       0.3690                 0.0070                 -0.0000   2476\n",
      "    0.3761           0.3774           0.3801       0.3717                 0.0084                  0.0040   2477\n",
      "    0.3761           0.3692           0.3744       0.3660                 0.0083                 -0.0017   2478\n",
      "    0.3761           0.3809           0.3819       0.3711                 0.0107                  0.0058   2479\n",
      "    0.3761           0.3737           0.3748       0.3678                 0.0070                 -0.0013   2480\n",
      "    0.3761           0.3715           0.3742       0.3677                 0.0065                 -0.0019   2481\n",
      "    0.3761           0.3742           0.3779       0.3694                 0.0085                  0.0018   2482\n",
      "    0.3761           0.3764           0.3786       0.3713                 0.0073                  0.0025   2483\n",
      "    0.3761           0.3771           0.3813       0.3719                 0.0094                  0.0052   2484\n",
      "    0.3761           0.3746           0.3758       0.3696                 0.0062                 -0.0003   2485\n",
      "    0.3761           0.3766           0.3791       0.3693                 0.0098                  0.0030   2486\n",
      "    0.3761           0.3797           0.3812       0.3741                 0.0070                  0.0051   2487\n",
      "    0.3761           0.3729           0.3738       0.3685                 0.0053                 -0.0023   2488\n",
      "    0.3761           0.3756           0.3808       0.3712                 0.0096                  0.0047   2489\n",
      "    0.3761           0.3739           0.3780       0.3709                 0.0071                  0.0019   2490\n",
      "    0.3761           0.3761           0.3796       0.3692                 0.0104                  0.0035   2491\n",
      "    0.3761           0.3754           0.3778       0.3702                 0.0076                  0.0017   2492\n",
      "Seed with lowest diff: 2488 with diff=0.005349999999999411\n",
      "num validation= 30000\n",
      "fnstar true err |   fnstar val err | f val error | Zval lookup error | fval - Zval lookup | fval - fnstar |   seed\n",
      "    0.3761           0.3816           0.3863       0.3794                 0.0069                  0.0102   2468\n",
      "    0.3761           0.3783           0.3782       0.3745                 0.0037                  0.0021   2469\n",
      "    0.3761           0.3772           0.3794       0.3735                 0.0058                  0.0033   2470\n",
      "    0.3761           0.3768           0.3828       0.3749                 0.0079                  0.0067   2471\n",
      "    0.3761           0.3720           0.3750       0.3703                 0.0047                 -0.0011   2472\n",
      "    0.3761           0.3734           0.3771       0.3714                 0.0057                  0.0010   2473\n",
      "    0.3761           0.3736           0.3775       0.3714                 0.0060                  0.0014   2474\n",
      "    0.3761           0.3746           0.3809       0.3723                 0.0086                  0.0048   2475\n",
      "    0.3761           0.3701           0.3755       0.3686                 0.0069                 -0.0006   2476\n",
      "    0.3761           0.3796           0.3824       0.3769                 0.0055                  0.0063   2477\n",
      "    0.3761           0.3713           0.3785       0.3692                 0.0093                  0.0024   2478\n",
      "    0.3761           0.3770           0.3772       0.3729                 0.0043                  0.0011   2479\n",
      "    0.3761           0.3738           0.3768       0.3713                 0.0055                  0.0007   2480\n",
      "    0.3761           0.3809           0.3849       0.3782                 0.0067                  0.0088   2481\n",
      "    0.3761           0.3760           0.3798       0.3727                 0.0071                  0.0037   2482\n",
      "    0.3761           0.3794           0.3817       0.3762                 0.0054                  0.0056   2483\n",
      "    0.3761           0.3718           0.3767       0.3700                 0.0067                  0.0006   2484\n",
      "    0.3761           0.3815           0.3875       0.3789                 0.0086                  0.0114   2485\n",
      "    0.3761           0.3741           0.3766       0.3701                 0.0065                  0.0005   2486\n",
      "    0.3761           0.3759           0.3772       0.3725                 0.0047                  0.0011   2487\n",
      "    0.3761           0.3767           0.3794       0.3731                 0.0063                  0.0033   2488\n",
      "    0.3761           0.3742           0.3784       0.3711                 0.0073                  0.0023   2489\n",
      "    0.3761           0.3732           0.3773       0.3703                 0.0070                  0.0012   2490\n",
      "    0.3761           0.3742           0.3770       0.3707                 0.0062                  0.0009   2491\n",
      "    0.3761           0.3749           0.3750       0.3702                 0.0047                 -0.0011   2492\n",
      "Seed with lowest diff: 2469 with diff=0.003666666666666263\n",
      "num validation= 40000\n",
      "fnstar true err |   fnstar val err | f val error | Zval lookup error | fval - Zval lookup | fval - fnstar |   seed\n",
      "    0.3761           0.3722           0.3743       0.3689                 0.0054                 -0.0018   2468\n",
      "    0.3761           0.3774           0.3788       0.3745                 0.0044                  0.0028   2469\n",
      "    0.3761           0.3790           0.3818       0.3759                 0.0059                  0.0057   2470\n",
      "    0.3761           0.3756           0.3804       0.3739                 0.0066                  0.0044   2471\n",
      "    0.3761           0.3775           0.3820       0.3761                 0.0059                  0.0059   2472\n",
      "    0.3761           0.3766           0.3797       0.3743                 0.0054                  0.0036   2473\n",
      "    0.3761           0.3760           0.3797       0.3734                 0.0063                  0.0037   2474\n",
      "    0.3761           0.3756           0.3790       0.3736                 0.0054                  0.0029   2475\n",
      "    0.3761           0.3716           0.3791       0.3710                 0.0081                  0.0030   2476\n",
      "    0.3761           0.3745           0.3799       0.3730                 0.0069                  0.0038   2477\n",
      "    0.3761           0.3764           0.3785       0.3743                 0.0042                  0.0024   2478\n",
      "    0.3761           0.3772           0.3821       0.3755                 0.0065                  0.0060   2479\n",
      "    0.3761           0.3791           0.3816       0.3767                 0.0049                  0.0055   2480\n",
      "    0.3761           0.3759           0.3793       0.3732                 0.0061                  0.0032   2481\n",
      "    0.3761           0.3727           0.3779       0.3709                 0.0070                  0.0018   2482\n",
      "    0.3761           0.3760           0.3805       0.3741                 0.0063                  0.0044   2483\n",
      "    0.3761           0.3755           0.3759       0.3717                 0.0042                 -0.0001   2484\n",
      "    0.3761           0.3759           0.3797       0.3743                 0.0054                  0.0036   2485\n",
      "    0.3761           0.3770           0.3797       0.3746                 0.0052                  0.0037   2486\n",
      "    0.3761           0.3742           0.3792       0.3727                 0.0065                  0.0031   2487\n",
      "    0.3761           0.3771           0.3800       0.3745                 0.0055                  0.0039   2488\n",
      "    0.3761           0.3778           0.3809       0.3758                 0.0051                  0.0048   2489\n",
      "    0.3761           0.3773           0.3789       0.3737                 0.0052                  0.0028   2490\n",
      "    0.3761           0.3754           0.3806       0.3740                 0.0066                  0.0045   2491\n",
      "    0.3761           0.3736           0.3777       0.3716                 0.0061                  0.0016   2492\n",
      "Seed with lowest diff: 2478 with diff=0.004150000000008758\n",
      "num validation= 50000\n",
      "fnstar true err |   fnstar val err | f val error | Zval lookup error | fval - Zval lookup | fval - fnstar |   seed\n",
      "    0.3761           0.3767           0.3779       0.3741                 0.0038                  0.0018   2468\n",
      "    0.3761           0.3736           0.3769       0.3719                 0.0050                  0.0008   2469\n",
      "    0.3761           0.3795           0.3821       0.3773                 0.0048                  0.0060   2470\n",
      "    0.3761           0.3770           0.3822       0.3758                 0.0064                  0.0062   2471\n",
      "    0.3761           0.3768           0.3804       0.3752                 0.0053                  0.0044   2472\n",
      "    0.3761           0.3746           0.3773       0.3726                 0.0047                  0.0012   2473\n",
      "    0.3761           0.3761           0.3813       0.3752                 0.0061                  0.0052   2474\n",
      "    0.3761           0.3758           0.3779       0.3728                 0.0051                  0.0018   2475\n",
      "    0.3761           0.3728           0.3768       0.3710                 0.0058                  0.0007   2476\n",
      "    0.3761           0.3810           0.3825       0.3781                 0.0044                  0.0064   2477\n",
      "    0.3761           0.3764           0.3794       0.3745                 0.0049                  0.0033   2478\n",
      "    0.3761           0.3786           0.3809       0.3761                 0.0048                  0.0048   2479\n",
      "    0.3761           0.3763           0.3776       0.3740                 0.0036                  0.0015   2480\n",
      "    0.3761           0.3755           0.3797       0.3747                 0.0050                  0.0037   2481\n",
      "    0.3761           0.3779           0.3826       0.3756                 0.0069                  0.0065   2482\n",
      "    0.3761           0.3767           0.3799       0.3749                 0.0050                  0.0039   2483\n",
      "    0.3761           0.3732           0.3779       0.3719                 0.0060                  0.0018   2484\n",
      "    0.3761           0.3777           0.3800       0.3757                 0.0043                  0.0039   2485\n",
      "    0.3761           0.3754           0.3779       0.3733                 0.0046                  0.0018   2486\n",
      "    0.3761           0.3774           0.3813       0.3749                 0.0064                  0.0052   2487\n",
      "    0.3761           0.3796           0.3800       0.3766                 0.0034                  0.0039   2488\n",
      "    0.3761           0.3767           0.3807       0.3752                 0.0055                  0.0046   2489\n",
      "    0.3761           0.3800           0.3811       0.3765                 0.0046                  0.0050   2490\n",
      "    0.3761           0.3734           0.3790       0.3727                 0.0063                  0.0029   2491\n",
      "    0.3761           0.3799           0.3811       0.3777                 0.0034                  0.0050   2492\n",
      "Seed with lowest diff: 2488 with diff=0.0034200000000034203\n"
     ]
    }
   ],
   "source": [
    "# build your _sparse_ boolean function here\n",
    "p_bitflip = .20\n",
    "signature_tuple = (0, 0, 0, 1, 1, 0, 0, 0, 0)\n",
    "fnstar_err, sensitivity_fnstar, fnstar = boolean.compute_fnstar_err_sens(signature_tuple, p_bitflip)\n",
    "p_bitflip = .2\n",
    "subseq_idx = [3, 4, 5, 6, 7, 9, 10, 11] \n",
    "assert len(subseq_idx) == len(signature_tuple) - 1\n",
    "k = len(signature_tuple) - 1\n",
    "assert len(signature_tuple) == k + 1\n",
    "signature = dict(zip(range(len(signature_tuple)), signature_tuple))\n",
    "\n",
    "# 1, compute the true sensitivity of the function f with the given signature\n",
    "all_bitstrings = np.array(list(itertools.product([0, 1], repeat=k)))\n",
    "true_func = lambda x: signature[sum(x)]\n",
    "true_sens = boolean.average_sensitivity(true_func, all_bitstrings)\n",
    "p_zy = boolean.generate_noisy_distr(k, p_bitflip, true_func)\n",
    "true_err = 1 - boolean.compute_acc_noisytest(p_zy, true_func, k) # accuracy of fN* MLE on noisy data\n",
    "print(f\"f TRUE error: {true_err}, f sensitivity: {true_sens}\")\n",
    "\n",
    "# 2. compute error and sensitivity for infinite dataset limit\n",
    "fnstar_true_err, sensitivity_fnstar, fnstar = boolean.compute_fnstar_err_sens(signature_tuple, p_bitflip)\n",
    "print(f\"fnstar TRUE error: {fnstar_true_err}, fnstar sensitivity: {sensitivity_fnstar}\")\n",
    "\n",
    "# now we generate `ntrials` datasets and compute the best-possible validation accuracy of any function\n",
    "ntrials = 25\n",
    "seedstart = 2468\n",
    "seeds = list(range(seedstart, seedstart + ntrials))\n",
    "n_bits = 14 # total bits in X, including label\n",
    "n_train = 10000\n",
    "# n_val = int(50000)\n",
    "\n",
    "for n_val in [10000, 20000, 30000, 40000, 50000]:\n",
    "    print(\"num validation=\", n_val)\n",
    "    acc_diffs = []\n",
    "    print(f\"{'fnstar true err':>10} |   {'fnstar val err':>10} | {'f val error':>10} | {'Zval lookup error':>15} | {'fval - Zval lookup':>15} | {'fval - fnstar':>12} | {'seed':>6}\")\n",
    "    for i in range(ntrials):\n",
    "        seed = seeds[i]\n",
    "        X, Z, subseq_idx = make_datasets.sparse_boolean_weightbased_k_n(n_bits, k, n_train + n_val, signature, p_bitflip=p_bitflip, seed=seed, subseq_idx=subseq_idx)\n",
    "        Z_tr = Z[:n_train]\n",
    "        Z_val = Z[n_train:]\n",
    "\n",
    "        # slice Z_tr, Z_val to only include the subseq_idx (and the label )\n",
    "        Z_tr_subset = Z_tr[:, subseq_idx + [-1]]\n",
    "        Z_val_subset = Z_val[:, subseq_idx + [-1]]\n",
    "\n",
    "        f_val_err = 1 - boolean.compute_acc_on_dataset(true_func, Z_val_subset) # error of f on noisy validation set\n",
    "        f_tr_err = 1 - boolean.compute_acc_on_dataset(true_func, Z_tr_subset) # error of f on noisy train set\n",
    "\n",
    "        fnstar_val_err = 1 - boolean.compute_acc_on_dataset(fnstar, Z_val_subset) # error of f on noisy validation set\n",
    "\n",
    "        # Note: we can't actually compute sensitivity when num_data << 2^nbits, since there are too many\n",
    "        # 'missing' bitstrings from the dataset.\n",
    "        # compute the performance of a `Ztrlookup` (lookup table on Z_tr) on Z_train and Z_val\n",
    "        # err_Ztrlookup_on_Ztrain, err_Ztrlookup_on_Zval, sens_Ztrlookup = boolean.compute_dataset_optimal_sens_and_err(X, Z_tr_subset, Z_val_subset)\n",
    "        # compute the performance of a `Zvallookup` (lookup table on Z_val) on Z_train and Z_val\n",
    "        err_Zvallookup_on_Zval, err_Zvallookup_on_Ztr, sens_Zvallookup = boolean.compute_dataset_optimal_sens_and_err(X, Z1=Z_val_subset, Z2=None)\n",
    "        # now, err_Zvallookup_on_Zval is the best _possible_ validation error of any function on this specific dataset\n",
    "        diff = (f_val_err - err_Zvallookup_on_Zval)\n",
    "        # column descriptions:\n",
    "        # fnstar val err is the error of fnstar applied to the (finite) validation set\n",
    "        # f val error is '' of f applied to ''\n",
    "        # Zval lookup error is '' of the lookup table on the validation set. This _should_ be about the same as fnstar val err for a sufficiently large validation set.\n",
    "\n",
    "        print(f\"{fnstar_true_err:10.4f}       {fnstar_val_err:10.4f}       {f_val_err:10.4f}   {err_Zvallookup_on_Zval:10.4f}             {diff:10.4f}              {f_val_err - fnstar_err:10.4f} {seed:6d}\")\n",
    "        acc_diffs.append(diff)\n",
    "    # print the seed and smallest value of diff\n",
    "    min_diff_idx = np.argmin(acc_diffs)\n",
    "    min_diff_seed = seeds[min_diff_idx]\n",
    "    print(f\"Seed with lowest diff: {min_diff_seed}\", f\"with diff={acc_diffs[min_diff_idx]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the training data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved counterexample_v2_012345678_nbits14_ntr10000_nval40000_bf20_seed2484/train.pkl, counterexample_v2_012345678_nbits14_ntr10000_nval40000_bf20_seed2484/val.pkl, counterexample_v2_012345678_nbits14_ntr10000_nval40000_bf20_seed2484/noiseless_train.pkl, counterexample_v2_012345678_nbits14_ntr10000_nval40000_bf20_seed2484/noiseless_val.pkl\n",
      "Saved counterexample_v2_012345678_nbits14_ntr10000_nval30000_bf20_seed2469/train.pkl, counterexample_v2_012345678_nbits14_ntr10000_nval30000_bf20_seed2469/val.pkl, counterexample_v2_012345678_nbits14_ntr10000_nval30000_bf20_seed2469/noiseless_train.pkl, counterexample_v2_012345678_nbits14_ntr10000_nval30000_bf20_seed2469/noiseless_val.pkl\n",
      "Saved counterexample_v2_012345678_nbits14_ntr10000_nval20000_bf20_seed2488/train.pkl, counterexample_v2_012345678_nbits14_ntr10000_nval20000_bf20_seed2488/val.pkl, counterexample_v2_012345678_nbits14_ntr10000_nval20000_bf20_seed2488/noiseless_train.pkl, counterexample_v2_012345678_nbits14_ntr10000_nval20000_bf20_seed2488/noiseless_val.pkl\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from mindreadingautobots.sequence_generators import make_datasets, data_io\n",
    "# build your _sparse_ boolean function here\n",
    "p_bitflip = .20\n",
    "signature_tuple = (0, 0, 0, 1, 1, 0, 0, 0, 0)\n",
    "p_bitflip = .2\n",
    "subseq_idx = [3, 4, 5, 6, 7, 9, 10, 11] \n",
    "assert len(subseq_idx) == len(signature_tuple) - 1\n",
    "k = len(signature_tuple) - 1\n",
    "assert len(signature_tuple) == k + 1\n",
    "signature = dict(zip(range(len(signature_tuple)), signature_tuple))\n",
    "\n",
    "# now we generate `ntrials` datasets and compute the best-possible validation accuracy of any function\n",
    "n_bits = 14 # total bits in X, including label\n",
    "n_train = 10000\n",
    "\n",
    "for seed, n_val in [(2484, 40000), (2469, 30000), (2488, 20000)]:\n",
    "    X, Z, subseq_idx = make_datasets.sparse_boolean_weightbased_k_n(n_bits, k, n_train + n_val, signature, p_bitflip=p_bitflip, seed=seed, subseq_idx=subseq_idx)\n",
    "    Z_train = Z[:n_train]\n",
    "    Z_val = Z[n_train:]\n",
    "\n",
    "    gen_name = \"counterexample_v2_\" + \"\".join([str(i) for i in signature])\n",
    "    p100 = int(p_bitflip * 100)\n",
    "    suffix = f\"_nbits{n_bits}_ntr{n_train}_nval{n_val}_bf{p100}_seed{seed}\"\n",
    "    dirname = gen_name + suffix\n",
    "    if not os.path.exists(dirname):\n",
    "        os.makedirs(dirname)\n",
    "\n",
    "    train_path = f\"{dirname}/train.pkl\"\n",
    "    val_path = f\"{dirname}/val.pkl\"\n",
    "    data_io.save_numpy_as_dict(Z_train, train_path)\n",
    "    data_io.save_numpy_as_dict(Z_val, val_path)\n",
    "\n",
    "    X_train = X[:n_train]\n",
    "    X_val = X[n_train:]\n",
    "    noiseless_train_path = f\"{dirname}/noiseless_train.pkl\"\n",
    "    noiseless_val_path = f\"{dirname}/noiseless_val.pkl\"\n",
    "    data_io.save_numpy_as_dict(X_train, noiseless_train_path)\n",
    "    data_io.save_numpy_as_dict(X_val, noiseless_val_path)\n",
    "    print(f\"Saved {train_path}, {val_path}, {noiseless_train_path}, {noiseless_val_path}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Verification\n",
    "\n",
    "- the below script was used to verify that for about $10^6$ samples we get convergence between a (lookup on val evaluated on val) and noisy $f_N^*$ accuracy, for noisy parity with $p=0.25$ and $3$ input bits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4375\n",
      "0.4375\n",
      "f sensitivity: 3.0\n",
      "0.3919999999999999\n",
      "0.3919999999999999\n",
      "fnstar TRUE error: 0.3919999999999999, fnstar sensitivity: 3.0\n",
      "fnstar true err |   fnstar val err | f val error | Zval lookup error | fval - Zval lookup | fval - fnstar |   seed\n",
      "    0.3920           0.3916           0.3916       0.3916                 0.0000                 -0.0459   3456\n",
      "    0.3920           0.3920           0.3920       0.3920                 0.0000                 -0.0455   3457\n",
      "    0.3920           0.3922           0.3922       0.3922                 0.0000                 -0.0453   3458\n",
      "    0.3920           0.3920           0.3920       0.3920                 0.0000                 -0.0455   3459\n",
      "    0.3920           0.3918           0.3918       0.3918                 0.0000                 -0.0457   3460\n"
     ]
    }
   ],
   "source": [
    "# build your _sparse_ boolean function here\n",
    "p_bitflip = .25\n",
    "signature_tuple = (0, 1, 0, 1)\n",
    "fnstar_err, sensitivity_fnstar, fnstar = boolean.compute_fnstar_err_sens(signature_tuple, p_bitflip)\n",
    "p_bitflip = .2\n",
    "subseq_idx = [0, 1, 2]\n",
    "assert len(subseq_idx) == len(signature_tuple) - 1\n",
    "k = len(signature_tuple) - 1\n",
    "assert len(signature_tuple) == k + 1\n",
    "signature = dict(zip(range(len(signature_tuple)), signature_tuple))\n",
    "\n",
    "# first, compute the true sensitivity of the function with the given signature\n",
    "all_bitstrings = np.array(list(itertools.product([0, 1], repeat=k)))\n",
    "true_func = lambda x: signature[sum(x)]\n",
    "true_sens = boolean.average_sensitivity(true_func, all_bitstrings)\n",
    "print(f\"f sensitivity: {true_sens}\")\n",
    "\n",
    "# compute error and sensitivity for infinite dataset limit\n",
    "fnstar_true_err, sensitivity_fnstar, fnstar = boolean.compute_fnstar_err_sens(signature_tuple, p_bitflip)\n",
    "print(f\"fnstar TRUE error: {fnstar_true_err}, fnstar sensitivity: {sensitivity_fnstar}\")\n",
    "\n",
    "# now we generate `ntrials` datasets and compute the best-possible validation accuracy of any function\n",
    "ntrials = 5\n",
    "seeds = list(range(3456, 3456 + ntrials))\n",
    "n_bits = 1 + k # total bits in X, including label\n",
    "n_train = 10000\n",
    "n_val = int(1e6)\n",
    "\n",
    "acc_diffs = []\n",
    "print(f\"{'fnstar true err':>10} |   {'fnstar val err':>10} | {'f val error':>10} | {'Zval lookup error':>15} | {'fval - Zval lookup':>15} | {'fval - fnstar':>12} | {'seed':>6}\")\n",
    "for i in range(ntrials):\n",
    "    seed = seeds[i]\n",
    "    X, Z, subseq_idx = make_datasets.sparse_boolean_weightbased_k_n(n_bits, k, n_train + n_val, signature, p_bitflip=p_bitflip, seed=seed, subseq_idx=subseq_idx)\n",
    "    Z_tr = Z[:n_train]\n",
    "    Z_val = Z[n_train:]\n",
    "\n",
    "    # slice Z_tr, Z_val to only include the subseq_idx (and the label )\n",
    "    Z_tr_subset = Z_tr[:, subseq_idx + [-1]]\n",
    "    Z_val_subset = Z_val[:, subseq_idx + [-1]]\n",
    "\n",
    "    f_val_err = 1 - boolean.compute_acc_on_dataset(true_func, Z_val_subset) # error of f on noisy validation set\n",
    "    f_tr_err = 1 - boolean.compute_acc_on_dataset(true_func, Z_tr_subset) # error of f on noisy train set\n",
    "\n",
    "    fnstar_val_err = 1 - boolean.compute_acc_on_dataset(fnstar, Z_val_subset) # error of f on noisy validation set\n",
    "\n",
    "    # Note: we can't actually compute sensitivity when num_data << 2^nbits, since there are too many\n",
    "    # 'missing' bitstrings from the dataset.\n",
    "    # compute the performance of a `Ztrlookup` (lookup table on Z_tr) on Z_train and Z_val\n",
    "    # err_Ztrlookup_on_Ztrain, err_Ztrlookup_on_Zval, sens_Ztrlookup = boolean.compute_dataset_optimal_sens_and_err(X, Z_tr_subset, Z_val_subset)\n",
    "    # compute the performance of a `Zvallookup` (lookup table on Z_val) on Z_train and Z_val\n",
    "    err_Zvallookup_on_Zval, err_Zvallookup_on_Ztr, sens_Zvallookup = boolean.compute_dataset_optimal_sens_and_err(X, Z1=Z_val_subset, Z2=None)\n",
    "    # now, err_Zvallookup_on_Zval is the best _possible_ validation error of any function on this specific dataset\n",
    "    diff = (err_Zvallookup_on_Zval - f_val_err)\n",
    "    # column descriptions:\n",
    "    # fnstar val err is the error of fnstar applied to the (finite) validation set\n",
    "    # f val error is '' of f applied to ''\n",
    "    # Zval lookup error is '' of the lookup table on the validation set. This _should_ be about the same as fnstar val err for a sufficiently large validation set.\n",
    "\n",
    "    print(f\"{fnstar_true_err:10.4f}       {fnstar_val_err:10.4f}       {f_val_err:10.4f}   {err_Zvallookup_on_Zval:10.4f}             {diff:10.4f}              {f_val_err - fnstar_err:10.4f} {seed:6d}\")\n",
    "    acc_diffs.append(diff)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Its really surprising that with $n=3$ we don't get convergence of fnstar (or fval lookup) to the true generalization error\n",
    "\n",
    "\n",
    "checks:\n",
    " - fnstar DOES agree with the lookup table given sufficient data; this is good.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.3799588799999999\n",
      "0.37608843999999875\n"
     ]
    }
   ],
   "source": [
    "p_bitflip = .2\n",
    "signature_tuple = (0, 0, 0, 1, 1, 0, 0, 0, 0)\n",
    "# error and sensitivity of fnstar in the infinite-dataset limit for this signature\n",
    "fnstar_err, sensitivity_fnstar = compute_fnstar_err_sens(signature_tuple, p_bitflip)\n",
    "fnstar_err, sensitivity_fnstar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "true sensitivity: 3.5\n",
      "0.3799588799999999\n",
      "0.37608843999999875\n",
      "fnstar_err: 0.37608843999999875, sensitivity_fnstar: 2.625\n",
      "fnstar err | f val error | Zval lookup error | fval - Zval lookup | fval - fnstar |   seed\n",
      "    0.3761     0.3782       0.3641                -0.0141                  0.0021   3456\n",
      "    0.3761     0.3818       0.3632                -0.0186                  0.0057   3457\n",
      "    0.3761     0.3750       0.3627                -0.0123                 -0.0011   3458\n",
      "    0.3761     0.3869       0.3676                -0.0193                  0.0108   3459\n",
      "    0.3761     0.3856       0.3676                -0.0180                  0.0095   3460\n"
     ]
    }
   ],
   "source": [
    "p_bitflip = .2\n",
    "signature_tuple = (0, 0, 0, 1, 1, 0, 0, 0, 0)  # label as a function of the Hamming weight 0..k of the input bits\n",
    "subseq_idx = [3, 4, 5, 6, 7, 9, 10, 11]  # columns of Z used as the k input bits\n",
    "k = len(signature_tuple) - 1\n",
    "# the generator is asked for k-of-n data, so there must be exactly one column index per input bit\n",
    "assert len(subseq_idx) == k, \"need exactly one subsequence index per input bit\"\n",
    "signature = dict(enumerate(signature_tuple))\n",
    "\n",
    "# first, compute the true sensitivity of the function with the given signature\n",
    "all_bitstrings = np.array(list(itertools.product([0, 1], repeat=k)))\n",
    "\n",
    "\n",
    "def true_func(x):\n",
    "    \"\"\"Target function f: the label is signature[Hamming weight of x].\"\"\"\n",
    "    return signature[sum(x)]\n",
    "\n",
    "\n",
    "true_sens = boolean.average_sensitivity(true_func, all_bitstrings)\n",
    "print(f\"true sensitivity: {true_sens}\")\n",
    "\n",
    "# compute error and sensitivity for infinite dataset limit\n",
    "fnstar_err, sensitivity_fnstar = compute_fnstar_err_sens(signature_tuple, p_bitflip)\n",
    "print(f\"fnstar_err: {fnstar_err}, sensitivity_fnstar: {sensitivity_fnstar}\")\n",
    "\n",
    "# now we generate `ntrials` datasets and compute the best-possible validation accuracy of any function\n",
    "ntrials = 5\n",
    "seeds = list(range(3456, 3456 + ntrials))\n",
    "n_bits = 14  # total bits in X, including label\n",
    "n_train = 10000\n",
    "n_val = 500000\n",
    "\n",
    "acc_diffs = []\n",
    "print(f\"{'fnstar err':>10} | {'f val error':>10} | {'Zval lookup error':>15} | {'Zval lookup - fval':>15} | {'fval - fnstar':>12} | {'seed':>6}\")\n",
    "for seed in seeds:\n",
    "    # keep the generator's returned indices under a separate name so `subseq_idx` is not rebound (cell stays idempotent)\n",
    "    X, Z, subseq_idx_used = make_datasets.sparse_boolean_weightbased_k_n(n_bits, k, n_train + n_val, signature, p_bitflip=p_bitflip, seed=seed, subseq_idx=subseq_idx)\n",
    "    # rows [0, n_train) are the training set; rows [n_train, n_train + n_val) are the validation set\n",
    "    Z_tr = Z[:n_train]\n",
    "    Z_val = Z[n_train:]\n",
    "    assert len(Z_tr) == n_train and len(Z_val) == n_val\n",
    "\n",
    "    # slice Z_tr, Z_val to only include the subseq_idx columns and the label (last column);\n",
    "    # list(...) makes `+` concatenate even if the generator returned an array\n",
    "    cols = list(subseq_idx_used) + [-1]\n",
    "    Z_tr_subset = Z_tr[:, cols]\n",
    "    Z_val_subset = Z_val[:, cols]\n",
    "\n",
    "    f_val_err = 1 - boolean.compute_acc_on_dataset(true_func, Z_val_subset)  # error of f on noisy validation set\n",
    "    f_tr_err = 1 - boolean.compute_acc_on_dataset(true_func, Z_tr_subset)  # error of f on noisy train set\n",
    "\n",
    "    # performance of a `Zvallookup` (lookup table fitted on Z_val) on Z_val and Z_train.\n",
    "    # Note: sens_Zvallookup is not meaningful when num_data << 2^nbits, since too many bitstrings are missing.\n",
    "    err_Zvallookup_on_Zval, err_Zvallookup_on_Ztr, sens_Zvallookup = boolean.compute_dataset_optimal_sens_and_err(X, Z_val_subset, Z_tr_subset)\n",
    "    # err_Zvallookup_on_Zval is the best _possible_ validation error of any function on this specific dataset,\n",
    "    # so diff = (Zval lookup - fval) should be <= 0\n",
    "    diff = err_Zvallookup_on_Zval - f_val_err\n",
    "    print(f\"{fnstar_err:10.4f} {f_val_err:10.4f}   {err_Zvallookup_on_Zval:10.4f}             {diff:10.4f}              {f_val_err - fnstar_err:10.4f} {seed:6d}\")\n",
    "    acc_diffs.append(diff)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-0.009699999999998932 3466\n",
      "-0.00989999999999891 3475\n",
      "-0.010399999999998855 3461\n",
      "-0.010999999999998789 3505\n",
      "-0.011099999999998778 3499\n"
     ]
    }
   ],
   "source": [
    "# print the 5 trials whose best-possible (lookup) validation error is closest to f's validation error,\n",
    "# i.e. the most 'average' validation sets (smallest |diff|)\n",
    "idx_best = np.argsort(np.abs(acc_diffs))[:5]\n",
    "for i in idx_best:\n",
    "    print(acc_diffs[i], seeds[i])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "good_seeds = [2364]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
