{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "24d1e519-0f85-4e46-83f9-d722aa6f8df4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys; sys.path.append(\"../src/\")\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import kernel, util\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "from sklearn import linear_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ab09343c-2742-4dda-966e-ab5c359b2cad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the generator so the synthetic benchmark inputs are reproducible on Restart & Run All.\n",
    "# NOTE(review): this only controls data drawn via `rng` in this notebook; any RNG used\n",
    "# inside kernel.StreamingRFFMMD is not seeded by this cell (TODO confirm).\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4026dabf-01c5-4d63-9490-fe3329fbafa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stream lengths to benchmark: 2**7 = 128 up to 2**18 = 262144.\n",
    "lengths = np.power(2, np.arange(7, 19))\n",
    "# Dimension of each streamed sample.\n",
    "d = 1\n",
    "# Default num_omegas; the benchmark cell iterates over its own list of values.\n",
    "num_omegas = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "74ee9787-fb1e-465c-bb85-9abb73fd3e8d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([   128,    256,    512,   1024,   2048,   4096,   8192,  16384,\n",
       "        32768,  65536, 131072, 262144])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the stream lengths used by the benchmark.\n",
    "lengths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b546083c-75c1-45c0-b1b5-8b5b9a12d915",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bb05b6f5-56d5-45b4-aa5b-652107a15f30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Benchmark accumulator: one row per timed run with columns l, total, per_insert, r.\n",
    "# Re-running the benchmark cell appends to this frame; re-run this cell to start fresh.\n",
    "results = pd.DataFrame()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ec81e9c-135e-4714-9394-fd2fc61a0571",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time streaming inserts (each followed by an MMD query) for every num_omegas value r\n",
    "# and every stream length l, repeated N_REPEATS times. Rows are collected in a list and\n",
    "# appended to `results` with a single concat instead of re-copying the frame on every run.\n",
    "# If the loop is interrupted, the completed runs are still available in `records`.\n",
    "N_REPEATS = 10\n",
    "OMEGA_GRID = [10, 50, 100, 500, 1000]  # num_omegas values passed to StreamingRFFMMD\n",
    "\n",
    "records = []\n",
    "for _ in tqdm(range(N_REPEATS)):\n",
    "    for r in OMEGA_GRID:\n",
    "        for l in reversed(lengths):  # longest stream first\n",
    "            samples = rng.normal(size=(l, d))\n",
    "            cd = kernel.StreamingRFFMMD(kernel=kernel.Gauss(gamma=1), d=d, num_omegas=r)\n",
    "            # Data generation and construction stay outside the timed region.\n",
    "            with util.ContextTimer() as t:\n",
    "                for e in samples:\n",
    "                    cd.insert(e)\n",
    "                    cd.mmd_values()\n",
    "            records.append({\"l\": l, \"total\": t.secs, \"per_insert\": t.secs / l, \"r\": r})\n",
    "\n",
    "# ignore_index gives a clean RangeIndex instead of a repeated 0 label on every row.\n",
    "results = pd.concat((results, pd.DataFrame(records)), ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50b5a401-6a35-4daa-ac5a-8948a81d036a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at the collected timings.\n",
    "results.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4e8e261-3aa9-4b62-996a-efa8c38d0430",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-insert runtime vs. stream length, one line per num_omegas value r.\n",
    "# Lengths are powers of two, so a base-2 log x-axis spaces them evenly; with a linear\n",
    "# y-axis, growth proportional to log(l) then shows up as a straight line.\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "sns.lineplot(data=results, x=\"l\", y=\"per_insert\", hue=\"r\", ax=ax)\n",
    "ax.set_xscale(\"log\", base=2)\n",
    "ax.set(xlabel=\"stream length l\", ylabel=\"seconds per insert (incl. MMD query)\",\n",
    "       title=\"StreamingRFFMMD runtime per insert\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fb984cf-d190-4c60-866c-2c57e6c6ccf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the raw timings so the plot and fit can be reproduced without re-running the benchmark.\n",
    "# NOTE(review): ../results/ must already exist; to_csv does not create missing directories.\n",
    "results.to_csv(\"../results/runtimes.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "786256d7-ad4b-414f-9469-9377b39bef23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean per-insert time for each stream length l, pooled over all r values and repetitions.\n",
    "# Selecting the column before aggregating avoids averaging (or failing on) unrelated columns.\n",
    "y = results.groupby(\"l\")[\"per_insert\"].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "f02c0921-21a1-4624-a196-be3c64b37788",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9981439103056091, array([6.97460331e-06]))"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fit per_insert ~ a + b * log(l) on the per-length means.\n",
    "# The regressor is built from y's own index so rows stay aligned with y, even when\n",
    "# only some of `lengths` were benchmarked.\n",
    "model = linear_model.LinearRegression()\n",
    "X = np.log(y.index.to_numpy()).reshape(-1, 1)\n",
    "model.fit(X, y)\n",
    "# (R^2 of the fit, slope b)\n",
    "model.score(X, y), model.coef_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d60f0f0-fc06-4b8c-8065-ef4a740a24d6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
