{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28957320-6da2-4e9e-9962-d0f264623905",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics.pairwise import rbf_kernel\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mmdew.fast_rbf_kernel import est_gamma\n",
    "from mmdew.mmdew import MMDEW\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import pandas as pd\n",
    "from notebooks.data import MixedNormal, Uniform, Laplace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c42ddd39-b9c3-4eb7-8c66-9637b3419a71",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d683107c-8a74-4ac0-9719-00b6dd063853",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration.\n",
    "d = 20           # number of columns of every reference sample drawn from rng\n",
    "ref_size = 1000  # NOTE(review): not referenced by any cell visible in this notebook\n",
    "\n",
    "# Fixed seed so the reference samples drawn from `rng` are identical on every\n",
    "# Restart & Run All. This seeds only this generator; whether q.draw() uses its\n",
    "# own random source is not visible here -- TODO confirm.\n",
    "SEED = 1234\n",
    "rng = np.random.default_rng(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "acd8cf9c-7361-4c18-9361-2dfd170bae54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the previously stored statistics used to calibrate the thresholds.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load a file you produced yourself.\n",
    "with open('mmdew-statistics.pickle', 'rb') as stats_file:\n",
    "    statistics = pickle.load(stats_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e181b8c1-0db1-489a-a7ae-0096259f25f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target average run lengths on a log10 scale: 3.0, 3.25, ..., 5.0.\n",
    "target_arls_log = np.arange(3, 5.1, .25)\n",
    "\n",
    "# For a target of 10**k, the threshold is the (1 - 10**-k) quantile of the loaded statistics.\n",
    "arl2thresh = {\n",
    "    log_arl: np.quantile(statistics, 1 - (1 / 10**log_arl))\n",
    "    for log_arl in target_arls_log\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "05468c21-0277-4365-b805-72c5f8973d19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def edd(arl2thresh, statistics):\n",
    "    \"\"\"Compute the mean detection delay for every threshold.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    arl2thresh : dict\n",
    "        Maps a label (here: log10 of the target ARL) to a detection threshold.\n",
    "    statistics : iterable of sequences\n",
    "        One sequence of test statistics per run (lists or 1-d arrays).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Maps each key of ``arl2thresh`` to the mean, over all runs, of the\n",
    "        1-based position of the first statistic exceeding the threshold.\n",
    "        A run that never exceeds it contributes ``len(run) + 1``.\n",
    "    \"\"\"\n",
    "    arl2edd = {}\n",
    "    for arl, thresh in arl2thresh.items():\n",
    "        # Append +inf as a sentinel: np.argmax returns 0 when nothing is True,\n",
    "        # which would be indistinguishable from a detection at the first sample.\n",
    "        # np.append (instead of `run + [np.inf]`) also works for ndarray runs.\n",
    "        delays = [\n",
    "            np.argmax(np.append(np.asarray(run, dtype=float), np.inf) > thresh)\n",
    "            for run in statistics\n",
    "        ]\n",
    "        arl2edd[arl] = np.mean(delays) + 1 # account for counting from 0\n",
    "    return arl2edd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bf683d1-f6e0-4d0e-9173-a6be40c9a320",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c21844df-bfb1-4aaa-a0eb-66bef7d0b8fc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8ba924ad-cc30-477e-a835-8327c2afe8e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative distributions the detector is evaluated on, keyed by display name.\n",
    "# Each one is constructed from n_q and d (presumably sample count and\n",
    "# dimension -- TODO confirm against notebooks.data).\n",
    "d = 20\n",
    "n_q = 500\n",
    "\n",
    "qs = {}\n",
    "qs[\"MixedNormal0.3\"] = MixedNormal(n_q, d, 0.3)\n",
    "qs[\"MixedNormal0.7\"] = MixedNormal(n_q, d, 0.7)\n",
    "qs[\"Laplace\"] = Laplace(n_q, d)\n",
    "qs[\"Uniform\"] = Uniform(n_q, d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dd5070a5-f59e-43e7-8fdd-29bd815309d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empty frame for the detection-delay results (one row per alternative distribution).\n",
    "df = pd.DataFrame()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bbefb957-ca1f-47cc-8c0e-5be3e8a6969b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 100/100 [01:51<00:00,  1.11s/it]\n",
      "100%|█████████████████████████████████████████| 100/100 [01:59<00:00,  1.20s/it]\n",
      "100%|█████████████████████████████████████████| 100/100 [01:58<00:00,  1.19s/it]\n",
      "100%|█████████████████████████████████████████| 100/100 [02:01<00:00,  1.22s/it]\n"
     ]
    }
   ],
   "source": [
    "# Monte Carlo settings for the detection-delay experiment.\n",
    "n_runs = 100        # repetitions per alternative distribution\n",
    "n_ref = 10000       # standard-normal reference samples drawn per repetition\n",
    "n_pre_change = 64   # reference samples inserted before the samples drawn from q\n",
    "\n",
    "row_names, edd_records = [], []\n",
    "for name, q in qs.items():\n",
    "    # One entry per repetition: detector.stats with the first n_pre_change\n",
    "    # entries dropped (presumably one entry per inserted sample -- TODO confirm).\n",
    "    h1_stats = []\n",
    "    for _ in tqdm(range(n_runs)):\n",
    "        ref = rng.normal(size=(n_ref, d))\n",
    "        gamma = est_gamma(ref)  # kernel parameter estimated from the full reference sample\n",
    "        detector = MMDEW(gamma=gamma)\n",
    "\n",
    "        # Feed the detector with reference data first ...\n",
    "        for elem in ref[:n_pre_change]:\n",
    "            detector.insert(elem.reshape(1, -1))\n",
    "\n",
    "        # ... then with the samples drawn from the alternative distribution.\n",
    "        for elem in q.draw():\n",
    "            detector.insert(elem.reshape(1, -1))\n",
    "        h1_stats.append(detector.stats[n_pre_change:])\n",
    "\n",
    "    row_names.append(name)\n",
    "    edd_records.append(edd(arl2thresh=arl2thresh, statistics=h1_stats))\n",
    "\n",
    "# Build the table in one step so that re-running this cell replaces df\n",
    "# instead of appending duplicate rows to it.\n",
    "df = pd.DataFrame(edd_records, index=row_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d4ddda08-7a7c-41f2-bb56-2aac8d8b901e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the distribution names from the index into a regular column called \"data\".\n",
    "df = df.reset_index(names=\"data\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3b20fe58-61bd-4539-969f-cd7b54afc0df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "      <th>3.0</th>\n",
       "      <th>3.25</th>\n",
       "      <th>3.5</th>\n",
       "      <th>3.75</th>\n",
       "      <th>4.0</th>\n",
       "      <th>4.25</th>\n",
       "      <th>4.5</th>\n",
       "      <th>4.75</th>\n",
       "      <th>5.0</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>MixedNormal0.3</td>\n",
       "      <td>1.62</td>\n",
       "      <td>1.62</td>\n",
       "      <td>1.64</td>\n",
       "      <td>1.65</td>\n",
       "      <td>1.68</td>\n",
       "      <td>1.71</td>\n",
       "      <td>1.79</td>\n",
       "      <td>1.82</td>\n",
       "      <td>1.82</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>MixedNormal0.7</td>\n",
       "      <td>3.06</td>\n",
       "      <td>3.14</td>\n",
       "      <td>3.15</td>\n",
       "      <td>3.22</td>\n",
       "      <td>3.25</td>\n",
       "      <td>3.35</td>\n",
       "      <td>3.38</td>\n",
       "      <td>3.46</td>\n",
       "      <td>3.46</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Laplace</td>\n",
       "      <td>1.01</td>\n",
       "      <td>1.01</td>\n",
       "      <td>1.01</td>\n",
       "      <td>1.01</td>\n",
       "      <td>1.01</td>\n",
       "      <td>1.01</td>\n",
       "      <td>1.01</td>\n",
       "      <td>1.01</td>\n",
       "      <td>1.03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Uniform</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "             data   3.0  3.25   3.5  3.75   4.0  4.25   4.5  4.75   5.0\n",
       "0  MixedNormal0.3  1.62  1.62  1.64  1.65  1.68  1.71  1.79  1.82  1.82\n",
       "1  MixedNormal0.7  3.06  3.14  3.15  3.22  3.25  3.35  3.38  3.46  3.46\n",
       "2         Laplace  1.01  1.01  1.01  1.01  1.01  1.01  1.01  1.01  1.03\n",
       "3         Uniform  1.00  1.00  1.00  1.00  1.00  1.00  1.00  1.00  1.00"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the wide table: one row per distribution, one column per key of arl2thresh.\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e74ced02-a996-4b3c-abb5-a390fbc3e2cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reshape wide -> long: one row per (data, logARL) pair, with the value stored in \"EDD\".\n",
    "df = df.melt(id_vars=\"data\",var_name=\"logARL\",value_name=\"EDD\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "871a9647-f6c1-494c-96ff-937ff4a89b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label every row with the algorithm that produced it.\n",
    "df[\"algorithm\"] = \"MMDEW\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4797d7f7-e2a1-4b6d-94cd-5cd7d1e521e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the long-format results to CSV; the target directory must already exist.\n",
    "df.to_csv(\"../results_rebuttal/arl-vs-edd/mmdew.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a293835c-2e3b-4f1e-8ec9-219af443100a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
