{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot figures"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-  **_Reproduce small sample versions_** of kernel thinning experiment figures from the paper https://arxiv.org/pdf/2105.05842.pdf\n",
    "    - The paper shows figures for coreset size sqrt(n) = 2^2 to 2^7\n",
    "    - This notebook plots the same figures for coreset size sqrt(n) = 2^1 to 2^4\n",
    "    - To reproduce the paper's figures exactly, run the slurm jobs with m_max set to 7, then rerun this notebook with ms = range(8) in the first code cell of Section 1\n",
    "- Plot MMD_k(P, S) and MMD_k(Sin, S) for kernel thinning with a Gaussian(sigma) kernel k, for\n",
    "    - P = N(0, I_d) in d = 2, 3, 4 dimensions; sigma = 1\n",
    "    - Mixture of Gaussians P (MoG) in d = 2, with M = 4, 6, 8 components; sigma = 1 (see the function compute_diag_mog_params in util_sample, or the paper for details)\n",
    "    - 8 MCMC data settings; sigma = median_distance in mcmc data (d = 4)\n",
    "- Show scatter plots of iid and KT coresets for the MoG settings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0.1 Import libraries and fix plot settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numpy.random as npr\n",
    "import numpy.linalg as npl\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "import pathlib\n",
    "import os\n",
    "import os.path\n",
    "import pickle as pkl\n",
    "\n",
    "# Fitting linear models\n",
    "import statsmodels.api as sm\n",
    "from scipy.stats import multivariate_normal\n",
    "\n",
    "# plotting libraries\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import pylab\n",
    "import seaborn as sns\n",
    "# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and 3.8 removed the\n",
    "# old names, so use whichever name this matplotlib installation provides\n",
    "plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')\n",
    "\n",
    "# utils for generating samples, evaluating kernels, and mmds\n",
    "from util_sample import sample, compute_mcmc_params_p, compute_diag_mog_params, sample_string\n",
    "from util_k_mmd import get_combined_mmd_filename\n",
    "\n",
    "# Autoreload packages that are modified\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "%load_ext line_profiler\n",
    "# https://jakevdp.github.io/PythonDataScienceHandbook/01.07-timing-and-profiling.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fix_plot_settings = True\n",
    "if fix_plot_settings:\n",
    "    plt.rc('font', family='serif')\n",
    "    plt.rc('text', usetex=False)\n",
    "    label_size = 20\n",
    "    mpl.rcParams['xtick.labelsize'] = label_size\n",
    "    mpl.rcParams['ytick.labelsize'] = label_size\n",
    "    mpl.rcParams['axes.labelsize'] = label_size\n",
    "    mpl.rcParams['axes.titlesize'] = label_size\n",
    "    mpl.rcParams['figure.titlesize'] = label_size\n",
    "    mpl.rcParams['lines.markersize'] = label_size\n",
    "    mpl.rcParams['grid.linewidth'] = 2.5\n",
    "    mpl.rcParams['legend.fontsize'] = label_size\n",
    "    pylab.rcParams['xtick.major.pad'] = 5\n",
    "    pylab.rcParams['ytick.major.pad'] = 5\n",
    "\n",
    "# Per-series line styles, markers, marker sizes and colors, indexed by series position in\n",
    "# plot_mmd_dict and in the coreset scatter plots. They are defined outside the `if` above so\n",
    "# those cells still run when fix_plot_settings is False.\n",
    "lss = ['--', ':', '-.', '-', '--', '-.', ':', '-', '--', '-.', ':', '-']\n",
    "mss = ['>', 'o', 's', 'D', '>', 's', 'o', 'D', '>', 's', 'o', 'D']\n",
    "ms_size = [25, 20, 20, 20, 20, 20, 20, 20, 20, 20]\n",
    "colors = ['#e41a1c', '#0000cd', '#4daf4a', 'black', 'magenta']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0.2 Define helper functions for loading and plotting MMDs, and for loading input points and coresets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_mmds(ms, params_p, params_k_split, params_k_swap, rep_ids, delta=0.5, \n",
    "                      sample_seed=1234567, thin_seed=9876543, results_dir=\"results_new/combined\"):\n",
    "    \"\"\"\n",
    "    Return a nested dictionary of saved MMD results for the baseline and KT coresets\n",
    "    \n",
    "    ms: range of thinning rounds\n",
    "    params_p: Dictionary of distribution parameters recognized by sample()\n",
    "    params_k_split: Dictionary of kernel parameters recognized by kernel() # used for kt split\n",
    "    params_k_swap: Dictionary of kernel parameters recognized by kernel() # used for kt swap; and computing mmd\n",
    "    rep_ids: Which replication numbers of the experiment to load; the replication\n",
    "        number determines the seeds set for reproducibility\n",
    "    delta: thinning parameter forwarded to get_combined_mmd_filename (per the experiment\n",
    "        notes, delta/(4^m) is the failure probability of the adaptive threshold sequence)\n",
    "    sample_seed, thin_seed, results_dir: accepted but NOT read here; the file locations come\n",
    "        solely from get_combined_mmd_filename (NOTE(review): confirm whether results_dir\n",
    "        should be forwarded once that helper supports it)\n",
    "    \n",
    "    Returns {\"P\": {baseline: ..., \"KT\": ...}, \"Sin\": {baseline: ..., \"KT\": ...}}, where \"P\"\n",
    "    holds the MMD_k(P, S) results, \"Sin\" holds the MMD_k(Sin, S) results, and baseline is\n",
    "    'iid' for the gauss / diag_mog targets and 'standard' otherwise. Each value is the object\n",
    "    unpickled from the corresponding combined results file. The insertion order (baseline\n",
    "    first, then KT) determines the colors/markers plot_mmd_dict assigns to each series.\n",
    "    \"\"\"\n",
    "    baseline = 'iid' if params_p[\"name\"] in [\"gauss\", \"diag_mog\"] else 'standard'\n",
    "    # (file prefix, MMD target, series label), in the order the results are loaded\n",
    "    sources = [(\"mc\", \"P\", baseline), (\"kt\", \"P\", \"KT\"),\n",
    "               (\"mc-sin\", \"Sin\", baseline), (\"kt-sin\", \"Sin\", \"KT\")]\n",
    "    mmds_dict = {\"P\": {}, \"Sin\": {}}\n",
    "    for prefix, target, series in sources:\n",
    "        path = get_combined_mmd_filename(prefix, ms, params_p, params_k_split, params_k_swap, rep_ids, delta)\n",
    "        # pickle can execute arbitrary code: only load result files produced by these experiments\n",
    "        with open(path, 'rb') as handle:\n",
    "            mmds_dict[target][series] = pkl.load(handle)\n",
    "    return mmds_dict\n",
    "\n",
    "def load_input_and_coreset(m, params_p, params_k_split, params_k_swap, rep_id, delta=0.5, \n",
    "                      sample_seed=1234567, thin_seed=9876543, results_dir=\"results_new\", verbose=False):\n",
    "    \"\"\"Return the input MC points and the KT coreset points for a single replication.\n",
    "    \n",
    "    The KT coreset indices are loaded from an existing pickle on disk; the input points are\n",
    "    regenerated with sample(n, params_p, seed=sample_seed+rep_id).\n",
    "    \n",
    "    Args:\n",
    "      m: Number of halving rounds (number of sample points n = 2^{2m}, coreset size 2^m)\n",
    "      params_p: Dictionary of distribution parameters recognized by sample()\n",
    "      params_k_split: Dictionary of kernel parameters recognized by kernel() # used for kt split\n",
    "      params_k_swap: Dictionary of kernel parameters recognized by kernel() # used for kt swap; and computing mmd\n",
    "      rep_id: A single rep id for which the coreset is returned\n",
    "      delta: threshold parameter; only used here to build the coreset filename\n",
    "      sample_seed: (Optional) random seed is set to sample_seed + rep_id\n",
    "        prior to generating the input sample\n",
    "      thin_seed: (Optional) thinning seed; only used here to build the coreset filename\n",
    "      results_dir: (Optional) Directory from which the coresets are loaded\n",
    "      verbose: (Optional) If True, print intermediate updates\n",
    "    \n",
    "    Returns:\n",
    "      (X, X_kt): all n input points, and the rows of X selected by the first 2^m saved indices\n",
    "    \n",
    "    Raises:\n",
    "      ValueError: if the coreset file does not exist\n",
    "    \"\"\"\n",
    "    d = params_p[\"d\"]\n",
    "    assert(d == params_k_split[\"d\"])\n",
    "    assert(d == params_k_swap[\"d\"])\n",
    "    sample_str = sample_string(params_p, sample_seed)\n",
    "    split_kernel_str = \"{}_var{:.3f}_seed{}\".format(params_k_split[\"name\"], params_k_split[\"var\"], thin_seed)\n",
    "    swap_kernel_str = \"{}_var{:.3f}\".format(params_k_swap[\"name\"], params_k_swap[\"var\"])\n",
    "    thresh_str = f\"delta{delta}\"\n",
    "    filename = os.path.join(results_dir, f\"kt-coresets-{sample_str}-split{split_kernel_str}-swap{swap_kernel_str}-d{d}-m{m}-{thresh_str}-rep{rep_id}.pkl\")\n",
    "    \n",
    "    # Check for the coreset file before regenerating the 4^m-point input sample\n",
    "    if not os.path.exists(filename):\n",
    "        raise ValueError(f\"File {filename} not found\")\n",
    "    if verbose:\n",
    "        print(f\"Loading KT coreset indices from {filename}\")\n",
    "    # pickle can execute arbitrary code: only load coreset files produced by these experiments\n",
    "    with open(filename, 'rb') as file:\n",
    "        coresets = pkl.load(file)\n",
    "    \n",
    "    n = int(2**(2*m))\n",
    "    ncoreset = int(2**m)\n",
    "    X = sample(n, params_p, seed=sample_seed+rep_id)\n",
    "    if verbose:\n",
    "        print(f\"Returning all {n} input MC points and {ncoreset} KT points\")\n",
    "    return(X, X[coresets[:ncoreset]])\n",
    "    \n",
    "def plot_mmd_dict(axes, ms, mmds_dict, size_factor=1., error_bar = True, error_shade = False, skip_ns=int(0),\n",
    "                  legend_size=mpl.rcParams['legend.fontsize']):\n",
    "    '''\n",
    "        Plot median MMD versus coreset size on log-log (base 2) axes, together with a\n",
    "        least-squares power-law fit for every series in mmds_dict.\n",
    "        \n",
    "        axes: matplotlib Axes object to draw on\n",
    "        ms: array of integers denoting the logarithmic range of coreset size (m=log_2 coreset size)\n",
    "        mmds_dict: Dictionary of label -> mmd array of shape len(ms) x Reps; the median and the\n",
    "            standard error are taken over axis 1 (replications), ignoring NaNs\n",
    "        size_factor: Float, scaling of marker size\n",
    "        error_bar: Boolean, When true one-sided (upper) standard-error bars are plotted for each series\n",
    "        error_shade: Boolean, When true shaded +/- one standard-error regions are plotted for each series\n",
    "        skip_ns: Integer, starting index for the plot (how many entries to skip in ms and in the\n",
    "            first axis of each mmds_dict value)\n",
    "        legend_size: Size of the legend in the plot\n",
    "        \n",
    "        The legend reports the fitted slope of log MMD against log n, where n = (coreset size)^2\n",
    "        is the input size. Series styles come from the global lss / mss / ms_size / colors lists,\n",
    "        indexed by the series position in mmds_dict.\n",
    "    '''\n",
    "    ns = np.power(2, ms[skip_ns:], dtype=int)\n",
    "    X = sm.add_constant(np.log(ns**2))\n",
    "    handles = []\n",
    "    labels = []\n",
    "    for i, (label, mmd) in enumerate(mmds_dict.items()):\n",
    "        y = np.nanmedian(mmd, axis=1)[skip_ns:]\n",
    "        # standard error over the non-NaN replications\n",
    "        yerr = (np.nanstd(mmd, axis=1) / np.sqrt((~np.isnan(mmd)).sum(axis=1)))[skip_ns:]\n",
    "        model = sm.OLS(np.log(y), X).fit()\n",
    "        \n",
    "        if error_bar:\n",
    "            marker_handle = axes.errorbar(ns, y, marker=mss[i], yerr=np.array([np.zeros_like(yerr), yerr]), linestyle='None', \n",
    "                                          color=colors[i], alpha=1., markersize=size_factor*ms_size[i], elinewidth=5)\n",
    "        else:\n",
    "            # plain markers; also used when only error_shade is requested so a legend handle always exists\n",
    "            marker_handle, = axes.plot(ns, y, marker=mss[i], linestyle='None', \n",
    "                                       color=colors[i], alpha=1., markersize=size_factor*ms_size[i])\n",
    "        if error_shade:\n",
    "            axes.fill_between(ns, y-yerr, y+yerr, alpha=0.2, color=colors[i])\n",
    "        \n",
    "        fit_handle, = axes.plot(ns, np.exp(model.predict(X)), linestyle=lss[i],\n",
    "                                linewidth=4, color=colors[i], alpha=.5)\n",
    "        \n",
    "        labels.append(label.replace(\"_\", \" \") + r\": n$^{%.2f}$\"%(model.params[1]))\n",
    "        handles.append((marker_handle, fit_handle))\n",
    "    \n",
    "    # Axis formatting is applied once, after all series are drawn\n",
    "    if handles:\n",
    "        axes.legend(handles, labels, loc=\"lower left\", handletextpad=0.0, fontsize=legend_size)\n",
    "    # `base` replaces the `basex`/`basey` keywords that matplotlib removed in 3.5\n",
    "    axes.set_xscale('log', base=2)\n",
    "    axes.set_yscale('log', base=2)\n",
    "    axes.spines['top'].set_visible(False)\n",
    "    axes.spines['right'].set_visible(False)\n",
    "    axes.grid(True, alpha=0.4)\n",
    "    return"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Load the MMD results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MCMC settings (d = 4) whose results are loaded below and plotted in Section 2.2\n",
    "mcmc_file_names = [\n",
    "    'Goodwin_RW', 'Goodwin_ADA-RW', 'Goodwin_MALA', 'Goodwin_PRECOND-MALA',\n",
    "    'Lotka_RW', 'Lotka_ADA-RW', 'Lotka_MALA', 'Lotka_PRECOND-MALA',\n",
    "]\n",
    "var = 1.\n",
    "rep_ids = range(100)  # replication ids whose combined results are loaded\n",
    "# thinning rounds m = 0, ..., 4 (coreset size 2^m); use range(8) to reproduce the paper's figures\n",
    "ms = range(4+1)\n",
    "ds = [2, 3, 4]  # dimensions of the Gauss P experiments\n",
    "Ms = [4, 6, 8]  # numbers of mixture components of the diag MoG P experiments (d = 2)\n",
    "# result containers, filled by the next cell\n",
    "mc_mmds = {}\n",
    "mcmc_mmds = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for d in ds:\n",
    "    var = 1.\n",
    "    params_p = {\"name\": \"gauss\", \"var\": var, \"d\": int(d), \"saved_samples\": False}\n",
    "    # swap kernel (also used for the MMD) and split kernel with half its variance\n",
    "    params_k_swap = {\"name\": \"gauss\", \"var\": var, \"d\": int(d)}\n",
    "    params_k_split = {\"name\": \"gauss_rt\", \"var\": var/2., \"d\": int(d)}\n",
    "    print(f\"loading results for Gauss kernel and Gauss P in d={d}\")\n",
    "    mc_mmds[f\"gauss-{d}\"] = load_mmds(ms, params_p, params_k_split, params_k_swap, rep_ids)\n",
    "\n",
    "    if d == 2:\n",
    "        # mixture-of-Gaussians targets reuse the unit-variance d = 2 kernels\n",
    "        for M in Ms:\n",
    "            params_p = compute_diag_mog_params(M)\n",
    "            print(f\"loading results for Gauss kernel and {M}-mix diag MOG in d={d}\")\n",
    "            mc_mmds[f\"diagmog-{M}\"] = load_mmds(ms, params_p, params_k_split, params_k_swap, rep_ids)\n",
    "    elif d == 4:\n",
    "        # MCMC targets: the kernel variance is params_p[\"med_dist\"]**2 for each setting; fresh\n",
    "        # kernel dicts are built rather than mutating the ones already used for gauss-4\n",
    "        for filename in mcmc_file_names:\n",
    "            params_p = compute_mcmc_params_p(filename, nmax=int(2**15), include_last=True)\n",
    "            var = params_p[\"med_dist\"]**2\n",
    "            params_k_swap = {\"name\": \"gauss\", \"var\": var, \"d\": int(d)}\n",
    "            params_k_split = {\"name\": \"gauss_rt\", \"var\": var/2., \"d\": int(d)}\n",
    "            print(f\"loading results for Gauss kernel and {filename} MCMC setting d={d}\")\n",
    "            mcmc_mmds[filename] = load_mmds(ms, params_p, params_k_split, params_k_swap, rep_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (uncomment to run): for each setting, print the median KT\n",
    "# MMD_k(P, S) over replications at each m. The loaded results live in mc_mmds.\n",
    "# for d in ds:\n",
    "#     print(d, np.median(mc_mmds[f\"gauss-{d}\"][\"P\"][\"KT\"], 1))\n",
    "    \n",
    "# for M in Ms:\n",
    "#     print(M, np.median(mc_mmds[f\"diagmog-{M}\"][\"P\"][\"KT\"], 1))\n",
    "\n",
    "# mc_mmds[f\"gauss-{2}\"][\"P\"][\"KT\"][1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.1 Gauss P and MoG P MMD rates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Font sizes and y-axis label shared by the MMD rate plots below\n",
    "title_size = 25\n",
    "ylab_size = 25\n",
    "xlab_size = 18\n",
    "ylab = 'Median MMD'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "skip_ns = int(1)\n",
    "# MMD_k(P, S) for Gauss P in each dimension d (results are stored in mc_mmds)\n",
    "fig, axes = plt.subplots(1, len(ds), figsize=[5*len(ds), 5], sharex=True, sharey=True, squeeze=False)\n",
    "for i, d in enumerate(ds):\n",
    "    ax = axes[0, i]\n",
    "    plot_mmd_dict(ax, ms, mc_mmds[f\"gauss-{d}\"][\"P\"], 0.5, skip_ns=skip_ns, error_bar=True)\n",
    "    ax.set_title(r\"d=$%d$\"%(d), fontsize=title_size)\n",
    "    if i==0:\n",
    "        ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "    ax.set_xlabel(r\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "plt.tight_layout()\n",
    "plt.suptitle(r\"$MMD_{\\bf k}(\\mathbb{P}, \\mathcal{S})$\", y=1.05, fontsize=title_size)\n",
    "plt.show()\n",
    "\n",
    "# MMD_k(P, S) for the diag MoG P with M components (d = 2)\n",
    "fig, axes = plt.subplots(1, len(Ms), figsize=[5*len(Ms), 5], sharex=True, sharey=True, squeeze=False)\n",
    "for i, M in enumerate(Ms):\n",
    "    ax = axes[0, i]\n",
    "    plot_mmd_dict(ax, ms, mc_mmds[f\"diagmog-{M}\"][\"P\"], 0.5, skip_ns=skip_ns, error_bar=True)\n",
    "    ax.set_title(r\"M=$%d$\"%(M), fontsize=title_size)\n",
    "    if i==0:\n",
    "        ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "    ax.set_xlabel(r\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "skip_ns = int(2)\n",
    "# MMD_k(Sin, S) for Gauss P in each dimension d (results are stored in mc_mmds)\n",
    "fig, axes = plt.subplots(1, len(ds), figsize=[5*len(ds), 5], sharex=True, sharey=True, squeeze=False)\n",
    "for i, d in enumerate(ds):\n",
    "    ax = axes[0, i]\n",
    "    plot_mmd_dict(ax, ms, mc_mmds[f\"gauss-{d}\"][\"Sin\"], 0.5, skip_ns=skip_ns, error_bar=True)\n",
    "    ax.set_title(r\"d=$%d$\"%(d), fontsize=title_size)\n",
    "    if i==0:\n",
    "        ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "    ax.set_xlabel(r\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "plt.tight_layout()\n",
    "plt.suptitle(r\"$MMD_{\\bf k}(\\mathcal{S}_{in}, \\mathcal{S})$\", y=1.05, fontsize=title_size)\n",
    "plt.show()\n",
    "\n",
    "# MMD_k(Sin, S) for the diag MoG P with M components (d = 2)\n",
    "fig, axes = plt.subplots(1, len(Ms), figsize=[5*len(Ms), 5], sharex=True, sharey=True, squeeze=False)\n",
    "for i, M in enumerate(Ms):\n",
    "    ax = axes[0, i]\n",
    "    plot_mmd_dict(ax, ms, mc_mmds[f\"diagmog-{M}\"][\"Sin\"], 0.5, skip_ns=skip_ns, error_bar=True)\n",
    "    ax.set_title(r\"M=$%d$\"%(M), fontsize=title_size)\n",
    "    if i==0:\n",
    "        ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "    ax.set_xlabel(r\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 MCMC MMD rates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "skip_ns = int(2)\n",
    "# one panel per MCMC setting, 4 per row (Goodwin settings first, then Lotka-Volterra)\n",
    "ncols = 4\n",
    "nrows = int(np.ceil(len(mcmc_file_names) / ncols))\n",
    "fig, axes = plt.subplots(nrows, ncols, figsize=[5*ncols, 5*nrows], sharex=True, sharey=True, squeeze=False)\n",
    "axes = axes.flatten()\n",
    "for i, filename in enumerate(mcmc_file_names):\n",
    "    ax = axes[i]\n",
    "    plot_mmd_dict(ax, ms, mcmc_mmds[filename][\"P\"], 0.5, skip_ns=skip_ns, error_bar=True)\n",
    "    # shorten sampler names for the titles, e.g. 'Lotka_PRECOND-MALA' -> 'Lotka-Volterra pMALA'\n",
    "    title = filename.replace(\"ADA-RW\", \"adaRW\").replace(\"PRECOND-MALA\", \"pMALA\")\n",
    "    title = title.replace(\"_\", \" \").replace(\"Lotka\", \"Lotka-Volterra\")\n",
    "    ax.set_title(title, fontsize=title_size)\n",
    "    if i % ncols == 0:\n",
    "        ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "    if i >= (nrows-1)*ncols:\n",
    "        ax.set_xlabel(r\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "    else:\n",
    "        ax.spines['bottom'].set_visible(False)\n",
    "# hide any unused panels in the last row\n",
    "for ax in axes[len(mcmc_file_names):]:\n",
    "    ax.set_visible(False)\n",
    "plt.suptitle(r\"$MMD_{\\bf k}(\\mathbb{P}, \\mathcal{S})$\", y=1.05, fontsize=title_size)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (uncomment to run): median baseline ('standard') and KT MMD_k(P, S)\n",
    "# over replications at each m, for the MCMC setting currently bound to `filename`.\n",
    "# print(filename, np.median(mcmc_mmds[filename][\"P\"][\"standard\"], 1), np.median(mcmc_mmds[filename][\"P\"][\"KT\"], 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "skip_ns = int(2)\n",
    "# one panel per MCMC setting, 4 per row (Goodwin settings first, then Lotka-Volterra)\n",
    "ncols = 4\n",
    "nrows = int(np.ceil(len(mcmc_file_names) / ncols))\n",
    "fig, axes = plt.subplots(nrows, ncols, figsize=[5*ncols, 5*nrows], sharex=True, sharey=True, squeeze=False)\n",
    "axes = axes.flatten()\n",
    "for i, filename in enumerate(mcmc_file_names):\n",
    "    ax = axes[i]\n",
    "    plot_mmd_dict(ax, ms, mcmc_mmds[filename][\"Sin\"], 0.5, skip_ns=skip_ns, error_bar=True)\n",
    "    # shorten sampler names for the titles, e.g. 'Lotka_PRECOND-MALA' -> 'Lotka-Volterra pMALA'\n",
    "    title = filename.replace(\"ADA-RW\", \"adaRW\").replace(\"PRECOND-MALA\", \"pMALA\")\n",
    "    title = title.replace(\"_\", \" \").replace(\"Lotka\", \"Lotka-Volterra\")\n",
    "    ax.set_title(title, fontsize=title_size)\n",
    "    if i % ncols == 0:\n",
    "        ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "    if i >= (nrows-1)*ncols:\n",
    "        ax.set_xlabel(r\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "    else:\n",
    "        ax.spines['bottom'].set_visible(False)\n",
    "# hide any unused panels in the last row\n",
    "for ax in axes[len(mcmc_file_names):]:\n",
    "    ax.set_visible(False)\n",
    "plt.suptitle(r\"$MMD_{\\bf k}(\\mathcal{S}_{in}, \\mathcal{S})$\", y=1.05, fontsize=title_size)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Scatter plots of coresets for MoG P"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_seed = 1234567\n",
    "# which rep_id's coresets to visualize\n",
    "rep_id = npr.default_rng(1).choice(range(100))\n",
    "\n",
    "# use contour (True) lines or filled contour plot (False)\n",
    "contour = True\n",
    "\n",
    "d = int(2)\n",
    "var = 1.\n",
    "# range of m to plot (coreset sizes 2^1, ..., 2^4)\n",
    "m_plot = np.arange(1, 5, dtype=int)\n",
    "for M in Ms:\n",
    "    params_k_swap = {\"name\": \"gauss\", \"var\": var, \"d\": int(d)}\n",
    "    params_k_split = {\"name\": \"gauss_rt\", \"var\": var/2., \"d\": int(d)}\n",
    "    params_p = compute_diag_mog_params(M)\n",
    "    \n",
    "    # axis limits for all plots: pad the range of the component means by a multiple of the largest cov\n",
    "    snr = 2*max(params_p[\"covs\"])\n",
    "    snr *= 2.5 if M == 4 else 2\n",
    "    lim0 = np.min(params_p[\"means\"])-snr\n",
    "    lim1 = np.max(params_p[\"means\"])+snr\n",
    "    x, y = np.mgrid[lim0:lim1:.05, lim0:lim1:.05]\n",
    "    pos = np.dstack((x, y))\n",
    "    # diag MoG density on the grid; the exp(-20) floor keeps the log finite\n",
    "    density = sum(params_p[\"weights\"][k] * multivariate_normal(params_p[\"means\"][k], np.eye(d) * params_p[\"covs\"][k]).pdf(pos)\n",
    "                  for k in range(M))\n",
    "    log_density = np.log(density+np.exp(-20))\n",
    "\n",
    "    fig, axes = plt.subplots(2, len(m_plot), figsize=[4*len(m_plot), 4.5*2], sharex=True, sharey=True, squeeze=False)\n",
    "\n",
    "    for i, m in enumerate(m_plot):\n",
    "        n = int(2**m)\n",
    "        # load the full input sample and the KT coreset (row 0: iid, row 1: KT)\n",
    "        Xs = dict()\n",
    "        Xs[\"iid\"], Xs[\"KT\"] = load_input_and_coreset(m, params_p, params_k_split, params_k_swap, rep_id, sample_seed=sample_seed)\n",
    "        for j, (label, X) in enumerate(Xs.items()):\n",
    "            ax = axes[j, i]\n",
    "            # plot density of the MoG\n",
    "            if contour:\n",
    "                C = ax.contour(x, y, log_density, cmap=\"Greys\", linewidths=.5)\n",
    "                ax.clabel(C, inline=1, fontsize=10)\n",
    "            else:\n",
    "                ax.contourf(x, y, log_density, cmap=\"Greys\")\n",
    "            if label == \"iid\":\n",
    "                # standard thinning of the iid input: keep every step-th point, n points in total\n",
    "                step = int(X.shape[0]/n)\n",
    "                points = X[step-1::step]\n",
    "            else:\n",
    "                points = X[:n]\n",
    "            ax.scatter(points[:, 0], points[:, 1], marker=mss[j], s=60, color=colors[j], label=label)\n",
    "            ax.set_title(f\"{n} {label} points\", fontsize=20)\n",
    "\n",
    "            # show the left spine only on the first column and the bottom spine only on the last row\n",
    "            for side in ['top', 'right']:\n",
    "                ax.spines[side].set_visible(False)\n",
    "            ax.spines['left'].set_visible(i == 0)\n",
    "            ax.spines['bottom'].set_visible(j == len(Xs)-1)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
