{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A notebook for running generalized kernel thinning (KT) and standard thinning experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n",
      "The line_profiler extension is already loaded. To reload it, use:\n",
      "  %reload_ext line_profiler\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import numpy.random as npr\n",
    "import numpy.linalg as npl\n",
    "from scipy.spatial.distance import pdist\n",
    "\n",
    "from argparse import ArgumentParser\n",
    "import pickle as pkl\n",
    "import pathlib\n",
    "import os\n",
    "import os.path\n",
    "\n",
    "# import kernel thinning\n",
    "from kernelthinning import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions\n",
    "from kernelthinning.util import isnotebook # Check whether this file is being executed as a script or as a notebook\n",
    "from kernelthinning.util import fprint  # for printing while flushing buffer\n",
    "from kernelthinning.tictoc import tic, toc # for timing blocks of code\n",
    "\n",
    "\n",
    "# utils for generating samples, evaluating kernels, and mmds\n",
    "from util_sample import sample, compute_params_p, sample_string\n",
    "from util_k_mmd import kernel_eval, compute_params_k, compute_power_kernel_params_k\n",
    "from util_k_mmd import p_kernel, ppn_kernel, pp_kernel, pnpn_kernel, squared_mmd, get_combined_results_filename\n",
    "from util_parse import init_parser, convert_arg_flags\n",
    "# for partial functions, to use kernel_eval for kernel\n",
    "from functools import partial\n",
    "\n",
    "# experiment functions\n",
    "from util_experiments import run_kernel_thinning_experiment, kt_split_best, kt_split_rand\n",
    "from util_experiments import run_standard_thinning_experiment, run_iid_thinning_experiment\n",
    "\n",
    "# Notebook-only setup: skipped when isnotebook() reports the file is not running as a notebook\n",
    "if isnotebook():\n",
    "    # Autoreload packages that are modified\n",
    "    %load_ext autoreload\n",
    "    %autoreload 2\n",
    "    %matplotlib inline\n",
    "    # line_profiler provides the %lprun magic used by the profiling cells below\n",
    "    %load_ext line_profiler\n",
    "    # https://jakevdp.github.io/PythonDataScienceHandbook/01.07-timing-and-profiling.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse experiment options. parse_known_args (rather than parse_args) returns any\n",
    "# unrecognized command-line tokens in `opt` instead of exiting with an error.\n",
    "# NOTE(review): inside Jupyter the kernel launcher's connection-file argument appears to be\n",
    "# parsed into args.filename; the notebook-only branch further down overrides it.\n",
    "parser = init_parser()\n",
    "parsed_args, opt = parser.parse_known_args()\n",
    "args = convert_arg_flags(parsed_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(M=None, P='gauss', computemmd=True, computepower=True, d=2, filename='/accounts/projects/binyu/raaz.rsk/.jupyter/runtime/kernel-4cccd30b-d604-47d4-bb10-cd5db3a89913.json', kernel='gauss', ktplus=True, m=6, nu=0.5, power=0.5, powerkt=True, rep0=0, repn=1, rerun=False, save_combined_results=False, stdthin=False, store_K=False, targetkt=True) [] ArgumentParser(prog='ipykernel_launcher.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)\n"
     ]
    }
   ],
   "source": [
    "# Echo the parsed options, any unrecognized command-line tokens, and the parser itself\n",
    "print(*(args, opt, parser))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run thinning experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "p {'name': 'diag_mog', 'weights': array([0.25, 0.25, 0.25, 0.25]), 'means': array([[ 3.,  3.],\n",
      "       [-3.,  3.],\n",
      "       [-3., -3.],\n",
      "       [ 3., -3.]]), 'covs': array([1., 1., 1., 1.]), 'd': 2, 'mean_sqdist': 40.0, 'saved_samples': False, 'flip_Pnmax': False}\n",
      "k {'name': 'gauss', 'var': 4.0, 'd': 2}\n",
      "kpower {'name': 'gauss_rt', 'd': 2, 'var': 2.0}\n",
      "combo {'name': 'combo_gauss_0.5', 'k': {'name': 'gauss', 'var': 4.0, 'd': 2}, 'kpower': {'name': 'gauss_rt', 'd': 2, 'var': 2.0}, 'var': 4.0, 'd': 2}\n"
     ]
    }
   ],
   "source": [
    "#\n",
    "# Choose sample and kernel parameters\n",
    "#\n",
    "\n",
    "if isnotebook():\n",
    "    # Notebook-only overrides of the command-line options\n",
    "    args.d = 38\n",
    "    args.M = 4\n",
    "    args.P = \"mog\"\n",
    "    args.kernel = \"gauss\"\n",
    "    args.rep0 = 0\n",
    "    args.repn = 2\n",
    "    args.m = 5\n",
    "    args.computepower = True\n",
    "    args.power = 0.5\n",
    "    # MCMC sample file; when specified, compute_params_p may change d (see below)\n",
    "    args.filename = 'Hinch_P_seed_1_temp_1_scaled'\n",
    "    # Available MCMC files:\n",
    "    # ['Goodwin_RW', 'Goodwin_ADA-RW', 'Goodwin_MALA', 'Goodwin_PRECOND-MALA', \n",
    "    # 'Lotka_RW', 'Lotka_ADA-RW', 'Lotka_MALA', 'Lotka_PRECOND-MALA', \n",
    "    # 'Hinch_P_seed_1_temp_1', 'Hinch_P_seed_2_temp_1', \n",
    "    # 'Hinch_TP_seed_1_temp_8', 'Hinch_TP_seed_2_temp_8',\n",
    "    # 'Hinch_P_seed_1_temp_1_scaled', 'Hinch_P_seed_2_temp_1_scaled', \n",
    "    # 'Hinch_TP_seed_1_temp_8_scaled', 'Hinch_TP_seed_2_temp_8_scaled']\n",
    "\n",
    "d, params_p, var_k = compute_params_p(args)\n",
    "# d can change from args when mcmc filename is specified, so write it back\n",
    "args.d = d\n",
    "params_k, params_k_power = compute_params_k(args, var_k, power_kernel=args.computepower, power=args.power)\n",
    "\n",
    "if args.ktplus:\n",
    "    # KT+ splits with a combined kernel built from the target kernel k and the power kernel kpower\n",
    "    assert args.power is not None, \"ktplus requires args.power to be set\"\n",
    "    assert params_k_power is not None, \"ktplus requires power-kernel params (is computepower enabled?)\"\n",
    "    params_k_combo = {\n",
    "        \"name\": \"combo_\" + params_k[\"name\"] + f\"_{args.power}\",\n",
    "        \"k\": params_k.copy(),\n",
    "        \"kpower\": params_k_power.copy(),\n",
    "        \"var\": params_k[\"var\"],\n",
    "        \"d\": args.d,\n",
    "    }\n",
    "\n",
    "# Print P without its (possibly very large) 'Pnmax' sample array\n",
    "print(\"p\", {key: val for key, val in params_p.items() if key != \"Pnmax\"})\n",
    "print(\"k\", params_k)\n",
    "print(\"kpower\", params_k_power)\n",
    "if args.ktplus:\n",
    "    print(\"combo\", params_k_combo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# Choose experiment parameters\n",
    "#\n",
    "\n",
    "if isnotebook():\n",
    "    # Notebook-only overrides. They are applied before rep_ids is derived so that args.repn\n",
    "    # and rep_ids agree; previously rep_ids = range(10) silently ignored args.repn.\n",
    "    args.rerun = False\n",
    "    args.repn = 10\n",
    "\n",
    "# List of replicate ID numbers\n",
    "rep_ids = np.arange(args.rep0, args.rep0+args.repn)\n",
    "\n",
    "# List of halving round numbers m to evaluate\n",
    "ms = range(args.m)\n",
    "\n",
    "# Failure probability\n",
    "delta = .5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result arrays are indexed [m, replicate]. Rows run up to max(ms) so that row m exists\n",
    "# for every m in ms; default=-1 yields 0 rows instead of a ValueError when ms is empty.\n",
    "result_shape = (max(ms, default=-1) + 1, len(rep_ids))\n",
    "\n",
    "if args.stdthin:\n",
    "    mmds_st = np.zeros(result_shape) # mmds from P\n",
    "    mmds_st_sin = np.zeros(result_shape) # mmds from Sin\n",
    "    fun_diff_st = np.zeros(result_shape) # fun diff from P\n",
    "    fun_diff_st_sin = np.zeros(result_shape) # fun diff from Sin\n",
    "\n",
    "if args.targetkt:\n",
    "    mmds_kt = np.zeros(result_shape) # mmds from P\n",
    "    mmds_kt_sin = np.zeros(result_shape) # mmds from Sin\n",
    "    fun_diff_kt = np.zeros(result_shape) # fun diff from P\n",
    "    fun_diff_kt_sin = np.zeros(result_shape) # fun diff from Sin\n",
    "\n",
    "if args.powerkt:\n",
    "    mmds_kt_krt = np.zeros(result_shape) # mmds from P\n",
    "    mmds_kt_krt_sin = np.zeros(result_shape) # mmds from Sin\n",
    "    fun_diff_kt_krt = np.zeros(result_shape) # fun diff from P\n",
    "    fun_diff_kt_krt_sin = np.zeros(result_shape) # fun diff from Sin\n",
    "\n",
    "if args.ktplus:\n",
    "    mmds_ktplus = np.zeros(result_shape) # mmds from P\n",
    "    mmds_ktplus_sin = np.zeros(result_shape) # mmds from Sin\n",
    "    fun_diff_ktplus = np.zeros(result_shape) # fun diff from P\n",
    "    fun_diff_ktplus_sin = np.zeros(result_shape) # fun diff from Sin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exp setting: k = {'name': 'imq', 'var': 65.08616976, 'd': 38, 'nu': 0.5},  P = {'saved_samples': True, 'name': 'Hinch_P_seed_1_temp_1_scaled', 'data_dir': 'data', 'nmax': 32768, 'burn_in': 0, 'include_last': True, 'med_dist': 8.0676, 'd': 38, 'flip_Pnmax': False, 'Pnmax': array([[-1.52716763,  0.07991351,  2.57791193, ...,  3.88998877,\n",
      "         3.96155122, -1.37342955],\n",
      "       [-1.55686675,  0.06266224,  2.57785841, ...,  3.91243187,\n",
      "         3.96414028, -1.37663901],\n",
      "       [-1.58389802,  0.04750193,  2.57547335, ...,  3.92266016,\n",
      "         3.95932316, -1.36276222],\n",
      "       ...,\n",
      "       [ 1.34991445, -1.19608383,  1.5011615 , ..., -0.88333755,\n",
      "        -1.45351143,  0.96341884],\n",
      "       [ 1.34376293, -1.19351681,  1.50051669, ..., -0.92128697,\n",
      "        -1.45529421,  0.95647493],\n",
      "       [ 1.3622799 , -1.18962683,  1.49751876, ..., -0.85989358,\n",
      "        -1.4585998 ,  0.95922069]])}, m = range(0, 5)\n",
      "Running kernel thinning  experiment with template results_new/kt-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitimq_var65.086_seed9876543-swapimq_var65.086-d38-m0-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.0633 (s)\n",
      "-elapsed time: 0.0545 (s)\n",
      "-elapsed time: 0.0289 (s)\n",
      "-elapsed time: 0.0135 (s)\n",
      "-elapsed time: 0.0479 (s)\n",
      "-elapsed time: 0.0333 (s)\n",
      "-elapsed time: 0.00882 (s)\n",
      "-elapsed time: 0.0207 (s)\n",
      "-elapsed time: 0.019 (s)\n",
      "-elapsed time: 0.0123 (s)\n",
      "-elapsed time: 0.31 (s)\n",
      "Running kernel thinning  experiment with template results_new/kt-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitimq_rt_var65.086_seed9876543-swapimq_var65.086-d38-m0-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.0127 (s)\n",
      "-elapsed time: 0.025 (s)\n",
      "-elapsed time: 0.0324 (s)\n",
      "-elapsed time: 0.0273 (s)\n",
      "-elapsed time: 0.0467 (s)\n",
      "-elapsed time: 0.0326 (s)\n",
      "-elapsed time: 0.0212 (s)\n",
      "-elapsed time: 0.0591 (s)\n",
      "-elapsed time: 0.0119 (s)\n",
      "-elapsed time: 0.0246 (s)\n",
      "-elapsed time: 0.301 (s)\n",
      "Running kernel thinning -plus experiment with template results_new/kt-plus-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitcombo_imq_0.5_var65.086_seed9876543-swapimq_var65.086-d38-m0-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00398 (s)\n",
      "-elapsed time: 0.00386 (s)\n",
      "-elapsed time: 0.00477 (s)\n",
      "-elapsed time: 0.00462 (s)\n",
      "-elapsed time: 0.005 (s)\n",
      "-elapsed time: 0.00401 (s)\n",
      "-elapsed time: 0.00379 (s)\n",
      "-elapsed time: 0.00383 (s)\n",
      "-elapsed time: 0.00461 (s)\n",
      "-elapsed time: 0.00406 (s)\n",
      "-elapsed time: 0.0489 (s)\n",
      "mmd target_kt [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
      " [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]\n",
      "Running kernel thinning  experiment with template results_new/kt-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitimq_var65.086_seed9876543-swapimq_var65.086-d38-m1-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00773 (s)\n",
      "-elapsed time: 0.0325 (s)\n",
      "-elapsed time: 0.019 (s)\n",
      "-elapsed time: 0.0451 (s)\n",
      "-elapsed time: 0.0333 (s)\n",
      "-elapsed time: 0.0305 (s)\n",
      "-elapsed time: 0.0142 (s)\n",
      "-elapsed time: 0.0188 (s)\n",
      "-elapsed time: 0.00842 (s)\n",
      "-elapsed time: 0.00785 (s)\n",
      "-elapsed time: 0.225 (s)\n",
      "Running kernel thinning  experiment with template results_new/kt-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitimq_rt_var65.086_seed9876543-swapimq_var65.086-d38-m1-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.0185 (s)\n",
      "-elapsed time: 0.00798 (s)\n",
      "-elapsed time: 0.0157 (s)\n",
      "-elapsed time: 0.0087 (s)\n",
      "-elapsed time: 0.0935 (s)\n",
      "-elapsed time: 0.0237 (s)\n",
      "-elapsed time: 0.0107 (s)\n",
      "-elapsed time: 0.0169 (s)\n",
      "-elapsed time: 0.00891 (s)\n",
      "-elapsed time: 0.0132 (s)\n",
      "-elapsed time: 0.225 (s)\n",
      "Running kernel thinning -plus experiment with template results_new/kt-plus-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitcombo_imq_0.5_var65.086_seed9876543-swapimq_var65.086-d38-m1-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00403 (s)\n",
      "-elapsed time: 0.00492 (s)\n",
      "-elapsed time: 0.00481 (s)\n",
      "-elapsed time: 0.00487 (s)\n",
      "-elapsed time: 0.00488 (s)\n",
      "-elapsed time: 0.00408 (s)\n",
      "-elapsed time: 0.00501 (s)\n",
      "-elapsed time: 0.00495 (s)\n",
      "-elapsed time: 0.00508 (s)\n",
      "-elapsed time: 0.00508 (s)\n",
      "-elapsed time: 0.0543 (s)\n",
      "mmd target_kt [[0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.29473562 0.29473562 0.29473562 0.29473562 0.29473562 0.29473562\n",
      "  0.29473562 0.29473562 0.29473562 0.29473562]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]]\n",
      "Running kernel thinning  experiment with template results_new/kt-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitimq_var65.086_seed9876543-swapimq_var65.086-d38-m2-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.0242 (s)\n",
      "-elapsed time: 0.00929 (s)\n",
      "-elapsed time: 0.0093 (s)\n",
      "-elapsed time: 0.00787 (s)\n",
      "-elapsed time: 0.0406 (s)\n",
      "-elapsed time: 0.0243 (s)\n",
      "-elapsed time: 0.0328 (s)\n",
      "-elapsed time: 0.00894 (s)\n",
      "-elapsed time: 0.00755 (s)\n",
      "-elapsed time: 0.00592 (s)\n",
      "-elapsed time: 0.178 (s)\n",
      "Running kernel thinning  experiment with template results_new/kt-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitimq_rt_var65.086_seed9876543-swapimq_var65.086-d38-m2-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00681 (s)\n",
      "-elapsed time: 0.00727 (s)\n",
      "-elapsed time: 0.008 (s)\n",
      "-elapsed time: 0.00751 (s)\n",
      "-elapsed time: 0.0285 (s)\n",
      "-elapsed time: 0.0224 (s)\n",
      "-elapsed time: 0.0269 (s)\n",
      "-elapsed time: 0.0218 (s)\n",
      "-elapsed time: 0.0254 (s)\n",
      "-elapsed time: 0.0276 (s)\n",
      "-elapsed time: 0.19 (s)\n",
      "Running kernel thinning -plus experiment with template results_new/kt-plus-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitcombo_imq_0.5_var65.086_seed9876543-swapimq_var65.086-d38-m2-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00402 (s)\n",
      "-elapsed time: 0.00506 (s)\n",
      "-elapsed time: 0.00388 (s)\n",
      "-elapsed time: 0.00384 (s)\n",
      "-elapsed time: 0.00472 (s)\n",
      "-elapsed time: 0.00492 (s)\n",
      "-elapsed time: 0.00395 (s)\n",
      "-elapsed time: 0.00408 (s)\n",
      "-elapsed time: 0.00489 (s)\n",
      "-elapsed time: 0.00415 (s)\n",
      "-elapsed time: 0.05 (s)\n",
      "mmd target_kt [[0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.29473562 0.29473562 0.29473562 0.29473562 0.29473562 0.29473562\n",
      "  0.29473562 0.29473562 0.29473562 0.29473562]\n",
      " [0.15846643 0.15091379 0.14882284 0.15846643 0.14882284 0.14882284\n",
      "  0.14882284 0.14882284 0.15053708 0.15846643]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]]\n",
      "Running kernel thinning  experiment with template results_new/kt-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitimq_var65.086_seed9876543-swapimq_var65.086-d38-m3-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00815 (s)\n",
      "-elapsed time: 0.00811 (s)\n",
      "-elapsed time: 0.0199 (s)\n",
      "-elapsed time: 0.0084 (s)\n",
      "-elapsed time: 0.00906 (s)\n",
      "-elapsed time: 0.0263 (s)\n",
      "-elapsed time: 0.0666 (s)\n",
      "-elapsed time: 0.0121 (s)\n",
      "-elapsed time: 0.0143 (s)\n",
      "-elapsed time: 0.00713 (s)\n",
      "-elapsed time: 0.187 (s)\n",
      "Running kernel thinning  experiment with template results_new/kt-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitimq_rt_var65.086_seed9876543-swapimq_var65.086-d38-m3-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00703 (s)\n",
      "-elapsed time: 0.00734 (s)\n",
      "-elapsed time: 0.00788 (s)\n",
      "-elapsed time: 0.00755 (s)\n",
      "-elapsed time: 0.00823 (s)\n",
      "-elapsed time: 0.00859 (s)\n",
      "-elapsed time: 0.0118 (s)\n",
      "-elapsed time: 0.0247 (s)\n",
      "-elapsed time: 0.0104 (s)\n",
      "-elapsed time: 0.00722 (s)\n",
      "-elapsed time: 0.108 (s)\n",
      "Running kernel thinning -plus experiment with template results_new/kt-plus-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitcombo_imq_0.5_var65.086_seed9876543-swapimq_var65.086-d38-m3-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00462 (s)\n",
      "-elapsed time: 0.0188 (s)\n",
      "-elapsed time: 0.00532 (s)\n",
      "-elapsed time: 0.00445 (s)\n",
      "-elapsed time: 0.00492 (s)\n",
      "-elapsed time: 0.00383 (s)\n",
      "-elapsed time: 0.0038 (s)\n",
      "-elapsed time: 0.00403 (s)\n",
      "-elapsed time: 0.00505 (s)\n",
      "-elapsed time: 0.00498 (s)\n",
      "-elapsed time: 0.0663 (s)\n",
      "mmd target_kt [[0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.29473562 0.29473562 0.29473562 0.29473562 0.29473562 0.29473562\n",
      "  0.29473562 0.29473562 0.29473562 0.29473562]\n",
      " [0.15846643 0.15091379 0.14882284 0.15846643 0.14882284 0.14882284\n",
      "  0.14882284 0.14882284 0.15053708 0.15846643]\n",
      " [0.07490314 0.08102318 0.07734988 0.07647028 0.07236608 0.07942531\n",
      "  0.07457984 0.07357238 0.07057173 0.07857234]\n",
      " [0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]]\n",
      "Running kernel thinning  experiment with template results_new/kt-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitimq_var65.086_seed9876543-swapimq_var65.086-d38-m4-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00847 (s)\n",
      "-elapsed time: 0.0217 (s)\n",
      "-elapsed time: 0.0096 (s)\n",
      "-elapsed time: 0.0104 (s)\n",
      "-elapsed time: 0.0639 (s)\n",
      "-elapsed time: 0.0239 (s)\n",
      "-elapsed time: 0.00738 (s)\n",
      "-elapsed time: 0.00896 (s)\n",
      "-elapsed time: 0.03 (s)\n",
      "-elapsed time: 0.00724 (s)\n",
      "-elapsed time: 0.199 (s)\n",
      "Running kernel thinning  experiment with template results_new/kt-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitimq_rt_var65.086_seed9876543-swapimq_var65.086-d38-m4-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00722 (s)\n",
      "-elapsed time: 0.00836 (s)\n",
      "-elapsed time: 0.00861 (s)\n",
      "-elapsed time: 0.00848 (s)\n",
      "-elapsed time: 0.0596 (s)\n",
      "-elapsed time: 0.103 (s)\n",
      "-elapsed time: 0.0131 (s)\n",
      "-elapsed time: 0.0304 (s)\n",
      "-elapsed time: 0.017 (s)\n",
      "-elapsed time: 0.00795 (s)\n",
      "-elapsed time: 0.271 (s)\n",
      "Running kernel thinning -plus experiment with template results_new/kt-plus-coresets-Hinch_P_seed_1_temp_1_scaled_endpt_nmax_15-splitcombo_imq_0.5_var65.086_seed9876543-swapimq_var65.086-d38-m4-delta0.5-rep{}.pkl.....\n",
      "-elapsed time: 0.00496 (s)\n",
      "-elapsed time: 0.00485 (s)\n",
      "-elapsed time: 0.00517 (s)\n",
      "-elapsed time: 0.00523 (s)\n",
      "-elapsed time: 0.00397 (s)\n",
      "-elapsed time: 0.00357 (s)\n",
      "-elapsed time: 0.00361 (s)\n",
      "-elapsed time: 0.00371 (s)\n",
      "-elapsed time: 0.0047 (s)\n",
      "-elapsed time: 0.00375 (s)\n",
      "-elapsed time: 0.05 (s)\n",
      "mmd target_kt [[0.         0.         0.         0.         0.         0.\n",
      "  0.         0.         0.         0.        ]\n",
      " [0.29473562 0.29473562 0.29473562 0.29473562 0.29473562 0.29473562\n",
      "  0.29473562 0.29473562 0.29473562 0.29473562]\n",
      " [0.15846643 0.15091379 0.14882284 0.15846643 0.14882284 0.14882284\n",
      "  0.14882284 0.14882284 0.15053708 0.15846643]\n",
      " [0.07490314 0.08102318 0.07734988 0.07647028 0.07236608 0.07942531\n",
      "  0.07457984 0.07357238 0.07057173 0.07857234]\n",
      " [0.03506242 0.03730957 0.03536346 0.03808509 0.03589797 0.0367158\n",
      "  0.03387728 0.03637768 0.03337522 0.03372612]]\n",
      "-elapsed time: 2.51 (s)\n"
     ]
    }
   ],
   "source": [
    "# Summarize P without dumping large sample arrays (e.g. the 'Pnmax' samples of MCMC targets)\n",
    "params_p_summary = {key: val for key, val in params_p.items() if key != \"Pnmax\"}\n",
    "print(f\"Exp setting: k = {params_k},  P = {params_p_summary}, m = {ms}\")\n",
    "tic()\n",
    "\n",
    "for m in ms:\n",
    "    #\n",
    "    # Run experiments and store quality of the 2^m thinned coreset.\n",
    "    # Each experiment returns per-round arrays; only row m is copied into the accumulators.\n",
    "    #\n",
    "    if args.stdthin:\n",
    "        mmd_st, mmd_st_sin, fd_st, fd_st_sin = run_standard_thinning_experiment(\n",
    "            m, params_p=params_p, rerun=args.rerun, params_k_mmd=params_k, rep_ids=rep_ids,\n",
    "            compute_mmds=args.computemmd)\n",
    "        mmds_st[m, :] = mmd_st[m, :]\n",
    "        mmds_st_sin[m, :] = mmd_st_sin[m, :]\n",
    "        fun_diff_st[m, :] = fd_st[m, :]\n",
    "        fun_diff_st_sin[m, :] = fd_st_sin[m, :]\n",
    "\n",
    "    if args.targetkt:\n",
    "        # Target KT: split and swap both use the target kernel\n",
    "        mmd_kt, mmd_kt_sin, fd_kt, fd_kt_sin = run_kernel_thinning_experiment(\n",
    "            m, thin_fun=kt.thin, thin_str=\"\", params_p=params_p, rerun=args.rerun,\n",
    "            params_k_split=params_k, params_k_swap=params_k, rep_ids=rep_ids,\n",
    "            delta=delta, store_K=args.store_K, compute_mmds=args.computemmd)\n",
    "        mmds_kt[m, :] = mmd_kt[m, :]\n",
    "        mmds_kt_sin[m, :] = mmd_kt_sin[m, :]\n",
    "        fun_diff_kt[m, :] = fd_kt[m, :]\n",
    "        fun_diff_kt_sin[m, :] = fd_kt_sin[m, :]\n",
    "\n",
    "    if args.powerkt:\n",
    "        # Power KT: split with the power kernel, swap with the target kernel\n",
    "        mmd_kt_krt, mmd_kt_krt_sin, fd_kt_krt, fd_kt_krt_sin = run_kernel_thinning_experiment(\n",
    "            m, thin_fun=kt.thin, thin_str=\"\", params_p=params_p, rerun=args.rerun,\n",
    "            params_k_split=params_k_power, params_k_swap=params_k, rep_ids=rep_ids,\n",
    "            delta=delta, store_K=args.store_K, compute_mmds=args.computemmd)\n",
    "        mmds_kt_krt[m, :] = mmd_kt_krt[m, :]\n",
    "        mmds_kt_krt_sin[m, :] = mmd_kt_krt_sin[m, :]\n",
    "        fun_diff_kt_krt[m, :] = fd_kt_krt[m, :]\n",
    "        fun_diff_kt_krt_sin[m, :] = fd_kt_krt_sin[m, :]\n",
    "\n",
    "    if args.ktplus:\n",
    "        # KT+: split with the combined kernel, swap with the target kernel\n",
    "        mmd_ktplus, mmd_ktplus_sin, fd_ktplus, fd_ktplus_sin = run_kernel_thinning_experiment(\n",
    "            m, thin_fun=kt.thin, thin_str=\"-plus\", params_p=params_p, rerun=args.rerun,\n",
    "            params_k_split=params_k_combo, params_k_swap=params_k, rep_ids=rep_ids,\n",
    "            delta=delta, store_K=args.store_K, compute_mmds=args.computemmd)\n",
    "        mmds_ktplus[m, :] = mmd_ktplus[m, :]\n",
    "        mmds_ktplus_sin[m, :] = mmd_ktplus_sin[m, :]\n",
    "        fun_diff_ktplus[m, :] = fd_ktplus[m, :]\n",
    "        fun_diff_ktplus_sin[m, :] = fd_ktplus_sin[m, :]\n",
    "\n",
    "    if args.targetkt:\n",
    "        # Report only this round's MMDs (one per replicate), not the whole partially-filled matrix\n",
    "        print(f\"mmd target_kt m={m}\", mmds_kt[m])\n",
    "toc()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.         0.29473562 0.15209644 0.07588341 0.03557906] [0.         0.29473562 0.153228   0.07525277 0.03628651] [0.         0.29473562 0.153228   0.07544027 0.03627685]\n"
     ]
    }
   ],
   "source": [
    "if isnotebook():\n",
    "    # Mean MMD over replicates for each m. The arrays exist only for the methods enabled\n",
    "    # in args, so each is guarded to avoid a NameError (or stale arrays from an earlier run).\n",
    "    mean_mmds = []\n",
    "    if args.targetkt:\n",
    "        mean_mmds.append(mmds_kt.mean(1))\n",
    "    if args.powerkt:\n",
    "        mean_mmds.append(mmds_kt_krt.mean(1))\n",
    "    if args.ktplus:\n",
    "        mean_mmds.append(mmds_ktplus.mean(1))\n",
    "    print(*mean_mmds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# profile kernel_eval function (disabled; change False to True to run)\n",
    "#\n",
    "if False:\n",
    "    query_point = params_p[\"Pnmax\"][0]\n",
    "    n_points = 4**8\n",
    "    xn = npr.randn(n_points, d)\n",
    "    %lprun -f kernel_eval kernel_eval(query_point.reshape(1, -1), xn, params_k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# profile squared_mmd function (disabled; change False to True to run)\n",
    "#\n",
    "if False:\n",
    "    X = sample(int(4**8), params_p)\n",
    "    xn = npr.randn(int(4**4), d)\n",
    "    tic()\n",
    "    # Reference \"distribution\" given by the drawn sample X (stored as its Pnmax)\n",
    "    params_p_sin = {\n",
    "        \"d\": d,\n",
    "        \"name\": params_p[\"name\"] + \"_sin\",\n",
    "        \"Pnmax\": X,\n",
    "        \"saved_samples\": False,\n",
    "    }\n",
    "    %lprun -f squared_mmd squared_mmd(params_k, params_p_sin, xn)\n",
    "    toc()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "#\n",
    "# profile kt.thin timings (disabled; change False to True to run)\n",
    "#\n",
    "\n",
    "if False:\n",
    "    # *_prof names keep this cell from overwriting d, params_p, params_k and m, which the\n",
    "    # experiment cells define and the results-saving cell below reads (params_p, params_k)\n",
    "    d_prof = 100\n",
    "    var_prof = 2*float(d_prof)\n",
    "    var_p_prof = 1.\n",
    "    params_p_prof = {\"name\": \"gauss\", \"var\": var_p_prof, \"d\": d_prof, \"saved_samples\": False}\n",
    "    params_k_prof = {\"name\": \"gauss\", \"var\": var_prof, \"d\": d_prof}\n",
    "    m_prof = 7\n",
    "    X_prof = sample(int(4**m_prof), params_p_prof)\n",
    "    split_kernel = partial(kernel_eval, params_k=params_k_prof)\n",
    "    swap_kernel = partial(kernel_eval, params_k=params_k_prof)\n",
    "    %time kt.thin(X_prof, m_prof, split_kernel, swap_kernel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save combined MMD and function-difference (fun diff) results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving kt-combinedmmd- to results_new/combined/kt-combinedmmd--gauss_var1.0_seed1234567--split_imq_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt-sin-combinedmmd- to results_new/combined/kt-sin-combinedmmd--gauss_var1.0_seed1234567--split_imq_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt-combinedfundiff- to results_new/combined/kt-combinedfundiff--gauss_var1.0_seed1234567--split_imq_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt-sin-combinedfundiff- to results_new/combined/kt-sin-combinedfundiff--gauss_var1.0_seed1234567--split_imq_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt_krt-combinedmmd- to results_new/combined/kt_krt-combinedmmd--gauss_var1.0_seed1234567--split_imq_rt_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt_krt-sin-combinedmmd- to results_new/combined/kt_krt-sin-combinedmmd--gauss_var1.0_seed1234567--split_imq_rt_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt_krt-combinedfundiff- to results_new/combined/kt_krt-combinedfundiff--gauss_var1.0_seed1234567--split_imq_rt_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt_krt-sin-combinedfundiff- to results_new/combined/kt_krt-sin-combinedfundiff--gauss_var1.0_seed1234567--split_imq_rt_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt-plus0.5-combinedmmd- to results_new/combined/kt-plus0.5-combinedmmd--gauss_var1.0_seed1234567--split_combo_imq_0.5_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt-plus0.5-sin-combinedmmd- to results_new/combined/kt-plus0.5-sin-combinedmmd--gauss_var1.0_seed1234567--split_combo_imq_0.5_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt-plus0.5-combinedfundiff- to results_new/combined/kt-plus0.5-combinedfundiff--gauss_var1.0_seed1234567--split_combo_imq_0.5_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n",
      "Saving kt-plus0.5-sin-combinedfundiff- to results_new/combined/kt-plus0.5-sin-combinedfundiff--gauss_var1.0_seed1234567--split_combo_imq_0.5_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl\n"
     ]
    }
   ],
   "source": [
    "#\n",
    "# Save all combined results\n",
    "#\n",
    "# Each enabled method (args.stdthin / targetkt / powerkt / ktplus) pickles four arrays:\n",
    "# MMD and function-difference results plus their *_sin counterparts, one file per\n",
    "# entry of generic_prefixes. Filenames are built by get_combined_results_filename\n",
    "# from the method prefix and the experiment settings (ms, params_p, kernels, rep_ids, delta).\n",
    "#\n",
    "if isnotebook():\n",
    "    # change this code to save results manually when running notebook\n",
    "    save_combined_results = True #True if args is None else args.save_combined_results\n",
    "else:\n",
    "    save_combined_results = False if args is None else args.save_combined_results\n",
    "\n",
    "generic_prefixes = [\"-combinedmmd-\", \"-sin-combinedmmd-\", \"-combinedfundiff-\", \"-sin-combinedfundiff-\"]\n",
    "\n",
    "\n",
    "def save_combined_arrays(method_prefix, data_arrays, params_k_split, params_k_swap):\n",
    "    \"\"\"Pickle each array in data_arrays under method_prefix + the matching generic prefix.\n",
    "\n",
    "    Args:\n",
    "      method_prefix: method tag (e.g. \"mc\", \"kt\") prepended to every entry of generic_prefixes\n",
    "      data_arrays: arrays in the same order as generic_prefixes,\n",
    "        i.e. [mmds, mmds_sin, fun_diff, fun_diff_sin]\n",
    "      params_k_split: forwarded to get_combined_results_filename as params_k_split\n",
    "      params_k_swap: forwarded to get_combined_results_filename as params_k_swap\n",
    "\n",
    "    Reads the notebook globals generic_prefixes, ms, params_p, rep_ids and delta.\n",
    "    \"\"\"\n",
    "    for generic_prefix, data_array in zip(generic_prefixes, data_arrays):\n",
    "        prefix = method_prefix + generic_prefix\n",
    "        # keyword arguments keep every call site consistent and independent of positional order\n",
    "        filename = get_combined_results_filename(prefix, ms, params_p,\n",
    "                                                 params_k_split=params_k_split, params_k_swap=params_k_swap,\n",
    "                                                 rep_ids=rep_ids, delta=delta)\n",
    "        with open(filename, 'wb') as out_file:\n",
    "            print(f\"Saving {prefix} to {filename}\")\n",
    "            pkl.dump(data_array, out_file, protocol=pkl.HIGHEST_PROTOCOL)\n",
    "\n",
    "\n",
    "if save_combined_results and args is None:\n",
    "    # the args.* method flags below decide which arrays exist; fail with a clear\n",
    "    # message instead of an AttributeError on None\n",
    "    raise ValueError(\"save_combined_results is True but args is None; cannot tell which methods were run\")\n",
    "\n",
    "if save_combined_results:\n",
    "    if args.stdthin:\n",
    "        save_combined_arrays(\"mc\",\n",
    "                             [mmds_st, mmds_st_sin, fun_diff_st, fun_diff_st_sin],\n",
    "                             params_k_split=params_k, params_k_swap=params_k)\n",
    "\n",
    "    if args.targetkt:\n",
    "        save_combined_arrays(\"kt\",\n",
    "                             [mmds_kt, mmds_kt_sin, fun_diff_kt, fun_diff_kt_sin],\n",
    "                             params_k_split=params_k, params_k_swap=params_k)\n",
    "\n",
    "    if args.powerkt:\n",
    "        # power 0.5 is saved under the \"kt_krt\" name; other powers embed the exponent\n",
    "        powerkt_name = \"kt_krt\" if args.power == 0.5 else f\"kt_power{args.power}\"\n",
    "        save_combined_arrays(powerkt_name,\n",
    "                             [mmds_kt_krt, mmds_kt_krt_sin, fun_diff_kt_krt, fun_diff_kt_krt_sin],\n",
    "                             params_k_split=params_k_power, params_k_swap=params_k)\n",
    "\n",
    "    if args.ktplus:\n",
    "        save_combined_arrays(f\"kt-plus{args.power}\",\n",
    "                             [mmds_ktplus, mmds_ktplus_sin, fun_diff_ktplus, fun_diff_ktplus_sin],\n",
    "                             params_k_split=params_k_combo, params_k_swap=params_k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OLD CODE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #\n",
    "# # Choose sample and kernel parameters\n",
    "# #\n",
    "\n",
    "# d = int(2) if args is None else args.d\n",
    "# # var = 2*float(d) # changing variance to scale with d; NOW ALWAYS USING THIS SETTING\n",
    "# # var_p = 1. # variance of P\n",
    "\n",
    "# # params_p = {\"name\": \"gauss\", \"var\": var_p, \"d\": int(d), \"saved_samples\": False, \"flip_Pnmax\":False}\n",
    "# # filename is for MCMC files\n",
    "# filename = None if args is None else args.filename\n",
    "# # M denotes the number of components for MOG settings\n",
    "# M = None if args is None else args.M\n",
    "\n",
    "# if isnotebook():\n",
    "#     # M = 8\n",
    "#     args.M = 4\n",
    "#     args.P = \"gauss\"\n",
    "#     args.kernel = \"sinc\"\n",
    "#     # TWEAK HERE IF WANT TO TEST THINGS FOR MCMC\n",
    "#     filename = None #'Hinch_TP_seed_2_temp_8_scaled' # 'Goodwin_RW' if args is None else args.filename\n",
    "#     # ['Goodwin_RW', 'Goodwin_ADA-RW', 'Goodwin_MALA', 'Goodwin_PRECOND-MALA', \n",
    "#     # 'Lotka_RW', 'Lotka_ADA-RW', 'Lotka_MALA', 'Lotka_PRECOND-MALA', \n",
    "#     # 'Hinch_P_seed_1_temp_1', 'Hinch_P_seed_2_temp_1', \n",
    "#     # 'Hinch_TP_seed_1_temp_8', 'Hinch_TP_seed_2_temp_8',\n",
    "#     # 'Hinch_P_seed_1_temp_1_scaled', 'Hinch_P_seed_2_temp_1_scaled', \n",
    "#     # 'Hinch_TP_seed_1_temp_8_scaled', 'Hinch_TP_seed_2_temp_8_scaled']\n",
    "\n",
    "# # if filename is not None:\n",
    "# #     # if a filename is specified then compute params_p\n",
    "    \n",
    "# #     params_p = compute_mcmc_params_p(filename, nmax=int(2**15), include_last=True, \n",
    "# #                                      flip_Pnmax=True)\n",
    "# #     # whether to use median_distance for kernel bandwidth for MCMC settings\n",
    "# #     use_median_distance = True  # NOW SET TO TRUE ALWAYS\n",
    "# #     d = params_p[\"d\"]\n",
    "\n",
    "# #     if use_median_distance:\n",
    "# #         var = (params_p[\"med_dist\"])**2\n",
    "\n",
    "# # if M is not None:\n",
    "# #     # if number of mixture is specified then compute params_p\n",
    "# #     params_p = compute_diag_mog_params(M)\n",
    "# #     d = params_p[\"d\"]\n",
    "# #     var = float(2*d)\n",
    "# #     # var = params_p[\"mean_sqdist\"]\n",
    "\n",
    "# # params_k =  {\"name\": \"gauss\", \"var\": var, \"d\": int(d)}\n",
    "# # params_krt =  {\"name\": \"gauss_rt\", \"var\": var/2., \"d\": int(d)}\n",
    "\n",
    "# # params_k =  {\"name\": \"sinc\", \"var\": d, \"d\": int(d)}\n",
    "\n",
    "# d, params_p, var_k = compute_params_p(args)\n",
    "\n",
    "# # compute params_k; and k_rt for Gauss setting\n",
    "\n",
    "# params_k = {\"name\": args.kernel, \"var\": var_k, \"d\": int(d)}\n",
    "# if params_k[\"name\"] == \"gauss\":\n",
    "#     params_krt = {\"name\": \"gauss_rt\", \"var\": var_k/2., \"d\": int(d)}\n",
    "    \n",
    "# if isnotebook():\n",
    "#     print(\"p\", params_p)\n",
    "#     print(\"k\", params_k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Misc code not being used at the moment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run_iid_thinning = False # iid subsampling\n",
    "# run_kt_split_rand_thining = False # kt.split rand thinning\n",
    "# run_kt_split_best_thining  = False # kt.split best thinning\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "    \n",
    "# if run_kt_split_rand_thining: \n",
    "#     mmds_kt_split = np.zeros((max(ms)+1, len(rep_ids))) # mmds from P\n",
    "#     mmds_kt_split_sin = np.zeros((max(ms)+1, len(rep_ids))) # mmds from Sin\n",
    "#     fun_diff_kt_split = np.zeros((max(ms)+1, len(rep_ids)))# fun diff from P\n",
    "#     fun_diff_kt_split_sin = np.zeros((max(ms)+1, len(rep_ids))) # fun diff from Sin\n",
    "    \n",
    "# if run_kt_split_best_thining: \n",
    "#     mmds_kt_split_best = np.zeros((max(ms)+1, len(rep_ids))) # mmds from P\n",
    "#     mmds_kt_split_best_sin = np.zeros((max(ms)+1, len(rep_ids))) # mmds from Sin\n",
    "#     fun_diff_kt_split_best = np.zeros((max(ms)+1, len(rep_ids)))# fun diff from P\n",
    "#     fun_diff_kt_split_best_sin = np.zeros((max(ms)+1, len(rep_ids))) # fun diff from Sin\n",
    "\n",
    "    \n",
    "# if run_iid_thinning:\n",
    "#     mmds_iid = np.zeros((max(ms)+1, len(rep_ids))) # mmds from P\n",
    "#     mmds_iid_sin = np.zeros((max(ms)+1, len(rep_ids))) # mmds from Sin\n",
    "#     fun_diff_iid  = np.zeros((max(ms)+1, len(rep_ids))) # fun diff from P\n",
    "#     fun_diff_iid_sin = np.zeros((max(ms)+1, len(rep_ids))) # fun diff from Sin\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for m in ms:\n",
    "#     if run_iid_thinning:\n",
    "#         mmd_iid, mmd_iid_sin, fd_iid, fd_iid_sin = run_iid_thinning_experiment(m, params_p=params_p, rerun=rerun,\n",
    "#                                                                                 params_k_mmd=params_k, rep_ids=rep_ids,\n",
    "#                                                                               compute_mmds=args.computemmd)\n",
    "#         mmds_iid[m, :] = mmd_iid[m, :]\n",
    "#         mmds_iid_sin[m, :] = mmd_iid_sin[m, :]\n",
    "#         fun_diff_iid[m, :] = fd_iid[m, :]\n",
    "#         fun_diff_iid_sin[m, :] = fd_iid_sin[m, :]\n",
    "        \n",
    "\n",
    "#     if run_kt_split_rand_thining:\n",
    "#         mmd_kt_split, mmd_kt_split_sin, fd_kt_split, fd_kt_split_sin = run_kernel_thinning_experiment(m,\n",
    "#                                                             thin_fun=kt_split_rand, thin_str=\"_split_rand\", params_p=params_p, rerun=rerun,\n",
    "#                                                             params_k_split=params_k, params_k_swap=params_k, rep_ids=rep_ids, \n",
    "#                                                             delta=delta, store_K=args.store_K ,\n",
    "#                                                                 compute_mmds=args.computemmd)\n",
    "#         mmds_kt_split[m, :] = mmd_kt_split[m, :]\n",
    "#         mmds_kt_split_sin[m, :] = mmd_kt_split_sin[m, :]\n",
    "#         fun_diff_kt_split[m, :] = fd_kt_split[m, :]\n",
    "#         fun_diff_kt_split_sin[m, :] = fd_kt_split_sin[m, :]\n",
    "        \n",
    "#     if run_kt_split_best_thining:\n",
    "#         mmd_kt_split_best, mmd_kt_split_best_sin, fd_kt_split_best, fd_kt_split_best_sin = run_kernel_thinning_experiment(m,rerun=rerun,\n",
    "#                                                             thin_fun=kt_split_best, thin_str=\"_split_best\", params_p=params_p, \n",
    "#                                                             params_k_split=params_k, params_k_swap=params_k, rep_ids=rep_ids, \n",
    "#                                                             delta=delta, store_K=args.store_K ,\n",
    "#                                                                 compute_mmds=args.computemmd)\n",
    "#         mmds_kt_split_best[m, :] = mmd_kt_split_best[m, :]\n",
    "#         mmds_kt_split_best_sin[m, :] = mmd_kt_split_best_sin[m, :]\n",
    "#         fun_diff_kt_split_best[m, :] = fd_kt_split_best[m, :]\n",
    "#         fun_diff_kt_split_best_sin[m, :] = fd_kt_split_best_sin[m, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#     if run_iid_thinning:\n",
    "#         prefixes = [\"mc-iid\" + prefix for prefix in generic_prefixes]\n",
    "#         data_arrays = [mmds_iid, mmds_iid_sin, fun_diff_iid, fun_diff_iid_sin]\n",
    "#         for prefix, data_array in zip(prefixes, data_arrays):\n",
    "#             filename = get_combined_results_filename(prefix, ms, params_p, params_k, params_k, rep_ids, delta)\n",
    "#             with open(filename, 'wb') as file:\n",
    "#                 print(f\"Saving {prefix} to {filename}\")\n",
    "#                 pkl.dump(data_array, file, protocol=pkl.HIGHEST_PROTOCOL)\n",
    "                \n",
    "                \n",
    "#     if run_kt_split_rand_thining:\n",
    "#         prefixes = [\"kt_split_rand\" + prefix for prefix in generic_prefixes]\n",
    "#         data_arrays = [mmds_kt_split, mmds_kt_split_sin, fun_diff_kt_split, fun_diff_kt_split_sin]\n",
    "#         for prefix, data_array in zip(prefixes, data_arrays):\n",
    "#             filename = get_combined_results_filename(prefix, ms, params_p, params_k_split=params_k, params_k_swap=params_k, rep_ids=rep_ids, delta=delta)\n",
    "#             with open(filename, 'wb') as file:\n",
    "#                 print(f\"Saving {prefix} to {filename}\")\n",
    "#                 pkl.dump(data_array, file, protocol=pkl.HIGHEST_PROTOCOL)\n",
    "\n",
    "#     if run_kt_split_best_thining:\n",
    "#         prefixes = [\"kt_split_best\" + prefix for prefix in generic_prefixes]\n",
    "#         data_arrays = [mmds_kt_split_best, mmds_kt_split_best_sin, fun_diff_kt_split_best, fun_diff_kt_split_best_sin]\n",
    "#         for prefix, data_array in zip(prefixes, data_arrays):\n",
    "#             filename = get_combined_results_filename(prefix, ms, params_p, params_k_split=params_k, params_k_swap=params_k, rep_ids=rep_ids, delta=delta)\n",
    "#             with open(filename, 'wb') as file:\n",
    "#                 print(f\"Saving {prefix} to {filename}\")\n",
    "#                 pkl.dump(data_array, file, protocol=pkl.HIGHEST_PROTOCOL)\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
