{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b3960d83-1962-4f28-bdbc-4ebf6caf3927",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-05-06 17:30:58.503723: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2025-05-06 17:30:58.507792: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2025-05-06 17:30:58.520047: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "E0000 00:00:1746545458.540912   57012 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "E0000 00:00:1746545458.546818   57012 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "W0000 00:00:1746545458.562724   57012 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1746545458.562747   57012 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1746545458.562749   57012 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1746545458.562751   57012 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "2025-05-06 17:30:58.567822: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import sys; sys.path.append(\"../src/\")\n",
    "import warnings\n",
    "from itertools import chain\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm\n",
    "\n",
    "import kernel\n",
    "import data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bd160d24-7d86-46fa-82dc-1bf2866efd7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded generator for this notebook. NOTE(review): `rng` is not referenced again\n",
    "# below and is not passed to data.MNIST or kernel.StreamingRFFMMD, so it does not\n",
    "# control their randomness; confirm how those modules seed if exact reruns matter.\n",
    "rng = np.random.default_rng(seed=1234)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88fe108b-7c73-4db0-b40d-9ef87587520e",
   "metadata": {},
   "source": [
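    "# MNIST data\n",
    "\n",
    "Calibrate detection thresholds for the streaming RFF-MMD statistic on MNIST images of the digit 0. The notebook runs 100 streams, each over one draw of the digit-0 data, and pools the per-step statistics. For a target ARL of $10^k$ ($k = 3, 3.25, \\dots, 5$) it takes the $1 - 10^{-k}$ quantile as the threshold. The thresholds are saved to `../results/mnist_0_thresholds.csv`."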
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcfb8eb2-73fb-49c7-af72-a2da09abec8f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "01193050-b29d-48e2-b9c2-4c943540402e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input dimension handed to StreamingRFFMMD; presumably flattened 28x28 MNIST\n",
    "# images (confirm against data.MNIST).\n",
    "d = 28 * 28"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "17124c08-0434-42a0-a225-5c59abdc5b07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 1024 is presumably the number of images per draw() and digit=0\n",
    "# restricts the data to zeros; confirm against data.MNIST.\n",
    "zeros = data.MNIST(1024, digit=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a93929b5-849a-43b2-89c3-723c3f36e7e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the Gauss kernel parameter from a single draw of the digit-0 data,\n",
    "# then build the kernel that the streaming detectors are constructed with.\n",
    "gamma_sample = zeros.draw()\n",
    "gamma = kernel.Gauss.est_gamma(gamma_sample)\n",
    "gauss = kernel.Gauss(gamma=gamma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c90a212a-5806-403b-9185-790ba901027b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.6499206594402888e-07\n"
     ]
    }
   ],
   "source": [
    "# Show the estimated kernel parameter.\n",
    "print(f\"{gamma}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2db60183-cfbe-4a23-ae6c-d010f7c7ee6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████| 100/100 [02:10<00:00,  1.31s/it]\n"
     ]
    }
   ],
   "source": [
    "# 100 independent runs. Each run streams one draw of the digit-0 data through a fresh\n",
    "# StreamingRFFMMD detector (num_omegas=1000). After every insertion it records np.max\n",
    "# of the normalized MMD values with a 0 appended.\n",
    "# NOTE(review): if normalized_mmd() returns a list, `+ [0]` floors the statistic at 0\n",
    "# and guards against an empty list; confirm the return type.\n",
    "mmd_vals = []\n",
    "for _ in tqdm(range(100)):\n",
    "    cd = kernel.StreamingRFFMMD(gauss, d=d, num_omegas=1000)\n",
    "    run_stats = []\n",
    "    for elem in zeros.draw():\n",
    "        cd.insert(elem)\n",
    "        run_stats.append(np.max(cd.normalized_mmd() + [0]))\n",
    "    mmd_vals.append(run_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9da8e39d-59fd-4dde-8acc-68d6cb80044b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the per-step statistics of all runs into one sample (converted to an array once).\n",
    "all_vals = np.asarray(list(chain.from_iterable(mmd_vals)))\n",
    "\n",
    "# Target log10 ARLs 3.00, 3.25, ..., 5.00. linspace pins both endpoints exactly,\n",
    "# unlike a float-step arange(3, 5.1, .25), whose length depends on rounding.\n",
    "target_arls_log = np.linspace(3, 5, 9)\n",
    "\n",
    "# Treating steps as independent, a threshold exceeded with per-step probability p\n",
    "# gives an ARL of about 1/p. So an ARL of 10**k maps to the (1 - 10**-k) quantile.\n",
    "# A single vectorized np.quantile call sorts the data once, not once per target.\n",
    "quantile_levels = 1 - 1 / 10.0 ** target_arls_log\n",
    "arl2thresh = dict(zip(target_arls_log, np.quantile(all_vals, quantile_levels)))\n",
    "\n",
    "# With fewer than 10**k pooled values, the (1 - 10**-k) quantile sits at or beyond\n",
    "# the sample maximum, so that threshold is not meaningfully estimated.\n",
    "if all_vals.size < 10 ** target_arls_log.max():\n",
    "    warnings.warn(\"fewer pooled samples than the largest target ARL; \"\n",
    "                  \"its threshold is poorly estimated\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "df38b0c5-5644-4ff5-8ac2-54002f739e3b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>log ARL</th>\n",
       "      <th>threshold</th>\n",
       "      <th>d</th>\n",
       "      <th>data</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.00</td>\n",
       "      <td>1.199031</td>\n",
       "      <td>784</td>\n",
       "      <td>MNIST 0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3.25</td>\n",
       "      <td>1.221040</td>\n",
       "      <td>784</td>\n",
       "      <td>MNIST 0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3.50</td>\n",
       "      <td>1.236295</td>\n",
       "      <td>784</td>\n",
       "      <td>MNIST 0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3.75</td>\n",
       "      <td>1.258325</td>\n",
       "      <td>784</td>\n",
       "      <td>MNIST 0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4.00</td>\n",
       "      <td>1.275323</td>\n",
       "      <td>784</td>\n",
       "      <td>MNIST 0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>4.25</td>\n",
       "      <td>1.292688</td>\n",
       "      <td>784</td>\n",
       "      <td>MNIST 0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>4.50</td>\n",
       "      <td>1.318921</td>\n",
       "      <td>784</td>\n",
       "      <td>MNIST 0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>4.75</td>\n",
       "      <td>1.328930</td>\n",
       "      <td>784</td>\n",
       "      <td>MNIST 0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>5.00</td>\n",
       "      <td>1.349832</td>\n",
       "      <td>784</td>\n",
       "      <td>MNIST 0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   log ARL  threshold    d     data\n",
       "0     3.00   1.199031  784  MNIST 0\n",
       "1     3.25   1.221040  784  MNIST 0\n",
       "2     3.50   1.236295  784  MNIST 0\n",
       "3     3.75   1.258325  784  MNIST 0\n",
       "4     4.00   1.275323  784  MNIST 0\n",
       "5     4.25   1.292688  784  MNIST 0\n",
       "6     4.50   1.318921  784  MNIST 0\n",
       "7     4.75   1.328930  784  MNIST 0\n",
       "8     5.00   1.349832  784  MNIST 0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per target log-ARL with its threshold, tagged with the input dimension\n",
    "# and the dataset label.\n",
    "df = pd.DataFrame(\n",
    "    {\"log ARL\": list(arl2thresh), \"threshold\": list(arl2thresh.values())}\n",
    ").assign(d=d, data=\"MNIST 0\")\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b2dfc598-4037-4997-bc8b-d4296c7681f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create ../results/ if it is missing (e.g. on a fresh checkout) so that to_csv does\n",
    "# not fail on Restart & Run All. NOTE(review): the RangeIndex is written as an\n",
    "# unnamed first column; pass index=False if downstream readers do not expect it.\n",
    "thresholds_path = Path(\"../results/mnist_0_thresholds.csv\")\n",
    "thresholds_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "df.to_csv(thresholds_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e24353ad-50dc-42ff-a031-1932b4b0e478",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
