{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test MMD bandwidth sensitivity in all scaling experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the config and set up plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from omegaconf import OmegaConf\n",
    "from torch.distributions import MultivariateNormal\n",
    "\n",
    "from labproject.data import DATASETS, DISTRIBUTIONS, get_dataset\n",
    "from labproject.experiments import *\n",
    "from labproject.plotting import cm2inch, generate_palette\n",
    "from labproject.utils import get_cfg, get_cfg_from_file, get_log_path, set_seed\n",
    "\n",
    "# render figures inline\n",
    "%matplotlib inline\n",
    "\n",
    "print(\"Running experiments...\")\n",
    "# read the experiment configuration and tag this run\n",
    "cfg = get_cfg_from_file(\"conf_mmd_main_scaling_experiment\")\n",
    "cfg.running_user = 'mmd_scaling_experiment'\n",
    "seed = cfg.seed\n",
    "\n",
    "set_seed(seed)\n",
    "print(f\"Seed: {seed}\")\n",
    "print(f\"Experiments: {cfg.experiments}\")\n",
    "print(f\"Data: {cfg.data}\")\n",
    "\n",
    "# every dataset needs exactly one sample count (n) and one dimensionality (d)\n",
    "assert len(cfg.data) == len(cfg.n) == len(cfg.d), \"Data, n and d must be lists of the same length\"\n",
    "\n",
    "# base color per metric\n",
    "color_dict = {\n",
    "    \"wasserstein\": \"#cc241d\",\n",
    "    \"mmd\": \"#eebd35\",\n",
    "    \"c2st\": \"#458588\",\n",
    "    \"fid\": \"#8ec07c\",\n",
    "    \"kl\": \"#8ec07c\",\n",
    "}\n",
    "\n",
    "# experiment class name -> key into color_dict\n",
    "col_map = {\n",
    "    'ScaleSampleSizeKL': 'kl',\n",
    "    'ScaleSampleSizeSW': 'wasserstein',\n",
    "    'ScaleSampleSizeMMD': 'mmd',\n",
    "    'ScaleSampleSizeC2ST': 'c2st',\n",
    "    'ScaleSampleSizeFID': 'fid',\n",
    "    'ScaleDimKL': 'kl',\n",
    "    'ScaleDimSW': 'wasserstein',\n",
    "    'ScaleDimMMD': 'mmd',\n",
    "    'ScaleDimC2ST': 'c2st',\n",
    "    'ScaleGammaMMD': 'mmd',\n",
    "    'ScaleDimFID': 'fid',\n",
    "}\n",
    "\n",
    "# experiment class name -> short metric label (used for axis labels and log tags)\n",
    "mapping = {\n",
    "    'ScaleSampleSizeKL': 'KL',\n",
    "    'ScaleSampleSizeSW': 'SW',\n",
    "    'ScaleSampleSizeMMD': 'MMD',\n",
    "    'ScaleSampleSizeC2ST': 'C2ST',\n",
    "    'ScaleSampleSizeFID': 'FD',\n",
    "    'ScaleDimKL': 'KL',\n",
    "    'ScaleDimSW': 'SW',\n",
    "    'ScaleDimMMD': 'MMD',\n",
    "    'ScaleDimC2ST': 'C2ST',\n",
    "    'ScaleDimFID': 'FD',\n",
    "    'ScaleGammaMMD': 'MMD',\n",
    "}\n",
    "\n",
    "# one dark and one light shade per experiment (sample-size and dimension experiments alike)\n",
    "col_dark = {}\n",
    "col_light = {}\n",
    "for exp_name in [*cfg.experiments, *cfg.experiments_dim]:\n",
    "    base_color = color_dict[col_map[exp_name]]\n",
    "    col_dark[exp_name] = generate_palette(base_color, saturation='dark')[2]\n",
    "    col_light[exp_name] = generate_palette(base_color, saturation='light')[-1]\n",
    "\n",
    "# index 0 -> light shade (true data), index 1 -> dark shade (shifted / generated data)\n",
    "color_list = [col_light, col_dark]\n",
    "\n",
    "# legend labels per dataset, in the same [true, shifted] order as color_list\n",
    "label_true = {data_name: \"true\" for data_name in cfg.data}\n",
    "label_shift = {data_name: \"generated\" for data_name in cfg.data}\n",
    "label_shift['toy_2d'] = 'approx.'\n",
    "label_shift['random'] = 'shifted'\n",
    "\n",
    "label_list = [label_true, label_shift]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loop over the three datasets, each with its own MMD bandwidths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _compare_and_plot(exp_name, ax, dd, bandwidth, dataset_gt, comparisons, run_kwargs, use_bandwidth):\n",
    "    \"\"\"Run one experiment against every comparison dataset, then log and plot the results.\n",
    "\n",
    "    Args:\n",
    "        exp_name: name of an experiment class available in the notebook namespace.\n",
    "        ax: matplotlib axis that receives the curves of all comparisons.\n",
    "        dd: index of the current dataset in cfg.data.\n",
    "        bandwidth: entry of cfg.mmd_bandwidth for this dataset/experiment; always used in\n",
    "            the log tag and y-label, forwarded to run_experiment only if use_bandwidth.\n",
    "        dataset_gt: ground-truth samples, passed as dataset1.\n",
    "        comparisons: datasets passed in turn as dataset2; the position in this sequence\n",
    "            selects color, legend label and line style.\n",
    "        run_kwargs: experiment-specific keyword arguments for run_experiment.\n",
    "        use_bandwidth: whether run_experiment receives the bandwidth keyword.\n",
    "    \"\"\"\n",
    "    data_name = cfg.data[dd]\n",
    "    experiment = globals()[exp_name]()\n",
    "    for dc, data_comp in enumerate(comparisons):\n",
    "        assert (\n",
    "            dataset_gt.shape == data_comp.shape\n",
    "        ), f\"Dataset shapes do not match: {dataset_gt.shape} vs. {data_comp.shape}\"\n",
    "\n",
    "        kwargs = dict(run_kwargs)\n",
    "        if use_bandwidth:\n",
    "            print(f\"MMD {data_name} {dd} {bandwidth}\")\n",
    "            kwargs[\"bandwidth\"] = bandwidth\n",
    "\n",
    "        time_start = time.time()\n",
    "        output = experiment.run_experiment(dataset1=dataset_gt, dataset2=data_comp, **kwargs)\n",
    "        time_end = time.time()\n",
    "        print(f\"Experiment {exp_name} finished in {time_end - time_start}\")\n",
    "\n",
    "        # one log file per (metric, dataset, bandwidth, comparison)\n",
    "        log_path = get_log_path(\n",
    "            cfg, tag=f\"_{mapping[exp_name]}_{data_name}_ds_{dd}_bw_{bandwidth}_{dc}\", timestamp=False\n",
    "        )\n",
    "        os.makedirs(os.path.dirname(log_path), exist_ok=True)\n",
    "        experiment.log_results(output, log_path)\n",
    "        print(f\"Numerical results saved to {log_path}\")\n",
    "\n",
    "        experiment.plot_experiment(\n",
    "            *output,\n",
    "            data_name,\n",
    "            ax=ax,\n",
    "            color=color_list[dc][exp_name],\n",
    "            label=label_list[dc][data_name],\n",
    "            linestyle=\"-\" if dc == 0 else \"--\",\n",
    "            lw=2,\n",
    "            marker=\"o\",\n",
    "        )\n",
    "        ax.set_ylabel(mapping[exp_name] + str(bandwidth))\n",
    "        ax.set_xlabel(\"\")\n",
    "        if mapping[exp_name] == \"C2ST\":\n",
    "            ax.set_ylim([0.45, 1])\n",
    "            ax.set_yticks([0.5, 1])\n",
    "    ax.legend()\n",
    "\n",
    "\n",
    "# make comparison plots: one column per dataset, one row per experiment\n",
    "fig, axes = plt.subplots(3, 3, figsize=cm2inch((20, 10)), sharex='col')\n",
    "for ax in axes.flatten():\n",
    "    # move spines outward\n",
    "    ax.spines['bottom'].set_position(('outward', 4))\n",
    "    ax.spines['left'].set_position(('outward', 4))\n",
    "    ax.locator_params(nbins=3)\n",
    "\n",
    "# Loop over all three datasets\n",
    "for dd, data_name in enumerate(cfg.data):\n",
    "    dataset_fn = get_dataset(data_name)\n",
    "    # draw n * runs samples per dataset; presumably so that repeated runs still see\n",
    "    # different samples at the largest sample size -- TODO confirm\n",
    "    n_samples = cfg.n[dd] * cfg.runs\n",
    "\n",
    "    # ground truth and an independent draw from the same dataset (intra comparison)\n",
    "    dataset_gt = dataset_fn(n_samples, cfg.d[dd])\n",
    "    dataset_intra = dataset_fn(n_samples, cfg.d[dd])\n",
    "\n",
    "    print(data_name, n_samples, cfg.d[dd])\n",
    "\n",
    "    # generate the inter dataset (approximation / shifted version of the ground truth)\n",
    "    if data_name == \"toy_2d\":\n",
    "        # Gaussian with the empirical mean and covariance of the ground truth\n",
    "        dataset_inter = MultivariateNormal(\n",
    "            torch.mean(dataset_gt, dim=0), torch.cov(dataset_gt.T)\n",
    "        ).sample((n_samples,))\n",
    "    elif data_name == \"random\" and cfg.augmentation[dd] == \"mean_shift\":\n",
    "        # shift the mean by 1 for all dimensions\n",
    "        dataset_inter = dataset_fn(n_samples, cfg.d[dd]) + 1\n",
    "    elif data_name == \"random\" and cfg.augmentation[dd] == \"one_dim_shift\":\n",
    "        # just shift the mean of the first dimension by 1\n",
    "        dataset_inter = dataset_fn(n_samples, cfg.d[dd])\n",
    "        dataset_inter[:, 0] += 1\n",
    "    else:\n",
    "        # fail loudly instead of silently reusing the previous dataset's dataset_inter\n",
    "        raise ValueError(\n",
    "            f\"Don't know how to build the comparison dataset for '{data_name}'; supported: \"\n",
    "            \"'toy_2d', or 'random' with augmentation 'mean_shift' / 'one_dim_shift'\"\n",
    "        )\n",
    "\n",
    "    comparisons = [dataset_intra, dataset_inter]\n",
    "\n",
    "    if dd < 2:  # for the first two datasets, we compare sample sizes\n",
    "        for e, exp_name in enumerate(cfg.experiments):\n",
    "            _compare_and_plot(\n",
    "                exp_name,\n",
    "                axes[e, dd],\n",
    "                dd,\n",
    "                cfg.mmd_bandwidth[dd][e],\n",
    "                dataset_gt,\n",
    "                comparisons,\n",
    "                run_kwargs=dict(sample_sizes=cfg.sample_size, nb_runs=cfg.runs),\n",
    "                use_bandwidth=mapping[exp_name] == \"MMD\",\n",
    "            )\n",
    "    else:  # for the last dataset, we compare dimensions\n",
    "        for e, exp_name in enumerate(cfg.experiments_dim):\n",
    "            ax = axes[e, 2]\n",
    "            ax.set_xscale(\"log\")\n",
    "            _compare_and_plot(\n",
    "                exp_name,\n",
    "                ax,\n",
    "                dd,\n",
    "                cfg.mmd_bandwidth[dd][e],\n",
    "                dataset_gt,\n",
    "                comparisons,\n",
    "                run_kwargs=dict(\n",
    "                    dataset_size=cfg.n[dd], dim_sizes=cfg.dim_sizes, nb_runs=cfg.runs_dim\n",
    "                ),\n",
    "                use_bandwidth=exp_name == \"ScaleDimMMD\",\n",
    "            )\n",
    "\n",
    "# only the bottom row carries x-axis labels (columns share their x-axis)\n",
    "axes[-1, -1].set_xlabel(\"dimensions\")\n",
    "axes[-1, 0].set_xlabel(\"sample size\")\n",
    "axes[-1, 1].set_xlabel(\"sample size\")\n",
    "\n",
    "os.makedirs(\"./results/plots\", exist_ok=True)\n",
    "fig.tight_layout()\n",
    "# save the same figure as raster and vector graphics\n",
    "plot_stem = f\"./results/plots/MMD_scaling_{cfg.mmd_bandwidth}_{cfg.n[0]}\"\n",
    "for ext in (\"png\", \"pdf\"):\n",
    "    fig.savefig(f\"{plot_stem}.{ext}\", dpi=300)\n",
    "\n",
    "print(\"Finished running experiments.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "labproject",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
