{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5125b1f4-34fe-4f08-b7f6-7f9c6f28ae48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited project modules (e.g. `kernel` from ../src/) before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b3960d83-1962-4f28-bdbc-4ebf6caf3927",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from itertools import chain\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm\n",
    "\n",
    "sys.path.append(\"../src/\")  # make the project-local `kernel` module importable\n",
    "import kernel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bd160d24-7d86-46fa-82dc-1bf2866efd7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One seeded generator; the data streams below are all drawn from it.\n",
    "rng = np.random.default_rng(seed=1234)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88fe108b-7c73-4db0-b40d-9ef87587520e",
   "metadata": {},
   "source": [
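    "# Normal distribution\n",
    "\n",
    "Calibrate alarm thresholds for `kernel.StreamingRFFMMD` on i.i.d. $\\mathcal{N}(0, I_d)$ data with $d = 20$. Each of 100 runs streams $2^{17}-1$ fresh samples into a new detector and records the maximum normalized MMD statistic after every insert. The threshold for a target average run length (ARL) of $10^k$ is the $1 - 10^{-k}$ quantile of the pooled values, for $k = 3, 3.25, \\ldots, 5$. The resulting table is saved to `../results/normal_thresholds.csv`."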
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7f7b3495-c152-419e-9449-88f95c805e0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stream dimension and number of samples per run.\n",
    "d = 20\n",
    "length = 2**17 - 1\n",
    "\n",
    "# Reference sample for the bandwidth heuristic. Keep this draw even though the\n",
    "# estimate is not used directly: it advances `rng`, so removing it would change\n",
    "# every stream generated below.\n",
    "X = rng.normal(size=(length, d))\n",
    "gamma_est = kernel.Gauss.est_gamma(X)\n",
    "\n",
    "# Pin the bandwidth instead of using the fresh estimate, for consistency across runs.\n",
    "# NOTE(review): presumably an earlier est_gamma output - compare against gamma_est.\n",
    "gamma = 0.025232487341699427"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0c19d67e-07a4-4e3e-8443-c4f1d2ad3b9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gaussian kernel with the pinned bandwidth; handed to every detector in the runs below.\n",
    "gauss = kernel.Gauss(\n",
    "    gamma=gamma,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2db60183-cfbe-4a23-ae6c-d010f7c7ee6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 100/100 [59:16<00:00, 35.56s/it]\n"
     ]
    }
   ],
   "source": [
    "n_runs = 100        # independent null streams\n",
    "num_omegas = 1000   # random Fourier features per detector\n",
    "\n",
    "# For each run: stream `length` fresh N(0, I) samples into a new detector and record,\n",
    "# after every insert, the largest normalized MMD statistic.\n",
    "# NOTE(review): presumably normalized_mmd() returns a list; `+ [0]` then keeps the max\n",
    "# defined when it is empty and floors it at 0 - TODO confirm.\n",
    "mmd_vals = []\n",
    "for _ in tqdm(range(n_runs)):\n",
    "    X = rng.normal(size=(length, d))\n",
    "    cd = kernel.StreamingRFFMMD(gauss, d=d, num_omegas=num_omegas)\n",
    "    acc = []\n",
    "    for elem in X:\n",
    "        cd.insert(elem)\n",
    "        acc.append(np.max(cd.normalized_mmd() + [0]))\n",
    "    mmd_vals.append(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9da8e39d-59fd-4dde-8acc-68d6cb80044b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool every per-step statistic from all runs into one float array. Converting once\n",
    "# avoids re-converting the ~13M-element list on every np.quantile call.\n",
    "all_vals = np.fromiter(chain.from_iterable(mmd_vals), dtype=float)\n",
    "\n",
    "# log10 target ARLs 3, 3.25, ..., 5 (linspace avoids the float-step arange pitfall).\n",
    "target_arls_log = np.linspace(3, 5, 9)\n",
    "# Threshold for log ARL k = the 1 - 10**-k quantile, i.e. a per-step exceedance rate of\n",
    "# 10**-k on null data. At k = 5 this rests on roughly 130 exceedances of the pooled sample.\n",
    "quantile_levels = [1 - (1 / 10**k) for k in target_arls_log]\n",
    "thresholds = np.quantile(all_vals, quantile_levels)\n",
    "arl2thresh = dict(zip(target_arls_log, thresholds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "df38b0c5-5644-4ff5-8ac2-54002f739e3b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>log ARL</th>\n",
       "      <th>threshold</th>\n",
       "      <th>d</th>\n",
       "      <th>data</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.00</td>\n",
       "      <td>1.086488</td>\n",
       "      <td>20</td>\n",
       "      <td>N(0,I)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3.25</td>\n",
       "      <td>1.102175</td>\n",
       "      <td>20</td>\n",
       "      <td>N(0,I)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3.50</td>\n",
       "      <td>1.118614</td>\n",
       "      <td>20</td>\n",
       "      <td>N(0,I)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3.75</td>\n",
       "      <td>1.135688</td>\n",
       "      <td>20</td>\n",
       "      <td>N(0,I)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4.00</td>\n",
       "      <td>1.150247</td>\n",
       "      <td>20</td>\n",
       "      <td>N(0,I)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   log ARL  threshold   d    data\n",
       "0     3.00   1.086488  20  N(0,I)\n",
       "1     3.25   1.102175  20  N(0,I)\n",
       "2     3.50   1.118614  20  N(0,I)\n",
       "3     3.75   1.135688  20  N(0,I)\n",
       "4     4.00   1.150247  20  N(0,I)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per target log-ARL, tagged with the dimension and distribution it was calibrated on.\n",
    "df = (\n",
    "    pd.DataFrame.from_dict(arl2thresh, orient=\"index\")\n",
    "    .reset_index()\n",
    "    .rename(columns={\"index\": \"log ARL\", 0: \"threshold\"})\n",
    "    .assign(d=d, data=\"N(0,I)\")\n",
    ")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b2dfc598-4037-4997-bc8b-d4296c7681f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv(\"../results/normal_thresholds.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e24353ad-50dc-42ff-a031-1932b4b0e478",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
