{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nuq import NuqClassifier\n",
    "\n",
    "import numpy as np\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_points = 100\n",
    "circle_r = 10\n",
    "dim = 2\n",
    "h = 4\n",
    "\n",
    "method = 'all_data' # all_data\n",
    "kernel = 'RBF'\n",
    "use_uniform_prior = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_points_on_circle(circle_r=10, circle_x=0, circle_y=0, n_points=10, uniform=True):\n",
    "    if not uniform:\n",
    "        alpha = 2 * np.pi * np.random.rand(n_points)\n",
    "    else:\n",
    "        alpha = 2 * np.pi * np.linspace(start=0, stop=1, num=n_points)\n",
    "\n",
    "    x = circle_r * np.cos(alpha) + circle_x\n",
    "    y = circle_r * np.sin(alpha) + circle_y\n",
    "\n",
    "    return np.vstack((x, y)).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = generate_points_on_circle(n_points=n_points, circle_r=circle_r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86c39ddf8a3b4cfc969de47655528df6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.close()\n",
    "plt.figure()\n",
    "\n",
    "plt.scatter(X[:, 0], X[:, 1])\n",
    "plt.scatter(0, 0, color='red')\n",
    "plt.axis('equal')\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "nuq = NuqClassifier(\n",
    "    kernel_type=kernel,\n",
    "    use_uniform_prior=use_uniform_prior,\n",
    "    method=method,\n",
    "    tune_bandwidth=False,\n",
    "    bandwidth=np.array([h, h])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = np.zeros(X.shape[0], dtype='int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "NuqClassifier(bandwidth=array([4, 4]), method='all_data', n_neighbors=100,\n",
       "              tune_bandwidth=False, use_uniform_prior=False)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nuq.fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_circle(nuq_trained, dim, r, i=0):\n",
    "    origin = np.zeros((1, dim))\n",
    "    rand_vect = np.random.randn(1, dim)\n",
    "    N = nuq_trained.n_neighbors\n",
    "    \n",
    "    kde_pred = nuq_trained.get_kde(origin)\n",
    "    kernel_output = nuq_trained.kernel(rand_vect, rand_vect)\n",
    "    uncertainties_pred = nuq_trained.predict_uncertainty(np.zeros((1, 2)), )\n",
    "    \n",
    "    aleatoric_pred = uncertainties_pred['aleatoric']\n",
    "    epistemic_pred = uncertainties_pred['epistemic']\n",
    "    \n",
    "    if len(np.unique(nuq_trained.bandwidth)) > 1:\n",
    "        raise ValueError('Bandwidth must be isotropic')\n",
    "    h = nuq_trained.bandwidth[0]\n",
    "    KDE_error = np.abs(-dim * (np.log(2 * np.pi) * 0.5 + np.log(h)) - (r**2 / (2 * h**2)) - kde_pred)\n",
    "    Kernel_error = np.abs(-dim / 2. * np.log(2 * np.pi) - kernel_output)\n",
    "    Aleatoric_error = np.abs(np.log(nuq_trained.coeff + min(1 - i / N, i / N)) - aleatoric_pred)\n",
    "    Epistemic_error = np.abs((6 - dim) * 0.25 * np.log(2.) - 0.5 * np.log(np.pi) - 0.5 * np.log(nuq_trained.n_neighbors) + \n",
    "                             0.5 * np.log(nuq_trained.coeff + i * (N - i) / N ** 2) + (r**2 / (4 * h**2)) - epistemic_pred)\n",
    "    \n",
    "    \n",
    "    print(f'{KDE_error=}')\n",
    "    print(f'{Kernel_error=}')\n",
    "    print(f'{Aleatoric_error=}')\n",
    "    print(f'{Epistemic_error=}')\n",
    "    \n",
    "    return {\n",
    "        \"kde_pred\": kde_pred,\n",
    "        \"kernel_output\": kernel_output,\n",
    "        **uncertainties_pred\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(\"KDE error:\", -dim * (np.log(2 * np.pi) * 0.5 + np.log(h)) - (circle_r**2 / (2 * h**2)) - nuq.get_kde(np.zeros((1, 2))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(\"Ue error:\", (6 - dim) * 0.25 * np.log(2.) - 0.5 * np.log(np.pi) - 0.5 * np.log(nuq.n_neighbors) + 0.5 * np.log(nuq.coeff) + (circle_r**2 / (4 * h**2))- uncertainties['epistemic'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rand_vect = np.random.randn(1, 2)\n",
    "# print(\"Kernel:\", -dim / 2. * np.log(2 * np.pi) - nuq.kernel(rand_vect, rand_vect))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> \u001b[0;32m/home/nkotelevskii/github/nw_uncertainty/nuq/method/nuq.py\u001b[0m(152)\u001b[0;36mpredict_uncertainty\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32m    150 \u001b[0;31m                \u001b[0;32mimport\u001b[0m \u001b[0mpdb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m    151 \u001b[0;31m                \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m--> 152 \u001b[0;31m                \u001b[0mUe\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlog_half_gaussian_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0masymptotic_var\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlog_as_var\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m    153 \u001b[0;31m                \u001b[0mUa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf1_hat_y_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m    154 \u001b[0;31m                \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_uniform_prior\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  log_half_gaussian_mean(asymptotic_var=log_as_var)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "array([[-6.37576559]])\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  np.min(f1_hat_y_x, axis=1, keepdims=True)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "array([[-inf]])\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  n\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> \u001b[0;32m/home/nkotelevskii/github/nw_uncertainty/nuq/method/nuq.py\u001b[0m(153)\u001b[0;36mpredict_uncertainty\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32m    151 \u001b[0;31m                \u001b[0mpdb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m    152 \u001b[0;31m                \u001b[0mUe\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlog_half_gaussian_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0masymptotic_var\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlog_as_var\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m--> 153 \u001b[0;31m                \u001b[0mUa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf1_hat_y_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m    154 \u001b[0;31m                \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_uniform_prior\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m    155 \u001b[0;31m                    Ua = logsumexp(\n",
      "\u001b[0m\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  n\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> \u001b[0;32m/home/nkotelevskii/github/nw_uncertainty/nuq/method/nuq.py\u001b[0m(154)\u001b[0;36mpredict_uncertainty\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32m    152 \u001b[0;31m                \u001b[0mUe\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlog_half_gaussian_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0masymptotic_var\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlog_as_var\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m    153 \u001b[0;31m                \u001b[0mUa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf1_hat_y_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m--> 154 \u001b[0;31m                \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_uniform_prior\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m    155 \u001b[0;31m                    Ua = logsumexp(\n",
      "\u001b[0m\u001b[0;32m    156 \u001b[0;31m                        np.concatenate([Ua[None], np.log(self.coeff) * np.ones(shape=broadcast_shape)],\n",
      "\u001b[0m\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  logsumexp(np.concatenate([Ua[None], np.log(self.coeff) * np.ones(shape=broadcast_shape)], axis=0).squeeze(), axis=0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-11.512925464970229\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  Ue\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "array(-6.37576559)\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  log_half_gaussian_mean(asymptotic_var=log_as_var).squeeze()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "array(-6.37576559)\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  log_half_gaussian_mean(asymptotic_var=log_as_var).squeeze().squeeze(0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "*** SyntaxError: unexpected EOF while parsing\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  log_half_gaussian_mean(asymptotic_var=log_as_var).squeeze().squeeze()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "array(-6.37576559)\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  logsumexp(np.concatenate([Ua[None], np.log(self.coeff) * np.ones(shape=broadcast_shape)], axis=0), axis=0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "array([[-11.51292546]])\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  logsumexp(np.concatenate([Ua[None], np.log(self.coeff) * np.ones(shape=broadcast_shape)], axis=0), axis=0).squeeze()\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "array(-11.51292546)\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "ipdb>  q\n"
     ]
    },
    {
     "ename": "BdbQuit",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mBdbQuit\u001b[0m                                   Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-15-6d980b134a45>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtest_circle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnuq_trained\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnuq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcircle_r\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-10-badb2019be56>\u001b[0m in \u001b[0;36mtest_circle\u001b[0;34m(nuq_trained, dim, r, i)\u001b[0m\n\u001b[1;32m      6\u001b[0m     \u001b[0mkde_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnuq_trained\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_kde\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morigin\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m     \u001b[0mkernel_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnuq_trained\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrand_vect\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrand_vect\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m     \u001b[0muncertainties_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnuq_trained\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_uncertainty\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m     \u001b[0maleatoric_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0muncertainties_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'aleatoric'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/github/nw_uncertainty/nuq/method/nuq.py\u001b[0m in \u001b[0;36mpredict_uncertainty\u001b[0;34m(self, X, batch_size)\u001b[0m\n\u001b[1;32m    152\u001b[0m                 \u001b[0mUe\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlog_half_gaussian_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0masymptotic_var\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlog_as_var\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    153\u001b[0m                 \u001b[0mUa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf1_hat_y_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 154\u001b[0;31m                 \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_uniform_prior\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    155\u001b[0m                     Ua = logsumexp(\n\u001b[1;32m    156\u001b[0m                         \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mUa\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcoeff\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbroadcast_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/github/nw_uncertainty/nuq/method/nuq.py\u001b[0m in \u001b[0;36mpredict_uncertainty\u001b[0;34m(self, X, batch_size)\u001b[0m\n\u001b[1;32m    152\u001b[0m                 \u001b[0mUe\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlog_half_gaussian_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0masymptotic_var\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlog_as_var\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    153\u001b[0m                 \u001b[0mUa\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf1_hat_y_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeepdims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 154\u001b[0;31m                 \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_uniform_prior\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    155\u001b[0m                     Ua = logsumexp(\n\u001b[1;32m    156\u001b[0m                         \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcatenate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mUa\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcoeff\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbroadcast_shape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/condatorch/lib/python3.8/bdb.py\u001b[0m in \u001b[0;36mtrace_dispatch\u001b[0;34m(self, frame, event, arg)\u001b[0m\n\u001b[1;32m     86\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0;31m# None\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     87\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'line'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_line\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     89\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mevent\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'call'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     90\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/condatorch/lib/python3.8/bdb.py\u001b[0m in \u001b[0;36mdispatch_line\u001b[0;34m(self, frame)\u001b[0m\n\u001b[1;32m    111\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstop_here\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbreak_here\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    112\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_line\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mframe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 113\u001b[0;31m             \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquitting\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mBdbQuit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    114\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_dispatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mBdbQuit\u001b[0m: "
     ]
    }
   ],
   "source": [
    "_ = test_circle(nuq_trained=nuq, dim=dim, r=circle_r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = np.random.permutation(np.arange(n_points))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = np.zeros(X.shape[0], dtype='int')\n",
    "\n",
    "dict_of_results = {\n",
    "    'epistemic': [],\n",
    "    'aleatoric': [],\n",
    "    'total': [],\n",
    "    'KDE': [],\n",
    "    'kernel': []\n",
    "}\n",
    "\n",
    "for i in tqdm(range(n_points + 1)):\n",
    "    \n",
    "    if i > 0:\n",
    "        ind = indices[i - 1]\n",
    "        y[ind] = 1\n",
    "    \n",
    "    nuq = NuqClassifier(\n",
    "        method=method,\n",
    "        tune_bandwidth=False,\n",
    "        bandwidth=np.array([h, h])\n",
    "    )\n",
    "    \n",
    "    nuq.fit(X, y)\n",
    "    res = test_circle(nuq_trained=nuq, dim=dim, r=circle_r, i=i)\n",
    "    \n",
    "    dict_of_results['epistemic'].append(res['epistemic'].squeeze())\n",
    "    dict_of_results['aleatoric'].append(res['aleatoric'].squeeze())\n",
    "    dict_of_results['total'].append(res['total'].squeeze())\n",
    "    dict_of_results['KDE'].append(res['kde_pred'].squeeze())\n",
    "    dict_of_results['kernel'].append(res['kernel_output'].squeeze())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.close()\n",
    "_, ax = plt.subplots(nrows=1, ncols=len(dict_of_results.keys()), sharex=True, figsize=(12, 4))\n",
    "\n",
    "for i, key in enumerate(dict_of_results.keys()):\n",
    "    ax[i].set_title(key)\n",
    "    ax[i].plot(np.arange(n_points + 1), dict_of_results[key])\n",
    "\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = np.zeros(X.shape[0], dtype='int')\n",
    "\n",
    "dict_of_results = {\n",
    "    'epistemic': [],\n",
    "    'aleatoric': [],\n",
    "    'total': [],\n",
    "    'KDE': [],\n",
    "    'kernel': []\n",
    "}\n",
    "\n",
    "radii = [1, 5, 10, 20, 50, 100]\n",
    "\n",
    "for r in tqdm(radii):\n",
    "    \n",
    "    X = generate_points_on_circle(n_points=n_points, circle_r=r)\n",
    "    \n",
    "    nuq = NuqClassifier(\n",
    "        method=method,\n",
    "        tune_bandwidth=False,\n",
    "        bandwidth=np.array([h, h])\n",
    "    )\n",
    "    \n",
    "    nuq.fit(X, y)\n",
    "    res = test_circle(nuq_trained=nuq, dim=dim, r=r, i=0)\n",
    "    \n",
    "    dict_of_results['epistemic'].append(res['epistemic'].squeeze())\n",
    "    dict_of_results['aleatoric'].append(res['aleatoric'].squeeze())\n",
    "    dict_of_results['total'].append(res['total'].squeeze())\n",
    "    dict_of_results['KDE'].append(res['kde_pred'].squeeze())\n",
    "    dict_of_results['kernel'].append(res['kernel_output'].squeeze())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.close()\n",
    "_, ax = plt.subplots(nrows=1, ncols=len(dict_of_results.keys()), sharex=True, figsize=(12, 4))\n",
    "\n",
    "for i, key in enumerate(dict_of_results.keys()):\n",
    "    ax[i].set_title(key)\n",
    "    ax[i].plot(radii, dict_of_results[key])\n",
    "\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Condatorch",
   "language": "python",
   "name": "condatorch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
