{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "d = 200\n",
    "theta = np.array([-1]*int(d/2) + [1]*int(d/2))\n",
    "d = len(theta)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
       "       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
       "       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
       "       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
       "       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n",
       "       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n",
       "        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "theta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_exp(Z_comp_guess, Z_version, i,l, num_samples):\n",
    "    \n",
    "    for j in range(Z_version.shape[0]):\n",
    "        #first try seeing what kill label is -1\n",
    "        if Z_version[j,i] == l:\n",
    "            if np.any(Z_comp_guess) == None:\n",
    "                Z_comp_guess = Z_version[j,:]\n",
    "            else:\n",
    "                Z_comp_guess = np.vstack([Z_comp_guess,Z_version[j,:]])\n",
    "                \n",
    "#     print(Z_comp_guess)\n",
    "\n",
    "    #if complement is empty set\n",
    "    if np.any(Z_comp_guess) == None:\n",
    "        return -np.inf\n",
    "\n",
    "    sup_vals = []\n",
    "    for k in range(num_samples):\n",
    "        eta_k = np.random.multivariate_normal([0]*d, np.eye(d), 1).flatten()\n",
    "\n",
    "        sup_vals.append(np.max(Z_comp_guess@eta_k))\n",
    "\n",
    "    val_1 = np.mean(sup_vals)\n",
    "\n",
    "    return val_1\n",
    "\n",
    "def alg(theta, num_samples=100):\n",
    "    Z = np.vstack([np.triu(np.ones((d,d))), np.zeros(d)])\n",
    "    Z_version = Z\n",
    "    Z_comp = None\n",
    "    I = []\n",
    "    num_queries = 0\n",
    "    while Z_version.shape[0] > 1:\n",
    "        print(f\"num_queries {num_queries}\")\n",
    "        num_queries += 1\n",
    "\n",
    "        best_val = -np.inf\n",
    "        best_index = 0 \n",
    "        for i in range(d):\n",
    "            if i not in I:\n",
    "                val = -np.inf\n",
    "                Z_comp_guess = np.copy(Z_comp)\n",
    "\n",
    "#                 print(f\"Z_comp_guess {Z_comp_guess}\")\n",
    "#                 print(f\"Z_version {Z_version}\")\n",
    "\n",
    "                val_1 = compute_exp(Z_comp_guess, Z_version, i,1, num_samples)\n",
    "                val_0 = compute_exp(Z_comp_guess, Z_version, i,0, num_samples)\n",
    "\n",
    "                val = np.min([val_1, val_0])\n",
    "#                 print(f\"val {val}\")\n",
    "                if val > best_val:\n",
    "                    best_val = val\n",
    "                    best_index = i\n",
    "                    \n",
    "#         print(f\"best_index {best_index}\")\n",
    "\n",
    "        I.append(best_index)\n",
    "\n",
    "        zs_to_delete = []\n",
    "        for j in range(Z_version.shape[0]):\n",
    "            #first try seeing what kill label is -1\n",
    "            if 2*Z_version[j,best_index]-1 != theta[best_index]:\n",
    "                if np.any(Z_comp) == None:\n",
    "                    Z_comp = Z_version[j,:]\n",
    "                else:\n",
    "                    Z_comp = np.vstack([Z_comp,Z_version[j,:]])\n",
    "\n",
    "                zs_to_delete.append(False)\n",
    "\n",
    "            else:\n",
    "                zs_to_delete.append(True)\n",
    "        \n",
    "        Z_comp = None\n",
    "\n",
    "    #     print(Z_version)\n",
    "\n",
    "        Z_version = Z_version[np.array(zs_to_delete)]\n",
    "\n",
    "\n",
    "    print(f\"num_queries {num_queries} log_2(d) {np.log2(d)}\")\n",
    "    if Z_version@theta == np.sum(theta > 0):\n",
    "        print(\"correct\")\n",
    "\n",
    "    return Z_version@theta == np.sum(theta > 0), num_queries\n",
    "\n",
    "        \n",
    "\n",
    "    \n",
    "                \n",
    "                \n",
    "                \n",
    "            \n",
    "        \n",
    "    \n",
    "    \n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n",
      "num_queries 4\n",
      "num_queries 5\n",
      "num_queries 6\n",
      "num_queries 7\n",
      "num_queries 8 log_2(d) 7.643856189774724\n",
      "correct\n",
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n",
      "num_queries 4\n",
      "num_queries 5\n",
      "num_queries 6\n",
      "num_queries 7\n",
      "num_queries 8\n",
      "num_queries 9\n",
      "num_queries 10 log_2(d) 7.643856189774724\n",
      "correct\n",
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n",
      "num_queries 4\n",
      "num_queries 5\n",
      "num_queries 6 log_2(d) 7.643856189774724\n",
      "correct\n",
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n",
      "num_queries 4\n",
      "num_queries 5\n",
      "num_queries 6\n",
      "num_queries 7 log_2(d) 7.643856189774724\n",
      "correct\n",
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n",
      "num_queries 4\n",
      "num_queries 5\n",
      "num_queries 6\n",
      "num_queries 7 log_2(d) 7.643856189774724\n",
      "correct\n",
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n",
      "num_queries 4\n",
      "num_queries 5\n",
      "num_queries 6\n",
      "num_queries 7\n",
      "num_queries 8\n",
      "num_queries 9\n",
      "num_queries 10\n",
      "num_queries 11 log_2(d) 7.643856189774724\n",
      "correct\n",
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n",
      "num_queries 4\n",
      "num_queries 5\n",
      "num_queries 6\n",
      "num_queries 7\n",
      "num_queries 8\n",
      "num_queries 9\n",
      "num_queries 10 log_2(d) 7.643856189774724\n",
      "correct\n",
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n",
      "num_queries 4\n",
      "num_queries 5\n",
      "num_queries 6\n",
      "num_queries 7\n",
      "num_queries 8\n",
      "num_queries 9\n",
      "num_queries 10\n",
      "num_queries 11 log_2(d) 7.643856189774724\n",
      "correct\n",
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n",
      "num_queries 4\n",
      "num_queries 5\n",
      "num_queries 6\n",
      "num_queries 7\n",
      "num_queries 8\n",
      "num_queries 9\n",
      "num_queries 10\n",
      "num_queries 11 log_2(d) 7.643856189774724\n",
      "correct\n",
      "num_queries 0\n",
      "num_queries 1\n",
      "num_queries 2\n",
      "num_queries 3\n",
      "num_queries 4\n",
      "num_queries 5\n",
      "num_queries 6 log_2(d) 7.643856189774724\n",
      "correct\n",
      "num_queries 8.7 log_2(d) 7.643856189774724 correct 1.0\n"
     ]
    }
   ],
   "source": [
    "num_trials = 10\n",
    "num_queries_list = []\n",
    "correct_list = []\n",
    "for i in range(num_trials):\n",
    "    correct, num_queries = alg(theta, 10)\n",
    "    num_queries_list.append(num_queries)\n",
    "    correct_list.append(correct)\n",
    "    \n",
    "print(f\"num_queries {np.mean(num_queries_list)} log_2(d) {np.log2(d)} correct {np.mean(correct_list)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 100)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eta_k.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-1.73743198,  0.98273711, -0.46309205, -1.25783626,  0.64364922,\n",
       "        -0.23975986, -0.16774429,  0.0364398 ,  0.71444186,  1.36807469,\n",
       "        -0.18409699,  1.03828584, -0.11592847,  1.75368764, -1.42249174,\n",
       "         0.34290708, -1.28202255, -0.89884292, -1.06404908, -1.75253268,\n",
       "         0.68227314, -0.20156948,  0.15957522,  0.17970811,  0.45694724,\n",
       "         0.04822644,  1.53353521,  0.35356725,  0.36640589,  1.49692999,\n",
       "        -0.54115376, -0.00478105,  0.34800978, -0.18184813, -0.50614433,\n",
       "        -0.58870755,  1.58103524, -1.21171044,  1.02559994,  1.03659995,\n",
       "         1.76624185,  0.73472849,  1.7679651 , -1.31740282,  0.22223177,\n",
       "         0.56118956, -0.86158222,  2.11415806,  0.69760646,  0.33290574,\n",
       "         0.53186914,  0.70161446, -0.46981829, -2.19971282, -0.29946021,\n",
       "        -0.19150158, -2.03139917, -1.40451026,  1.5596409 , -1.38175786,\n",
       "         0.85496229, -0.01705598, -1.23637523,  0.3861246 ,  0.00692676,\n",
       "        -0.47639615, -1.0311239 ,  0.35373542, -0.47975858,  1.09415221,\n",
       "        -2.53700516,  0.83232339, -0.8047969 ,  1.00801098,  1.48762908,\n",
       "        -0.888423  ,  0.02113312, -0.20753878,  0.18251979,  0.2412457 ,\n",
       "        -1.02036447,  1.16590248,  1.53842794,  1.70379488,  0.09005794,\n",
       "        -0.22345326,  1.69051059, -0.69131475, -0.18622543, -0.21113146,\n",
       "        -0.28631329,  1.75401731, -0.31459552, -1.35390137, -0.07324743,\n",
       "        -0.46365294, -1.68241723,  1.31447427, -0.06718388, -0.69759969]])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "eta_k"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
