{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from argparse import ArgumentParser\n",
    "import copy\n",
    "import joblib\n",
    "\n",
    "import numpy as np\n",
    "import numpy.random as npr\n",
    "import numpy.linalg as npl\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "import pathlib\n",
    "import os\n",
    "import os.path\n",
    "import pickle as pkl\n",
    "\n",
    "# Fitting linear models\n",
    "import statsmodels.api as sm\n",
    "from scipy.stats import multivariate_normal\n",
    "\n",
    "# plotting libraries\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "import pylab\n",
    "import seaborn as sns\n",
    "plt.style.use('seaborn-white')\n",
    "\n",
    "# utils for generating samples, evaluating kernels, and mmds\n",
    "from util_sample import sample, compute_mcmc_params_p, compute_diag_mog_params, sample_string, compute_params_p\n",
    "from util_k_mmd import get_combined_mmd_filename, compute_params_k\n",
    "from util_filenames import get_file_template, get_combined_file_template\n",
    "from util_parse import init_parser\n",
    "\n",
    "from construct_compress_thin_coresets import construct_compress_thin_coresets\n",
    "from construct_kt_coresets import construct_kt_coresets\n",
    "from construct_st_coresets import construct_st_coresets\n",
    "from construct_herding_coresets import construct_herding_coresets\n",
    "\n",
    "# Autoreload packages that are modified\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "# %matplotlib inline\n",
    "%load_ext line_profiler\n",
    "# https://jakevdp.github.io/PythonDataScienceHandbook/01.07-timing-and-profiling.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Font sizes for axis labels, legend and titles, and the shared y-axis label\n",
    "xlab_size = ylab_size = 25\n",
    "leg_size = 25\n",
    "title_size = 30\n",
    "ylab = 'Mean MMD'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib style plus the per-series style lists (lss, mss, ms_size, colors)\n",
    "# that the plotting helpers index by series number.\n",
    "fix_plot_settings = True\n",
    "if fix_plot_settings:\n",
    "    plt.rc('font', family='serif')\n",
    "    plt.rc('text', usetex=False)\n",
    "    label_size = 20\n",
    "    mpl.rcParams['xtick.labelsize'] = label_size\n",
    "    mpl.rcParams['ytick.labelsize'] = label_size\n",
    "    mpl.rcParams['axes.labelsize'] = label_size\n",
    "    mpl.rcParams['axes.titlesize'] = label_size\n",
    "    mpl.rcParams['figure.titlesize'] = label_size\n",
    "    mpl.rcParams['lines.markersize'] = label_size\n",
    "    mpl.rcParams['grid.linewidth'] = 2.5\n",
    "    mpl.rcParams['legend.fontsize'] = label_size\n",
    "    pylab.rcParams['xtick.major.pad']=5\n",
    "    pylab.rcParams['ytick.major.pad']=5\n",
    "\n",
    "    lss = ['-', '-.',  ':', '--',  '--', '-.', ':', '-', '--', '-.', ':', '-']*2\n",
    "    mss = ['>', 'o',  's', 'D', '>', 's', 'o', 'D', '>', 's', 'o', 'D']*2\n",
    "    # One marker size per marker style (the first series is drawn larger), so ms_size[i]\n",
    "    # exists for every i that mss supports (the old list had only 10 entries).\n",
    "    ms_size = [25] + [20] * (len(mss) - 1)\n",
    "    colors = ['#e41a1c',  'magenta','#4daf4a', '#0000cd',  'cyan', 'black' ,'orange','yellow','gray']*2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The legacy plot_mmd_dict was removed: it was entirely commented out and is superseded by\n",
    "# plot_mmd_dict_new (next cell), which adds per-key size ranges and configurable log bases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_mmd_dict_new(axes, ms, mmds_dict, size_factor=1., fit_model=True, error_bar=True,\n",
    "                      error_shade=False, skip_ns=int(0), legend_size=mpl.rcParams['legend.fontsize'],\n",
    "                      rm_keys=[], summary=np.mean, legend_loc='best', change_cpthin=False,\n",
    "                      alpha_kt=1., adjust_alpha=False, sizes_kt=None, sizes_non_kt=None,\n",
    "                      basex=2, basey=2):\n",
    "    '''\n",
    "        Plot MMD versus input size n on log-log axes, one series per entry of mmds_dict.\n",
    "\n",
    "        axes: matplotlib Axes object to draw on\n",
    "        ms: integers m; the input size plotted for index j is n = 4**ms[j]\n",
    "            (replaced per key by sizes_kt / sizes_non_kt when adjust_alpha is True)\n",
    "        mmds_dict: dict label -> MMD array of shape (len(ms), Rep); axis 1 indexes repetitions\n",
    "            when summary is np.mean or np.nanmedian, otherwise the entry is plotted as-is\n",
    "        size_factor: float, scaling of marker size\n",
    "        fit_model: whether to fit and draw a log-log least-squares line; its slope is shown in the legend\n",
    "        error_bar: when True, one-sided (upper) standard-error bars across Reps are drawn\n",
    "        error_shade: when True, a shaded +/- standard-error band is drawn\n",
    "        skip_ns: number of leading sizes to skip (applied to ms and to every mmds_dict entry)\n",
    "        legend_size: font size of the legend\n",
    "        rm_keys: keys of mmds_dict not to be plotted\n",
    "        summary: function used to combine results across reps\n",
    "        legend_loc: location of the legend\n",
    "        change_cpthin: only plot ST, KT, KT-Comp, KT-Comp++ (and the herding variants),\n",
    "            renaming the compress labels for display\n",
    "        alpha_kt: transparency of the \"KT\" series\n",
    "        adjust_alpha: when True, take the size range per key from sizes_kt (\"KT\") or\n",
    "            sizes_non_kt (every other key) instead of ms\n",
    "        sizes_kt, sizes_non_kt: inclusive [min, max] ranges of m; required when adjust_alpha is True\n",
    "        basex, basey: logarithm bases of the x and y axes\n",
    "    '''\n",
    "    # Labels hidden, and display names used, when change_cpthin is set.\n",
    "    cpthin_hidden = [\"CPthin+KT-1\", \"CPthin+KT-2\", \"CPthin+KT-3\", \"CP+KT-1\", \"CP+KT-2\", \"CP+KT-3\", \"CP+KTnosymm-0\"]\n",
    "    cpthin_names = {\"CPthin+KT-0\": \"KT-Comp\", \"CP+KT-0\": \"KT-Comp\",\n",
    "                    \"CPthin+KT-4\": \"KT-Comp++\", \"CP+KT-4\": \"KT-Comp++\",\n",
    "                    \"CPHerd-0\": \"Herd-Comp\", \"CPHerd-4\": \"Herd-Comp++\",\n",
    "                    \"CP+KTnosymm-4\": \"KT-Comp++(NS)\"}\n",
    "\n",
    "    mmds_dict_new = mmds_dict.copy()\n",
    "    for rm_key in rm_keys:\n",
    "        mmds_dict_new.pop(rm_key, None)\n",
    "    if change_cpthin:\n",
    "        for label in cpthin_hidden:\n",
    "            mmds_dict_new.pop(label, None)\n",
    "\n",
    "    ls = []\n",
    "    labs = []\n",
    "    for i, (label, mmd) in enumerate(mmds_dict_new.items()):\n",
    "        # KT and the other families may be stored over different size ranges.\n",
    "        if adjust_alpha:\n",
    "            if label == \"KT\":\n",
    "                assert(sizes_kt is not None)\n",
    "                key_ms = range(sizes_kt[0], sizes_kt[1]+1)\n",
    "            else:\n",
    "                assert(sizes_non_kt is not None)\n",
    "                key_ms = range(sizes_non_kt[0], sizes_non_kt[1]+1)\n",
    "        else:\n",
    "            key_ms = ms\n",
    "        ns = np.power(4, key_ms[skip_ns:], dtype=int)\n",
    "        X = sm.add_constant(np.log(ns))\n",
    "\n",
    "        # Transparency and size range use the raw key; the display name is applied afterwards.\n",
    "        alpha = alpha_kt if label == \"KT\" else 1.\n",
    "        if change_cpthin:\n",
    "            label = cpthin_names.get(label, label)\n",
    "\n",
    "        if summary == np.mean or summary == np.nanmedian:\n",
    "            y = summary(mmd, axis=1)\n",
    "            # standard error over the non-NaN repetitions\n",
    "            yerr = np.nanstd(mmd, axis=1) / np.sqrt((~np.isnan(mmd)).sum(axis=1))\n",
    "            y = y[skip_ns:]\n",
    "            yerr = yerr[skip_ns:]\n",
    "        else:\n",
    "            # already summarized: skip the same leading sizes as ns so the lengths match\n",
    "            y = np.asarray(mmd)[skip_ns:]\n",
    "            yerr = np.zeros_like(y)\n",
    "\n",
    "        # Cycle the style lists so more series than styles never raises IndexError.\n",
    "        marker = mss[i % len(mss)]\n",
    "        color = colors[i % len(colors)]\n",
    "        markersize = size_factor*ms_size[i % len(ms_size)]\n",
    "\n",
    "        if error_bar:\n",
    "            l1 = axes.errorbar(ns, y, marker=marker, yerr=np.array([np.zeros_like(yerr), yerr]), linestyle='None',\n",
    "                               color=color, alpha=alpha, markersize=markersize, linewidth=5)\n",
    "        else:\n",
    "            # markers only (also when error_shade is set, so l1 always exists for the legend)\n",
    "            l1, = axes.plot(ns, y, marker=marker, linestyle='None',\n",
    "                            color=color, alpha=alpha, markersize=markersize)\n",
    "        if error_shade:\n",
    "            axes.fill_between(ns, y-yerr, y+yerr, alpha=0.2, color=color)\n",
    "\n",
    "        if fit_model:\n",
    "            model = sm.OLS(np.log(y), X).fit()\n",
    "            l2, = axes.plot(ns, np.exp(model.predict(X)), linestyle=lss[i % len(lss)],\n",
    "                            linewidth=4, color=color, alpha=.5)\n",
    "            labs.append(label.replace(\"_\", \" \") + r\": n$^{%.2f}$\"%(model.params[1]))\n",
    "            ls.append((l1, l2))\n",
    "        else:\n",
    "            labs.append(label.replace(\"_\", \" \"))\n",
    "            ls.append(l1)\n",
    "\n",
    "    if not ls:\n",
    "        return\n",
    "    # Decorate the axes once, after all series are drawn.\n",
    "    axes.legend(ls, labs, loc=legend_loc, handletextpad=0.0, fontsize=legend_size)\n",
    "    # `base=` replaces the `basex=`/`basey=` keywords that Matplotlib 3.5 removed.\n",
    "    axes.set_xscale('log', base=basex)\n",
    "    axes.set_yscale('log', base=basey)\n",
    "    axes.spines['top'].set_visible(False)\n",
    "    axes.spines['right'].set_visible(False)\n",
    "    axes.grid(True, alpha=0.4)\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_arguments():\n",
    "    '''\n",
    "        Extend the shared parser from init_parser with this notebook's options\n",
    "        (--size_min, --size_max, --cset) and parse the command line.\n",
    "\n",
    "        parse_known_args is used, so unrecognized arguments (such as those a notebook\n",
    "        kernel may be launched with) are ignored instead of causing an error.\n",
    "    '''\n",
    "    parser = init_parser()\n",
    "    extra_options = (\n",
    "        (('--size_min', '-size_min'), int, 0, \"min size\"),\n",
    "        (('--size_max', '-size_max'), int, 3, \"max size\"),\n",
    "        (('--cset', '-cset'), str, 'KT', \"which type of coreset results to combine\"),\n",
    "    )\n",
    "    for option_names, option_type, option_default, option_help in extra_options:\n",
    "        parser.add_argument(*option_names, type=option_type, default=option_default, help=option_help)\n",
    "    known_args, _ignored = parser.parse_known_args()\n",
    "    return known_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine_results(args, params_p, params_k_split, params_k_swap, adjust_alpha=False):\n",
    "    '''\n",
    "        Build, or load from the joblib cache, the MMD results of one coreset family across sizes.\n",
    "\n",
    "        args: parsed arguments; args.cset selects the family (\"ST\", \"KT\", \"Herd\", \"CP+KT\" or\n",
    "            \"CPHerd\") and sizes run over range(args.size_min, args.size_max+1)\n",
    "        params_p, params_k_split, params_k_swap: target and kernel parameters (used here only to\n",
    "            name the cache file)\n",
    "        adjust_alpha: for \"CP+KT\", compute only alphas 0 and 4 instead of 0..min(args.alpha_max, size)\n",
    "\n",
    "        Returns {args.cset: {'size_<sz>': result}} for ST/KT/Herd and\n",
    "        {args.cset: {'size_<sz>': {'alpha_<a>': result}}} for CP+KT/CPHerd, where result is the\n",
    "        return value of the matching construct_* function. The dict is written under\n",
    "        results/combined and reloaded on later calls unless args.recombine is set.\n",
    "        Raises ValueError for an unknown args.cset instead of caching an empty result.\n",
    "    '''\n",
    "    folder = \"results/combined\"\n",
    "    pathlib.Path(folder).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    prefix = args.cset\n",
    "    file_template = get_combined_file_template(folder, prefix, args.d, args.size_min,\n",
    "                    args.size_max, args.size_max, params_p, params_k_split, params_k_swap,\n",
    "                    delta=0.5, experiment_seed=args.seed, compressalg = args.compressalg)\n",
    "\n",
    "    combo_filename  = file_template.format(\"combined_mmd\", f\"s{args.rep0}-{args.rep0+args.repn}\")\n",
    "    if not args.recombine and os.path.exists(combo_filename):\n",
    "        # joblib.load unpickles: only point this at caches produced by this code\n",
    "        return joblib.load(combo_filename)\n",
    "\n",
    "    # Families built with one construct_* call per size ...\n",
    "    single_stage = {\"ST\": construct_st_coresets, \"KT\": construct_kt_coresets,\n",
    "                    \"Herd\": construct_herding_coresets}\n",
    "    # ... and compress families built with construct_compress_thin_coresets per (size, alpha).\n",
    "    # By default there is no difference between CP+KT and CPthin+KT.\n",
    "    compress_stage = (\"CP+KT\", \"CPHerd\")\n",
    "    if args.cset not in single_stage and args.cset not in compress_stage:\n",
    "        raise ValueError(f\"Unknown coreset type args.cset={args.cset!r}\")\n",
    "\n",
    "    temp_args = copy.deepcopy(args)\n",
    "    per_size = dict()\n",
    "    for sz in range(args.size_min, args.size_max+1):\n",
    "        if args.cset in single_stage:\n",
    "            temp_args.size = sz\n",
    "            temp_args.m = sz\n",
    "            per_size[\"size_%d\"%sz] = single_stage[args.cset](temp_args)\n",
    "            continue\n",
    "        if args.cset == \"CPHerd\":\n",
    "            alphas = [0, 1, 2, 3, 4]\n",
    "        elif adjust_alpha:\n",
    "            alphas = [0, 4]\n",
    "        else:\n",
    "            alphas = range(min(args.alpha_max, sz)+1)\n",
    "        per_size[\"size_%d\"%sz] = dict()\n",
    "        for alpha in alphas:\n",
    "            temp_args.size = sz\n",
    "            temp_args.alpha = alpha\n",
    "            temp_args.m = sz\n",
    "            if args.cset == \"CPHerd\":\n",
    "                temp_args.compressalg = \"herding\"\n",
    "            per_size[\"size_%d\"%sz][\"alpha_%d\"%alpha] = construct_compress_thin_coresets(temp_args)\n",
    "\n",
    "    mmd_results = {args.cset: per_size}\n",
    "    joblib.dump(mmd_results, combo_filename)\n",
    "    return mmd_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def post_process(mmd_results, args, keys=[\"KT\"], adjust_alpha=False, sizes_non_kt=None, sizes_kt = [4, 8]):\n",
    "    '''\n",
    "        Flatten combine_results outputs into {plot label: array of shape (num_sizes, args.repn)}.\n",
    "\n",
    "        mmd_results: dict cset -> combine_results output, so mmd_results[key][key]['size_<sz>']\n",
    "            (and ['alpha_<a>'] below it for the compress families) holds the per-repetition\n",
    "            MMDs that are copied into one row of length args.repn\n",
    "        args: parsed arguments; reads size_min, size_max, repn and alpha_max\n",
    "        keys: families to convert: \"ST\", \"KT\", \"Herd\", \"CPHerd\", or any key containing \"CP+KT\"\n",
    "        adjust_alpha: when True, KT uses sizes_kt and only alphas 0 and 4 are read for CP+KT\n",
    "        sizes_non_kt: inclusive [min, max] size range of the non-KT families\n",
    "            (defaults to [args.size_min, args.size_max])\n",
    "        sizes_kt: inclusive [min, max] size range of KT; replaced by sizes_non_kt unless adjust_alpha\n",
    "\n",
    "        Returns a dict with keys \"ST\"/\"KT\"/\"Herd\" and \"<key>-<alpha>\" for the compress families.\n",
    "    '''\n",
    "    if sizes_non_kt is None:\n",
    "        sizes_non_kt = [args.size_min, args.size_max]\n",
    "    if not adjust_alpha:\n",
    "        sizes_kt = sizes_non_kt\n",
    "    else:\n",
    "        assert(sizes_kt is not None)\n",
    "\n",
    "    # Every family except KT is stored over sizes_non_kt. Computing the count once also fixes\n",
    "    # Herd/CPHerd, which previously reused a num_sizes left over from another key (or raised\n",
    "    # NameError when processed first).\n",
    "    non_kt_sizes = range(sizes_non_kt[0], sizes_non_kt[1]+1)\n",
    "    num_non_kt = 1+sizes_non_kt[1]-sizes_non_kt[0]\n",
    "\n",
    "    def _stack_sizes(by_size, num_sizes):\n",
    "        # Rows follow the insertion order of by_size, which combine_results fills by increasing size.\n",
    "        assert(len(by_size.keys())==num_sizes)\n",
    "        stacked = np.zeros((num_sizes, args.repn))\n",
    "        for row, values in enumerate(by_size.values()):\n",
    "            stacked[row, :] = values\n",
    "        return stacked\n",
    "\n",
    "    mmd_dict = dict()\n",
    "    for key in keys:\n",
    "        if key == \"ST\" or key == \"Herd\":\n",
    "            mmd_dict[key] = _stack_sizes(mmd_results[key][key], num_non_kt)\n",
    "        if key == \"KT\":\n",
    "            mmd_dict[key] = _stack_sizes(mmd_results[key][key], 1+sizes_kt[1]-sizes_kt[0])\n",
    "        if key == \"CPHerd\":\n",
    "            # combine_results always computes alphas 0..4 for CPHerd\n",
    "            for a in [0, 1, 2, 3, 4]:\n",
    "                mmd_dict[key+f\"-{a}\"] = np.zeros((num_non_kt, args.repn))\n",
    "                for j, sz in enumerate(non_kt_sizes):\n",
    "                    mmd_dict[key+f\"-{a}\"][j, :] = mmd_results[key][key][f'size_{sz}'][f'alpha_{a}']\n",
    "        if \"CP+KT\" in key: # for Hinch\n",
    "            if adjust_alpha:\n",
    "                alphas = [0, 4]\n",
    "            else:\n",
    "                # NOTE(review): combine_results stores only alpha <= size for each size, so this\n",
    "                # raises KeyError when sizes_non_kt[0] < min(args.alpha_max, args.size_max).\n",
    "                alphas = range(min(args.alpha_max + 1, args.size_max+1))\n",
    "            assert(len(mmd_results[key][key].keys())==num_non_kt)\n",
    "            for a in alphas:\n",
    "                mmd_dict[key+f\"-{a}\"] = np.zeros((num_non_kt, args.repn))\n",
    "                for j, sz in enumerate(non_kt_sizes):\n",
    "                    mmd_dict[key+f\"-{a}\"][j, :] = mmd_results[key][key][f'size_{sz}'][f'alpha_{a}']\n",
    "    return(mmd_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Combine and load results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set experiment parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Which coreset families to combine/plot (each experiment section below overrides these)\n",
    "st_coresets = True\n",
    "kt_coresets = compress_thin_coresets = False\n",
    "# compressthin_thin_coresets = False\n",
    "herding_coresets = cp_herding_coresets = False\n",
    "\n",
    "## gauss experiments: one run per dimension d\n",
    "run_gauss_experiments = True\n",
    "ds = [2, 4, 10, 100]\n",
    "\n",
    "## mog experiments (Ms presumably lists the numbers of mixture components)\n",
    "run_mog_experiments = True\n",
    "Ms = [4, 6, 8]\n",
    "\n",
    "run_mcmc_experiments = False # run experiments with MCMC P\n",
    "all_mcmc_filenames = [\n",
    "    'Goodwin_RW', 'Goodwin_ADA-RW', 'Goodwin_MALA', 'Goodwin_PRECOND-MALA',\n",
    "    'Lotka_RW', 'Lotka_ADA-RW', 'Lotka_MALA', 'Lotka_PRECOND-MALA',\n",
    "    'Hinch_P_seed_1_temp_1_scaled', 'Hinch_P_seed_2_temp_1_scaled',\n",
    "    'Hinch_TP_seed_1_temp_8_scaled', 'Hinch_TP_seed_2_temp_8_scaled',\n",
    "]\n",
    "\n",
    "# experiment key (e.g. \"d_2\") -> coreset family -> combine_results output\n",
    "combined_mmd_results = dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Gauss Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gauss section: ST, KT and CP+KT on; both herding families off\n",
    "st_coresets = kt_coresets = compress_thin_coresets = True\n",
    "herding_coresets = cp_herding_coresets = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# adjust_alpha: only comp0 and comp4 will be loaded, and KT / non-KT families use their\n",
    "# own [min, max] size ranges; when it is False both ranges are None.\n",
    "adjust_alpha = True\n",
    "sizes_kt, sizes_non_kt = ([4, 8], [4, 10]) if adjust_alpha is not False else (None, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%capture\n",
    "\n",
    "args = parse_arguments()\n",
    "\n",
    "args.krt = 0 # 0 if no krt, anything else for krt\n",
    "\n",
    "# mmd and size params\n",
    "args.combine_mmd = 1\n",
    "args.computemmd = 1\n",
    "args.size_min, args.size_max = 4, 8\n",
    "args.rep0 = 0\n",
    "args.repn = 10\n",
    "args.alpha_max = 4\n",
    "args.symm1 = 1\n",
    "\n",
    "# rerun/compute params\n",
    "args.rerun = 0\n",
    "args.recomputemmd = 0 # CHECK THIS\n",
    "\n",
    "# recombine\n",
    "args.recombine = 0 # whether to create a new combined mmd file in case one already exists\n",
    "\n",
    "# Coreset families paired with their enable flags (Herd <-> herding_coresets,\n",
    "# CPHerd <-> cp_herding_coresets; the two labels were previously swapped). Defined outside\n",
    "# the `if` below because the plotting cells reuse flags/labels to choose plot_keys.\n",
    "flags = [st_coresets, compress_thin_coresets, kt_coresets, herding_coresets, cp_herding_coresets]\n",
    "labels = [\"ST\", \"CP+KT\", \"KT\", \"Herd\", \"CPHerd\"]\n",
    "\n",
    "if run_gauss_experiments:\n",
    "    args.setting = \"gauss\"\n",
    "    for d in ds:\n",
    "        args.d = d\n",
    "        combined_mmd_results[\"d_%d\"%d] = dict()\n",
    "\n",
    "        # compute d, params_p and var_k for the setting\n",
    "        d, params_p, var_k = compute_params_p(args)\n",
    "\n",
    "        # define the kernels\n",
    "        params_k_split, params_k_swap, _, _ = compute_params_k(d=d, var_k=var_k,\n",
    "                                                            use_krt_split=args.krt, name=\"gauss\")\n",
    "        for f, lab in zip(flags, labels):\n",
    "            if not f:\n",
    "                continue\n",
    "            if adjust_alpha:\n",
    "                # KT and the other families are stored over different size ranges\n",
    "                args.size_min, args.size_max = sizes_kt if lab == \"KT\" else sizes_non_kt\n",
    "            args.cset = lab\n",
    "            combined_mmd_results[\"d_%d\"%d][lab] = combine_results(args, params_p, params_k_split, params_k_swap, adjust_alpha=adjust_alpha)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Gauss results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_dir = \"figs/\"\n",
    "# create the output directory now so later plt.savefig calls do not fail on a fresh checkout\n",
    "pathlib.Path(fig_dir).mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot exactly the coreset families that were enabled (and loaded) above.\n",
    "plot_keys = [lab for f, lab in zip(flags, labels) if f]\n",
    "\n",
    "# plot_keys = [\"ST\", \"KT\",\"CPthin+KT\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coreset families that will be plotted\n",
    "plot_keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Matplotlib style for the result figures plus the per-series style lists\n",
    "# (lss, mss, ms_size, colors) that the plotting helpers index by series number.\n",
    "fix_plot_settings = True\n",
    "if fix_plot_settings:\n",
    "    plt.rc('font', family='serif')\n",
    "    plt.rc('text', usetex=False)\n",
    "    label_size = 20\n",
    "    mpl.rcParams['xtick.labelsize'] = label_size\n",
    "    mpl.rcParams['ytick.labelsize'] = label_size\n",
    "    mpl.rcParams['axes.labelsize'] = label_size\n",
    "    mpl.rcParams['axes.titlesize'] = label_size\n",
    "    mpl.rcParams['figure.titlesize'] = label_size\n",
    "    mpl.rcParams['lines.markersize'] = label_size\n",
    "    mpl.rcParams['grid.linewidth'] = 2.5\n",
    "    mpl.rcParams['legend.fontsize'] = 15\n",
    "    pylab.rcParams['xtick.major.pad']=5\n",
    "    pylab.rcParams['ytick.major.pad']=5\n",
    "\n",
    "    lss = ['-', '-.',  ':', '--',  '--', '-.', ':', '-', '--', '-.', ':', '-']*2\n",
    "    mss = ['>', 's', 'o', 'D', '+', '*',  '>', 's', 'o', 'D', '>', 's', 'o', 'D']*2\n",
    "    # One marker size per marker style (the first series is drawn larger), so ms_size[i]\n",
    "    # exists for every i that mss supports (the old list had only 10 entries).\n",
    "    ms_size = [25] + [20] * (len(mss) - 1)\n",
    "    colors = ['#e41a1c', #red\n",
    "              'orange',\n",
    "              '#0000cd', #blue\n",
    "              '#4daf4a', #green\n",
    "              'magenta', 'black' , 'yellow','gray']*2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logarithm bases of the x (input size) and y (MMD) axes passed to plot_mmd_dict_new\n",
    "basex, basey = 10, 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_fig = True\n",
    "if run_gauss_experiments:\n",
    "    # True: paper-style legend names (change_cpthin) and the figure is saved; False: raw keys.\n",
    "    simplified_plot = True\n",
    "    skip_ns = int(0)\n",
    "    # squeeze=False keeps a 2-D array of Axes, so indexing is the same for any len(ds).\n",
    "    fig, axs = plt.subplots(1, len(ds), figsize=[6*len(ds), 5], sharex=True, sharey=False, squeeze=False)\n",
    "    for i, d in enumerate(ds):\n",
    "        ax = axs[0][i]\n",
    "        simple_mmd_dict = post_process(combined_mmd_results[f'd_{d}'], args, plot_keys, adjust_alpha=adjust_alpha,\n",
    "                                       sizes_non_kt=sizes_non_kt, sizes_kt=sizes_kt)\n",
    "        if simplified_plot:\n",
    "            plot_mmd_dict_new(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5,\n",
    "                              skip_ns=skip_ns, error_bar=True, rm_keys=[], change_cpthin=True, alpha_kt=.8,\n",
    "                              legend_loc='lower left', legend_size=16, adjust_alpha=adjust_alpha,\n",
    "                              sizes_non_kt=sizes_non_kt, sizes_kt=sizes_kt, basex=basex, basey=basey)\n",
    "            xlabel = r\"Input size $n$\"\n",
    "        else:\n",
    "            plot_mmd_dict_new(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5,\n",
    "                              skip_ns=skip_ns, error_bar=True, rm_keys=[], adjust_alpha=adjust_alpha,\n",
    "                              sizes_non_kt=sizes_non_kt, sizes_kt=sizes_kt, basex=basex, basey=basey)\n",
    "            # raw string: \"\\s\" in a plain string literal is an invalid escape sequence\n",
    "            xlabel = r\"Coreset size $\\sqrt{n}$\"\n",
    "        ax.set_title(r\"d=$%d$\"%(d), fontsize=title_size)\n",
    "        if i==0:\n",
    "            ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "        ax.set_xlabel(xlabel, fontsize=xlab_size)\n",
    "    plt.tight_layout()\n",
    "    if simplified_plot and save_fig:\n",
    "        suff = f\"{sizes_non_kt[0]}_{sizes_non_kt[1]}\" if adjust_alpha else \"\"\n",
    "        fig_name = \"gauss_herd_mmd\" if herding_coresets else \"gauss_kt_mmd\"\n",
    "        # make sure the output directory exists before saving\n",
    "        pathlib.Path(fig_dir).mkdir(parents=True, exist_ok=True)\n",
    "        fig.savefig(fig_dir + f\"{fig_name}_{suff}.pdf\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MOG "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coreset families to combine and plot (True = include)\n",
    "st_coresets, kt_coresets, compress_thin_coresets = True, True, True\n",
    "herding_coresets, cp_herding_coresets = False, False\n",
    "# compressthin_thin_coresets = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjust_alpha = True  # True: only comp0 and comp4 will be loaded\n",
    "# [size_min, size_max] for KT and for every other coreset, respectively\n",
    "sizes_kt, sizes_non_kt = [4, 8], [4, 10]\n",
    "if adjust_alpha is False:\n",
    "    # no per-coreset size ranges without alpha adjustment\n",
    "    sizes_kt = sizes_non_kt = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parse_arguments()\n",
    "\n",
    "args.krt = 0 # 0 if no krt, anything else for krt\n",
    "\n",
    "# mmd and size params\n",
    "args.combine_mmd = 1 \n",
    "args.computemmd = 1\n",
    "args.size_min, args.size_max = 4, 8\n",
    "args.rep0 = 0\n",
    "args.repn = 10\n",
    "args.alpha_max = 4\n",
    "args.symm1 = 1\n",
    "\n",
    "# rerun/compute params\n",
    "args.rerun = 0\n",
    "args.recomputemmd = 0 # CHECK THIS\n",
    "\n",
    "# recombine\n",
    "args.recombine = False # whether to create a new combined mmd file in case one already exists\n",
    "\n",
    "if run_mog_experiments:\n",
    "    args.setting = \"mog\"\n",
    "    args.d = 2\n",
    "    d = args.d\n",
    "    # Each coreset flag is paired positionally with the label its results are stored under.\n",
    "    # Keep the pairing identical to the MCMC cells: herding_coresets -> \"Herd\", cp_herding_coresets -> \"CPHerd\".\n",
    "    flags = [st_coresets, compress_thin_coresets, # compressthin_thin_coresets,\n",
    "             kt_coresets, herding_coresets, cp_herding_coresets]\n",
    "    labels = [\"ST\", \"CP+KT\", # \"CPthin+KT\",\n",
    "              \"KT\", \"Herd\", \"CPHerd\"]\n",
    "    for M in Ms:\n",
    "        args.M = M\n",
    "        m_key = \"M_%d\"%M\n",
    "        combined_mmd_results[m_key] = dict()\n",
    "\n",
    "        # compute params_p and var_k for the setting (d is fixed above)\n",
    "        _, params_p, var_k = compute_params_p(args)\n",
    "\n",
    "        # define the kernels\n",
    "        params_k_split, params_k_swap, _, _ = compute_params_k(d=d, var_k=var_k, \n",
    "                                                            use_krt_split=args.krt, name=\"gauss\")\n",
    "\n",
    "        for f, lab in zip(flags, labels):\n",
    "            if not f:\n",
    "                continue\n",
    "            if adjust_alpha:\n",
    "                # KT is combined over sizes_kt; every other coreset over sizes_non_kt\n",
    "                size_range = sizes_kt if lab == \"KT\" else sizes_non_kt\n",
    "                args.size_min, args.size_max = size_range[0], size_range[1]\n",
    "            args.cset = lab\n",
    "            combined_mmd_results[m_key][lab] = combine_results(args, params_p, params_k_split, params_k_swap, adjust_alpha=adjust_alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# labels whose coreset flag is enabled, in flag order\n",
    "plot_keys = [lab for f, lab in zip(flags, labels) if f]\n",
    "\n",
    "print(plot_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MOG figure: one panel per M in Ms\n",
    "save_fig = True\n",
    "if run_mog_experiments:\n",
    "    simplified_plot = True\n",
    "    if not simplified_plot:\n",
    "        skip_ns = int(0)\n",
    "        # squeeze=False keeps `axes` 2-D, so axes[0, i] works for one panel or several\n",
    "        fig, axes = plt.subplots(1, len(Ms), figsize=[6*len(Ms), 5], sharex=True, sharey=False, squeeze=False)\n",
    "        for i, M in enumerate(Ms):\n",
    "            ax = axes[0, i]\n",
    "            simple_mmd_dict = post_process(combined_mmd_results[f'M_{M}'], args, plot_keys, adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt)\n",
    "            plot_mmd_dict_new(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5, \n",
    "                          skip_ns=skip_ns, error_bar=True, rm_keys = [], adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt, basex=basex, basey=basey)\n",
    "            ax.set_title(r\"M=$%d$\"%(M), fontsize=title_size)\n",
    "            if i==0:\n",
    "                ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "            # raw string: \"\\s\" is an invalid escape sequence in a plain string literal\n",
    "            ax.set_xlabel(r\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "    if simplified_plot:\n",
    "        skip_ns = int(0)\n",
    "        fig, axes = plt.subplots(1, len(Ms), figsize=[6*len(Ms), 5], sharex=True, sharey=False, squeeze=False)\n",
    "        for i, M in enumerate(Ms):\n",
    "            ax = axes[0, i]\n",
    "            simple_mmd_dict = post_process(combined_mmd_results[f'M_{M}'], args, plot_keys, adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt)\n",
    "            plot_mmd_dict_new(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5, \n",
    "                          skip_ns=skip_ns, error_bar=True, rm_keys = [], change_cpthin=True, alpha_kt=0.6,\n",
    "                          legend_loc='lower left',\n",
    "                          legend_size=16, adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt, basex=basex, basey=basey)\n",
    "            ax.set_title(r\"M=$%d$\"%(M), fontsize=title_size)\n",
    "            if i==0:\n",
    "                ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "            ax.set_xlabel(r\"Input size $n$\", fontsize=xlab_size)\n",
    "        plt.tight_layout()\n",
    "        if save_fig:\n",
    "            suff  = f\"4_{sizes_non_kt[1]}\" if adjust_alpha else \"\"\n",
    "            if herding_coresets:\n",
    "                plt.savefig(fig_dir + f\"mog_herd_mmd_{suff}.pdf\")\n",
    "            else:\n",
    "                plt.savefig(fig_dir + f\"mog_kt_mmd_{suff}.pdf\")\n",
    "        plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MCMC results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mcmc_titles(a):\n",
    "    \"\"\"Return a short figure title built from the MCMC results filename `a`.\n",
    "\n",
    "    The substring rewrites below are applied one after another in the listed\n",
    "    order, so later rewrites see the output of earlier ones.\n",
    "    \"\"\"\n",
    "    rewrites = (\n",
    "        (\"_P_\", \" \"),\n",
    "        (\"_TP_\", \" Tempered \"),\n",
    "        (\"seed_1\", \"1\"),\n",
    "        (\"seed_2\", \"2\"),\n",
    "        (\"_temp_1\", \"\"),\n",
    "        (\"_temp_8\", \"\"),\n",
    "        (\"_scaled\", \"\"),\n",
    "        (\"_float_step\", \"\"),\n",
    "        (\"_\", \" \"),\n",
    "        (\"PRECOND-\", \"p\"),\n",
    "    )\n",
    "    for old, new in rewrites:\n",
    "        a = a.replace(old, new)\n",
    "    return a\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coreset families to combine and plot (True = include)\n",
    "st_coresets, kt_coresets, compress_thin_coresets = True, True, True\n",
    "herding_coresets, cp_herding_coresets = False, False\n",
    "# compressthin_thin_coresets = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjust_alpha = True  # True: only comp0 and comp4 will be loaded\n",
    "# [size_min, size_max] for KT and for every other coreset, respectively\n",
    "sizes_kt, sizes_non_kt = [4, 8], [4, 9]\n",
    "if adjust_alpha is False:\n",
    "    # no per-coreset size ranges without alpha adjustment\n",
    "    sizes_kt = sizes_non_kt = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable the MCMC combine/plot cells below (they check this flag)\n",
    "run_mcmc_experiments = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parse_arguments()\n",
    "\n",
    "args.krt = 0 # 0 if no krt, anything else for krt\n",
    "\n",
    "# mmd and size params\n",
    "args.combine_mmd = 1 \n",
    "args.computemmd = 1\n",
    "args.size_min, args.size_max = 4, 8\n",
    "args.rep0 = 0\n",
    "args.repn = 10\n",
    "args.alpha_max = 4\n",
    "args.symm1 = 1\n",
    "\n",
    "# rerun/compute params\n",
    "args.rerun = 0\n",
    "args.recomputemmd = 0 # CHECK THIS\n",
    "\n",
    "# recombine\n",
    "args.recombine = False # whether to create a new combined mmd file in case one already exists\n",
    "\n",
    "if run_mcmc_experiments:\n",
    "    args.setting = \"mcmc\"\n",
    "    # Each coreset flag is paired positionally with the label its results are stored under\n",
    "    # (loop-invariant, so built once rather than once per file).\n",
    "    flags = [st_coresets, compress_thin_coresets, # compressthin_thin_coresets,\n",
    "             kt_coresets, herding_coresets, cp_herding_coresets]\n",
    "    labels = [\"ST\", \"CP+KT\", # \"CPthin+KT\",\n",
    "              \"KT\", \"Herd\", \"CPHerd\"]\n",
    "    for filename in all_mcmc_filenames:\n",
    "        args.filename = filename\n",
    "        combined_mmd_results[filename] = dict()\n",
    "\n",
    "        # compute d, params_p and var_k for the setting\n",
    "        d, params_p, var_k = compute_params_p(args)\n",
    "        args.d  = d \n",
    "\n",
    "        # define the kernels\n",
    "        params_k_split, params_k_swap, _, _ = compute_params_k(d=d, var_k=var_k, \n",
    "                                                            use_krt_split=args.krt, name=\"gauss\")\n",
    "\n",
    "        for f, lab in zip(flags, labels):\n",
    "            if not f:\n",
    "                continue\n",
    "            if adjust_alpha:\n",
    "                # KT, and every coreset on a Hinch file, is combined over sizes_kt; the rest over sizes_non_kt\n",
    "                size_range = sizes_kt if (lab == \"KT\" or 'Hinch' in args.filename) else sizes_non_kt\n",
    "                args.size_min, args.size_max = size_range[0], size_range[1]\n",
    "            args.cset = lab\n",
    "            combined_mmd_results[filename][lab] = combine_results(args, params_p, params_k_split, params_k_swap, adjust_alpha=adjust_alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot style cycles: line styles, markers, marker sizes and colours\n",
    "lss = (['-', '-.', ':', '--'] + ['--', '-.', ':', '-'] * 2) * 2\n",
    "mss = (['>', 's', 'o', 'D', '+', '*'] + ['>', 's', 'o', 'D'] * 2) * 2\n",
    "ms_size = [25] + [20] * 9\n",
    "colors = [\n",
    "    '#e41a1c',  # red\n",
    "    'orange',\n",
    "    '#0000cd',  # blue\n",
    "    '#4daf4a',  # green\n",
    "    'magenta',\n",
    "    'black',\n",
    "    'yellow',\n",
    "    'gray',\n",
    "] * 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MCMC panels use a smaller title font than the title_size set in the plotting-config cell\n",
    "title_size = 25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# labels whose coreset flag is enabled, in flag order\n",
    "plot_keys = [lab for f, lab in zip(flags, labels) if f]\n",
    "\n",
    "print(plot_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjust_alpha = True # only comp0 and comp4 will be loaded\n",
    "sizes_kt = [4, 8]\n",
    "sizes_non_kt = [4, 8]\n",
    "if adjust_alpha is False:\n",
    "    sizes_kt = None\n",
    "    sizes_non_kt = None\n",
    "\n",
    "save_fig = False\n",
    "if run_mcmc_experiments:\n",
    "    # one figure per group of four consecutive MCMC files\n",
    "    for t in [0, 4, 8]:\n",
    "        if t == 8 and adjust_alpha:\n",
    "            # the group starting at 8 (Hinch files) uses the KT size range for every coreset\n",
    "            sizes_non_kt = sizes_kt.copy()\n",
    "        plot_files = np.array(all_mcmc_filenames)[range(t, t+4)]\n",
    "        simplified_plot = True\n",
    "        if not simplified_plot:\n",
    "            skip_ns = int(0)\n",
    "            # squeeze=False keeps `axes` 2-D, so axes[0, i] works for one panel or several\n",
    "            fig, axes = plt.subplots(1, len(plot_files), figsize=[6*len(plot_files), 5], sharex=True, sharey=True, squeeze=False)\n",
    "            for i, filename in enumerate(plot_files):\n",
    "                ax = axes[0, i]\n",
    "                simple_mmd_dict = post_process(combined_mmd_results[filename], args, plot_keys,adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt)\n",
    "                plot_mmd_dict_new(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5, \n",
    "                          skip_ns=skip_ns, error_bar=True, rm_keys = [], adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt, basex=basex, basey=basey)\n",
    "                ax.set_title(filename, fontsize=title_size)\n",
    "                if 'Hinch' not in filename:\n",
    "                    ax.set_xticks([4**k for k in range(4, 10)])\n",
    "                    ax.set_xticklabels([4**k for k in range(4, 10)])\n",
    "                if i==0:\n",
    "                    ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "                ax.set_xlabel(r\"Input size $n$\", fontsize=xlab_size)\n",
    "            plt.tight_layout()\n",
    "            plt.show()\n",
    "        if simplified_plot:\n",
    "            skip_ns = int(0)\n",
    "            fig, axes = plt.subplots(1, len(plot_files), figsize=[6*len(plot_files), 5], sharex=True, sharey=False, squeeze=False)\n",
    "            for i, filename in enumerate(plot_files):\n",
    "                ax = axes[0, i]\n",
    "                simple_mmd_dict = post_process(combined_mmd_results[filename], args, plot_keys, adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt)\n",
    "                plot_mmd_dict_new(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5, \n",
    "                          skip_ns=skip_ns, error_bar=True, rm_keys = [], change_cpthin=True, alpha_kt=0.6,\n",
    "                          legend_loc='lower left',\n",
    "                          legend_size=16, adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt, basex=basex, basey=basey)\n",
    "                ax.set_title(mcmc_titles(filename), fontsize=title_size)\n",
    "                if 'Hinch' not in filename:\n",
    "                    ax.set_xticks([4**k for k in range(4, 10)])\n",
    "                if i==0:\n",
    "                    ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "                ax.set_xlabel(r\"Input size $n$\", fontsize=xlab_size)\n",
    "            plt.tight_layout()\n",
    "            if save_fig:\n",
    "                # suffix reflects the non-KT size range actually plotted\n",
    "                suff  = f\"{sizes_non_kt[0]}_{sizes_non_kt[1]}\" if adjust_alpha else \"\"\n",
    "                plt.savefig(fig_dir + f\"mcmc_{plot_files[0][:4]}_kt_mmd_{suff}.pdf\")\n",
    "            plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
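    "### Hinch plot\n",
    "\n",
    "**Note:** assembling these results is a bit messy.\n",
    "- First combine all Hinch results with `symm1 = 1`.\n",
    "- Then combine them again with `symm1 = 0`.\n",
    "- Use `recombine` carefully: the combined-result filenames do not encode the `symm1` flag, so recombining one setting can overwrite the file of the other.\n",
    "- To skip all of this, just load the saved results: `combined_mmd_results = joblib.load(\"results/combined/Hinch_all.pkl\")`"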
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved Hinch results (replaces the whole combined_mmd_results dict)\n",
    "combined_mmd_results = joblib.load(\"results/combined/Hinch_all.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DISABLED: recipe used to combine the files at indices 8-11 with args.symm1 = 0 into `no_symm_mmd`.\n",
    "# NOTE(review): `no_symm_mmd` is not created in this cell - initialise it (e.g. `no_symm_mmd = dict()`)\n",
    "# before re-enabling. Per the Hinch markdown note, loading Hinch_all.pkl avoids rerunning this.\n",
    "# args = parse_arguments()\n",
    "\n",
    "# args.krt = 0 # 0 if no krt, anything else for krt\n",
    "\n",
    "# # mmd and size params\n",
    "# args.combine_mmd = 1 \n",
    "# args.computemmd = 1\n",
    "# args.size_min, args.size_max = 4, 8\n",
    "# args.rep0 = 0\n",
    "# args.repn = 10\n",
    "# args.alpha_max = 4\n",
    "# args.symm1 = 0\n",
    "\n",
    "# # rerun/compute params\n",
    "# args.rerun = 0\n",
    "# args.recomputemmd = 0 # CHECK THIS\n",
    "\n",
    "# # recombine\n",
    "# args.recombine = True # whether to create a new combined mmd file in case one already exists\n",
    "\n",
    "\n",
    "# if run_mcmc_experiments:\n",
    "#     args.setting = \"mcmc\"\n",
    "#     for filename in np.array(all_mcmc_filenames)[range(8, 12)]:\n",
    "#         args.filename = filename\n",
    "#         no_symm_mmd[filename] = dict()\n",
    "        \n",
    "#         # compute d, params_p and var_k for the setting\n",
    "#         d, params_p, var_k = compute_params_p(args)\n",
    "#         args.d  = d \n",
    "        \n",
    "#         # define the kernels\n",
    "#         params_k_split, params_k_swap, _, _ = compute_params_k(d=d, var_k=var_k, \n",
    "#                                                             use_krt_split=args.krt, name=\"gauss\")\n",
    "        \n",
    "#         flags = [st_coresets, compress_thin_coresets, # compressthin_thin_coresets,\n",
    "#                  kt_coresets, herding_coresets, cp_herding_coresets]\n",
    "#         labels = [\"ST\", \"CP+KT\", # \"CPthin+KT\",\n",
    "#                   \"KT\", \"Herd\", \"CPHerd\"]\n",
    "#         for f, lab in zip(flags, labels):\n",
    "#             if f:\n",
    "#                 if adjust_alpha:\n",
    "#                     if lab == \"KT\" or 'Hinch' in args.filename:\n",
    "#                         args.size_min, args.size_max = sizes_kt[0], sizes_kt[1]\n",
    "#                     else:\n",
    "#                         args.size_min, args.size_max = sizes_non_kt[0], sizes_non_kt[1]\n",
    "#                 args.cset = lab\n",
    "#                 no_symm_mmd[filename][lab] = dict()\n",
    "#                 no_symm_mmd[filename][lab] = combine_results(args, params_p, params_k_split, params_k_swap, adjust_alpha=adjust_alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DISABLED: copies the symm1 = 0 CP+KT results from `no_symm_mmd` into a \"CP+KTnosymm\" entry of\n",
    "# combined_mmd_results (one of the plot_keys of the Hinch figure). Relies on `t` and `no_symm_mmd`\n",
    "# having been set by earlier cells.\n",
    "# for filename in np.array(all_mcmc_filenames)[range(t, t+4)]:\n",
    "#     combined_mmd_results[filename][\"CP+KTnosymm\"] = dict()\n",
    "#     combined_mmd_results[filename][\"CP+KTnosymm\"][\"CP+KTnosymm\"] = dict()\n",
    "#     combined_mmd_results[filename][\"CP+KTnosymm\"][\"CP+KTnosymm\"] = no_symm_mmd[filename][\"CP+KT\"][\"CP+KT\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Curves for the Hinch figure (\"CP+KTnosymm\" = CP+KT results combined with symm1 = 0)\n",
    "plot_keys = [\"ST\", \"CP+KT\", \"CP+KTnosymm\", \"KT\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot style cycles for the Hinch figure: line styles, markers, marker sizes and colours\n",
    "lss = ['-', '-.',  ':', '-', '--', '--', '-.', ':', '-', '--', '-.', ':', '-']*2\n",
    "mss = ['>', 's', 'o', '+', 'D',  '*',  '>', 's', 'o', 'D', '>', 's', 'o', 'D']*2\n",
    "ms_size = [25, 20, 20, 20, 20, 20, 20, 20, 20, 20]\n",
    "colors = ['#e41a1c', #red\n",
    "          'orange',\n",
    "          '#0000cd', #blue\n",
    "          'darkgoldenrod',\n",
    "          '#4daf4a', #green\n",
    "          'magenta', 'black' , 'yellow',]*2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# x-axis base passed to plot_mmd_dict_new (presumably the log-scale base - confirm in the helper)\n",
    "basex = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjust_alpha = True # only comp0 and comp4 will be loaded\n",
    "sizes_kt = [4, 8]\n",
    "sizes_non_kt = [4, 8]\n",
    "if adjust_alpha is False:\n",
    "    sizes_kt = None\n",
    "    sizes_non_kt = None\n",
    "\n",
    "save_fig = True\n",
    "if run_mcmc_experiments:\n",
    "    # Hinch figure: only the group of four files starting at index 8\n",
    "    for t in [8]:\n",
    "        plot_files = np.array(all_mcmc_filenames)[range(t, t+4)]\n",
    "        simplified_plot = True\n",
    "        if not simplified_plot:\n",
    "            skip_ns = int(0)\n",
    "            # squeeze=False keeps `axes` 2-D, so axes[0, i] works for one panel or several\n",
    "            fig, axes = plt.subplots(1, len(plot_files), figsize=[6*len(plot_files), 5], sharex=True, sharey=True, squeeze=False)\n",
    "            for i, filename in enumerate(plot_files):\n",
    "                ax = axes[0, i]\n",
    "                simple_mmd_dict = post_process(combined_mmd_results[filename], args, plot_keys,adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt)\n",
    "                plot_mmd_dict_new(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5, \n",
    "                          skip_ns=skip_ns, error_bar=True, rm_keys = [], adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt, basex=basex, basey=basey)\n",
    "                ax.set_title(filename, fontsize=title_size)\n",
    "                if i==0:\n",
    "                    ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "                ax.set_xlabel(r\"Input size $n$\", fontsize=xlab_size)\n",
    "            plt.tight_layout()\n",
    "            plt.show()\n",
    "        if simplified_plot:\n",
    "            skip_ns = int(0)\n",
    "            fig, axes = plt.subplots(1, len(plot_files), figsize=[6*len(plot_files), 5], sharex=True, sharey=False, squeeze=False)\n",
    "            for i, filename in enumerate(plot_files):\n",
    "                ax = axes[0, i]\n",
    "                simple_mmd_dict = post_process(combined_mmd_results[filename], args, plot_keys, adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt)\n",
    "                plot_mmd_dict_new(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5, \n",
    "                          skip_ns=skip_ns, error_bar=True, rm_keys = [], change_cpthin=True, alpha_kt=0.6,\n",
    "                          legend_loc='lower left',\n",
    "                          legend_size=16, adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt, basex=basex, basey=basey)\n",
    "                ax.set_title(mcmc_titles(filename), fontsize=title_size)\n",
    "                if i==0:\n",
    "                    ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "                ax.set_xlabel(r\"Input size $n$\", fontsize=xlab_size)\n",
    "            plt.tight_layout()\n",
    "            if save_fig:\n",
    "                # single combined figure, so the filename carries no size suffix\n",
    "                plt.savefig(fig_dir + \"hinch_all.pdf\")\n",
    "            plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cache the combined results under the same path the Hinch section reloads from\n",
    "joblib.dump(combined_mmd_results, \"results/combined/Hinch_all.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coreset families to combine and plot (True = include)\n",
    "st_coresets, kt_coresets, compress_thin_coresets = True, True, True\n",
    "herding_coresets, cp_herding_coresets = False, False\n",
    "# compressthin_thin_coresets = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjust_alpha = False  # True: only comp0 and comp4 will be loaded\n",
    "# [size_min, size_max] for KT and for every other coreset, respectively\n",
    "sizes_kt, sizes_non_kt = [4, 8], [4, 9]\n",
    "if adjust_alpha is False:\n",
    "    # no per-coreset size ranges without alpha adjustment\n",
    "    sizes_kt = sizes_non_kt = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable the MCMC combine/plot cells below (they check this flag)\n",
    "run_mcmc_experiments = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parse_arguments()\n",
    "\n",
    "args.krt = 0 # 0 if no krt, anything else for krt\n",
    "\n",
    "# mmd and size params\n",
    "args.combine_mmd = 1 \n",
    "args.computemmd = 1\n",
    "args.size_min, args.size_max = 6, 8\n",
    "args.rep0 = 0\n",
    "args.repn = 10\n",
    "args.alpha_max = 6\n",
    "args.symm1 = 1\n",
    "\n",
    "# rerun/compute params\n",
    "args.rerun = 0\n",
    "args.recomputemmd = 0 # CHECK THIS\n",
    "\n",
    "# recombine\n",
    "args.recombine = False # whether to create a new combined mmd file in case one already exists\n",
    "\n",
    "if run_mcmc_experiments:\n",
    "    args.setting = \"mcmc\"\n",
    "    # Each coreset flag is paired positionally with the label its results are stored under\n",
    "    # (loop-invariant, so built once rather than once per file).\n",
    "    flags = [st_coresets, compress_thin_coresets, # compressthin_thin_coresets,\n",
    "             kt_coresets, herding_coresets, cp_herding_coresets]\n",
    "    labels = [\"ST\", \"CP+KT\", # \"CPthin+KT\",\n",
    "              \"KT\", \"Herd\", \"CPHerd\"]\n",
    "    for filename in np.array(all_mcmc_filenames)[range(8, 12)]:\n",
    "        args.filename = filename\n",
    "        combined_mmd_results[filename] = dict()\n",
    "\n",
    "        # compute d, params_p and var_k for the setting\n",
    "        d, params_p, var_k = compute_params_p(args)\n",
    "        args.d  = d \n",
    "\n",
    "        # define the kernels\n",
    "        params_k_split, params_k_swap, _, _ = compute_params_k(d=d, var_k=var_k, \n",
    "                                                            use_krt_split=args.krt, name=\"gauss\")\n",
    "\n",
    "        for f, lab in zip(flags, labels):\n",
    "            if not f:\n",
    "                continue\n",
    "            if adjust_alpha:\n",
    "                # KT, and every coreset on a Hinch file, is combined over sizes_kt; the rest over sizes_non_kt\n",
    "                size_range = sizes_kt if (lab == \"KT\" or 'Hinch' in args.filename) else sizes_non_kt\n",
    "                args.size_min, args.size_max = size_range[0], size_range[1]\n",
    "            args.cset = lab\n",
    "            combined_mmd_results[filename][lab] = combine_results(args, params_p, params_k_split, params_k_swap, adjust_alpha=adjust_alpha)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect KT vs. CP+KT MMDs for the last file combined above, chosen explicitly instead of\n",
    "# relying on `filename` leaking out of that loop (which iterates indices 8-11).\n",
    "inspect_file = np.array(all_mcmc_filenames)[11]\n",
    "kt_mmd = combined_mmd_results[inspect_file][\"KT\"][\"KT\"]\n",
    "cpkt_mmd = combined_mmd_results[inspect_file][\"CP+KT\"][\"CP+KT\"]\n",
    "for sz in [6, 7, 8]:\n",
    "    print(f\"size = {sz}\")\n",
    "    print(\"KT:\", kt_mmd[f\"size_{sz}\"].mean(0), kt_mmd[f\"size_{sz}\"].std(0))\n",
    "    for alpha in range(args.alpha_max+1):\n",
    "        print(f\"CP+KT, alpha ={alpha}\", cpkt_mmd[f\"size_{sz}\"][f\"alpha_{alpha}\"].mean(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# labels whose coreset flag is enabled, in flag order\n",
    "plot_keys = [lab for f, lab in zip(flags, labels) if f]\n",
    "\n",
    "print(plot_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot style cycles: line styles, markers, marker sizes and colours\n",
    "lss = (['-', '-.', ':', '--'] + ['--', '-.', ':', '-'] * 2) * 2\n",
    "mss = (['>', 's', 'o', 'D', '+', '*'] + ['>', 's', 'o', 'D'] * 2) * 2\n",
    "ms_size = [25] + [20] * 9\n",
    "colors = [\n",
    "    '#e41a1c',  # red\n",
    "    'orange',\n",
    "    '#0000cd',  # blue\n",
    "    '#4daf4a',  # green\n",
    "    'magenta',\n",
    "    'black',\n",
    "    'green',\n",
    "    'gray',\n",
    "] * 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MCMC figures, one per group of four consecutive files (gated on the MCMC flag, not the MOG one)\n",
    "save_fig = False\n",
    "if run_mcmc_experiments:\n",
    "    for t in [0, 4, 8]:\n",
    "        plot_files = np.array(all_mcmc_filenames)[range(t, t+4)]\n",
    "        simplified_plot = False\n",
    "        if not simplified_plot:\n",
    "            skip_ns = int(0)\n",
    "            # squeeze=False keeps `axes` 2-D, so axes[0, i] works for one panel or several\n",
    "            fig, axes = plt.subplots(1, len(plot_files), figsize=[10*len(plot_files), 7], sharex=True, sharey=False, squeeze=False)\n",
    "            for i, filename in enumerate(plot_files):\n",
    "                ax = axes[0, i]\n",
    "                simple_mmd_dict = post_process(combined_mmd_results[filename], args, plot_keys,adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt)\n",
    "                plot_mmd_dict_new(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5, \n",
    "                          skip_ns=skip_ns, error_bar=True, rm_keys = [], adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt)\n",
    "                ax.set_title(filename, fontsize=title_size)\n",
    "                if i==0:\n",
    "                    ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "                # raw string: \"\\s\" is an invalid escape sequence in a plain string literal\n",
    "                ax.set_xlabel(r\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "            plt.tight_layout()\n",
    "            plt.show()\n",
    "        if simplified_plot:\n",
    "            skip_ns = int(0)\n",
    "            fig, axes = plt.subplots(1, len(plot_files), figsize=[6*len(plot_files), 5], sharex=True, sharey=False, squeeze=False)\n",
    "            for i, filename in enumerate(plot_files):\n",
    "                ax = axes[0, i]\n",
    "                simple_mmd_dict = post_process(combined_mmd_results[filename], args, plot_keys, adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt)\n",
    "                plot_mmd_dict_new(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5, \n",
    "                          skip_ns=skip_ns, error_bar=True, rm_keys = [], change_cpthin=True, alpha_kt=0.6,\n",
    "                          legend_loc='lower left',\n",
    "                          legend_size=16, adjust_alpha=adjust_alpha, \n",
    "                                           sizes_non_kt = sizes_non_kt, sizes_kt=sizes_kt)\n",
    "                ax.set_title(mcmc_titles(filename), fontsize=title_size)\n",
    "                if i==0:\n",
    "                    ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "                ax.set_xlabel(r\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "            plt.tight_layout()\n",
    "            if save_fig:\n",
    "                # suffix reflects the non-KT size range actually plotted\n",
    "                suff  = f\"{sizes_non_kt[0]}_{sizes_non_kt[1]}\" if adjust_alpha else \"\"\n",
    "                plt.savefig(fig_dir + f\"mcmc_{plot_files[0][:4]}_kt_mmd_{suff}.pdf\")\n",
    "            plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(simple_mmd_dict[\"CP+KT-0\"].mean(1))\n",
    "print(simple_mmd_dict[\"CP+KT-1\"].mean(1))\n",
    "print(simple_mmd_dict[\"CP+KT-2\"].mean(1))\n",
    "print(simple_mmd_dict[\"CP+KT-3\"].mean(1))\n",
    "print(simple_mmd_dict[\"CP+KT-4\"].mean(1))\n",
    "print(simple_mmd_dict[\"KT\"].mean(1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Time plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_time_dir = \"results/run_time\"\n",
    "num_trials = 3\n",
    "size_time = range(4, 11)\n",
    "alpha_max = 4\n",
    "alphas = [0, 4]\n",
    "d_time = [2, 4, 10, 100]\n",
    "\n",
    "kt_times = np.zeros((len(d_time), len(size_time), num_trials))\n",
    "cpp_times = np.zeros((len(d_time), alpha_max+1, len(size_time), num_trials))\n",
    "\n",
    "for i, d in enumerate(d_time):\n",
    "    for j, size in enumerate(size_time):\n",
    "        for t in range(num_trials):\n",
    "            if size < 10:\n",
    "                filename = os.path.join(results_time_dir, f\"kt_gauss_k_d_{d}_size_{size}_rep_{t}.pkl\")\n",
    "                kt_times[i, j, t] = joblib.load(filename)\n",
    "            for a in alphas: #range(0, min(alpha_max, size)+1)\n",
    "#                 filename = os.path.join(results_time_dir, f\"cpthin{a}_gauss_k_d_{d}_size_{size}_rep_{t}.pkl\") # same partition as kt but till 4^9\n",
    "\n",
    "                filename = os.path.join(results_time_dir, f\"cpthin{a}_part_yugroup_c1_mem_32G_gauss_k_d_{d}_size_{size}_rep_{t}.pkl\") # different partition till 4^`10\n",
    "                cpp_times[i, a, j, t] = joblib.load(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lss = ['--', ':',  '-.', '--',  '--', '-.', ':', '-', '--', '-.', ':', '-']*2\n",
    "mss = ['D', 'o', 's', 'D', '+', '*',  '>', 's', 'o', 'D', '>', 's', 'o', 'D']*2\n",
    "ms_size = [25, 20, 20, 20, 20, 20, 20, 20, 20, 20]\n",
    "colors = ['#4daf4a', #green\n",
    "          '#0000cd', #blue\n",
    "          'orange',  \n",
    "          '#e41a1c', #red\n",
    "          'magenta', 'black' , 'yellow','gray']*2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "savefig = True\n",
    "for ylog in [True]:\n",
    "    axes = plt.subplots(1, len(d_time), figsize=[6*len(d_time), 5], sharex=True, sharey=False)\n",
    "    for i, d, in enumerate(d_time):\n",
    "        ax = axes[1][i]\n",
    "        \n",
    "        y =  kt_times[i, :, :].mean(1)\n",
    "        yerr =  2*kt_times[i, :, :].std(1) / np.sqrt(num_trials)\n",
    "        ax.errorbar(4**np.array(size_time[:-1]), y[:-1], marker=mss[0], \n",
    "                yerr=yerr[:-1], #np.array([-yerr, yerr]),\n",
    "                linestyle=lss[0], markersize=ms_size[0]/2, color=colors[0], linewidth =3,\n",
    "                 label='KT')\n",
    "        for jj, a in enumerate([alpha_max, 0]): #range(0, min(alpha_max, size)+1):\n",
    "            y = cpp_times[i, a, :, :].mean(1)\n",
    "            yerr = cpp_times[i, a, :, :].std(1) / np.sqrt(num_trials)\n",
    "            if a == 0:\n",
    "                lab = \"KT-Comp\"\n",
    "            else:\n",
    "                lab = 'KT-Comp++'\n",
    "            ax.errorbar(4**np.array(size_time), y, yerr=np.array([-yerr, yerr]),\n",
    "                    marker=mss[1+jj], linestyle=lss[jj+1], markersize=ms_size[1+jj]/2,\n",
    "                    color=colors[1+jj], linewidth = 3,\n",
    "    #                 linestyle='--', \n",
    "                    label=lab)\n",
    "        ax.set_xscale('log', basex=10)\n",
    "        if ylog:\n",
    "            ax.set_yscale('log', basey=10)\n",
    "            ax.set_ylim([5*1e-2, 1.4e5])\n",
    "            ax.yaxis.tick_right()\n",
    "            ax.set_yticks([1, 60, 600, 3600, 28800, 86400])\n",
    "            ax.set_yticklabels( ['1s','1m', '10m', '1hr', '8hr', '1day'], fontsize=20, color='black')\n",
    "            ax.set_xticks([10**3, 10**4, 10**5, 10**6])\n",
    "#             ax.set_xticklabels( ['1K','10K', '100K', '1M'], fontsize=20, color='gray')\n",
    "        ax.set_xlabel(r'Input size $n$', fontsize=xlab_size)\n",
    "\n",
    "        ax.spines['top'].set_visible(False)\n",
    "        ax.spines['left'].set_visible(False)\n",
    "        ax.spines['right'].set_visible(False)\n",
    "        # ax.spines['bottom'].set_visible(False)a\n",
    "        \n",
    "        ax.grid(True, alpha=0.4)\n",
    "        ax.set_title(\"d=%d\"%d, fontsize=1.1*title_size)\n",
    "        if i == 0:\n",
    "#             locs, labels = ax.get_yticks()  \n",
    "#             ax.set_yticks([1, 60, 600, 3600, 36000, 86400])\n",
    "            ax.set_ylabel('Single core runtime',fontsize=ylab_size)\n",
    "            ax.legend(fontsize=22)\n",
    "    plt.tight_layout()\n",
    "    if savefig:\n",
    "#         plt.savefig(fig_dir + \"time_kt_1M.pdf\")\n",
    "        plt.savefig(fig_dir + \"time_kt.pdf\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Herding run-time\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "results_time_dir = \"results/run_time\"\n",
    "num_trials = 3\n",
    "size_time = range(4, 10)\n",
    "alpha_max = 2\n",
    "d_time = [2, 4, 10, 100]\n",
    "\n",
    "herd_times = np.zeros((len(d_time), len(size_time), num_trials))\n",
    "cpp_herd_times = np.zeros((len(d_time), alpha_max+1, len(size_time), num_trials))\n",
    "# cpthin2herding_part_yugroup_c1_mem_32G_gauss_k_d_100_size_4_rep_1.pkl\n",
    "for i, d in enumerate(d_time):\n",
    "    for j, size in enumerate(size_time):\n",
    "        for t in range(num_trials):\n",
    "            if size < 10:\n",
    "                filename = os.path.join(results_time_dir, \"herding__\" + f\"gauss_k_d_{d}_size_{size}_rep_{t}.pkl\")\n",
    "                herd_times[i, j, t] = joblib.load(filename)\n",
    "            for a in range(0, min(alpha_max, size)+1):\n",
    "                filename = os.path.join(results_time_dir, f\"cpthin{a}herding_part_yugroup_c1_mem_32G_gauss_k_d_{d}_size_{size}_rep_{t}.pkl\")\n",
    "                cpp_herd_times[i, a, j, t] = joblib.load(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lss = ['--', ':',  '-.', '--',  '--', '-.', ':', '-', '--', '-.', ':', '-']*2\n",
    "mss = ['D', 'o', 's', 'D', '+', '*',  '>', 's', 'o', 'D', '>', 's', 'o', 'D']*2\n",
    "ms_size = [25, 20, 20, 20, 20, 20, 20, 20, 20, 20]\n",
    "colors = ['#4daf4a', #green\n",
    "          '#0000cd', #blue\n",
    "          'orange',  \n",
    "          '#e41a1c', #red\n",
    "          'magenta', 'black' , 'yellow','gray']*2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "savefig = False\n",
    "for ylog in [True]:\n",
    "    axes = plt.subplots(1, len(d_time), figsize=[6*len(d_time), 5], sharex=True, sharey=False)\n",
    "    for i, d, in enumerate(d_time):\n",
    "        ax = axes[1][i]\n",
    "        \n",
    "        y =  herd_times[i, :, :].mean(1)\n",
    "        yerr =  herd_times[i, :, :].std(1) / np.sqrt(num_trials)\n",
    "        ax.errorbar(4**np.array(size_time[:]), y, marker=mss[0], \n",
    "                yerr=yerr, #np.array([-yerr, yerr]),\n",
    "                linestyle=lss[0], markersize=ms_size[0]/2, color=colors[0], linewidth =3,\n",
    "                 label='Herd')\n",
    "        for jj, a in enumerate([alpha_max, 0]): #range(0, min(alpha_max, size)+1):\n",
    "            y = cpp_herd_times[i, a, :, :].mean(1)\n",
    "            yerr = cpp_herd_times[i, a, :, :].std(1) / np.sqrt(num_trials)\n",
    "            if a == 0:\n",
    "                lab = \"Herd-Comp\"\n",
    "            else:\n",
    "                lab = 'Herd-Comp++'\n",
    "            ax.errorbar(4**np.array(size_time), y, yerr=np.array([-yerr, yerr]),\n",
    "                    marker=mss[1+jj], linestyle=lss[jj+1], markersize=ms_size[1+jj]/2,\n",
    "                    color=colors[1+jj], linewidth = 3,\n",
    "    #                 linestyle='--', \n",
    "                    label=lab)\n",
    "        ax.set_xscale('log', basex=10)\n",
    "        if ylog:\n",
    "            ax.set_yscale('log', basey=10)\n",
    "            ax.set_ylim([5*1e-3, 1e5])\n",
    "            ax.yaxis.tick_right()\n",
    "            ax.set_yticks([1, 10, 60, 600, 3600, 36000])\n",
    "            ax.set_yticklabels( ['1s', '10s', '1m', '10m', '1hr', '10hr'], fontsize=20, color='black')\n",
    "            ax.set_xticks([10**3, 10**4, 10**5, 10**6])\n",
    "#             ax.set_xticklabels( ['1K','10K', '100K', '1M'], fontsize=20, color='gray')\n",
    "        ax.set_xlabel('Input size', fontsize=xlab_size)\n",
    "\n",
    "        ax.spines['top'].set_visible(False)\n",
    "        ax.spines['left'].set_visible(False)\n",
    "        ax.spines['right'].set_visible(False)\n",
    "        # ax.spines['bottom'].set_visible(False)a\n",
    "        \n",
    "        ax.grid(True, alpha=0.4)\n",
    "        ax.set_title(\"d=%d\"%d, fontsize=1.1*title_size)\n",
    "        if i == 0:\n",
    "#             locs, labels = ax.get_yticks()  \n",
    "#             ax.set_yticks([1, 60, 600, 3600, 36000, 86400])\n",
    "            ax.set_ylabel('Single core runtime',fontsize=ylab_size)\n",
    "            ax.legend(fontsize=22)\n",
    "    plt.tight_layout()\n",
    "    if savefig:\n",
    "        plt.savefig(fig_dir + \"time_herd.pdf\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## OLD KT TImes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# have results for d = 2, 4, 10, 100; size = range(4, 10), 3 trials\n",
    "old_kt_times, old_cpp_times = joblib.load(\"results/run_time/old_runtimes.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.array(yerr)/y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for ylog in [True, False]:\n",
    "    axes = plt.subplots(1, len(d_time), figsize=[6*len(d_time), 5], sharex=True,)\n",
    "    for i, d, in enumerate(d_time[:]):\n",
    "        ax = axes[1][i]\n",
    "        \n",
    "        y =  kt_times[i, :, :].mean(1)\n",
    "        yerr =  2*kt_times[i, :, :].std(1) / np.sqrt(num_trials)\n",
    "        ax.errorbar(4**np.array(size_time), y, marker=mss[0], \n",
    "                yerr=np.array(yerr),\n",
    "                linestyle=lss[0], markersize=ms_size[0]/2, color='red',\n",
    "                 label='KT')\n",
    "        \n",
    "        y =  old_kt_times[i, :len(size_time), :num_trials].mean(1)\n",
    "        yerr =  old_kt_times[i, :len(size_time), :num_trials].std(1) / np.sqrt(num_trials)\n",
    "        \n",
    "        ax.errorbar(4**np.array(size_time), y, marker=mss[0], \n",
    "                yerr=np.array([-yerr, yerr]),\n",
    "                linestyle=lss[0], markersize=ms_size[0]/2, color='orange', # alpha = 0.5,\n",
    "                 label='KT old')\n",
    "\n",
    "#         plot_old = False\n",
    "#         a = alpha_max\n",
    "#         lab = \"Comp\" if a==0 else f\"C++{a}\"\n",
    "#         color = 'olive' if a==0 else 'black'\n",
    "#         for a in [0, alpha_max]:\n",
    "#             y =  cpp_times[i, a, :, :].mean(1)\n",
    "#             yerr =  2*cpp_times[i, a, :, :].std(1) / np.sqrt(num_trials)\n",
    "#             ax.errorbar(4**np.array(size_time[:]), y, marker=mss[a], \n",
    "#                     yerr=np.array(yerr),\n",
    "#                     linestyle=lss[a], markersize=ms_size[a]/2, color=color,\n",
    "#                      label=lab)\n",
    "#             if plot_old:\n",
    "#                 color = 'green' if a==0 else 'gray'\n",
    "#                 lab += \" old\"\n",
    "#                 y =  old_cpp_times[i, a, :len(size_time), :num_trials].mean(1)\n",
    "#                 yerr =  old_cpp_times[i, a, :len(size_time), :num_trials ].std(1) / np.sqrt(num_trials)\n",
    "#                 ax.errorbar(4**np.array(size_time), y, marker=mss[a+1], \n",
    "#                         yerr=np.array([-yerr, yerr]),\n",
    "#                         linestyle=lss[a], markersize=ms_size[a+1]/2, color=color, # alpha = 0.5,\n",
    "#                          label=lab)\n",
    "\n",
    "        \n",
    "        ax.set_xscale('log', basex=4)\n",
    "        if ylog:\n",
    "            ax.set_yscale('log', basey=4)\n",
    "        ax.set_xlabel('Input size')\n",
    "\n",
    "        ax.spines['top'].set_visible(False)\n",
    "        # ax.spines['left'].set_visible(False)\n",
    "        ax.spines['right'].set_visible(False)\n",
    "        # ax.spines['bottom'].set_visible(False)\n",
    "        ax.grid(True, alpha=0.4)\n",
    "        ax.set_title(\"d=%d\"%d)\n",
    "        if i == 0:\n",
    "            ax.set_ylabel('Run time (seconds)',fontsize=ylab_size*0.9)\n",
    "            ax.legend(fontsize=15)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for ylog in [False]:\n",
    "    axes = plt.subplots(1, len(d_time), figsize=[6*len(d_time), 5], sharex=True, sharey=True)\n",
    "    for i, d, in enumerate(d_time[:]):\n",
    "        ax = axes[1][i]\n",
    "        \n",
    "        a = alpha_max\n",
    "        lab = \"Comp\" if a==0 else f\"C++{a}\"\n",
    "        color = 'olive' if a==0 else 'black'\n",
    "        y = cpp_times[i, a, 1:, :].mean(1)/cpp_times[i, a, :-1, :].mean(1)\n",
    "#         yerr =  cpp_times[i, a, :, :].std(1) / np.sqrt(num_trials)\n",
    "        ax.plot(4**np.array(size_time[:-1]), y, marker=mss[a+1], \n",
    "#                 yerr=np.array([-yerr, yerr]),\n",
    "                linestyle=lss[a+1], markersize=ms_size[a+1]/2, color=color,\n",
    "                 label=lab)\n",
    "        \n",
    "        ax.set_xscale('log', basex=4)\n",
    "        if ylog:\n",
    "            ax.set_yscale('log', basey=4)\n",
    "        ax.set_xlabel('n')\n",
    "\n",
    "        ax.spines['top'].set_visible(False)\n",
    "        # ax.spines['left'].set_visible(False)\n",
    "        ax.spines['right'].set_visible(False)\n",
    "        # ax.spines['bottom'].set_visible(False)\n",
    "        ax.grid(True, alpha=0.4)\n",
    "        ax.set_title(\"d=%d\"%d)\n",
    "        if i == 0:\n",
    "            ax.set_ylabel(r'Time ratio $\\frac{T(4n)}{T(n)}$',fontsize=ylab_size*0.9)\n",
    "            ax.legend(fontsize=15)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## New KT vs OLD KT (another attempt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = \"results/run_time\"\n",
    "\n",
    "\n",
    "partition = 'high'\n",
    "num_trials = 2\n",
    "rerun = 1\n",
    "size_time = range(4, 9)\n",
    "d_time = [2, 10]\n",
    "rt = dict()\n",
    "axes = plt.subplots(1, len(d_time)*2, figsize=[6*2*len(d_time), 5], sharex=True, sharey=True)\n",
    "\n",
    "for kk, (mem, cores) in enumerate(zip([\"32G\", \"8G\"],[1, 4])):\n",
    "    \n",
    "    prefix = f'p_{partition}_c{cores}_mem_{mem}' \n",
    "    \n",
    "    rt[prefix] = dict()\n",
    "    rt[prefix][\"kt\"] = np.zeros((len(d_time), len(size_time), num_trials))\n",
    "    rt[prefix][\"ktold\"] = np.zeros((len(d_time), len(size_time), num_trials))\n",
    "\n",
    "    for r0 in range(num_trials):\n",
    "        for i, d in enumerate(d_time):\n",
    "            ax = axes[1][i+kk*2]\n",
    "            for j, size in enumerate(size_time):        \n",
    "                for k, talg in enumerate([\"kt\", \"ktold\"]):\n",
    "                    filename = os.path.join(results_dir, f\"{talg}_{prefix}_gauss_k_d_{d}_size_{size}_rep_{r0}.pkl\")\n",
    "                    rt[prefix][talg][i, j, r0] = joblib.load(filename)\n",
    "                    \n",
    "            ax.plot(4**np.array(size_time), rt[prefix][\"kt\"][i, :, r0], color='red', label=f\"kt{r0}\", linestyle=\"None\", marker=mss[r0+num_trials*k], markersize=8)\n",
    "            ax.plot(4**np.array(size_time), rt[prefix][\"ktold\"][i, :, r0], color='blue', label=f\"kt_old{r0}\",  linestyle=\"None\",marker=mss[r0+(num_trials+1)*k],  markersize=8)\n",
    "            ax.set_title(f'd={d}, set={prefix}')\n",
    "            if i == 0 and kk == 0:\n",
    "                ax.legend(fontsize=15, loc='upper left')\n",
    "            ax.grid(alpha=0.4)\n",
    "            ax.set_xscale('log', basex=4)\n",
    "            ax.set_xlabel('input size n')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir = \"results/run_time\"\n",
    "\n",
    "\n",
    "partition = 'high'\n",
    "num_trials = 2\n",
    "rerun = 1\n",
    "size_time = range(4, 9)\n",
    "d_time = [2, 10]\n",
    "rt = dict()\n",
    "axes = plt.subplots(1, len(d_time)*2, figsize=[6*2*len(d_time), 5], sharex=True, sharey=True)\n",
    "\n",
    "for kk, (mem, cores) in enumerate(zip([\"32G\", \"8G\"],[1, 4])):\n",
    "    \n",
    "    prefix = f'p_{partition}_c{cores}_mem_{mem}' \n",
    "    \n",
    "    rt[prefix] = dict()\n",
    "    rt[prefix][\"kt\"] = np.zeros((len(d_time), len(size_time), num_trials))\n",
    "    rt[prefix][\"ktold\"] = np.zeros((len(d_time), len(size_time), num_trials))\n",
    "\n",
    "\n",
    "    for i, d in enumerate(d_time):\n",
    "        ax = axes[1][i+kk*2]\n",
    "        for j, size in enumerate(size_time):        \n",
    "            for k, talg in enumerate([\"kt\", \"ktold\"]):\n",
    "                for r0 in range(num_trials):\n",
    "                    filename = os.path.join(results_dir, f\"{talg}_{prefix}_gauss_k_d_{d}_size_{size}_rep_{r0}.pkl\")\n",
    "                    rt[prefix][talg][i, j, r0] = joblib.load(filename)\n",
    "\n",
    "        ax.errorbar(4**np.array(size_time), \n",
    "                    y= rt[prefix][\"kt\"][i, :, :].mean(1), \n",
    "                    yerr= np.sqrt(2)*rt[prefix][\"kt\"][i, :, :].std(1), \n",
    "                    color='red', label=f\"kt\", linestyle=\"None\", marker=mss[r0+num_trials*k], markersize=8)\n",
    "        ax.errorbar(4**np.array(size_time), \n",
    "                y= rt[prefix][\"ktold\"][i, :, :].mean(1), \n",
    "                yerr= np.sqrt(2)*rt[prefix][\"ktold\"][i, :, :].std(1), \n",
    "                color='blue', label=f\"kt_old\",  linestyle=\"None\",marker=mss[r0+(num_trials+1)*k],  markersize=8)\n",
    "        ax.set_title(f'd={d}, set={prefix}')\n",
    "        if i == 0 and kk == 0:\n",
    "            ax.legend(fontsize=15, loc='upper left')\n",
    "        ax.grid(alpha=0.4)\n",
    "        ax.set_xscale('log', basex=4)\n",
    "        ax.set_xlabel('input size n')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load MOG results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parse_arguments()\n",
    "\n",
    "args.krt = 0\n",
    "\n",
    "# mmd and size params\n",
    "args.combine_mmd = 1 \n",
    "args.computemmd = 1\n",
    "args.size_max = 10\n",
    "args.size_min = 0\n",
    "args.rep0 = 0\n",
    "args.repn = 10\n",
    "args.alpha_max = 0\n",
    "\n",
    "# rerun/compute params\n",
    "args.rerun = 0\n",
    "args.recomputemmd = 0 # CHECK THIS\n",
    "\n",
    "# recombine\n",
    "args.recombine = False # whether to recombine\n",
    "\n",
    "if run_mog_experiments:\n",
    "    args.setting = \"mog\"\n",
    "    args.d = 2\n",
    "    d = args.d\n",
    "    for M in Ms:\n",
    "        args.M = M\n",
    "        combined_mmd_results[\"M_%d\"%M] = dict()\n",
    "        \n",
    "        # compute d, params_p and var_k for the setting\n",
    "        _, params_p, var_k = compute_params_p(args)\n",
    "        \n",
    "        \n",
    "        # define the kernels\n",
    "        params_k_split, params_k_swap, _, _ = compute_params_k(d=d, var_k=var_k, \n",
    "                                                            use_krt_split=args.krt, name=\"gauss\")\n",
    "        \n",
    "        flags = [st_coresets, compress_thin_coresets, compressthin_thin_coresets, kt_coresets, herding_coresets, cp_herding_coresets]\n",
    "        labels = [\"ST\", \"CP+KT\", \"CPthin+KT\",\"KT\", \"Herd\", \"CPHerd\"]\n",
    "        for f, lab in zip(flags, labels):\n",
    "            if f:\n",
    "                args.cset = lab\n",
    "                combined_mmd_results[\"M_%d\"%M][lab] = dict()\n",
    "                combined_mmd_results[\"M_%d\"%M][lab] = combine_results(args, params_p, params_k_split, params_k_swap)\n",
    "        \n",
    "#         if kt_coresets:\n",
    "#             args.cset = \"KT\"\n",
    "#             combined_mmd_results[\"M_%d\"%M][\"KT\"] = dict()\n",
    "#             combined_mmd_results[\"M_%d\"%M][\"KT\"] = combine_results(args, params_p, params_k_split, params_k_swap)\n",
    "\n",
    "#         if compress_thin_coresets:\n",
    "#             args.cset = \"CP+KT\"\n",
    "#             combined_mmd_results[\"M_%d\"%M][\"CP+KT\"] = dict()\n",
    "#             combined_mmd_results[\"M_%d\"%M][\"CP+KT\"] = combine_results(args, params_p, params_k_split, params_k_swap)\n",
    "            \n",
    "#         if compressthin_thin_coresets:\n",
    "#             args.cset = \"CPthin+KT\"\n",
    "#             combined_mmd_results[\"M_%d\"%M][\"CPthin+KT\"] = dict()\n",
    "#             combined_mmd_results[\"M_%d\"%M][\"CPthin+KT\"] = combine_results(args, params_p, params_k_split, params_k_swap)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot MoG resutls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_keys = []\n",
    "for f, lab in zip(flags, labels):\n",
    "    if f:\n",
    "        plot_keys.append(lab)\n",
    "\n",
    "# plot_keys = [\"ST\", \"KT\",\"CPthin+KT\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if run_mog_experiments:\n",
    "    skip_ns = int(5)\n",
    "    axes = plt.subplots(1, len(Ms), figsize=[10*len(Ms), 10], sharex=True, sharey=True)\n",
    "    for i, M in enumerate(Ms):\n",
    "        ax = axes[1] if len(Ms) == 1 else axes[1][i]\n",
    "        simple_mmd_dict = post_process(combined_mmd_results[f'M_{M}'], args, plot_keys)\n",
    "        plot_mmd_dict(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5, \n",
    "                      skip_ns=skip_ns, error_bar=True, rm_keys = [])\n",
    "        ax.set_title(r\"M=$%d$\"%(M), fontsize=title_size)\n",
    "        if i==0:\n",
    "            ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "        ax.set_xlabel(\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "    plt.tight_layout()\n",
    "    plt.suptitle(r\"Mixture of Gaussian Target\", y=1.05, fontsize=title_size)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load MCMC Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = parse_arguments()\n",
    "\n",
    "args.krt = 0 # is krt used for split\n",
    "\n",
    "# mmd and size params\n",
    "args.combine_mmd = 1 \n",
    "args.computemmd = 1\n",
    "args.size_max = 8\n",
    "args.size_min = 4\n",
    "args.rep0 = 0\n",
    "args.repn = 10\n",
    "args.alpha_max = 4\n",
    "\n",
    "# rerun/compute params\n",
    "args.rerun = 0\n",
    "args.recomputemmd = 0 # CHECK THIS\n",
    "\n",
    "# recombine\n",
    "args.recombine = False # whether to recombine\n",
    "\n",
    "if run_mcmc_experiments:\n",
    "    args.setting = \"mcmc\"\n",
    "    for filename in all_mcmc_filenames:\n",
    "        args.filename = filename\n",
    "        combined_mmd_results[filename] = dict()\n",
    "        \n",
    "        # compute d, params_p and var_k for the setting\n",
    "        d, params_p, var_k = compute_params_p(args)\n",
    "        args.d = d\n",
    "        \n",
    "        # define the kernels\n",
    "        params_k_split, params_k_swap, _, _ = compute_params_k(d=d, var_k=var_k, \n",
    "                                                            use_krt_split=args.krt, name=\"gauss\")\n",
    "        \n",
    "        if kt_coresets:\n",
    "            args.cset = \"KT\"\n",
    "            combined_mmd_results[filename][\"KT\"] = dict()\n",
    "            combined_mmd_results[filename][\"KT\"] = combine_results(args, params_p, params_k_split, params_k_swap)\n",
    "\n",
    "        if compress_thin_coresets:\n",
    "            args.cset = \"CP+KT\"\n",
    "            combined_mmd_results[filename][\"CP+KT\"] = dict()\n",
    "            combined_mmd_results[filename][\"CP+KT\"] = combine_results(args, params_p, params_k_split, params_k_swap)\n",
    "            \n",
    "        if compressthin_thin_coresets:\n",
    "            args.cset = \"CPthin+KT\"\n",
    "            combined_mmd_results[filename][\"CPthin+KT\"] = dict()\n",
    "            combined_mmd_results[filename][\"CPthin+KT\"] = combine_results(args, params_p, params_k_split, params_k_swap)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot MCMC results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for t in range(0,3):\n",
    "    mcmc_file_range = slice(t*4, t*4+4)\n",
    "    mcmc_filenames = all_mcmc_filenames[mcmc_file_range]\n",
    "\n",
    "    skip_ns = int(0)\n",
    "    axes = plt.subplots(1, len(mcmc_filenames), figsize=[10*len(mcmc_filenames), 10], sharex=True, sharey=True)\n",
    "    for i, filename in enumerate(mcmc_filenames):\n",
    "        ax = axes[1] if len(mcmc_filenames) == 1 else axes[1][i]\n",
    "        simple_mmd_dict = post_process(combined_mmd_results[filename], args, [\"KT\",\"CPthin+KT\"])\n",
    "        plot_mmd_dict(ax, range(args.size_min, args.size_max+1), simple_mmd_dict, 0.5, \n",
    "                      skip_ns=skip_ns, error_bar=True, rm_keys = [])\n",
    "        ax.set_title(filename, fontsize=title_size)\n",
    "        if i==0:\n",
    "            ax.set_ylabel(ylab, fontsize=ylab_size)\n",
    "        ax.set_xlabel(\"Coreset size $\\sqrt{n}$\", fontsize=xlab_size)\n",
    "    plt.tight_layout()\n",
    "    # plt.suptitle(r\"MCMC Target\", y=1.05, fontsize=title_size)\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
