{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4a04ef8-b27e-42cd-b184-60ac4113f46c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess # to use subprocess \n",
    "import pathlib\n",
    "import os\n",
    "import os.path\n",
    "import pickle as pkl\n",
    "from datetime import datetime\n",
    "from slurmpy import Slurm\n",
    "import numpy as np\n",
    "from kernelthinning.util import isnotebook # Check whether this file is being executed as a script or as a notebook\n",
    "if isnotebook():\n",
    "    %load_ext autoreload\n",
    "    %autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "028e682c-2331-4c73-9388-4916ffb7fe11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_python_command(d, m, sz, rep0, repn, rerun, compute_mmd, recompute_mmd, log_folder,\n",
    "                       symm1=1, rh2 = 0,\n",
    "                   alg = \"kt\", compress_alg=None, n_compress=None, alpha=None,\n",
    "                       setting = \"gauss\", M=None, filename=None\n",
    "                         ):\n",
    "    '''\n",
    "        get python command prefix for experiment settings\n",
    "        alpha: the blow up factor in log_2 base for compress_and_thin experiments\n",
    "        n_compress: number of compress coresets to be generated\n",
    "        compress_alg: the algorithm to be used for thinning in compress\n",
    "        m: thinning factor in log_2 base (useful typically only for kt.thin)\n",
    "        sz: the size of input data in log_4 base \n",
    "        (by default we usually set m = sz)\n",
    "        M: number of Mog components\n",
    "    '''\n",
    "    exp_name =  ['python3', f'construct_{alg}_coresets.py']\n",
    "                                   \n",
    "    if alg == \"compress_thin\" or alg == \"compress\":\n",
    "        assert(compress_alg is not None)\n",
    "        exp_name.extend(['--compressalg', f'{compress_alg}']) \n",
    "    if alg == \"compress_thin\":\n",
    "        assert(alpha is not None)\n",
    "        exp_name.extend(['--alpha', f'{alpha}']) \n",
    "    \n",
    "    if alg == \"compress\":\n",
    "        assert(n_compress is not None)\n",
    "        exp_name.extend(['--ncompress', f'{n_compress}']) \n",
    "     \n",
    "    if setting == \"mog\":\n",
    "        exp_name.extend(['--M', f'{M}'])\n",
    "    \n",
    "    if setting != \"mcmc\":\n",
    "        exp_name.extend(['--d', f'{d}'])# add dimension for non mcmc settings\n",
    "    else:\n",
    "        exp_name.extend(['--filename', f'{filename}'])\n",
    "       \n",
    "    exp_name.extend(['--m', f'{m}', f'--size', f'{sz}', '--setting', f'{setting}'])\n",
    "    \n",
    "    exp_name.extend(['--rep0', f'{rep0}', '--repn', f'{repn}','--rerun' ,str(rerun) ,'--computemmd',\n",
    "                str(compute_mmd), \n",
    "                     '--recomputemmd', str(recompute_mmd),\n",
    "                     '--symm1', str(symm1), '--rh2', str(rh2),\n",
    "                     '--recomputemmd', str(recompute_mmd),\n",
    "                    \n",
    "                    ])\n",
    "    \n",
    "    log_file = ''.join(exp_name[1:])\n",
    "    # removing redundant characters to save space\n",
    "    log_file = log_file.replace(\"--\", \"-\")\n",
    "    log_file = log_file.replace(\"construct_\", \"\")\n",
    "    log_file = log_file.replace(\"_coresets.py\", \"\")\n",
    "    # adding time stamp\n",
    "    suffix = datetime.now().strftime('%H_%M')\n",
    "    out_file = os.path.join(log_folder, log_file+suffix+\".out\")\n",
    "    err_file = os.path.join(log_folder, log_file+suffix+\".err\")\n",
    "    \n",
    "    return(exp_name, out_file, err_file)\n",
    "\n",
    "def deploy_terminal_run(exp_name, out, err):\n",
    "    '''\n",
    "        launch the command in exp_name as a separate subprocess (it is not waited for),\n",
    "        writing its std_out to the file out and its std_err to the file err\n",
    "    '''\n",
    "    with open(out, \"wb\") as out_handle, open(err, \"wb\") as err_handle:\n",
    "        subprocess.Popen(exp_name, stdout=out_handle, stderr=err_handle)\n",
    "    return\n",
    "    \n",
    "def deploy_slurm_run(exp_name, partition, prefix):\n",
    "    '''\n",
    "        hand the command in exp_name to a slurmpy Slurm object created with\n",
    "        prefix as its first argument, and return that object\n",
    "\n",
    "        every partition except \"jsteinhardt\" additionally gets the option \"c\": 1\n",
    "    '''\n",
    "    slurm_options = {\"partition\": partition}\n",
    "    if partition != \"jsteinhardt\":\n",
    "        # other options tried in the past: \"mem\": \"5G\" here, \"requeue\": 1 for jsteinhardt\n",
    "        slurm_options[\"c\"] = 1\n",
    "    job = Slurm(prefix, slurm_options)\n",
    "    job.run('module load python; ' + \" \".join(exp_name))\n",
    "    return job\n",
    "\n",
    "def deploy_experiment(deploy_slurm, ids, exp_name, out, err, partition, prefix, debug=False):\n",
    "    '''\n",
    "    deploy_slurm:\n",
    "    '''\n",
    "    if debug:\n",
    "        ids.append(\" \".join(exp_name))\n",
    "        return\n",
    "    \n",
    "    if deploy_slurm:\n",
    "        ids.append(deploy_slurm_run(partition=partition, prefix=prefix, exp_name=exp_name))\n",
    "    else:\n",
    "        deploy_terminal_run(exp_name, out, err)\n",
    "        ids.append(exp_name)\n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0769a09-6420-4a4f-9cc3-905612d146fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# True: deploy_experiment goes through deploy_slurm_run;\n",
    "# False: deploy_experiment goes through deploy_terminal_run (local subprocess)\n",
    "deploy_slurm = True\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73285091-44cb-4ffd-bed0-a921f7c2b57d",
   "metadata": {},
   "source": [
    "# Experiment parameters (that need to be adjusted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f8513ad-0ced-4d1e-9c08-03dc18491835",
   "metadata": {},
   "outputs": [],
   "source": [
    "### ASSUMING SIZE = 4^m; and final output = sqrt(SIZE) = 2^m #####\n",
    "m_min = 9 # min size will be 4**m_min\n",
    "m_max = 10 # max size will be 4**m_max\n",
    "alpha_min = 0 # the min value of alpha (compress_thin runs); the alpha loops start at\n",
    "            # min(m, alpha_min)\n",
    "alpha_max = 4 # the max value of alpha (compress_thin runs, inclusive); the alpha loops stop at\n",
    "            # min(m, alpha_max)\n",
    "\n",
    "# define repetition\n",
    "total_reps = 10 # set this to the max number of repetitions\n",
    "reps_per_job = 1 # number of reps per python call\n",
    "\n",
    "# which type of coresets\n",
    "st_coresets = False\n",
    "kt_coresets = False\n",
    "\n",
    "compress_alg = \"thin\" # refers to kt.thin\n",
    "compress_thin_coresets = False\n",
    "compress_coresets = False\n",
    "n_compress = 1 # CHECK!!!!!! how many compress coresets for each setting\n",
    "# sgd_coresets = False\n",
    "herding_coresets = False\n",
    "cp_herding_coresets = True\n",
    "\n",
    "\n",
    "## whether to use symmetric compress in stage 1; and recursive halve in stage 2\n",
    "symm1 = 1\n",
    "rh2 = 0\n",
    "\n",
    "## whether to compute mmd\n",
    "compute_mmd = 1\n",
    "\n",
    "### whether to regenerate coreset / recompute MMD / IMPORTANT NOTE DIFFERENT FLAGS ####\n",
    "### if we want to recompute mmds (for some reason) then we have to set that flag to True ##\n",
    "rerun = 0 # re generate coresets\n",
    "recompute_mmd = 0 # do not recompute mmd; works ONLY if compute_mmd is true in the first place\n",
    "\n",
    "### All experiments are run with Gauss(sigma) as k and Gauss(sigma/sqrt(2)) as krt ###\n",
    "run_gauss_experiments = True # run experiments with Gauss P\n",
    "ds = [2, 4, 10, 100] # dimensions looped over in the Gauss cells\n",
    "\n",
    "\n",
    "run_mog_experiments = True # run experiments with MoG P\n",
    "Ms = [4, 6, 8] # supported values: 3, 4, 6, 8\n",
    "\n",
    "run_mcmc_experiments = False # run experiments with MCMC P\n",
    "all_mcmc_filenames = ['Goodwin_RW', 'Goodwin_ADA-RW', \n",
    "'Goodwin_MALA', 'Goodwin_PRECOND-MALA', \n",
    "'Lotka_RW', 'Lotka_ADA-RW', \n",
    "'Lotka_MALA', 'Lotka_PRECOND-MALA',  \n",
    "'Hinch_P_seed_1_temp_1_scaled', 'Hinch_P_seed_2_temp_1_scaled', \n",
    "'Hinch_TP_seed_1_temp_8_scaled', \n",
    "'Hinch_TP_seed_2_temp_8_scaled']\n",
    "\n",
    "\n",
    "# supports the following 12 cases\n",
    "# ['Goodwin_RW', 'Goodwin_ADA-RW', \n",
    "# 'Goodwin_MALA', 'Goodwin_PRECOND-MALA', \n",
    "# 'Lotka_RW', 'Lotka_ADA-RW', \n",
    "# 'Lotka_MALA', 'Lotka_PRECOND-MALA',  \n",
    "# 'Hinch_P_seed_1_temp_1_scaled', 'Hinch_P_seed_2_temp_1_scaled', \n",
    "# 'Hinch_TP_seed_1_temp_8_scaled', \n",
    "# 'Hinch_TP_seed_2_temp_8_scaled']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "147d41a0-caef-451e-b06e-137d0bdcc061",
   "metadata": {},
   "source": [
    "# Gauss Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "954db08f-fb0c-4446-95d5-01736eca36b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# partition names; the cells below pick one of them by index\n",
    "partitions = [\"high\", \"yugroup\", \"jsteinhardt\", \"low\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08bb5462-e36e-4116-bd1f-0094671c0866",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_exp_setting(ds, m_min, m_max, total_reps, reps_per_job, n_compress, alg, target, Ms=None, filenames=None):\n",
    "    if filenames is not None:\n",
    "        print(f'Running {alg} experiments (ncompress={n_compress}, alpha_range={range(alpha_min, alpha_max)}, for {target} target for settings = {filenames} in d = {ds}, '\n",
    "              + f'm = {range(m_min, m_max+1)}, reps = {range(0, total_reps)},' +\n",
    "                 f' with {reps_per_job} reps per python call')\n",
    "    elif Ms is not None:\n",
    "        print(f'Running {alg} experiments (ncompress={n_compress}) alpha_range={range(alpha_min, alpha_max)}, for {target} target for M = {Ms} in d = {ds}, ' +\n",
    "                  f'm = {range(m_min, m_max+1)}, reps = {range(0, total_reps)},' +\n",
    "                 f' with {reps_per_job} reps per python call')\n",
    "    else:\n",
    "        print(f'Running {alg} experiments (ncompress={n_compress}) alpha_range={range(alpha_min, alpha_max)}, for {target} target for d = {ds}, ' +\n",
    "                  f'm = {range(m_min, m_max+1)}, reps = {range(0, total_reps)},' +\n",
    "                 f' with {reps_per_job} reps per python call')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2201b356-9e9e-4c52-a4e4-201d02dfb5b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# debug=True: deploy_experiment only records the command strings, nothing is launched\n",
    "debug = False\n",
    "partition = partitions[1] # index into partitions: \"high\", \"yugroup\", \"jsteinhardt\", \"low\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "879401ef-b83b-468d-8882-2b93d67d69a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch the st / kt / compress_thin / compress runs for the Gauss target.\n",
    "# gauss_ids gets one entry appended by deploy_experiment per job.\n",
    "gauss_ids = []\n",
    "if run_gauss_experiments:\n",
    "    # one log folder per calendar day (month abbreviation_day_year)\n",
    "    log_folder = f\"logs/{datetime.now().strftime('%b_%d_%Y')}\"\n",
    "    pathlib.Path(log_folder).mkdir(parents=True, exist_ok=True)\n",
    "    \n",
    "    target=\"gauss\"\n",
    "    \n",
    "    # grid for st: dimension d x size exponent m x first repetition index i\n",
    "    if st_coresets:\n",
    "        alg = \"st\"\n",
    "        print_exp_setting(ds, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target)\n",
    "        for d in ds:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                for i in range(0, total_reps, reps_per_job):\n",
    "                    # experiment command, and filenames\n",
    "                    exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job, \n",
    "                                              rerun=rerun, compute_mmd=compute_mmd, \n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=None, n_compress=None, alpha=None,\n",
    "                                              setting=target,\n",
    "                                            symm1=symm1, rh2=rh2)\n",
    "                    # short job prefix: first letter of alg and target, then d, m and rep index\n",
    "                    prefix = alg[:1] + target[:1] + f\"d{d}n{m}r{i}\"\n",
    "                    deploy_experiment(deploy_slurm, gauss_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "\n",
    "    \n",
    "    # same grid for kt\n",
    "    if kt_coresets:\n",
    "        alg = \"kt\"\n",
    "        print_exp_setting(ds, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target)\n",
    "        for d in ds:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                for i in range(0, total_reps, reps_per_job):\n",
    "                    # experiment command, and filenames\n",
    "                    exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job, \n",
    "                                              rerun=rerun, compute_mmd=compute_mmd, \n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=None, n_compress=None, alpha=None,\n",
    "                                              setting=target,\n",
    "                                            symm1=symm1, rh2=rh2)\n",
    "                    prefix = alg[:1] + target[:1] + f\"d{d}n{m}r{i}\"\n",
    "                    deploy_experiment(deploy_slurm, gauss_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "\n",
    "    # compress_thin additionally loops over alpha, with both bounds capped at m\n",
    "    if compress_thin_coresets:\n",
    "        alg = \"compress_thin\"\n",
    "        print_exp_setting(ds, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target)\n",
    "        for d in ds:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                for alpha in range(min(alpha_min, m), min(alpha_max, m)+1):\n",
    "                    for i in range(0, total_reps, reps_per_job):\n",
    "                        # experiment command, and filenames\n",
    "                        exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job, \n",
    "                                              rerun=rerun, compute_mmd=compute_mmd, \n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=compress_alg,  n_compress=None, alpha=alpha,\n",
    "                                            setting=target,\n",
    "                                            symm1=symm1, rh2=rh2)\n",
    "                        prefix = alg[:1] + str(alpha) + target[:1] + f\"d{d}n{m}r{i}\"\n",
    "                        deploy_experiment(deploy_slurm, gauss_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "                            \n",
    "                                \n",
    "    # compress passes n_compress instead of alpha\n",
    "    if compress_coresets:\n",
    "        alg  = \"compress\"\n",
    "        print_exp_setting(ds, m_min, m_max, total_reps, reps_per_job, n_compress=n_compress, alg=alg, \n",
    "                          target=target)\n",
    "        for d in ds:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                for i in range(0, total_reps, reps_per_job):\n",
    "                    # experiment command, and filenames\n",
    "                    exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job, \n",
    "                                              rerun=rerun, compute_mmd=compute_mmd, \n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=compress_alg, n_compress=n_compress, alpha=None,\n",
    "                                              setting=target,\n",
    "                                            symm1=symm1, rh2=rh2)\n",
    "                    prefix = alg[:1] + target[:1] + f\"d{d}n{m}r{i}\"\n",
    "                    deploy_experiment(deploy_slurm, gauss_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "                    \n",
    "        \n",
    "    print(f'{partition} partition; Number of processes/jobs:{len(gauss_ids)} slurm:{deploy_slurm}, terminal: {not deploy_slurm}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01b0d687-83d2-4254-9f06-5b0b022d4a98",
   "metadata": {},
   "source": [
    "## herding Gauss experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2376afa3-3b5f-4b8f-b5fd-f693f06278ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# same value as in the parameter cell: do not regenerate coresets for the herding runs\n",
    "rerun = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc463a03-b40e-4395-ae3a-d3b15f4847d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# debug=True: deploy_experiment only records the command strings, nothing is launched\n",
    "debug = False\n",
    "partition = partitions[2] # index into partitions: \"high\", \"yugroup\", \"jsteinhardt\", \"low\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc3b00ab-21f5-4751-9eda-1a658d428cf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch the herding based runs for the Gauss target: plain herding, and\n",
    "# compress_thin with herding as the thinning algorithm inside compress.\n",
    "# gauss_ids gets one entry appended by deploy_experiment per job.\n",
    "gauss_ids = []\n",
    "if run_gauss_experiments:\n",
    "\n",
    "    # one log folder per calendar day (month abbreviation_day_year)\n",
    "    log_folder = f\"logs/{datetime.now().strftime('%b_%d_%Y')}\"\n",
    "    pathlib.Path(log_folder).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    target = \"gauss\"\n",
    "\n",
    "    if herding_coresets:\n",
    "        alg = \"herding\"\n",
    "        print_exp_setting(ds, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target)\n",
    "        for d in ds:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                for i in range(0, total_reps, reps_per_job):\n",
    "                    # experiment command, and filenames\n",
    "                    exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job,\n",
    "                                              rerun=rerun, compute_mmd=compute_mmd,\n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=None, n_compress=None, alpha=None,\n",
    "                                              setting=target,\n",
    "                                              symm1=symm1, rh2=rh2)\n",
    "                    prefix = alg[:1] + target[:1] + f\"d{d}n{m}r{i}\"\n",
    "                    deploy_experiment(deploy_slurm, gauss_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "\n",
    "    if cp_herding_coresets:\n",
    "        alg = \"compress_thin\"\n",
    "        # dedicated name: the notebook-level compress_alg must keep its value\n",
    "        # for the cells that read it, whatever order the cells are run in\n",
    "        herding_compress_alg = \"herding\"\n",
    "        print_exp_setting(ds, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target)\n",
    "        for d in ds:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                # both alpha bounds are capped at m\n",
    "                for alpha in range(min(alpha_min, m), min(alpha_max, m)+1):\n",
    "                    for i in range(0, total_reps, reps_per_job):\n",
    "                        exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job,\n",
    "                                              rerun=rerun, compute_mmd=compute_mmd,\n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=herding_compress_alg, n_compress=None, alpha=alpha,\n",
    "                                              setting=target,\n",
    "                                              symm1=symm1, rh2=rh2)\n",
    "                        # the extra \"h\" marks herding inside compress_thin\n",
    "                        prefix = alg[:1] + str(alpha) + \"h\" + target[:1] + f\"d{d}n{m}r{i}\"\n",
    "                        deploy_experiment(deploy_slurm, gauss_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "\n",
    "    print(f'Herding: {partition} partition; Number of processes/jobs:{len(gauss_ids)} slurm:{deploy_slurm}, terminal: {not deploy_slurm}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25eb10ca-4f36-48b3-b95a-24b3897be55b",
   "metadata": {},
   "source": [
    "## Herding MOG experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d03ee595-5826-42fe-bbfa-d6d10e650fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# debug=True: deploy_experiment only records the command strings, nothing is launched\n",
    "debug = False\n",
    "partition = partitions[0] # index into partitions: \"high\", \"yugroup\", \"jsteinhardt\", \"low\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f433b6b9-1424-4d21-bab3-ebed62078ca1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch the herding based runs for the MoG target (d = 2), once per number of\n",
    "# components M in Ms: plain herding, and compress_thin with herding inside compress.\n",
    "# mog_ids gets one entry appended by deploy_experiment per job.\n",
    "mog_ids = []\n",
    "if run_mog_experiments:\n",
    "    # one log folder per calendar day (month abbreviation_day_year)\n",
    "    log_folder = f\"logs/{datetime.now().strftime('%b_%d_%Y')}\"\n",
    "    pathlib.Path(log_folder).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    d = 2\n",
    "    target = \"mog\"\n",
    "\n",
    "    if herding_coresets:\n",
    "        alg = \"herding\"\n",
    "        print_exp_setting(d, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target, Ms=Ms)\n",
    "        for M in Ms:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                for i in range(0, total_reps, reps_per_job):\n",
    "                    # experiment command, and filenames\n",
    "                    exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job,\n",
    "                                              rerun=rerun, compute_mmd=compute_mmd,\n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=None, n_compress=None, alpha=None,\n",
    "                                              setting=target, M=M,\n",
    "                                              symm1=symm1, rh2=rh2)\n",
    "                    # d is constant here, so the prefix carries M to tell the jobs apart\n",
    "                    # NOTE(review): target[1] is 'o' for \"mog\" while the Gauss cells use target[:1] - confirm intent\n",
    "                    prefix = alg[:1] + target[1] + f\"M{M}n{m}r{i}\"\n",
    "                    deploy_experiment(deploy_slurm, mog_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "\n",
    "    if cp_herding_coresets:\n",
    "        alg = \"compress_thin\"\n",
    "        # dedicated name: the notebook-level compress_alg must keep its value\n",
    "        # for the cells that read it, whatever order the cells are run in\n",
    "        herding_compress_alg = \"herding\"\n",
    "        print_exp_setting(d, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target, Ms=Ms)\n",
    "        for M in Ms:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                # both alpha bounds are capped at m\n",
    "                for alpha in range(min(alpha_min, m), min(alpha_max, m)+1):\n",
    "                    for i in range(0, total_reps, reps_per_job):\n",
    "                        exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job,\n",
    "                                              rerun=rerun, compute_mmd=compute_mmd,\n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=herding_compress_alg, n_compress=None,\n",
    "                                              alpha=alpha, setting=target, M=M,\n",
    "                                              symm1=symm1, rh2=rh2)\n",
    "                        # the extra \"h\" marks herding inside compress_thin\n",
    "                        prefix = alg[:1] + str(alpha) + \"h\" + target[1] + f\"M{M}n{m}r{i}\"\n",
    "                        deploy_experiment(deploy_slurm, mog_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "    print(f'Herding: {partition} partition; Number of processes/jobs:{len(mog_ids)} slurm:{deploy_slurm}, terminal: {not deploy_slurm}')    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b8bb2d6-eba2-45b3-8e13-eb4897ff7a97",
   "metadata": {},
   "source": [
    "# Run MOG experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ace5375d-ebcb-499c-bab5-d5e7b612aed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# partition = 'jsteinhardt' # 'jsteinhardt' \n",
    "\n",
    "# debug=True: deploy_experiment only records the command strings, nothing is launched\n",
    "debug = False\n",
    "partition = partitions[1]  # index into partitions: \"high\", \"yugroup\", \"jsteinhardt\", \"low\"\n",
    "# reset here (not in the launching cell), so re-running the launching cell alone keeps appending\n",
    "mog_ids = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5b77031-9c5e-4015-96c1-7a443f3c8f62",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Launch the st / kt / compress_thin / compress runs for the MoG target (d = 2),\n",
    "# once per number of components M in Ms. Entries are appended to the existing mog_ids list.\n",
    "if run_mog_experiments:\n",
    "    # one log folder per calendar day (month abbreviation_day_year)\n",
    "    log_folder = f\"logs/{datetime.now().strftime('%b_%d_%Y')}\"\n",
    "    pathlib.Path(log_folder).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    d = 2\n",
    "    target = \"mog\"\n",
    "\n",
    "    # NOTE(review): the prefixes below use target[1] ('o' for \"mog\") while the\n",
    "    # Gauss cells use target[:1] - confirm which one is intended\n",
    "    if st_coresets:\n",
    "        alg = \"st\"\n",
    "        print_exp_setting(d, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target, Ms=Ms)\n",
    "        for M in Ms:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                for i in range(0, total_reps, reps_per_job):\n",
    "                    # experiment command, and filenames\n",
    "                    exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job,\n",
    "                                              rerun=rerun, compute_mmd=compute_mmd,\n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=None, n_compress=None, alpha=None,\n",
    "                                              setting=target, M=M,\n",
    "                                              symm1=symm1, rh2=rh2)\n",
    "                    prefix = alg[:1] + target[1] + f\"M{M}n{m}r{i}\"\n",
    "                    deploy_experiment(deploy_slurm, mog_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "\n",
    "    if kt_coresets:\n",
    "        alg = \"kt\"\n",
    "        print_exp_setting(d, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target, Ms=Ms)\n",
    "        for M in Ms:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                for i in range(0, total_reps, reps_per_job):\n",
    "                    # experiment command, and filenames\n",
    "                    exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job,\n",
    "                                              rerun=rerun, compute_mmd=compute_mmd,\n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=None, n_compress=None, alpha=None,\n",
    "                                              setting=target, M=M,\n",
    "                                              symm1=symm1, rh2=rh2)\n",
    "                    prefix = alg[:1] + target[1] + f\"M{M}n{m}r{i}\"\n",
    "                    deploy_experiment(deploy_slurm, mog_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "\n",
    "    if compress_thin_coresets:\n",
    "        alg = \"compress_thin\"\n",
    "        print_exp_setting(d, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target, Ms=Ms)\n",
    "        for M in Ms:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                # both alpha bounds are capped at m\n",
    "                for alpha in range(min(alpha_min, m), min(alpha_max, m)+1):\n",
    "                    for i in range(0, total_reps, reps_per_job):\n",
    "                        # experiment command, and filenames\n",
    "                        exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job,\n",
    "                                              rerun=rerun, compute_mmd=compute_mmd,\n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=compress_alg, n_compress=None,\n",
    "                                              alpha=alpha, setting=target, M=M,\n",
    "                                              symm1=symm1, rh2=rh2)\n",
    "                        prefix = alg[:1] + str(alpha) + target[1] + f\"M{M}n{m}r{i}\"\n",
    "                        deploy_experiment(deploy_slurm, mog_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "\n",
    "    if compress_coresets:\n",
    "        alg = \"compress\"\n",
    "        # report the MoG dimension d (not the Gauss list ds), like the branches above\n",
    "        print_exp_setting(d, m_min, m_max, total_reps, reps_per_job, n_compress=n_compress, alg=alg,\n",
    "                          target=target, Ms=Ms)\n",
    "        for M in Ms:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                for i in range(0, total_reps, reps_per_job):\n",
    "                    # experiment command, and filenames\n",
    "                    exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job,\n",
    "                                              rerun=rerun, compute_mmd=compute_mmd,\n",
    "                                              recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                              alg=alg, compress_alg=compress_alg, n_compress=n_compress,\n",
    "                                              alpha=None, setting=target, M=M,\n",
    "                                              symm1=symm1, rh2=rh2)\n",
    "                    # per-job prefix carrying M, actually handed to deploy_experiment\n",
    "                    prefix = alg[:1] + target[1] + f\"M{M}n{m}r{i}\"\n",
    "                    deploy_experiment(deploy_slurm, mog_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "\n",
    "    print(f'{partition} partition; Number of processes/jobs:{len(mog_ids)} slurm:{deploy_slurm}, terminal: {not deploy_slurm}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "093ab5c2-5490-4915-9ba4-086f2c01108b",
   "metadata": {},
   "source": [
    "# Run MCMC experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70308d84-c5dc-4938-9355-96ed70915064",
   "metadata": {},
   "outputs": [],
   "source": [
    "# slice of all_mcmc_filenames to run; slice(11, 12) keeps only the entry at index 11,\n",
    "# i.e. the last of the 12 names\n",
    "mcmc_file_range = slice(11, 12)\n",
    "mcmc_filenames = all_mcmc_filenames[mcmc_file_range]\n",
    "print(mcmc_filenames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9938d466-4051-4735-9044-4dd17219cafa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# debug=True: deploy_experiment only records the command strings, nothing is launched\n",
    "debug = False\n",
    "partition = partitions[1] # index into partitions: \"high\", \"yugroup\", \"jsteinhardt\", \"low\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1870631-f02f-4536-93e2-732b0de70406",
   "metadata": {},
   "outputs": [],
   "source": [
    "# same value as in the parameter cell; set again here right before the MCMC cell\n",
    "st_coresets = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0de2f867-02a9-4134-b7ee-6972b7104893",
   "metadata": {},
   "outputs": [],
   "source": [
    "mcmc_ids = []\n",
    "if run_mcmc_experiments:\n",
    "    log_folder = f\"logs/{datetime.now().strftime('%b_%d_%Y')}\"\n",
    "    pathlib.Path(log_folder).mkdir(parents=True, exist_ok=True)\n",
    "    \n",
    "    d = 0 # None # no fixed d\n",
    "    target=\"mcmc\"\n",
    "    \n",
    "    def _deploy_mcmc_jobs(alg, compress_alg=None, n_compress=None, alphas_for_m=None):\n",
    "        '''\n",
    "            deploy one job per (filename, m, alpha, batch of reps) for a single algorithm\n",
    "            alg: algorithm name passed to print_exp_setting and get_python_command\n",
    "            compress_alg: forwarded unchanged to get_python_command\n",
    "            n_compress: forwarded unchanged to get_python_command\n",
    "            alphas_for_m: optional callable m -> iterable of alpha values to sweep;\n",
    "                          when None, a single job with alpha=None is deployed per (filename, m, batch)\n",
    "            all other settings (d, target, log_folder, m_min, m_max, total_reps, reps_per_job, rerun,\n",
    "            compute_mmd, recompute_mmd, symm1, rh2, mcmc_filenames, deploy_slurm, partition, debug)\n",
    "            are read from the notebook globals; the jobs are recorded via the shared mcmc_ids list\n",
    "        '''\n",
    "        print_exp_setting(d, m_min, m_max, total_reps, reps_per_job, n_compress=None, alg=alg, target=target, Ms=None,\n",
    "                         filenames=mcmc_filenames)\n",
    "        for filename in mcmc_filenames:\n",
    "            for m in range(m_min, m_max+1):\n",
    "                alphas = [None] if alphas_for_m is None else alphas_for_m(m)\n",
    "                for alpha in alphas:\n",
    "                    for i in range(0, total_reps, reps_per_job):\n",
    "                        # experiment command, and filenames\n",
    "                        exp_name, out, err = get_python_command(d=d, m=m, sz=m, rep0=i, repn=reps_per_job,\n",
    "                                                  rerun=rerun, compute_mmd=compute_mmd,\n",
    "                                                  recompute_mmd=recompute_mmd, log_folder=log_folder,\n",
    "                                                  alg=alg, compress_alg=compress_alg, n_compress=n_compress,\n",
    "                                                  alpha=alpha, setting=target, filename=filename,\n",
    "                                                  symm1=symm1, rh2=rh2)\n",
    "                        # alpha is part of the job prefix only when it is being swept\n",
    "                        alpha_tag = \"\" if alpha is None else str(alpha)\n",
    "                        prefix = alg[:1] + alpha_tag + filename[:2] + f\"d{d}n{m}r{i}\"\n",
    "                        deploy_experiment(deploy_slurm, mcmc_ids, exp_name, out, err, partition, prefix=prefix, debug=debug)\n",
    "\n",
    "    if st_coresets:\n",
    "        _deploy_mcmc_jobs(\"st\")\n",
    "\n",
    "    if kt_coresets:\n",
    "        _deploy_mcmc_jobs(\"kt\")\n",
    "\n",
    "    if compress_thin_coresets:\n",
    "        # both ends of the alpha sweep are clipped at m\n",
    "        _deploy_mcmc_jobs(\"compress_thin\", compress_alg=compress_alg,\n",
    "                          alphas_for_m=lambda m: range(min(alpha_min, m), min(alpha_max, m)+1))\n",
    "\n",
    "    if compress_coresets:\n",
    "        _deploy_mcmc_jobs(\"compress\", compress_alg=compress_alg, n_compress=n_compress)\n",
    "\n",
    "    print(f'{partition} partition; Number of processes/jobs:{len(mcmc_ids)} slurm:{deploy_slurm}, terminal: {not deploy_slurm}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2a1c099-a93d-4659-a8af-731f3447e19b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# last 26 entries of mcmc_ids (presumably the job/process ids recorded by deploy_experiment -- TODO confirm)\n",
    "mcmc_ids[-26:]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19f764a3-e592-4c7c-940b-0d17d37ca536",
   "metadata": {},
   "source": [
    "# Runtime experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2666060-c7a6-481b-8066-e37b3deb1f65",
   "metadata": {},
   "outputs": [],
   "source": [
    "from slurmpy import Slurm\n",
    "\n",
    "def slurm_command(d, size, rep0, talg, alpha, rerun=0, prefix='', compresalg=\"thin\"):\n",
    "    '''\n",
    "        build the shell command string for a single run_time.py job (setting is always gauss)\n",
    "        d: value passed as -d\n",
    "        size: value passed as both -m and -sz\n",
    "        rep0: value passed as -r0 (-rn is always 1)\n",
    "        talg: value passed as -talg\n",
    "        alpha: value passed as -alp\n",
    "        rerun: value passed as -rr\n",
    "        prefix: value passed as -prefix; the flag is left out when prefix is \"\"\n",
    "        compresalg: value passed as -ca\n",
    "    '''\n",
    "    pieces = [\n",
    "        'module load python; python3 run_time.py ',\n",
    "        f' -setting gauss -d {d!s} -m {size} -sz {size} -r0 {rep0} -rn 1',\n",
    "        f' -alp {alpha} -talg {talg}  -rr {rerun!s}',\n",
    "        f' -ca  {compresalg}',\n",
    "    ]\n",
    "    if prefix != \"\":\n",
    "        pieces.append(f' -prefix {prefix}')\n",
    "    return ''.join(pieces)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4c1c83f-bb50-444d-944f-663382fdf90c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runtime jobs for cpthin: with debug = True the commands are only collected in rt_list,\n",
    "# otherwise each command is submitted through slurmpy.\n",
    "debug = False\n",
    "rt_list = []\n",
    "partition = 'yugroup'\n",
    "cores = 1 # value sent to slurm as \"c\"; previously undefined in this cell (NameError on a fresh kernel)\n",
    "c = cores # alias kept for any other cell that still reads `c`\n",
    "mem = None # e.g. \"32G\"; None means no \"mem\" option is sent to slurm\n",
    "\n",
    "num_trials = 3\n",
    "rerun = 1\n",
    "alphas = [0, 4]\n",
    "alpha_max = 4\n",
    "sizes = range(10, 11)\n",
    "ds = [2, 4, 10, 100]\n",
    "compress_alg = \"herding\" # \"herding\"\n",
    "\n",
    "# slurm options shared by every job of this cell; \"mem\" is added only when it is set\n",
    "slurm_opts = {\"partition\": partition, \"c\": cores}\n",
    "if mem is not None:\n",
    "    slurm_opts[\"mem\"] = mem\n",
    "# resource tag handed to slurm_command, which appends it to the command as -prefix\n",
    "prefix = f'part_{partition}_c{cores}_mem_{mem}'\n",
    "\n",
    "for r0 in range(num_trials):\n",
    "    for d in ds:\n",
    "        for size in sizes:\n",
    "            for talg in [\"cpthin\"]:\n",
    "                # a fresh copy of the options is handed to every Slurm object\n",
    "                s = Slurm(f\"{talg[0]}d{d}n{size}r{r0}\", dict(slurm_opts))\n",
    "                if talg == \"cpthin\":\n",
    "                    for alpha in alphas:\n",
    "                        if debug:\n",
    "                            rt_list.append(slurm_command(d, size, r0, talg, alpha, rerun, prefix, compress_alg))\n",
    "                        else:\n",
    "                            s.run(slurm_command(d, size, r0, talg, alpha, rerun, prefix, compress_alg))\n",
    "                else:\n",
    "                    # NOTE(review): non-cpthin runs use alpha 0 and pass neither prefix nor compress_alg -- confirm intended\n",
    "                    if debug:\n",
    "                        rt_list.append(slurm_command(d, size, r0, talg, 0, rerun))\n",
    "                    else:\n",
    "                        s.run(slurm_command(d, size, r0, talg, 0, rerun))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d947d0e-67fd-4b34-bf3c-42f4bdfd83cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of commands collected in debug mode, followed by the commands themselves\n",
    "print(f\"{len(rt_list)} commands in rt_list\")\n",
    "rt_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b150404f-646b-4e73-9c28-6730bcdd7ff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Runtime jobs for kt and cpthin; every command is submitted through slurmpy.\n",
    "partition = 'high'\n",
    "cores = 1\n",
    "mem = \"128G\"\n",
    "\n",
    "num_trials = 3\n",
    "rerun = 1\n",
    "alpha_max = 4\n",
    "sizes = range(0, 10)\n",
    "ds = [2, 4, 10, 100]\n",
    "for r0 in range(num_trials):\n",
    "    for d in ds:\n",
    "        for size in sizes:\n",
    "            for talg in [\"kt\", \"cpthin\"]:\n",
    "                s = Slurm(f\"{talg}rt\", {\"partition\": partition,  \"c\": cores, \"mem\": mem})\n",
    "                # resource tag built from this cell's own `cores` (was `c`, which this cell never defines)\n",
    "                prefix = f'part_{partition}_c{cores}_mem_{mem}'\n",
    "                if talg == \"cpthin\":\n",
    "                    # sweep alpha from 0 up to alpha_max, clipped at size\n",
    "                    for alpha in range(0, min(alpha_max, size)+1):\n",
    "                        s.run(slurm_command(d, size, r0, talg, alpha, rerun, prefix))\n",
    "                else:\n",
    "                    # NOTE(review): kt runs are submitted without the prefix -- confirm this is intended\n",
    "                    s.run(slurm_command(d, size, r0, talg, 0, rerun))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2d2994e-0d93-4033-89a6-dc688181c091",
   "metadata": {},
   "source": [
    "### KT new vs KT old"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bc6e1fb-b20e-4ec3-9143-cfe185b606ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Submit one slurm job per (trial, dimension, size, algorithm) comparing kt against ktold.\n",
    "partition = 'high'\n",
    "mem = \"8G\" # \"32G\"\n",
    "cores = 4 # 1\n",
    "num_trials = 2\n",
    "rerun = 1\n",
    "alpha_range = [0, 4]\n",
    "sizes = range(4, 9)\n",
    "ds = [2, 10]\n",
    "\n",
    "# all combinations, enumerated in the order: trial, then dimension, then size, then algorithm\n",
    "job_grid = [\n",
    "    (r0, d, size, talg)\n",
    "    for r0 in range(num_trials)\n",
    "    for d in ds\n",
    "    for size in sizes\n",
    "    for talg in [\"kt\", \"ktold\"]\n",
    "]\n",
    "for r0, d, size, talg in job_grid:\n",
    "    # the job name encodes algorithm, dimension, size, cores, memory and trial\n",
    "    job_name = f\"{talg}-d{d}-n{size}-c{cores}-m{mem}-t{r0}\"\n",
    "    s = Slurm(job_name, {\"partition\": partition,  \"c\": cores, \"mem\": mem })\n",
    "    prefix = f'p_{partition}_c{cores}_mem_{mem}' \n",
    "    # alpha is fixed to 0 for both algorithms\n",
    "    s.run(slurm_command(d, size, r0, talg, 0, rerun, prefix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bfdb2a6-6f97-46a2-a418-4b14a8f600eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# NOTE(review): this cell repeats the settings and job submissions of the preceding cell;\n",
    "# running both submits every kt/ktold job twice -- confirm whether the duplicate is intentional.\n",
    "partition = 'high'\n",
    "mem = \"8G\" # \"32G\"\n",
    "cores = 4 # 1\n",
    "num_trials = 2\n",
    "rerun = 1\n",
    "# alpha_range is only referenced by the commented-out cpthin branch below\n",
    "alpha_range = [0, 4]\n",
    "sizes = range(4, 9)\n",
    "ds = [2, 10]\n",
    "for r0 in range(num_trials):\n",
    "    for d in ds:\n",
    "        for size in sizes:        \n",
    "            for talg in [\"kt\", \"ktold\"]:\n",
    "                # one Slurm object per job; its name encodes algorithm, dimension, size, cores, memory and trial\n",
    "                s = Slurm(f\"{talg}-d{d}-n{size}-c{cores}-m{mem}-t{r0}\", {\"partition\": partition,  \"c\": cores, \"mem\": mem })\n",
    "                prefix = f'p_{partition}_c{cores}_mem_{mem}' \n",
    "#                 if talg == \"cpthin\":\n",
    "#                     for alpha in alpha_range:\n",
    "#                         s.run(slurm_command(d, size, r0, talg, alpha, rerun, prefix))\n",
    "#                 else:\n",
    "                # alpha is fixed to 0 for both algorithms\n",
    "                s.run(slurm_command(d, size, r0, talg, 0, rerun, prefix))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
