{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "46594c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch, os, pickle\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.lines import Line2D\n",
    "from matplotlib.patches import Patch\n",
    "from IPython.display import clear_output\n",
    "from tqdm.notebook import tqdm\n",
    "import gc\n",
    "\n",
    "# wall-clock timing logs: the sequential baseline and the parallel (Q-DEER) runs\n",
    "sequential_logs = pd.read_csv(\"logs_sequential.csv\")\n",
    "parallel_logs = pd.read_csv(\"logs_parallel.csv\")\n",
    "\n",
    "# simulation grid -- both axes double at every step:\n",
    "#   batch sizes      B = 1, 2, 4, ..., 32\n",
    "#   sequence lengths L = 1K, 2K, 4K, ..., 64K\n",
    "Bs = [2 ** k for k in range(6)]\n",
    "Ls = [1_000 * 2 ** k for k in range(7)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06d9a6e0",
   "metadata": {},
   "source": [
    "# 1. Computing Wall-Clock Time Summaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "98fead2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_times(logs, method):\n",
    "    '''\n",
    "    Tabulate the 5% / 50% / 95% wall-clock quantiles of a single method.\n",
    "\n",
    "    Returns a data frame with one row per (B, L) setting of the simulation grid,\n",
    "    ordered B-major then L, with columns B, L, lb (5%), med (50%), ub (95%).\n",
    "    '''\n",
    "    summary = pd.DataFrame(data=None, columns=[\"B\", \"L\", \"lb\", \"med\", \"ub\"])\n",
    "    levels = [0.05, 0.5, 0.95]\n",
    "    for B in Bs:\n",
    "        for L in Ls:\n",
    "            runs = logs.query(f\"B == {B} and L == {L} and method == '{method}'\")\n",
    "            summary.loc[len(summary.index)] = [B, L, *runs.time.quantile(levels).values]\n",
    "    return summary\n",
    "\n",
    "# one summary table per method\n",
    "sequential_summaries = summarize_times(sequential_logs, \"sequential\")\n",
    "parallel_summaries = summarize_times(parallel_logs, \"qdeer\")\n",
    "\n",
    "# save our log summaries\n",
    "sequential_summaries.to_csv(\"sequential_summaries.csv\", index=False)\n",
    "parallel_summaries.to_csv(\"parallel_summaries.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1956f97",
   "metadata": {},
   "source": [
    "# 2. Computing Full-Convergence MMDs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8576db6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's set a maximum sample size of 50K to not burn a GPU\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "MAX_SAMPLE_SIZE, SIGMA = 50_000, 0.3545334045797586\n",
    "\n",
    "# helper function for computing the (squared) MMD\n",
    "def MMD(X1, X2, sigma):\n",
    "    \n",
    "    '''\n",
    "    Squared maximum mean discrepancy between the samples X1 and X2 under the\n",
    "    Gaussian kernel k(x, y) = exp(-|| x - y ||^2 / (2 sigma^2)).\n",
    "\n",
    "    - X1 (m rows) and X2 (n rows) are 2-D tensors on the same device; rows are samples.\n",
    "    - This is the biased (V-statistic) estimator: every kernel mean includes the\n",
    "      i == j diagonal terms, so the result is >= 0 up to rounding.\n",
    "    - If X1 has more than MAX_SAMPLE_SIZE rows, MAX_SAMPLE_SIZE of them are drawn\n",
    "      without replacement using NumPy's global RNG, so repeated calls give a\n",
    "      Monte Carlo estimate. X2 is always used in full.\n",
    "    - Returns the estimate as a Python float.\n",
    "    '''\n",
    "    \n",
    "    # let's work with a smaller cdist matrix\n",
    "    if X1.shape[0] > MAX_SAMPLE_SIZE:\n",
    "        X1 = X1[np.random.choice(a=X1.shape[0], size=MAX_SAMPLE_SIZE, replace=False)]\n",
    "    \n",
    "    # 1/m^2 * \\sum_i \\sum_j k(x_i, x_j)   (diagonal i == j included)\n",
    "    t1 = torch.exp(-(torch.cdist(X1, X1) ** 2) / (2.0 * (sigma ** 2))).mean()\n",
    "    \n",
    "    # 1/n^2 * \\sum_i \\sum_j k(y_i, y_j)   (diagonal i == j included)\n",
    "    t2 = torch.exp(-(torch.cdist(X2, X2) ** 2) / (2.0 * (sigma ** 2))).mean()\n",
    "    \n",
    "    # -2/[mn] * \\sum_i \\sum_j k(x_i, y_j)\n",
    "    t3 = -2.0 * torch.exp(-(torch.cdist(X1, X2) ** 2) / (2.0 * (sigma ** 2))).mean()\n",
    "    \n",
    "    # squared MMD (V-statistic)\n",
    "    return (t1 + t2 + t3).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "55459ddc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Working on setting 1680 of 1680: B=1, L=1000, seed=19, method=SEQUENTIAL.\n"
     ]
    }
   ],
   "source": [
    "# load in our true samples\n",
    "X_truth = torch.tensor(np.load(\"truth/NUTS_blr-gcp_d=24_whiten=True.npy\"), device=device)\n",
    "n_trials, D, counter = 10, 24, 0\n",
    "n_seeds, methods = 20, [\"QDEER\", \"SEQUENTIAL\"]\n",
    "n_settings = len(methods) * n_seeds * len(Bs) * len(Ls)\n",
    "\n",
    "# set seeds for reproducibility: MMD subsamples rows with NumPy's global RNG,\n",
    "# so seeding torch alone does not make the MMD estimates repeatable\n",
    "torch.random.manual_seed(858)\n",
    "np.random.seed(858)\n",
    "\n",
    "# list the sample files once instead of re-scanning the directory for every setting\n",
    "available_files = set(os.listdir(\"samples\"))\n",
    "\n",
    "# create a dataframe to store our results\n",
    "mmd_logs = pd.DataFrame(data=None, columns=[\"B\", \"L\", \"seed\", \"method\", \"mmd\"])\n",
    "\n",
    "for B in Bs:\n",
    "    for L in Ls:\n",
    "        for seed in range(n_seeds):\n",
    "            for method in methods:\n",
    "            \n",
    "                # 0. status update\n",
    "                clear_output(wait=True)\n",
    "                print(\n",
    "                    f\"Working on setting {counter + 1} of {n_settings}:\" + \n",
    "                    f\" B={B}, L={L}, seed={seed}, method={method}.\")\n",
    "\n",
    "                # 1. load in our samples, skipping runs that produced no output file\n",
    "                fname = f\"{method}_B={B}_L={L}_seed={seed}.npz\"\n",
    "                if fname not in available_files:\n",
    "                    counter += 1\n",
    "                    continue\n",
    "                with np.load(f\"samples/{fname}\") as data:  # closes the .npz file handle\n",
    "                    X_samples = torch.tensor(data[\"samples\"], device=device).reshape(-1, D)\n",
    "\n",
    "                # 2. compute our MMD over multiple trials if necessary (MMD subsamples large inputs)\n",
    "                if X_samples.shape[0] > MAX_SAMPLE_SIZE:\n",
    "                    mmd = np.mean([MMD(X1=X_samples, X2=X_truth, sigma=SIGMA) for _ in tqdm(range(n_trials))])\n",
    "                else:\n",
    "                    mmd = MMD(X1=X_samples, X2=X_truth, sigma=SIGMA)\n",
    "\n",
    "                # 3. put in the MMD\n",
    "                mmd_logs.loc[len(mmd_logs.index)] = [B, L, seed, method, mmd]\n",
    "\n",
    "                # 4. update our counter + clear our cache\n",
    "                counter += 1\n",
    "                del X_samples\n",
    "                gc.collect()\n",
    "                torch.cuda.empty_cache()\n",
    "            \n",
    "# save as a .csv\n",
    "mmd_logs.to_csv(\"full-convergence_mmd_logs.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cd88035",
   "metadata": {},
   "source": [
    "# 3. Computing Time Multipliers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30019105",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load in our wall-clock timing logs\n",
    "sequential_logs = pd.read_csv(\"logs_sequential.csv\")\n",
    "parallel_logs = pd.read_csv(\"logs_parallel.csv\")\n",
    "\n",
    "# on each seed, how many times faster was Q-DEER than SEQUENTIAL?\n",
    "# Each log is restricted to the method it is summarized for in Section 1, so that\n",
    "# (B, L, seed) is a unique key and the column-wise concat pairs one sequential run\n",
    "# with one Q-DEER run (duplicate index labels would break the alignment).\n",
    "seed_key = [\"B\", \"L\", \"seed\"]\n",
    "multipliers = pd.concat(\n",
    "    [\n",
    "        sequential_logs.query(\"method == 'sequential'\")[seed_key + [\"time\"]]\n",
    "        .rename(columns={\"time\" : \"sequential\"}).set_index(seed_key),\n",
    "        parallel_logs.query(\"method == 'qdeer'\")[seed_key + [\"time\"]]\n",
    "        .rename(columns={\"time\" : \"qdeer\"}).set_index(seed_key)\n",
    "    ], axis=1).reset_index()\n",
    "multipliers[\"multiplier\"] = multipliers.sequential / multipliers.qdeer\n",
    "multipliers = multipliers[[\"B\", \"L\", \"seed\", \"multiplier\"]] # seq. time / parallel time\n",
    "multipliers.to_csv(\"seq-time-divided-by-parallel-time.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0d4d696",
   "metadata": {},
   "source": [
    "# 4. Number of Iterations Until Convergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff64e361",
   "metadata": {},
   "outputs": [],
   "source": [
    "# convergence-iteration metrics: a dict keyed by (B, L, seed) (only unpickle trusted files)\n",
    "with open(\"convg_iters.pickle\", \"rb\") as file:\n",
    "    convg_iters = pickle.load(file)\n",
    "\n",
    "# one row per completed (B, L, seed) setting, holding the 25/50/75/90% quantiles of the\n",
    "# number of iterations each chain needed to converge\n",
    "convg_logs = pd.DataFrame(data=None, columns=[\"B\", \"L\", \"seed\", \"q25\", \"q50\", \"q75\", \"q90\"])\n",
    "convg_levels = [0.25, 0.5, 0.75, 0.9]\n",
    "for (B, L, seed), iters in convg_iters.items():\n",
    "    convg_logs.loc[len(convg_logs.index)] = [B, L, seed, *np.quantile(iters, q=convg_levels)]\n",
    "\n",
    "# save as a .csv\n",
    "convg_logs.to_csv(\"iters-to-converge_logs.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4630a7cc",
   "metadata": {},
   "source": [
    "# 5. Summaries of Wall-Clock Time vs. MMD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4767f748",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load in our relevant log files (note that these are medians)\n",
    "parallel_times = pd.read_csv(\"parallel_summaries.csv\"); parallel_times[\"method\"] = \"qdeer\"\n",
    "sequential_times = pd.read_csv(\"sequential_summaries.csv\"); sequential_times[\"method\"] = \"sequential\"\n",
    "\n",
    "# load the full-convergence MMDs written in Section 2. They were logged with upper-case\n",
    "# method names (\"QDEER\" / \"SEQUENTIAL\"), so lower-case them to match the timing summaries;\n",
    "# otherwise the join below matches nothing and dropna() empties the table.\n",
    "mmds = pd.read_csv(\"full-convergence_mmd_logs.csv\")\n",
    "mmds[\"method\"] = mmds[\"method\"].str.lower()\n",
    "\n",
    "# take the medians of the mmds (over seeds) for each combination of B, L, method\n",
    "mmds = mmds.groupby([\"B\", \"L\", \"method\"], as_index=False)[\"mmd\"].median()\n",
    "\n",
    "# concatenate into one big dataframe; the B / L keys are cast to int on both sides so the\n",
    "# join does not depend on whether a CSV stored them as e.g. 1 or 1.0\n",
    "keys = [\"B\", \"L\", \"method\"]\n",
    "times = (\n",
    "    pd.concat([parallel_times, sequential_times])[[\"B\", \"L\", \"method\", \"med\"]]\n",
    "    .rename(columns={\"med\" : \"time\"})\n",
    "    .astype({\"B\": int, \"L\": int})\n",
    ")\n",
    "mmds = mmds.astype({\"B\": int, \"L\": int})\n",
    "logs = (\n",
    "    pd.concat([times.set_index(keys), mmds.set_index(keys)], axis=1)\n",
    "    .reset_index()[[\"B\", \"L\", \"method\", \"time\", \"mmd\"]]\n",
    "    .dropna()\n",
    ")\n",
    "\n",
    "# save our results\n",
    "logs.to_csv(\"mmd-and-times-summaries.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5718c0bb",
   "metadata": {},
   "source": [
    "# 6. Computing Full-Trace MMDs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea372904",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: this cell redefines the constants and the MMD helper from Section 2. The version\n",
    "# below additionally subsamples X2, and it shadows the earlier one once this cell has run.\n",
    "\n",
    "# let's set a maximum sample size of 50K to not burn a GPU\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "MAX_SAMPLE_SIZE, SIGMA = 50_000, 0.3545334045797586\n",
    "\n",
    "# helper function for computing the (squared) MMD\n",
    "def MMD(X1, X2, sigma):\n",
    "    \n",
    "    '''\n",
    "    Squared maximum mean discrepancy between the samples X1 and X2 under the\n",
    "    Gaussian kernel k(x, y) = exp(-|| x - y ||^2 / (2 sigma^2)).\n",
    "\n",
    "    - X1 (m rows) and X2 (n rows) are 2-D tensors on the same device; rows are samples.\n",
    "    - This is the biased (V-statistic) estimator: every kernel mean includes the\n",
    "      i == j diagonal terms, so the result is >= 0 up to rounding.\n",
    "    - Each of X1 and X2 that has more than MAX_SAMPLE_SIZE rows is independently\n",
    "      subsampled to MAX_SAMPLE_SIZE rows without replacement using NumPy's global RNG,\n",
    "      so repeated calls give a Monte Carlo estimate.\n",
    "    - Returns the estimate as a Python float.\n",
    "    '''\n",
    "    \n",
    "    # let's work with a smaller cdist matrix for both X1 and X2\n",
    "    if X1.shape[0] > MAX_SAMPLE_SIZE:\n",
    "        X1 = X1[np.random.choice(a=X1.shape[0], size=MAX_SAMPLE_SIZE, replace=False)]\n",
    "    if X2.shape[0] > MAX_SAMPLE_SIZE:\n",
    "        X2 = X2[np.random.choice(a=X2.shape[0], size=MAX_SAMPLE_SIZE, replace=False)]\n",
    "    \n",
    "    # 1/m^2 * \\sum_i \\sum_j k(x_i, x_j)   (diagonal i == j included)\n",
    "    t1 = torch.exp(-(torch.cdist(X1, X1) ** 2) / (2.0 * (sigma ** 2))).mean()\n",
    "    \n",
    "    # 1/n^2 * \\sum_i \\sum_j k(y_i, y_j)   (diagonal i == j included)\n",
    "    t2 = torch.exp(-(torch.cdist(X2, X2) ** 2) / (2.0 * (sigma ** 2))).mean()\n",
    "    \n",
    "    # -2/[mn] * \\sum_i \\sum_j k(x_i, y_j)\n",
    "    t3 = -2.0 * torch.exp(-(torch.cdist(X1, X2) ** 2) / (2.0 * (sigma ** 2))).mean()\n",
    "    \n",
    "    # squared MMD (V-statistic)\n",
    "    return (t1 + t2 + t3).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be894837",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load in our true samples\n",
    "X_truth = torch.tensor(np.load(\"truth/NUTS_blr-gcp_d=24_whiten=True.npy\"), device=device)\n",
    "n_trials, D, counter = 10, 24, 0\n",
    "n_settings = len(Bs) * len(Ls) * 20\n",
    "\n",
    "# set seeds for reproducibility (MMD subsamples rows with NumPy's global RNG)\n",
    "torch.random.manual_seed(858)\n",
    "np.random.seed(858)\n",
    "\n",
    "# create a dataframe to store our results\n",
    "iter_mmd_logs = pd.DataFrame(data=None, columns=[\"B\", \"L\", \"seed\", \"method\", \"iteration\", \"mmd\"])\n",
    "\n",
    "# go thru all possible combinations of batchsize B and sequence length L\n",
    "for B in Bs[::-1]:\n",
    "    for L in Ls[::-1]:\n",
    "        \n",
    "        # what is the maximum number of iterations?\n",
    "        max_iters = int(50 + 5e-4 * L)\n",
    "        \n",
    "        # go thru all the seeds\n",
    "        for seed in range(20):\n",
    "        \n",
    "            # skip the settings that ran out of memory\n",
    "            if ((B == 32) and (L > 16_000)) or ((B == 16) and (L > 32_000)):\n",
    "                print(f\"Skipping {counter+1} of {n_settings}: B={B}, L={L}, seed={seed}.\")\n",
    "                counter += 1\n",
    "                continue\n",
    "\n",
    "            # status update\n",
    "            clear_output(wait=True)\n",
    "\n",
    "            # skip runs that did not produce a full-trace sample file\n",
    "            fname = f\"FULL-TRACE_QDEER_B={B}_L={L}_seed={seed}.npz\"\n",
    "            if fname not in os.listdir(\"samples\"):\n",
    "                print(f\"Skipping {counter+1} of {n_settings}: B={B}, L={L}, seed={seed}.\")\n",
    "                counter += 1\n",
    "                continue\n",
    "\n",
    "            # np.load ignores mmap_mode for .npz archives, so indexing the archive reads the\n",
    "            # whole trace into host RAM regardless. Keep it there and copy only one iteration's\n",
    "            # slice to the device at a time, rather than holding the entire trace in GPU memory.\n",
    "            with np.load(f\"samples/{fname}\") as data:\n",
    "                X_samples = data[\"samples\"]\n",
    "\n",
    "            # the trace is indexed as X_samples[:, iteration, :, :]; never read past its end\n",
    "            n_iters = min(max_iters + 1, X_samples.shape[1])\n",
    "\n",
    "            # compute our MMD at every iteration, over multiple trials if necessary\n",
    "            for iteration in tqdm(range(0, n_iters), desc=f\"B={B}, L={L}, seed={seed}\"):\n",
    "\n",
    "                # a. copy this iteration's slice to the device and flatten it to (-1, D)\n",
    "                X_samples_iteration = torch.tensor(\n",
    "                    X_samples[:, iteration, :, :], device=device).reshape(-1, D)\n",
    "\n",
    "                # b. compute the MMD (MMD subsamples large inputs, so average several trials)\n",
    "                if X_samples_iteration.shape[0] > MAX_SAMPLE_SIZE:\n",
    "                    mmd = np.mean(\n",
    "                        [MMD(\n",
    "                            X1=X_samples_iteration, \n",
    "                            X2=X_truth, sigma=SIGMA) \n",
    "                         for _ in range(n_trials)])\n",
    "                else:\n",
    "                    mmd = MMD(\n",
    "                        X1=X_samples_iteration, \n",
    "                        X2=X_truth, sigma=SIGMA)\n",
    "\n",
    "                # c. record the MMD\n",
    "                iter_mmd_logs.loc[len(iter_mmd_logs.index)] = [B, L, seed, \"qdeer\", iteration, mmd]\n",
    "\n",
    "                # d. clean house\n",
    "                del X_samples_iteration\n",
    "                gc.collect()\n",
    "                torch.cuda.empty_cache()\n",
    "\n",
    "            # clean house again\n",
    "            counter += 1\n",
    "            del X_samples, data\n",
    "            gc.collect()\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "            # do intermediate checkpointing\n",
    "            iter_mmd_logs.to_csv(\"iter_mmd_logs.csv\", index=False)\n",
    "        \n",
    "# final saving at the end\n",
    "iter_mmd_logs.to_csv(\"iter_mmd_logs.csv\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e66a100e",
   "metadata": {},
   "source": [
    "# 7. Computing Maximum Absolute (L-infinity) Errors Across Batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "6922c341",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Working on setting 840 of 840: B=1, L=1000, seed=19.\n"
     ]
    }
   ],
   "source": [
    "# simulation settings\n",
    "Bs = [1, 2, 4, 8, 16, 32]\n",
    "Ls = [1_000, 2_000, 4_000, 8_000, 16_000, 32_000, 64_000]\n",
    "n_settings = len(Bs) * len(Ls) * 20\n",
    "\n",
    "# start a counter\n",
    "counter = 0\n",
    "\n",
    "# make sure everything is on GPU\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "# create a dictionary to store all of our maximum absolute errors (max over sequence L and dimension D)\n",
    "max_abs_errs_dict = {}\n",
    "\n",
    "# go thru all of our simulation settings\n",
    "for B in Bs[::-1]:\n",
    "    for L in Ls[::-1]:\n",
    "        for seed in range(20):\n",
    "            \n",
    "            # let's do a quick skip over settings that ran out of memory.\n",
    "            if ((B == 32) and (L > 16_000)) or ((B == 16) and (L > 32_000)):\n",
    "                print(f\"Skipping {counter+1} of {n_settings}: B={B}, L={L}, seed={seed}.\")\n",
    "                counter += 1\n",
    "                continue\n",
    "\n",
    "            # status update\n",
    "            clear_output(wait=True)\n",
    "            print(f\"Working on setting {counter+1} of {n_settings}: B={B}, L={L}, seed={seed}.\")\n",
    "\n",
    "            # skip unless this run produced both the full Q-DEER trace and the sequential samples\n",
    "            fname = f\"FULL-TRACE_QDEER_B={B}_L={L}_seed={seed}.npz\"\n",
    "            seq_fname = f\"SEQUENTIAL_B={B}_L={L}_seed={seed}.npz\"\n",
    "            available_files = set(os.listdir(\"samples\"))\n",
    "            if (fname not in available_files) or (seq_fname not in available_files):\n",
    "                print(f\"Skipping {counter+1} of {n_settings}: B={B}, L={L}, seed={seed}.\")\n",
    "                counter += 1\n",
    "                continue\n",
    "\n",
    "            # load both sample sets; the `with` blocks close the .npz file handles\n",
    "            # (mmap_mode is ignored for .npz archives, so each array is read fully anyway)\n",
    "            with np.load(f\"samples/{fname}\") as data:\n",
    "                X_samples = torch.tensor(data[\"samples\"], device=device)\n",
    "            with np.load(f\"samples/{seq_fname}\") as seq_data:\n",
    "                X_sequential = torch.tensor(seq_data[\"samples\"], device=device)\n",
    "            \n",
    "            # max |Q-DEER - sequential| over the last two axes, broadcasting the sequential\n",
    "            # samples across axis 1 of the trace (the iteration axis)\n",
    "            max_abs_errs = (\n",
    "                torch.abs(X_samples - X_sequential[:, None, :, :])\n",
    "                .amax(dim=(2, 3))\n",
    "                .cpu()\n",
    "                .numpy()\n",
    "            )\n",
    "            max_abs_errs_dict[(B, L, seed)] = max_abs_errs\n",
    "            \n",
    "            # increment our counter and free GPU memory before loading the next setting\n",
    "            counter += 1\n",
    "            del X_samples, X_sequential\n",
    "            gc.collect()\n",
    "            torch.cuda.empty_cache()\n",
    "                \n",
    "# save our results to a .pickle\n",
    "with open(\"max_abs_errs_dict.pickle\", \"wb\") as file:\n",
    "    pickle.dump(max_abs_errs_dict, file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10 (Afterburner)",
   "language": "python",
   "name": "afterburner"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
