{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "35de5c1e-9130-4019-b2ab-805bfda44279",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-07-29 09:01:26.682840: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2025-07-29 09:01:26.687848: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n",
      "2025-07-29 09:01:26.701728: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "E0000 00:00:1753772486.723799 1470003 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "E0000 00:00:1753772486.730252 1470003 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "W0000 00:00:1753772486.747545 1470003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1753772486.747570 1470003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1753772486.747574 1470003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "W0000 00:00:1753772486.747579 1470003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n",
      "2025-07-29 09:01:26.752975: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import pandas as pd \n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from online_rff_mmd import kernel, util, data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74d6d3bf-706f-4d5b-a3e1-28bdf9208760",
   "metadata": {},
   "source": [
    "As in Section 5.2 of \"[Online Kernel CUSUM for Change-Point Detection](https://arxiv.org/pdf/2211.15070)\": the \"Walking\" -> \"Staying\" transition of subject 101. According to their results, this is the most difficult setting (the one with the most missed detections)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "19a5b1d0-3a26-4c17-a588-745ab777f41f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset wrapper from the project package (online_rff_mmd.data). Below, hasc.h0(num=i)\n",
    "# and hasc.h1(num=i) supply the pre- and post-change sequences of trial i.\n",
    "hasc = data.HASC()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5ca01037-cbe6-4d7b-ba71-da8c33e3c01a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(1.552088888358322)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gaussian-kernel parameter estimated from the first H0 sequence (num=0). get_alg reads\n",
    "# this global; the estimation rule lives in kernel.Gauss.est_gamma and is not visible here.\n",
    "gamma = kernel.Gauss.est_gamma(hasc.h0(num=0))\n",
    "gamma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4aaee2aa-dfc7-414e-90d7-4e372a759427",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_alg(name, ref_sample, kernel_gamma=None, d=3):\n",
    "    \"\"\"Instantiate the change-point detector called `name`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    name : str\n",
    "        One of \"streamingrffmmd\", \"newma\", \"scanb\", \"okcusum\".\n",
    "    ref_sample : array-like\n",
    "        Pre-change reference sample (below: the first `len_ref` rows of an H0\n",
    "        sequence). Not used by \"streamingrffmmd\".\n",
    "    kernel_gamma : float, optional\n",
    "        Gaussian-kernel parameter for \"streamingrffmmd\", \"scanb\" and \"okcusum\".\n",
    "        Defaults to the notebook-level `gamma` estimated above; pass it explicitly\n",
    "        to avoid depending on that global.\n",
    "    d : int, optional\n",
    "        Data dimension passed to \"streamingrffmmd\" and \"newma\" (default 3).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `name` is not one of the supported algorithms.\n",
    "    \"\"\"\n",
    "    def resolved_gamma():\n",
    "        # Read the global lazily so \"newma\", which takes no kernel parameter, never needs it.\n",
    "        return gamma if kernel_gamma is None else kernel_gamma\n",
    "\n",
    "    if name == \"streamingrffmmd\":\n",
    "        return kernel.StreamingRFFMMD(kernel=kernel.Gauss(gamma=resolved_gamma()), d=d, num_omegas=1000)\n",
    "    if name == \"newma\":\n",
    "        return kernel.NewMAAdapter(reference_sample=ref_sample, d=d, B=200)\n",
    "    # B0 / B_max = 114 and N = 14 match len_ref = 14 * 114 (see the appendix of the paper).\n",
    "    if name == \"scanb\":\n",
    "        return kernel.ScanBStatistic(reference_sample=ref_sample, B0=114, N=14, gamma=resolved_gamma())\n",
    "    if name == \"okcusum\":\n",
    "        return kernel.OKCUSUM(reference_sample=ref_sample, B_max=114, N=14, gamma=resolved_gamma())\n",
    "    raise ValueError(\n",
    "        f\"Invalid argument: unknown algorithm {name!r}; \"\n",
    "        \"expected 'streamingrffmmd', 'newma', 'scanb' or 'okcusum'\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2e6a32e9-645a-4209-b185-d81372fbdcf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Length of the reference window taken from the start of each H0 sequence.\n",
    "len_ref = 14*114 # N * B0 with N=14, B0=114 as in get_alg; based on the discussion in their appendix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "32f3d957-67b4-499c-9363-0e4c73fa71f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibration level: each threshold is the (1 - alpha) quantile of the statistic over all\n",
    "# H0 time steps, so about a fraction alpha of calibration steps exceed it.\n",
    "alpha = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "abd00111-9052-47aa-9ba5-4fc5bd2a8c5d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "streamingrffmmd ... working ...\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'streamingrffmmd': np.float64(2.68840147621956)}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "alg2threshold = {}\n",
    "for name_alg in [\"streamingrffmmd\"]: #, \"newma\", \"scanb\", \"okcusum\"]:\n",
    "    print(name_alg, \"... working ...\")\n",
    "    stats = []\n",
    "    \n",
    "    for i in range(10):\n",
    "        w = hasc.h0(num=i)\n",
    "        \n",
    "        alg = get_alg(name_alg, ref_sample=w.values[:len_ref])\n",
    "        for elem in w.values[len_ref:]:\n",
    "            alg.insert(elem)\n",
    "            stats += [alg.statistic()]\n",
    "    alg2threshold[name_alg] = np.quantile(stats, 1-alpha)\n",
    "alg2threshold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "69dc3b98-c109-4f06-9c1d-2dfc1df3ecc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "streamingrffmmd ... working ...\n"
     ]
    }
   ],
   "source": [
    "results = pd.DataFrame()\n",
    "\n",
    "for name_alg in [\"streamingrffmmd\"]: #, \"newma\", \"scanb\", \"okcusum\"]:\n",
    "    print(name_alg, \"... working ...\")\n",
    "    delays = []\n",
    "    \n",
    "    for i in range(10):\n",
    "        w = hasc.h0(num=i)\n",
    "        \n",
    "        s = hasc.h1(num=i)\n",
    "        \n",
    "        alg = get_alg(name_alg, ref_sample=w.values[:len_ref])\n",
    "        for i, elem in enumerate(np.concatenate((w.values[len_ref:len_ref+100], s.values[:100]))):\n",
    "            alg.insert(elem)\n",
    "            if alg.statistic() >= alg2threshold[name_alg]:\n",
    "                delays += [i]\n",
    "                break\n",
    "    d = np.array(delays)\n",
    "    results = pd.concat((results, pd.DataFrame({\n",
    "        \"Algorithm\" : [name_alg],\n",
    "        \"Average delay\" : [np.mean(d[d>=100]-100)],\n",
    "        \"Too early\" : [np.sum(d<100)],\n",
    "        \"Too late\" : [10 - len(d)]\n",
    "    })))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "46897533-138f-46a7-ba79-559b21f844b7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Algorithm</th>\n",
       "      <th>Average delay</th>\n",
       "      <th>Too early</th>\n",
       "      <th>Too late</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>streamingrffmmd</td>\n",
       "      <td>21.857143</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         Algorithm  Average delay  Too early  Too late\n",
       "0  streamingrffmmd      21.857143          2         1"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per evaluated algorithm: mean delay after the change, alarms before it, and misses.\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "9b955e7f-bc96-4a6a-9d45-02faf82476e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lrrr}\n",
      "\\toprule\n",
      "Algorithm & Average delay & Too early & Too late \\\\\n",
      "\\midrule\n",
      "streamingrffmmd & 21.857143 & 2 & 1 \\\\\n",
      "newma & 34.250000 & 1 & 5 \\\\\n",
      "scanb & 31.200000 & 0 & 0 \\\\\n",
      "okcusum & 17.444444 & 1 & 0 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# print() shows the LaTeX source with real line breaks instead of an escaped string repr.\n",
    "# NOTE(review): the saved output lists four algorithms, but the loops above only run\n",
    "# \"streamingrffmmd\"; it comes from an earlier run -- re-run all cells to refresh it.\n",
    "print(results.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c307fc2b-0a82-423c-b3e9-cff2773e29a3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
