{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# reload imported modules (e.g. labproject) automatically before each cell runs,\n",
    "# so edits to library code take effect without restarting the kernel\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the config and set up plotting colors and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import os\n",
    "import pickle\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from omegaconf import OmegaConf\n",
    "from torch.distributions import MultivariateNormal\n",
    "\n",
    "from labproject.data import DATASETS, DISTRIBUTIONS, get_dataset\n",
    "from labproject.experiments import *\n",
    "from labproject.plotting import cm2inch, generate_palette\n",
    "from labproject.utils import get_cfg, get_cfg_from_file, get_log_path, set_seed\n",
    "\n",
    "# inline plotting\n",
    "%matplotlib inline\n",
    "\n",
    "print(\"Running experiments...\")\n",
    "# load the config file; running_user names this run (presumably consumed by get_log_path -- TODO confirm)\n",
    "cfg = get_cfg_from_file(\"conf_mmd_bandwidth_experiment\")\n",
    "cfg.running_user = 'MMD_scale_bandwidth'\n",
    "seed = cfg.seed\n",
    "\n",
    "set_seed(seed)\n",
    "print(f\"Seed: {seed}\")\n",
    "print(f\"Experiments: {cfg.experiments}\")\n",
    "print(f\"Data: {cfg.data}\")\n",
    "\n",
    "# cfg.n and cfg.d are indexed in parallel with cfg.data, so their lengths must agree\n",
    "assert len(cfg.data) == len(cfg.n) == len(cfg.d), \"Data, n and d must be lists of the same length\"\n",
    "\n",
    "# base color of each metric family\n",
    "color_dict = {\"wasserstein\": \"#cc241d\",\n",
    "              \"mmd\": \"#eebd35\",\n",
    "              \"c2st\": \"#458588\",\n",
    "              \"fid\": \"#8ec07c\",\n",
    "              \"kl\": \"#8ec07c\"}\n",
    "\n",
    "# experiment class name -> metric family (key into color_dict)\n",
    "col_map = {'ScaleSampleSizeKL':'kl', 'ScaleSampleSizeSW':'wasserstein',\n",
    "           'ScaleSampleSizeMMD':'mmd', 'ScaleSampleSizeC2ST':'c2st',\n",
    "           'ScaleSampleSizeFID':'fid', 'ScaleDimKL':'kl', 'ScaleDimSW':'wasserstein',\n",
    "           'ScaleDimMMD':'mmd', 'ScaleDimC2ST':'c2st', 'ScaleGammaMMD':'mmd',\n",
    "           'ScaleDimFID':'fid',}\n",
    "\n",
    "# experiment class name -> short metric label used in the log-file tags\n",
    "mapping = {'ScaleSampleSizeKL':'KL', 'ScaleSampleSizeSW':'SW',\n",
    "           'ScaleSampleSizeMMD':'MMD', 'ScaleSampleSizeC2ST':'C2ST',\n",
    "           'ScaleSampleSizeFID':'FD', 'ScaleDimKL':'KL', 'ScaleDimSW':'SW',\n",
    "           'ScaleDimMMD':'MMD', 'ScaleDimC2ST':'C2ST',\n",
    "           'ScaleDimFID':'FD', 'ScaleGammaMMD':'MMD'}\n",
    "\n",
    "# dark and light shade of each experiment's base color, to tell the intra and inter comparisons apart\n",
    "col_dark = {}\n",
    "col_light = {}\n",
    "# NOTE(review): later cells may rely on `exp_name` staying defined after this loop -- check before renaming\n",
    "for exp_name in cfg.experiments:\n",
    "    base_color = color_dict[col_map[exp_name]]\n",
    "    col_dark[exp_name] = generate_palette(base_color, saturation='dark')[2]\n",
    "    col_light[exp_name] = generate_palette(base_color, saturation='light')[-1]\n",
    "\n",
    "# index 0 -> intra comparison (light shade), index 1 -> inter comparison (dark shade)\n",
    "color_list = [col_light, col_dark]\n",
    "\n",
    "# legend labels per dataset, indexed the same way as color_list\n",
    "label_true = {data_name: \"true\" for data_name in cfg.data}\n",
    "label_shift = {data_name: \"generated\" for data_name in cfg.data}\n",
    "# dataset-specific wording for the inter comparison\n",
    "label_shift['toy_2d'] = 'approx.'\n",
    "label_shift['random'] = 'shifted'\n",
    "\n",
    "label_list = [label_true, label_shift]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sweep the MMD bandwidth over each dataset's configured range"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one panel per dataset, 6 cm wide each (18 cm for three datasets)\n",
    "n_panels = len(cfg.data)\n",
    "fig, axes = plt.subplots(\n",
    "    1, n_panels, figsize=cm2inch((6 * n_panels, 4)), sharex=\"col\", squeeze=False\n",
    ")\n",
    "axes = axes.ravel()  # squeeze=False keeps an array even for a single panel\n",
    "\n",
    "for ax in axes:\n",
    "    # move spines outward\n",
    "    ax.spines[\"bottom\"].set_position((\"outward\", 4))\n",
    "    ax.spines[\"left\"].set_position((\"outward\", 4))\n",
    "    ax.locator_params(axis=\"y\", nbins=3)\n",
    "    ax.locator_params(axis=\"x\", nbins=8)\n",
    "\n",
    "# Name the experiment explicitly rather than relying on `exp_name` leaking out of the\n",
    "# color-setup loop in the config cell (which left it at the last entry of cfg.experiments).\n",
    "# NOTE(review): only this one experiment is run; other entries in cfg.experiments are ignored.\n",
    "exp_name = cfg.experiments[-1]\n",
    "\n",
    "for dd, ds in enumerate(cfg.data):\n",
    "    n_dim = cfg.d[dd]\n",
    "\n",
    "    # bandwidth values to sweep for this dataset\n",
    "    # custom_values = np.linspace(cfg.val_min[dd], cfg.val_max[dd], cfg.val_step[dd]) # switch for linspacing\n",
    "    custom_values = np.array(cfg.value_sizes[dd])\n",
    "\n",
    "    dataset_fn = get_dataset(ds)\n",
    "    # sample double the number of samples to ensure variability at the highest sample set size\n",
    "    n_samples = cfg.n[dd] * 2\n",
    "\n",
    "    # ground truth plus an independent draw from the same generator (intra comparison)\n",
    "    dataset_gt = dataset_fn(n_samples, n_dim)\n",
    "    dataset_intra = dataset_fn(n_samples, n_dim)\n",
    "\n",
    "    print(ds, n_samples, n_dim)\n",
    "\n",
    "    # inter comparison: samples from a different distribution\n",
    "    if ds == \"toy_2d\":\n",
    "        # Gaussian with the ground truth's empirical mean and covariance\n",
    "        dataset_inter = MultivariateNormal(\n",
    "            torch.mean(dataset_gt, dim=0), torch.cov(dataset_gt.T)\n",
    "        ).sample((n_samples,))\n",
    "    elif ds == \"random\" and cfg.augmentation[dd] == \"mean_shift\":\n",
    "        # shift the mean by 1 for all dimensions\n",
    "        dataset_inter = dataset_fn(n_samples, n_dim) + 1\n",
    "    elif ds == \"random\" and cfg.augmentation[dd] == \"one_dim_shift\":\n",
    "        # shift only the mean of the first dimension by 1\n",
    "        dataset_inter = dataset_fn(n_samples, n_dim)\n",
    "        dataset_inter[:, 0] += 1\n",
    "    else:\n",
    "        # fail loudly: otherwise dataset_inter is undefined or silently reused from the previous dataset\n",
    "        raise ValueError(f\"No inter dataset defined for dataset {dd} ({ds!r}); check cfg.data / cfg.augmentation\")\n",
    "\n",
    "    experiment = globals()[exp_name](value_sizes=custom_values)\n",
    "    ax = axes[dd]\n",
    "\n",
    "    # dc == 0: intra comparison (solid line), dc == 1: inter comparison (dashed line)\n",
    "    for dc, data_comp in enumerate([dataset_intra, dataset_inter]):\n",
    "        assert (\n",
    "            dataset_gt.shape == data_comp.shape\n",
    "        ), f\"Dataset shapes do not match: {dataset_gt.shape} vs. {data_comp.shape}\"\n",
    "\n",
    "        time_start = time.perf_counter()\n",
    "        output = experiment.run_experiment(dataset1=dataset_gt, dataset2=data_comp, n=cfg.n[dd])\n",
    "        elapsed = time.perf_counter() - time_start\n",
    "        print(f\"Experiment {exp_name} finished in {elapsed}\")\n",
    "\n",
    "        # log the results to the experiment folder\n",
    "        log_path = get_log_path(cfg, tag=f\"_{mapping[exp_name]}_{ds}_ds_{dd}_{dc}\", timestamp=False)\n",
    "        # `or \".\"`: dirname() of a bare filename is \"\", which os.makedirs rejects\n",
    "        os.makedirs(os.path.dirname(log_path) or \".\", exist_ok=True)\n",
    "        experiment.log_results(output, log_path)\n",
    "        print(f\"Numerical results saved to {log_path}\")\n",
    "\n",
    "        # plot experiment results\n",
    "        experiment.plot_experiment(\n",
    "            *output,\n",
    "            ds,\n",
    "            ax=ax,\n",
    "            color=color_list[dc][exp_name],\n",
    "            label=label_list[dc][ds],\n",
    "            linestyle=\"-\" if dc == 0 else \"--\",\n",
    "            lw=2,\n",
    "            marker=\"o\",\n",
    "        )\n",
    "    ax.legend()\n",
    "\n",
    "for ax in axes:\n",
    "    ax.set_xlabel(\"bandwidth\")\n",
    "\n",
    "fig.tight_layout()\n",
    "# if folder results/plots does not exist, create it\n",
    "os.makedirs(\"./results/plots\", exist_ok=True)\n",
    "fig.savefig(f\"./results/plots/MMD_scale_bandwidth_{cfg.n[0]}.png\", dpi=300)\n",
    "fig.savefig(f\"./results/plots/MMD_scale_bandwidth_{cfg.n[0]}.pdf\", dpi=300)\n",
    "\n",
    "print(\"Finished running experiments.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "labproject",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
