{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A notebook for running kernel thinning and standard thinning experiments\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numpy.random as npr\n",
    "import numpy.linalg as npl\n",
    "# from scipy.spatial.distance import pdist\n",
    "\n",
    "from argparse import ArgumentParser\n",
    "import pickle as pkl\n",
    "import pathlib\n",
    "import os\n",
    "import os.path\n",
    "\n",
    "# import kernel thinning\n",
    "from kernelthinning import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions\n",
    "from kernelthinning.util import isnotebook # Check whether this file is being executed as a script or as a notebook\n",
    "from kernelthinning.util import fprint  # for printing while flushing buffer\n",
    "from kernelthinning.tictoc import tic, toc # for timing blocks of code\n",
    "\n",
    "\n",
    "# utils for generating samples, evaluating kernels, and mmds\n",
    "from util_sample import sample, compute_mcmc_params_p, compute_diag_mog_params, sample_string\n",
    "from util_k_mmd import kernel_eval, squared_mmd, get_combined_mmd_filename\n",
    "\n",
    "# for partial functions, to use kernel_eval for kernel\n",
    "from functools import partial\n",
    "\n",
    "# Notebook-only setup: the IPython magics below are executed only when isnotebook() is true\n",
    "if isnotebook():\n",
    "    # Automatically reload imported modules that were modified on disk\n",
    "    %load_ext autoreload\n",
    "    %autoreload 2\n",
    "    # Render matplotlib figures inline in the notebook output\n",
    "    %matplotlib inline\n",
    "    # Load the line_profiler extension (provides the %lprun magic); background reading:\n",
    "    %load_ext line_profiler\n",
    "    # https://jakevdp.github.io/PythonDataScienceHandbook/01.07-timing-and-profiling.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# When run as a script, read the experiment settings from the command line;\n",
    "# in a notebook args stays None and later cells fall back to their defaults\n",
    "if not isnotebook():\n",
    "    def _str2bool(value):\n",
    "        \"\"\"Converts a command-line string such as 'True' or 'false' into a bool.\n",
    "\n",
    "        argparse's type=bool applies bool() to the raw string, so every non-empty\n",
    "        value (including 'False') is parsed as True. The empty string still maps\n",
    "        to False, as it did with type=bool; unrecognized values raise ValueError,\n",
    "        which argparse reports as a usage error.\n",
    "        \"\"\"\n",
    "        if isinstance(value, bool):\n",
    "            return value\n",
    "        lowered = value.strip().lower()\n",
    "        if lowered in (\"true\", \"t\", \"yes\", \"y\", \"1\"):\n",
    "            return True\n",
    "        if lowered in (\"false\", \"f\", \"no\", \"n\", \"0\", \"\"):\n",
    "            return False\n",
    "        raise ValueError(f\"expected a boolean value, got {value!r}\")\n",
    "\n",
    "    parser = ArgumentParser()\n",
    "    parser.add_argument('--rep0', '-r0', type=int, default=0,\n",
    "                        help=\"starting experiment id\")\n",
    "    parser.add_argument('--repn', '-rn', type=int, default=1,\n",
    "                        help=\"number of experiment replication\")\n",
    "    # Booleans go through _str2bool so that '--store_K False' really disables the option\n",
    "    parser.add_argument('--store_K', '-sk', type=_str2bool, default=False,\n",
    "                        help=\"whether to save K matrix, 2-3x faster runtime, but larger memory O(n^2)\")\n",
    "    parser.add_argument('--m', '-m', type=int, default=6,\n",
    "                        help=\"number of thinning rounds\")\n",
    "    parser.add_argument('--d', '-d', type=int, default=1,\n",
    "                        help=\"dimensions\")\n",
    "    parser.add_argument('--M', '-M', type=int, default=None,\n",
    "                        help=\"number of mixture for diag mog in d=2\")\n",
    "    parser.add_argument('--filename', '-f', type=str, default=None,\n",
    "                       help=\"name for saved (MCMC) samples\")\n",
    "    parser.add_argument('--combine_mmd', '-cm', type=_str2bool, default=False,\n",
    "                        help=\"whether to save combined_mmd results; should be set to True once all experiments are done running\")\n",
    "    # parse_known_args tolerates extra arguments; the unrecognized ones are kept in opt\n",
    "    args, opt = parser.parse_known_args()\n",
    "else:\n",
    "    args = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define kernel thinning experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_kernel_thinning_experiment(m, params_p, params_k_split, params_k_swap, rep_ids,\n",
    "                     delta=None, store_K=False,\n",
    "                      sample_seed=1234567, thin_seed=9876543,\n",
    "                      compute_mmds = True,\n",
    "                      rerun=False, results_dir=\"results_new\"):\n",
    "    \"\"\"Runs kernel thinning on samples from params_p for every replication in rep_ids.\n",
    "\n",
    "    The output of kt.thin for each replication is pickled under results_dir. If\n",
    "    compute_mmds is True, MMDs of the thinned points are evaluated with the\n",
    "    params_k_swap kernel against both the target distribution params_p and the\n",
    "    input sample, and are pickled under results_dir as well.\n",
    "\n",
    "    Args:\n",
    "      m: Number of halving rounds (number of sample points n = 2^{2m})\n",
    "      params_p: Dictionary of distribution parameters recognized by sample()\n",
    "      params_k_split: Dictionary of kernel parameters recognized by kernel_eval()\n",
    "      params_k_swap: Dictionary of kernel parameters recognized by kernel_eval()\n",
    "      rep_ids: Which replication numbers of experiment to run; the replication\n",
    "        number determines the seeds set for reproducibility\n",
    "      delta: delta/(4^m) is the failure probability for\n",
    "        adaptive threshold sequence\n",
    "      store_K: If False, runs O(nd) space version which does not store kernel\n",
    "        matrix; if True, stores n x n kernel matrix\n",
    "      sample_seed: (Optional) random seed is set to sample_seed + rep\n",
    "        prior to generating input sample for replication rep\n",
    "      thin_seed: (Optional) random seed is set to thin_seed + rep\n",
    "        prior to running thinning for replication rep\n",
    "      compute_mmds: (Optional) Whether to compute mmds of coresets (using params_k_swap)\n",
    "      rerun: (Optional) If False and results have been previously saved to\n",
    "        disk, load results from disk instead of rerunning experiment\n",
    "      results_dir: (Optional) Directory in which results should be saved\n",
    "\n",
    "    Returns:\n",
    "      If compute_mmds is True, a tuple (mmds_p, mmds_sin) of arrays of shape\n",
    "      (m+1, len(rep_ids)); entry [j, r_i] is the MMD of X[coresets[:2^j]] for the\n",
    "      r_i-th replication, measured against params_p (mmds_p) or against the\n",
    "      input sample X (mmds_sin). If compute_mmds is False, returns None.\n",
    "    \"\"\"\n",
    "    # Create results directory if necessary\n",
    "    pathlib.Path(results_dir).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # Kernels handed to kt.thin\n",
    "    split_kernel = partial(kernel_eval, params_k=params_k_split)\n",
    "    swap_kernel = partial(kernel_eval, params_k=params_k_swap)\n",
    "\n",
    "    # The target and both kernels must agree on the dimension\n",
    "    d = params_p[\"d\"]\n",
    "    assert(d==params_k_split[\"d\"])\n",
    "    assert(d==params_k_swap[\"d\"])\n",
    "\n",
    "    # Construct results filename templates with placeholder for rep value\n",
    "    sample_str = sample_string(params_p, sample_seed)\n",
    "    split_kernel_str = \"{}_var{:.3f}_seed{}\".format(params_k_split[\"name\"], params_k_split[\"var\"], thin_seed)\n",
    "    swap_kernel_str =  \"{}_var{:.3f}\".format(params_k_swap[\"name\"], params_k_swap[\"var\"])\n",
    "    thresh_str = f\"delta{delta}\"\n",
    "    setting_str = f\"{sample_str}-split{split_kernel_str}-swap{swap_kernel_str}-d{d}-m{m}-{thresh_str}\"\n",
    "    file_template = os.path.join(results_dir, f\"kt-coresets-{setting_str}-rep{{}}.pkl\")\n",
    "    # The MMD templates are built unconditionally so that the loop below never\n",
    "    # touches an undefined name when compute_mmds is False\n",
    "    mmd_p_file_template = os.path.join(results_dir, f\"kt-mmd-{setting_str}-rep{{}}.pkl\")\n",
    "    mmd_sin_file_template = os.path.join(results_dir, f\"kt-mmd-sin-{setting_str}-rep{{}}.pkl\")\n",
    "\n",
    "    # Create arrays to store MMD evaluations from P, and Sin\n",
    "    if compute_mmds:\n",
    "        mmds_p = np.zeros((m+1, len(rep_ids)))\n",
    "        mmds_sin = np.zeros((m+1, len(rep_ids)))\n",
    "\n",
    "    # Number of sample points\n",
    "    n = int(2**(2*m))\n",
    "    fprint(f\"Running kernel thinning experiment with template {file_template}.....\")\n",
    "    tic()\n",
    "    for r_i, rep in enumerate(rep_ids):\n",
    "        # Include replication number in filenames\n",
    "        filename = file_template.format(rep)\n",
    "        mmd_p_filename = mmd_p_file_template.format(rep)\n",
    "        mmd_sin_filename = mmd_sin_file_template.format(rep)\n",
    "\n",
    "        # Generate matrix of input sample points\n",
    "        X = sample(n, params_p, seed=sample_seed+rep)\n",
    "\n",
    "        # NOTE: the pickle files read below are expected to come from earlier runs of\n",
    "        # this function; never point results_dir at untrusted files\n",
    "        if not rerun and os.path.exists(filename):\n",
    "            # Reuse previously saved coresets\n",
    "            with open(filename, 'rb') as file:\n",
    "                coresets = pkl.load(file)\n",
    "        else:\n",
    "            # Run kernel thinning and save its output to disk\n",
    "            print(f\"Kernel Thinning rep {rep}...\", flush=True)\n",
    "            coresets = kt.thin(X, m, split_kernel, swap_kernel, delta=delta, seed=thin_seed+rep, store_K=store_K)\n",
    "            with open(filename, 'wb') as file:\n",
    "                pkl.dump(coresets, file, protocol=pkl.HIGHEST_PROTOCOL)\n",
    "\n",
    "        if not compute_mmds:\n",
    "            continue\n",
    "\n",
    "        # MMD with respect to the target distribution P\n",
    "        if not rerun and os.path.exists(mmd_p_filename):\n",
    "            with open(mmd_p_filename, 'rb') as file:\n",
    "                mmds_p[:, r_i] = pkl.load(file)\n",
    "        else:\n",
    "            for j in range(m+1):\n",
    "                nj = int(2**j)\n",
    "                mmds_p[j, r_i] = np.sqrt(\n",
    "                    squared_mmd(params_k_swap, params_p, X[coresets[:nj]]))\n",
    "            with open(mmd_p_filename, 'wb') as file:\n",
    "                pkl.dump(mmds_p[:, r_i], file, protocol=pkl.HIGHEST_PROTOCOL)\n",
    "\n",
    "        # MMD with respect to the input sample Sin\n",
    "        if not rerun and os.path.exists(mmd_sin_filename):\n",
    "            with open(mmd_sin_filename, 'rb') as file:\n",
    "                mmds_sin[:, r_i] = pkl.load(file)\n",
    "        else:\n",
    "            # Redefine the target p as the distribution on Sin\n",
    "            params_p_sin = dict()\n",
    "            params_p_sin[\"name\"] =  params_p[\"name\"]+ \"_sin\"\n",
    "            params_p_sin[\"Pnmax\"] = X\n",
    "            params_p_sin[\"d\"] = d\n",
    "            for j in range(m+1):\n",
    "                nj = int(2**j)\n",
    "                mmds_sin[j, r_i] = np.sqrt(squared_mmd(params_k_swap, params_p_sin, X[coresets[:nj]]))\n",
    "            with open(mmd_sin_filename, 'wb') as file:\n",
    "                pkl.dump(mmds_sin[:, r_i], file, protocol=pkl.HIGHEST_PROTOCOL)\n",
    "    toc()\n",
    "    if compute_mmds:\n",
    "        return(mmds_p, mmds_sin)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define standard thinning experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _standard_thinning_mmds(X, m, params_k_mmd, params_target, min_mmd):\n",
    "    \"\"\"Returns an array of m+1 MMDs of standard-thinned subsets of X against params_target.\n",
    "\n",
    "    For j = 0, ..., m the first 4^j rows of X are thinned by keeping every\n",
    "    2^j-th row. If min_mmd is True, the smallest MMD over all 2^j possible\n",
    "    starting offsets is kept; otherwise only the first offset is evaluated.\n",
    "    \"\"\"\n",
    "    mmds = np.zeros(m+1)\n",
    "    for j in range(m+1):\n",
    "        # Target coreset size, which is also the thinning step size\n",
    "        coreset_size = int(2**j)\n",
    "        input_size = int(coreset_size**2)\n",
    "        # There are input_size/coreset_size = coreset_size thinned sequences,\n",
    "        # indexed by their starting point\n",
    "        num_starts = coreset_size if min_mmd else 1\n",
    "        best = np.inf\n",
    "        for start in range(num_starts):\n",
    "            thinned = X[(coreset_size-1-start):input_size:coreset_size]\n",
    "            best = min(best, np.sqrt(squared_mmd(params_k_mmd, params_target, thinned)))\n",
    "        mmds[j] = best\n",
    "    return mmds\n",
    "\n",
    "\n",
    "def run_standard_thinning_experiment(m, params_p, params_k_mmd, rep_ids, sample_seed=1234567, \n",
    "                      rerun=False, results_dir=\"results_new\", compute_mmds=True,\n",
    "                      min_mmd=False):\n",
    "    \"\"\"Evaluates MMDs of standard-thinned draws from params_p and saves them to disk.\n",
    "\n",
    "    Args:\n",
    "      m: Number of halving rounds (defines number of sample points via n = 2^{2m})\n",
    "      params_p: Dictionary of distribution parameters recognized by sample()\n",
    "      params_k_mmd: Dictionary of kernel parameters for MMD evaluation\n",
    "      rep_ids: Which replication numbers of experiment to run; the replication\n",
    "        number determines the seeds set for reproducibility\n",
    "      sample_seed: (Optional) random seed is set to sample_seed + rep\n",
    "        prior to generating input sample for replication rep\n",
    "      rerun: (Optional) If False and results have been previously saved to\n",
    "        disk, load results from disk instead of rerunning experiment\n",
    "      results_dir: (Optional) Directory in which results should be saved\n",
    "      compute_mmds: (Optional) Not used by this function; MMDs are always\n",
    "        computed (or loaded) and returned\n",
    "      min_mmd: (Optional) if True, returns the minimum MMD over all sqrt(n) thinned \n",
    "        sequences of n points with step size sqrt(n); if False, returns the MMD\n",
    "        of the first such thinned sequence\n",
    "\n",
    "    Returns:\n",
    "      Tuple (mmds_p, mmds_sin) of arrays of shape (m+1, len(rep_ids)) holding the\n",
    "      MMDs against params_p and against the input sample, respectively.\n",
    "    \"\"\"\n",
    "    # Create results directory if necessary\n",
    "    pathlib.Path(results_dir).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # Create arrays to store MMD evaluations\n",
    "    mmds_p = np.zeros((m+1, len(rep_ids)))\n",
    "    mmds_sin = np.zeros((m+1, len(rep_ids)))\n",
    "\n",
    "    # Construct results filename templates with placeholder for rep value\n",
    "    d = params_p[\"d\"]\n",
    "    assert(d == params_k_mmd[\"d\"])\n",
    "    sample_str = sample_string(params_p, sample_seed)\n",
    "    kernel_str = \"{}_var{:.3f}\".format(params_k_mmd[\"name\"], params_k_mmd[\"var\"])\n",
    "    min_str = \"min_\" if min_mmd else \"\"\n",
    "    mmd_p_file_template = os.path.join(results_dir, f\"{min_str}mc-mmd-{sample_str}-{kernel_str}-d{d}-m{m}-rep{{}}.pkl\")\n",
    "    mmd_sin_file_template = os.path.join(results_dir, f\"{min_str}mc-sin-mmd-{sample_str}-{kernel_str}-d{d}-m{m}-rep{{}}.pkl\")\n",
    "\n",
    "    # Number of sample points\n",
    "    n = int(2**(2*m))\n",
    "\n",
    "    fprint(f\"Running standard thinning experiment for m={m}\")\n",
    "    tic()\n",
    "    for r_i, rep in enumerate(rep_ids):\n",
    "        # Include replication number in filenames\n",
    "        mmd_p_filename = mmd_p_file_template.format(rep)\n",
    "        mmd_sin_filename = mmd_sin_file_template.format(rep)\n",
    "\n",
    "        # The input sample is generated lazily, at most once per replication\n",
    "        X = None\n",
    "\n",
    "        # NOTE: the pickle files read below are expected to come from earlier runs of\n",
    "        # this function; never point results_dir at untrusted files\n",
    "        if not rerun and os.path.exists(mmd_p_filename):\n",
    "            # Reuse previously saved results\n",
    "            with open(mmd_p_filename, 'rb') as file:\n",
    "                mmds_p[:, r_i] = pkl.load(file)\n",
    "        else:\n",
    "            X = sample(n, params_p, seed=sample_seed+rep)\n",
    "            mmds_p[:, r_i] = _standard_thinning_mmds(X, m, params_k_mmd, params_p, min_mmd)\n",
    "            with open(mmd_p_filename, 'wb') as file:\n",
    "                pkl.dump(mmds_p[:, r_i], file, protocol=pkl.HIGHEST_PROTOCOL)\n",
    "\n",
    "        if not rerun and os.path.exists(mmd_sin_filename):\n",
    "            # Reuse previously saved results\n",
    "            with open(mmd_sin_filename, 'rb') as file:\n",
    "                mmds_sin[:, r_i] = pkl.load(file)\n",
    "        else:\n",
    "            if X is None:\n",
    "                X = sample(n, params_p, seed=sample_seed+rep)\n",
    "            # Redefine the target p as the distribution on Sin\n",
    "            params_p_sin = dict()\n",
    "            params_p_sin[\"name\"] =  params_p[\"name\"]+\"_sin\"\n",
    "            params_p_sin[\"Pnmax\"] = X\n",
    "            params_p_sin[\"d\"] = d\n",
    "            mmds_sin[:, r_i] = _standard_thinning_mmds(X, m, params_k_mmd, params_p_sin, min_mmd)\n",
    "            with open(mmd_sin_filename, 'wb') as file:\n",
    "                pkl.dump(mmds_sin[:, r_i], file, protocol=pkl.HIGHEST_PROTOCOL)\n",
    "    toc()\n",
    "    return(mmds_p, mmds_sin)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deploy thinning experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Sample (target distribution) and kernel parameters\n",
    "#\n",
    "var = 1.0  # Variance\n",
    "\n",
    "# Dimension: notebook default, or the --d command-line value\n",
    "if args is None:\n",
    "    d = 2\n",
    "else:\n",
    "    d = args.d\n",
    "params_p = {\"name\": \"gauss\", \"var\": var, \"d\": int(d), \"saved_samples\": False}\n",
    "\n",
    "# filename: name of the saved MCMC samples; M: number of mixture components for the MOG setting\n",
    "if args is None:\n",
    "    filename = None\n",
    "    M = None\n",
    "else:\n",
    "    filename = args.filename\n",
    "    M = args.M\n",
    "\n",
    "if isnotebook():\n",
    "    # Place to set the MCMC filename by hand when working interactively, e.g. one of\n",
    "    # ['Goodwin_RW', 'Goodwin_ADA-RW', 'Goodwin_MALA', 'Goodwin_PRECOND-MALA', 'Lotka_RW', 'Lotka_ADA-RW', 'Lotka_MALA', 'Lotka_PRECOND-MALA']\n",
    "    filename = args.filename if args is not None else None\n",
    "\n",
    "if filename is not None:\n",
    "    # MCMC setting: the target parameters are computed from the named sample file\n",
    "    d = 4\n",
    "    params_p = compute_mcmc_params_p(filename, nmax=2**15, include_last=True)\n",
    "    # Whether the kernel bandwidth is taken from the median distance stored in params_p\n",
    "    use_median_distance = True\n",
    "\n",
    "    if use_median_distance:\n",
    "        var = params_p[\"med_dist\"] ** 2\n",
    "\n",
    "if M is not None:\n",
    "    # Mixture-of-Gaussians setting: the target parameters are computed from M\n",
    "    d = 2\n",
    "    params_p = compute_diag_mog_params(M)\n",
    "\n",
    "# The split kernel uses half the variance of the swap kernel\n",
    "params_k_swap = {\"name\": \"gauss\", \"var\": var, \"d\": int(d)}\n",
    "params_k_split = {\"name\": \"gauss_rt\", \"var\": var / 2.0, \"d\": int(d)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Experiment parameters\n",
    "#\n",
    "\n",
    "if args is None:\n",
    "    # Notebook defaults\n",
    "    rep_ids = range(10)   # replicate ID numbers\n",
    "    ms = range(2+1)       # halving round numbers m to evaluate\n",
    "    store_K = False\n",
    "else:\n",
    "    # Command-line values: replicates rep0, ..., rep0+repn-1 and rounds m = 0, ..., args.m-1\n",
    "    rep_ids = np.arange(args.rep0, args.rep0+args.repn)\n",
    "    ms = range(args.m)\n",
    "    store_K = args.store_K\n",
    "\n",
    "# store_K=True keeps the kernel matrix during thinning: less computation but O(n^2)\n",
    "# memory, which is an issue for larger n; store_K=False requires O(nd) memory\n",
    "\n",
    "# Failure probability\n",
    "delta = .5\n",
    "\n",
    "# Methods to run\n",
    "run_standard_thinning = True\n",
    "run_kernel_thinning = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result arrays: one row per halving round m, one column per replication\n",
    "if run_standard_thinning:\n",
    "    mmds_st = np.zeros((max(ms)+1, len(rep_ids))) # mmds from P\n",
    "    mmds_st_sin = np.zeros_like(mmds_st) # mmds from Sin\n",
    "if run_kernel_thinning:\n",
    "    mmds_kt = np.zeros((max(ms)+1, len(rep_ids))) # mmds from P\n",
    "    mmds_kt_sin = np.zeros_like(mmds_kt) # mmds from Sin\n",
    "\n",
    "print(\"Exp setting:\", params_p[\"name\"], ms)\n",
    "for m in ms:\n",
    "    # Each experiment returns (m+1) x len(rep_ids) arrays; only row m,\n",
    "    # the quality of the 2^m thinned coreset, is stored\n",
    "    if run_standard_thinning:\n",
    "        mmd_st, mmd_st_sin = run_standard_thinning_experiment(\n",
    "            m, params_p, params_k_swap, rep_ids)\n",
    "        mmds_st[m] = mmd_st[m]\n",
    "        mmds_st_sin[m] = mmd_st_sin[m]\n",
    "\n",
    "    if run_kernel_thinning:\n",
    "        mmd_kt, mmd_kt_sin = run_kernel_thinning_experiment(\n",
    "            m, params_p, params_k_split, params_k_swap, rep_ids,\n",
    "            delta=delta, store_K=store_K)\n",
    "        mmds_kt[m] = mmd_kt[m]\n",
    "        mmds_kt_sin[m] = mmd_kt_sin[m]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save MMD Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Save the combined mmd arrays of all settings\n",
    "#\n",
    "save_combined_mmd = args is not None and args.combine_mmd\n",
    "\n",
    "if save_combined_mmd:\n",
    "    # (filename prefix, label used in the log message, array to save), in saving order\n",
    "    combined_results = []\n",
    "    if run_standard_thinning:\n",
    "        combined_results.append((\"mc\", \"mc mmd\", mmds_st))\n",
    "        combined_results.append((\"mc-sin\", \"mc mmd_sin\", mmds_st_sin))\n",
    "    if run_kernel_thinning:\n",
    "        combined_results.append((\"kt\", \"kt mmd\", mmds_kt))\n",
    "        combined_results.append((\"kt-sin\", \"kt mmd_sin\", mmds_kt_sin))\n",
    "\n",
    "    for prefix, label, mmds in combined_results:\n",
    "        filename = get_combined_mmd_filename(prefix, ms, params_p, params_k_split, params_k_swap, rep_ids, delta)\n",
    "        with open(filename, 'wb') as file:\n",
    "            print(f\"Saving combined {label} to {filename}\")\n",
    "            pkl.dump(mmds, file, protocol=pkl.HIGHEST_PROTOCOL)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
