{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "d = 200\n",
    "theta = np.array([-1]*int(d/2) + [1]*int(d/2))\n",
    "d = len(theta)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
       "       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
       "       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
       "       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
       "       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
       "       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "theta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_exp(Z_comp_guess, Z_version, i,l, num_samples):\n",
    "    \n",
    "    for j in range(Z_version.shape[0]):\n",
    "        #first try seeing what kill label is -1\n",
    "        if Z_version[j,i] == l:\n",
    "            if np.any(Z_comp_guess) == None:\n",
    "                Z_comp_guess = Z_version[j,:]\n",
    "            else:\n",
    "                Z_comp_guess = np.vstack([Z_comp_guess,Z_version[j,:]])\n",
    "                \n",
    "#     print(Z_comp_guess)\n",
    "\n",
    "    #if complement is empty set\n",
    "    if np.any(Z_comp_guess) == None:\n",
    "        return -np.inf\n",
    "\n",
    "    sup_vals = []\n",
    "    for k in range(num_samples):\n",
    "        eta_k = np.random.multivariate_normal([0]*d, np.eye(d), 1).flatten()\n",
    "\n",
    "        sup_vals.append(np.max(Z_comp_guess@eta_k))\n",
    "\n",
    "    val_1 = np.mean(sup_vals)\n",
    "\n",
    "    return val_1\n",
    "\n",
    "def alg(theta, num_samples=100):\n",
    "    Z = np.vstack([np.triu(np.ones((d,d))), np.zeros(d)])\n",
    "    Z_version = Z\n",
    "    Z_comp = None\n",
    "    I = []\n",
    "    num_queries = 0\n",
    "    while Z_version.shape[0] > 1:\n",
    "        print(f\"num_queries {num_queries}\")\n",
    "        num_queries += 1\n",
    "\n",
    "        best_val = -np.inf\n",
    "        best_index = 0 \n",
    "        for i in range(d):\n",
    "            if i not in I:\n",
    "                val = -np.inf\n",
    "                Z_comp_guess = np.copy(Z_comp)\n",
    "\n",
    "#                 print(f\"Z_comp_guess {Z_comp_guess}\")\n",
    "#                 print(f\"Z_version {Z_version}\")\n",
    "\n",
    "                val_1 = compute_exp(Z_comp_guess, Z_version, i,1, num_samples)\n",
    "                val_0 = compute_exp(Z_comp_guess, Z_version, i,0, num_samples)\n",
    "\n",
    "                val = np.min([val_1, val_0])\n",
    "#                 print(f\"val {val}\")\n",
    "                if val > best_val:\n",
    "                    best_val = val\n",
    "                    best_index = i\n",
    "                    \n",
    "#         print(f\"best_index {best_index}\")\n",
    "\n",
    "        I.append(best_index)\n",
    "\n",
    "        zs_to_delete = []\n",
    "        for j in range(Z_version.shape[0]):\n",
    "            #first try seeing what kill label is -1\n",
    "            if 2*Z_version[j,best_index]-1 != theta[best_index]:\n",
    "                if np.any(Z_comp) == None:\n",
    "                    Z_comp = Z_version[j,:]\n",
    "                else:\n",
    "                    Z_comp = np.vstack([Z_comp,Z_version[j,:]])\n",
    "\n",
    "                zs_to_delete.append(False)\n",
    "\n",
    "            else:\n",
    "                zs_to_delete.append(True)\n",
    "        \n",
    "        Z_comp = None\n",
    "\n",
    "    #     print(Z_version)\n",
    "\n",
    "        Z_version = Z_version[np.array(zs_to_delete)]\n",
    "\n",
    "\n",
    "    print(f\"num_queries {num_queries} log_2(d) {np.log2(d)}\")\n",
    "    if Z_version@theta == np.sum(theta > 0):\n",
    "        print(\"correct\")\n",
    "\n",
    "    return Z_version@theta == np.sum(theta > 0), num_queries\n",
    "\n",
    "        \n",
    "\n",
    "    \n",
    "                \n",
    "                \n",
    "                \n",
    "            \n",
    "        \n",
    "    \n",
    "    \n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-4-2fc9c6b27543>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mcorrect_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_trials\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m     \u001b[0mcorrect\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_queries\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtheta\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m     \u001b[0mnum_queries_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_queries\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0mcorrect_list\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcorrect\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-3-0d4d918decc7>\u001b[0m in \u001b[0;36malg\u001b[0;34m(theta, num_samples)\u001b[0m\n\u001b[1;32m     45\u001b[0m \u001b[0;31m#                 print(f\"Z_version {Z_version}\")\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m                 \u001b[0mval_1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_exp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mZ_comp_guess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mZ_version\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     48\u001b[0m                 \u001b[0mval_0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_exp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mZ_comp_guess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mZ_version\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-3-0d4d918decc7>\u001b[0m in \u001b[0;36mcompute_exp\u001b[0;34m(Z_comp_guess, Z_version, i, l, num_samples)\u001b[0m\n\u001b[1;32m     17\u001b[0m     \u001b[0msup_vals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 19\u001b[0;31m         \u001b[0meta_k\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultivariate_normal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meye\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     21\u001b[0m         \u001b[0msup_vals\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mZ_comp_guess\u001b[0m\u001b[0;34m@\u001b[0m\u001b[0meta_k\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32mmtrand.pyx\u001b[0m in \u001b[0;36mnumpy.random.mtrand.RandomState.multivariate_normal\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36msvd\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/lib/python3.7/site-packages/numpy/linalg/linalg.py\u001b[0m in \u001b[0;36msvd\u001b[0;34m(a, full_matrices, compute_uv, hermitian)\u001b[0m\n\u001b[1;32m   1634\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1635\u001b[0m         \u001b[0msignature\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'D->DdD'\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misComplexType\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m'd->ddd'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1636\u001b[0;31m         \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgufunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msignature\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msignature\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mextobj\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mextobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1637\u001b[0m         \u001b[0mu\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_t\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1638\u001b[0m         \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mastype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_realType\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_t\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "num_trials = 10\n",
    "num_queries_list = []\n",
    "correct_list = []\n",
    "for i in range(num_trials):\n",
    "    correct, num_queries = alg(theta, 10)\n",
    "    num_queries_list.append(num_queries)\n",
    "    correct_list.append(correct)\n",
    "    \n",
    "print(f\"num_queries {np.mean(num_queries_list)} log_2(d) {np.log2(d)} correct {np.mean(correct_list)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 100)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eta_k.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-1.73743198,  0.98273711, -0.46309205, -1.25783626,  0.64364922,\n",
       "        -0.23975986, -0.16774429,  0.0364398 ,  0.71444186,  1.36807469,\n",
       "        -0.18409699,  1.03828584, -0.11592847,  1.75368764, -1.42249174,\n",
       "         0.34290708, -1.28202255, -0.89884292, -1.06404908, -1.75253268,\n",
       "         0.68227314, -0.20156948,  0.15957522,  0.17970811,  0.45694724,\n",
       "         0.04822644,  1.53353521,  0.35356725,  0.36640589,  1.49692999,\n",
       "        -0.54115376, -0.00478105,  0.34800978, -0.18184813, -0.50614433,\n",
       "        -0.58870755,  1.58103524, -1.21171044,  1.02559994,  1.03659995,\n",
       "         1.76624185,  0.73472849,  1.7679651 , -1.31740282,  0.22223177,\n",
       "         0.56118956, -0.86158222,  2.11415806,  0.69760646,  0.33290574,\n",
       "         0.53186914,  0.70161446, -0.46981829, -2.19971282, -0.29946021,\n",
       "        -0.19150158, -2.03139917, -1.40451026,  1.5596409 , -1.38175786,\n",
       "         0.85496229, -0.01705598, -1.23637523,  0.3861246 ,  0.00692676,\n",
       "        -0.47639615, -1.0311239 ,  0.35373542, -0.47975858,  1.09415221,\n",
       "        -2.53700516,  0.83232339, -0.8047969 ,  1.00801098,  1.48762908,\n",
       "        -0.888423  ,  0.02113312, -0.20753878,  0.18251979,  0.2412457 ,\n",
       "        -1.02036447,  1.16590248,  1.53842794,  1.70379488,  0.09005794,\n",
       "        -0.22345326,  1.69051059, -0.69131475, -0.18622543, -0.21113146,\n",
       "        -0.28631329,  1.75401731, -0.31459552, -1.35390137, -0.07324743,\n",
       "        -0.46365294, -1.68241723,  1.31447427, -0.06718388, -0.69759969]])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eta_k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
