{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import imageio\n",
    "\n",
    "from utils import introduce_missing_data\n",
    "from utils import normalization, renormalization\n",
    "from utils import compute_normalised_rmse\n",
    "from utils import log_likelihood_model1, log_likelihood_model2, log_likelihood_model3, log_likelihood_model4\n",
    "\n",
    "from sklearn.impute import KNNImputer\n",
    "from sklearn.experimental import enable_iterative_imputer\n",
    "from sklearn.impute import IterativeImputer\n",
    "from sklearn.ensemble import ExtraTreesRegressor\n",
    "from sklearn.linear_model import BayesianRidge\n",
    "\n",
    "from scipy.special import softmax\n",
    "from sklearn.metrics.pairwise import nan_euclidean_distances\n",
    "\n",
    "from warnings import simplefilter\n",
    "from sklearn.exceptions import ConvergenceWarning\n",
    "simplefilter(\"ignore\", category=ConvergenceWarning)\n",
    "\n",
    "import os\n",
    "# gain.py lives in GAIN/ and is imported relative to the working directory (this relies\n",
    "# on the cwd being on sys.path). Restore the cwd in `finally` so a failed import does not\n",
    "# leave the notebook inside GAIN/, which would break every relative path below.\n",
    "os.chdir(\"GAIN\")\n",
    "try:\n",
    "    from gain import gain\n",
    "finally:\n",
    "    os.chdir(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_receivers(norm_miss_data, current_miss_pattern):\n",
    "    \"\"\"Return the rows whose missing pattern is exactly `current_miss_pattern`.\n",
    "    Args:\n",
    "        - norm_miss_data: normalized data with NaN for missing cells, shape (n, d)\n",
    "        - current_miss_pattern: boolean mask of shape (d, ), True = missing\n",
    "    Returns:\n",
    "        - id_receivers: integer array with the indices of the rows whose NaN\n",
    "        positions agree with the pattern on all d columns (the rows to impute)\n",
    "    \"\"\"\n",
    "    # Broadcast the (d, ) pattern against the (n, d) NaN mask and keep the rows\n",
    "    # that agree with it in every column.\n",
    "    row_matches = np.all(np.isnan(norm_miss_data) == current_miss_pattern, axis=1)\n",
    "    return np.where(row_matches)[0]\n",
    "\n",
    "\n",
    "def select_givers(norm_miss_data, current_miss_pattern):\n",
    "    \"\"\"Return the rows that can serve as donors (givers) for a missing pattern.\n",
    "    Args:\n",
    "        - norm_miss_data: normalized data with NaN for missing cells, shape (n, d)\n",
    "        - current_miss_pattern: boolean mask of shape (d, ), True = missing\n",
    "    Returns:\n",
    "        - id_givers: integer array with the indices of the rows observed in every\n",
    "        column the pattern marks as missing (their other columns may still be NaN)\n",
    "    \"\"\"\n",
    "    target_cols = np.asarray(current_miss_pattern, dtype=bool)\n",
    "    # A row is disqualified as soon as one of the target columns is NaN.\n",
    "    has_gap = np.isnan(norm_miss_data[:, target_cols]).any(axis=1)\n",
    "    return np.where(~has_gap)[0]\n",
    "    \n",
    "\n",
    "def kNNxKDE(norm_miss_data, h=0.03, tau=50.0, nb_draws=10000):\n",
    "    \"\"\"The kNNxKDE algorithm: sample each missing cell from a Gaussian KDE centred\n",
    "    on giver rows, the givers being weighted by a softmax over their distances.\n",
    "    Args:\n",
    "        - norm_miss_data: normalized missing data (NaN = missing), shape (n, d)\n",
    "        - h: standard deviation (bandwidth) of the Gaussian kernel KDE (default=0.03)\n",
    "        - tau: temperature for the softmax over -distance; larger values put more\n",
    "        weight on the closest givers (default=50.0)\n",
    "        - nb_draws: number of draws per missing entry (default=10000)\n",
    "    Returns:\n",
    "        - imputed_samples: a dictionary using tuples for key. The entry\n",
    "        imputed_samples[(i, j)] is a numpy array with nb_draws samples,\n",
    "        representing the estimated distribution for the missing cell (i, j).\n",
    "        Fully observed rows, entirely missing rows, and patterns without any\n",
    "        giver get no entry (a \"Not done\" line is printed instead).\n",
    "    \"\"\"\n",
    "    (_, d) = norm_miss_data.shape\n",
    "    all_miss_patterns = np.unique(np.isnan(norm_miss_data), axis=0)\n",
    "    imputed_samples = dict()\n",
    "    \n",
    "    for current_miss_pattern in all_miss_patterns:\n",
    "        if not np.logical_or.reduce(current_miss_pattern):  # if there is no missing value\n",
    "            print(f\"Not done: {current_miss_pattern}\")\n",
    "            continue\n",
    "        if np.logical_and.reduce(current_miss_pattern):  # if there are only missing values\n",
    "            print(f\"Not done: {current_miss_pattern}\")\n",
    "            continue\n",
    "        \n",
    "        id_receivers = select_receivers(norm_miss_data, current_miss_pattern)\n",
    "        id_givers = select_givers(norm_miss_data, current_miss_pattern)\n",
    "        if len(id_givers) == 0:\n",
    "            # No row observes all the missing columns: nothing to sample from, and\n",
    "            # nan_euclidean_distances would reject the empty giver array.\n",
    "            print(f\"Not done (no giver): {current_miss_pattern}\")\n",
    "            continue\n",
    "        \n",
    "        data_receivers = norm_miss_data[id_receivers]\n",
    "        data_givers = norm_miss_data[id_givers]\n",
    "        \n",
    "        # A NaN distance means receiver and giver share no observed coordinate;\n",
    "        # mapping it to inf gives that giver a softmax weight of exp(-inf) = 0.\n",
    "        d_ij = nan_euclidean_distances(data_receivers, data_givers)\n",
    "        d_ij[np.isnan(d_ij)] = np.inf\n",
    "        # If a whole row is inf, softmax returns NaN probabilities and np.random.choice\n",
    "        # raises; fall back to uniform weights over the givers for those receivers.\n",
    "        no_comparable_giver = np.isinf(d_ij).all(axis=1)\n",
    "        if no_comparable_giver.any():\n",
    "            print(f\"Uniform givers for {no_comparable_giver.sum()} receiver(s): {current_miss_pattern}\")\n",
    "            d_ij[no_comparable_giver] = 0.0\n",
    "        p_ij = softmax(- tau * d_ij, axis=1)\n",
    "        \n",
    "        for i1 in range(len(id_receivers)):\n",
    "            probs = p_ij[i1]\n",
    "            # Draw nb_draws givers with replacement according to the softmax weights,\n",
    "            # then add N(0, h^2) kernel noise. Noise is drawn for all d columns (only the\n",
    "            # missing ones are kept), which fixes how much of the global RNG is consumed.\n",
    "            neighbors = np.random.choice(len(id_givers), p=probs, size=nb_draws)  # Corresponding shuffled id\n",
    "            current_sample = data_givers[neighbors] + np.random.normal(loc=0.0, scale=h, size=(nb_draws, d))\n",
    "            for i2 in range(d):\n",
    "                if current_miss_pattern[i2]:\n",
    "                    imputed_samples[(id_receivers[i1], i2)] = current_sample[:, i2]\n",
    "    \n",
    "    return imputed_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def renormalization_samples(imputed_samples, norm_parameters):\n",
    "    \"\"\"Map kNNxKDE samples from the normalized scale back to the original range.\n",
    "    This is the sample-dictionary counterpart of `renormalization`, needed because\n",
    "    kNNxKDE returns a whole sample per missing cell instead of a single value.\n",
    "    Args:\n",
    "        - imputed_samples: output of kNNxKDE, dict {(row, col): normalized samples}\n",
    "        - norm_parameters: dict with per-column arrays \"min_val\" and \"max_val\"\n",
    "    Returns:\n",
    "        - renorm_samples: dict with the same keys, each sample mapped through\n",
    "        x * (max_val[col] + 1e-6) + min_val[col]\n",
    "    \"\"\"\n",
    "    # NOTE(review): this inverse assumes `normalization` stored the per-column range\n",
    "    # (after subtracting the min) in \"max_val\"; confirm against utils.\n",
    "    col_min = norm_parameters[\"min_val\"]\n",
    "    col_scale = norm_parameters[\"max_val\"]\n",
    "    return {\n",
    "        key: samples * (col_scale[key[1]] + 1e-6) + col_min[key[1]]\n",
    "        for key, samples in imputed_samples.items()\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_gif_3d_sphere(miss_data, renorm_imputed_data, name, out_dir=\"neurips2022_figures/gif_3d\"):\n",
    "    \"\"\"Save a rotating 3D scatter plot of imputed data as a GIF.\n",
    "    Points are coloured by missing pattern. One PNG per 5-degree azimuth step is\n",
    "    written to {out_dir}/{name}/angle{angle}.png, then the 72 frames are assembled\n",
    "    into {out_dir}/{name}.gif at 10 fps.\n",
    "    Args:\n",
    "        - miss_data: data with NaN for missing cells, shape (n, 3); only its NaN mask is used\n",
    "        - renorm_imputed_data: imputed data on the original scale, shape (n, 3)\n",
    "        - name: name of the frame folder and of the GIF file\n",
    "        - out_dir: root output folder (default=\"neurips2022_figures/gif_3d\")\n",
    "    \"\"\"\n",
    "    # (missing pattern, marker size, colour, legend label); rows missing all three\n",
    "    # coordinates match no pattern and are not drawn.\n",
    "    styles = [\n",
    "        ((False, False, False), 10, \"C0\", \"Complete\"),\n",
    "        ((True, False, False), 40, \"C1\", \"Miss $x_1$\"),\n",
    "        ((False, True, False), 40, \"C3\", \"Miss $x_2$\"),\n",
    "        ((False, False, True), 40, \"C5\", \"Miss $x_3$\"),\n",
    "        ((True, True, False), 40, \"C4\", \"Miss $x_1, x_2$\"),\n",
    "        ((True, False, True), 40, \"C2\", \"Miss $x_1, x_3$\"),\n",
    "        ((False, True, True), 40, \"C6\", \"Miss $x_2, x_3$\"),\n",
    "    ]\n",
    "    is_missing = np.isnan(miss_data)\n",
    "    masks = [np.all(is_missing == pattern, axis=1) for pattern, _, _, _ in styles]\n",
    "\n",
    "    frame_dir = f\"{out_dir}/{name}\"\n",
    "    os.makedirs(frame_dir, exist_ok=True)\n",
    "    angles = range(0, 360, 5)\n",
    "\n",
    "    for angle in angles:\n",
    "        print(angle, end=\" \")\n",
    "        fig = plt.figure(figsize=(10, 10))\n",
    "        ax = fig.add_subplot(projection=\"3d\")\n",
    "        # Elevation climbs from 0 to 45 degrees over the first half-turn, then back down.\n",
    "        elev = angle / 4 if angle <= 180 else 90 - angle / 4\n",
    "        ax.view_init(elev, angle)\n",
    "        for mask, (_, size, color, label) in zip(masks, styles):\n",
    "            points = renorm_imputed_data[mask]\n",
    "            ax.scatter(points[:, 0], points[:, 1], points[:, 2], s=size, alpha=0.5, color=color, label=label)\n",
    "        ax.set_xlabel(\"$x_1$\", size=14)\n",
    "        ax.set_ylabel(\"$x_2$\", size=14)\n",
    "        ax.set_zlabel(\"$x_3$\", size=14)\n",
    "        fig.legend()\n",
    "        fig.savefig(f\"{frame_dir}/angle{angle}.png\")\n",
    "        plt.close(fig)  # close this frame's figure explicitly rather than \"the current one\"\n",
    "\n",
    "    # Make a GIF from the saved frames.\n",
    "    images = [imageio.imread(f\"{frame_dir}/angle{angle}.png\") for angle in angles]\n",
    "    imageio.mimsave(f\"{out_dir}/{name}.gif\", images, fps=10)\n",
    "    print(\"GIF done!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
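    "# Plot the 3D sphere\n",
    "\n",
    "Each section below loads `datasets/dataset4.csv`, seeds NumPy with 666, removes 20% of the entries (`miss_rate = 0.2`), normalises the data and imputes it with one method. Most sections then print the normalised RMSE, followed by the summed log-likelihood (`log_likelihood_model4`) of the original data and of the imputed data. Cells that call `make_gif_3d_sphere` also save a rotating 3D GIF of the points, coloured by missing pattern."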
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### kNN Imputer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.25211956154329696\n",
      "-2129.954351306573\n",
      "-5682.965028561877\n"
     ]
    }
   ],
   "source": [
    "original_data = np.genfromtxt(\"datasets/dataset4.csv\", delimiter=\",\")\n",
    "np.random.seed(666)  # seed the global NumPy RNG; presumably consumed by introduce_missing_data\n",
    "miss_rate = 0.2  # fraction of missing entries to introduce\n",
    "\n",
    "miss_data = introduce_missing_data(original_data, miss_rate)\n",
    "norm_miss_data, norm_params = normalization(data=miss_data)\n",
    "# Normalise the ground truth with the parameters fitted on miss_data so both share one scale.\n",
    "norm_original_data, _ = normalization(data=original_data, parameters=norm_params)\n",
    "\n",
    "imputer = KNNImputer(n_neighbors=20)  # 20 neighbors...\n",
    "norm_imputed_data = imputer.fit_transform(norm_miss_data)\n",
    "renorm_imputed_data = renormalization(norm_imputed_data, norm_params)\n",
    "\n",
    "#make_gif_3d_sphere(miss_data, renorm_imputed_data, name=\"knnimpute\")\n",
    "# Prints: normalised RMSE, summed log-likelihood (model 4) of the original data, then of the imputed data.\n",
    "rmse = compute_normalised_rmse(norm_original_data, norm_miss_data, norm_imputed_data)\n",
    "print(rmse)\n",
    "loglik_original = log_likelihood_model4(original_data)\n",
    "print(loglik_original.sum())\n",
    "loglik = log_likelihood_model4(renorm_imputed_data)\n",
    "print(loglik.sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MissForest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 5 10 15 20 25 30 35 40 45 50 55 60 65 70 75 80 85 90 95 100 105 110 115 120 125 130 135 140 145 150 155 160 165 170 175 180 185 190 195 200 205 210 215 220 225 230 235 240 245 250 255 260 265 270 275 280 285 290 295 300 305 310 315 320 325 330 335 340 345 350 355 GIF done!\n",
      "0.2755368711171749\n",
      "-2129.954351306573\n",
      "-4022.742979556496\n"
     ]
    }
   ],
   "source": [
    "original_data = np.genfromtxt(\"datasets/dataset4.csv\", delimiter=\",\")\n",
    "np.random.seed(666)  # seed the global NumPy RNG; presumably consumed by introduce_missing_data and the trees\n",
    "miss_rate = 0.2  # fraction of missing entries to introduce\n",
    "\n",
    "miss_data = introduce_missing_data(original_data, miss_rate)\n",
    "norm_miss_data, norm_params = normalization(data=miss_data)\n",
    "# Normalise the ground truth with the parameters fitted on miss_data so both share one scale.\n",
    "norm_original_data, _ = normalization(data=original_data, parameters=norm_params)\n",
    "\n",
    "estimator = ExtraTreesRegressor(n_estimators=15)  # 15 trees (extra-trees regressor inside IterativeImputer, MissForest-style)\n",
    "imputer = IterativeImputer(estimator=estimator, max_iter=10, tol=1e-1, verbose=0)\n",
    "norm_imputed_data = imputer.fit_transform(norm_miss_data)\n",
    "renorm_imputed_data = renormalization(norm_imputed_data, norm_params)\n",
    "\n",
    "make_gif_3d_sphere(miss_data, renorm_imputed_data, name=\"missforest\")\n",
    "# Prints: normalised RMSE, summed log-likelihood (model 4) of the original data, then of the imputed data.\n",
    "rmse = compute_normalised_rmse(norm_original_data, norm_miss_data, norm_imputed_data)\n",
    "print(rmse)\n",
    "loglik_original = log_likelihood_model4(original_data)\n",
    "print(loglik_original.sum())\n",
    "loglik = log_likelihood_model4(renorm_imputed_data)\n",
    "print(loglik.sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MICE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.24765220475352293\n",
      "-2129.954351306573\n",
      "-6308.570702068148\n"
     ]
    }
   ],
   "source": [
    "original_data = np.genfromtxt(\"datasets/dataset4.csv\", delimiter=\",\")\n",
    "np.random.seed(666)  # seed the global NumPy RNG; presumably consumed by introduce_missing_data\n",
    "miss_rate = 0.2  # fraction of missing entries to introduce\n",
    "\n",
    "miss_data = introduce_missing_data(original_data, miss_rate)\n",
    "norm_miss_data, norm_params = normalization(data=miss_data)\n",
    "# Normalise the ground truth with the parameters fitted on miss_data so both share one scale.\n",
    "norm_original_data, _ = normalization(data=original_data, parameters=norm_params)\n",
    "\n",
    "# MICE-style: Bayesian ridge regressions chained inside IterativeImputer.\n",
    "estimator = BayesianRidge()  # default hyperparameters, none tuned here\n",
    "imputer = IterativeImputer(estimator=estimator, max_iter=10, tol=1e-1, verbose=0)\n",
    "norm_imputed_data = imputer.fit_transform(norm_miss_data)\n",
    "renorm_imputed_data = renormalization(norm_imputed_data, norm_params)\n",
    "\n",
    "#make_gif_3d_sphere(miss_data, renorm_imputed_data, name=\"mice\")\n",
    "# Prints: normalised RMSE, summed log-likelihood (model 4) of the original data, then of the imputed data.\n",
    "rmse = compute_normalised_rmse(norm_original_data, norm_miss_data, norm_imputed_data)\n",
    "print(rmse)\n",
    "loglik_original = log_likelihood_model4(original_data)\n",
    "print(loglik_original.sum())\n",
    "loglik = log_likelihood_model4(renorm_imputed_data)\n",
    "print(loglik.sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GAIN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 5 10 15 20 25 30 35 40 45 50 55 60 65 70 75 80 85 90 95 100 105 110 115 120 125 130 135 140 145 150 155 160 165 170 175 180 185 190 195 200 205 210 215 220 225 230 235 240 245 250 255 260 265 270 275 280 285 290 295 300 305 310 315 320 325 330 335 340 345 350 355 GIF done!\n",
      "0.256785908536229\n",
      "-2129.954351306573\n",
      "-5793.318505298776\n"
     ]
    }
   ],
   "source": [
    "original_data = np.genfromtxt(\"datasets/dataset4.csv\", delimiter=\",\")\n",
    "np.random.seed(666)  # seed the global NumPy RNG; presumably consumed by introduce_missing_data\n",
    "miss_rate = 0.2  # fraction of missing entries to introduce\n",
    "\n",
    "miss_data = introduce_missing_data(original_data, miss_rate)\n",
    "norm_miss_data, norm_params = normalization(data=miss_data)\n",
    "# Normalise the ground truth with the parameters fitted on miss_data so both share one scale.\n",
    "norm_original_data, _ = normalization(data=original_data, parameters=norm_params)\n",
    "\n",
    "# NOTE: \"iterations\" is overridden to 1200 on the next line; the 10000 below is never used.\n",
    "gain_parameters = {\"batch_size\": 128, \"hint_rate\": 0.9, \"alpha\": 100, \"iterations\": 10000}\n",
    "gain_parameters[\"iterations\"] = 1200  # 1200 iterations...\n",
    "norm_imputed_data = gain(norm_miss_data, gain_parameters)\n",
    "renorm_imputed_data = renormalization(norm_imputed_data, norm_params)\n",
    "\n",
    "make_gif_3d_sphere(miss_data, renorm_imputed_data, name=\"gain\")\n",
    "# Prints: normalised RMSE, summed log-likelihood (model 4) of the original data, then of the imputed data.\n",
    "rmse = compute_normalised_rmse(norm_original_data, norm_miss_data, norm_imputed_data)\n",
    "print(rmse)\n",
    "loglik_original = log_likelihood_model4(original_data)\n",
    "print(loglik_original.sum())\n",
    "loglik = log_likelihood_model4(renorm_imputed_data)\n",
    "print(loglik.sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
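    "### kNNxKDE\n",
    "\n",
    "kNNxKDE returns a whole sample per missing cell. This cell (with `h=0.02`, `tau=100.0`) builds 5 imputed copies of the data from 5 random draws and plots all of them, so the rotating GIF shows the spread of the imputation rather than a single point."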
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Not done: [False False False]\n",
      "0 5 10 15 20 25 30 35 40 45 50 55 60 65 70 75 80 85 90 95 100 105 110 115 120 125 130 135 140 145 150 155 160 165 170 175 180 185 190 195 200 205 210 215 220 225 230 235 240 245 250 255 260 265 270 275 280 285 290 295 300 305 310 315 320 325 330 335 340 345 350 355 GIF done!\n"
     ]
    }
   ],
   "source": [
    "original_data = np.genfromtxt(\"datasets/dataset4.csv\", delimiter=\",\")\n",
    "np.random.seed(666)\n",
    "miss_rate = 0.2\n",
    "\n",
    "miss_data = introduce_missing_data(original_data, miss_rate)\n",
    "norm_miss_data, norm_params = normalization(data=miss_data)\n",
    "norm_original_data, _ = normalization(data=original_data, parameters=norm_params)\n",
    "\n",
    "nb_draws = 10000\n",
    "# Non-default hyperparameters (defaults: h=0.03, tau=50.0): narrower kernel, sharper softmax.\n",
    "imputed_samples = kNNxKDE(norm_miss_data, h=0.02, tau=100.0, nb_draws=nb_draws)\n",
    "renorm_samples = renormalization_samples(imputed_samples, norm_params)\n",
    "\n",
    "# Build nb_sub_samples imputed copies of the data; copy i uses draw rand_ind[i] for\n",
    "# every missing cell. rand_ind is drawn after kNNxKDE from the same seeded global RNG.\n",
    "nb_sub_samples = 5\n",
    "(n, d) = miss_data.shape\n",
    "\n",
    "renorm_imputed_data = np.tile(miss_data, (nb_sub_samples, 1, 1))  # size (nb_sub_samples, n, d)\n",
    "rand_ind = np.random.randint(low=0, high=nb_draws, size=nb_sub_samples)\n",
    "for key in renorm_samples.keys():\n",
    "    renorm_imputed_data[:, key[0], key[1]] = renorm_samples[key][rand_ind]\n",
    "\n",
    "# One mask per missing pattern; rows missing all three coordinates are not drawn.\n",
    "list_patterns = [(False, False, False), (True, False, False), (False, True, False), (False, False, True),\n",
    "                 (True, True, False), (True, False, True), (False, True, True)]\n",
    "is_missing = np.isnan(miss_data)\n",
    "list_masks = [np.all(is_missing == pattern, axis=1) for pattern in list_patterns]\n",
    "list_sizes = [30, 10, 10, 10, 10, 10, 10]\n",
    "list_colors = [\"C0\", \"C1\", \"C3\", \"C5\", \"C4\", \"C2\", \"C6\"]\n",
    "list_labels = [\"Complete\", \"Miss $x_1$\", \"Miss $x_2$\", \"Miss $x_3$\", \"Miss $x_1, x_2$\", \"Miss $x_1, x_3$\", \"Miss $x_2, x_3$\"]\n",
    "\n",
    "frame_dir = \"neurips2022_figures/gif_3d/knnxkde\"\n",
    "os.makedirs(frame_dir, exist_ok=True)  # savefig does not create missing folders\n",
    "\n",
    "for angle in range(0, 360, 5):\n",
    "    print(angle, end=\" \")\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    ax = fig.add_subplot(projection=\"3d\")\n",
    "    elev = angle / 4 if angle <= 180 else 90 - angle / 4  # elevation goes 0 -> 45 -> back down\n",
    "    ax.view_init(elev, angle)\n",
    "    for k, cur_mask in enumerate(list_masks):\n",
    "        data1 = renorm_imputed_data[:, cur_mask, 0]\n",
    "        data2 = renorm_imputed_data[:, cur_mask, 1]\n",
    "        data3 = renorm_imputed_data[:, cur_mask, 2]\n",
    "        if k == 0:\n",
    "            # Complete rows are identical in every copy: plot them once.\n",
    "            data1, data2, data3 = data1[0], data2[0], data3[0]\n",
    "        # Otherwise all nb_sub_samples draws are plotted to show the spread of the imputation.\n",
    "        ax.scatter(data1, data2, data3, s=list_sizes[k], alpha=0.5, color=list_colors[k], label=list_labels[k])\n",
    "    ax.set_xlabel(\"$x_1$\", size=14)\n",
    "    ax.set_ylabel(\"$x_2$\", size=14)\n",
    "    ax.set_zlabel(\"$x_3$\", size=14)\n",
    "    fig.legend()\n",
    "    fig.savefig(f\"{frame_dir}/angle{angle}.png\")\n",
    "    plt.close(fig)\n",
    "\n",
    "images = [imageio.imread(f\"{frame_dir}/angle{angle}.png\") for angle in range(0, 360, 5)]  # Make a GIF\n",
    "imageio.mimsave(\"neurips2022_figures/gif_3d/knnxkde.gif\", images, fps=10)\n",
    "print(\"GIF done!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
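    "### kNNxKDE, single draw\n",
    "\n",
    "Fill every missing cell with one draw from its kNNxKDE sample (default `h=0.03`, `tau=50.0`) so that the RMSE and log-likelihood can be compared with the point imputers above."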
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Not done: [False False False]\n",
      "0.3580793672535385\n",
      "-2129.954351306573\n",
      "-3007.676442177979\n"
     ]
    }
   ],
   "source": [
    "original_data = np.genfromtxt(\"datasets/dataset4.csv\", delimiter=\",\")\n",
    "np.random.seed(666)\n",
    "miss_rate = 0.2\n",
    "\n",
    "miss_data = introduce_missing_data(original_data, miss_rate)\n",
    "norm_miss_data, norm_params = normalization(data=miss_data)\n",
    "norm_original_data, _ = normalization(data=original_data, parameters=norm_params)\n",
    "\n",
    "nb_draws = 10000\n",
    "imputed_samples = kNNxKDE(norm_miss_data, h=0.03, tau=50.0, nb_draws=nb_draws)  # with default values\n",
    "renorm_samples = renormalization_samples(imputed_samples, norm_params)\n",
    "\n",
    "# Turn the samples into a point imputation by taking draw r for every missing cell,\n",
    "# on both scales, so that RMSE and log-likelihood can be computed.\n",
    "norm_imputed_data = np.copy(norm_miss_data)\n",
    "renorm_imputed_data = np.copy(miss_data)\n",
    "r = np.random.randint(low=0, high=nb_draws)  # must stay below nb_draws to index the samples\n",
    "for key in renorm_samples.keys():\n",
    "    norm_imputed_data[key[0], key[1]] = imputed_samples[key][r]\n",
    "    renorm_imputed_data[key[0], key[1]] = renorm_samples[key][r]\n",
    "\n",
    "#make_gif_3d_sphere(miss_data, renorm_imputed_data, name=\"knnxkde\")\n",
    "rmse = compute_normalised_rmse(norm_original_data, norm_miss_data, norm_imputed_data)\n",
    "print(rmse)\n",
    "loglik_original = log_likelihood_model4(original_data)\n",
    "print(loglik_original.sum())\n",
    "loglik = log_likelihood_model4(renorm_imputed_data)\n",
    "print(loglik.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
