{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform supplementary dimensionality scaling experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension (mode 2) so that edits to imported modules\n",
    "# are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the configs and set up the plotting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "from omegaconf import OmegaConf\n",
    "from torch.distributions import MultivariateNormal\n",
    "import matplotlib.pyplot as plt\n",
    "# inline plotting\n",
    "%matplotlib inline\n",
    "\n",
    "from labproject.experiments import Experiment, ScaleDim, ScaleSampleSize\n",
    "from labproject.utils import get_log_path, get_cfg, set_seed, get_cfg_from_file\n",
    "from labproject.metrics import METRICS\n",
    "from labproject.data import DATASETS, DISTRIBUTIONS, get_dataset\n",
    "from labproject.experiments import *\n",
    "from labproject.plotting import cm2inch, generate_palette, color_dict\n",
    "\n",
    "print(\"Running experiments...\")\n",
    "# Load the experiment configuration and tag the run.\n",
    "cfg = get_cfg_from_file(\"conf_supp_dim_scaling_experiment\")\n",
    "cfg.running_user = 'scaling_dims'\n",
    "seed = cfg.seed\n",
    "\n",
    "# Apply the configured seed before anything else runs.\n",
    "set_seed(seed)\n",
    "print(f\"Seed: {seed}\")\n",
    "print(f\"Experiments: {cfg.experiments}\")\n",
    "print(f\"Data: {cfg.data}\")\n",
    "\n",
    "# cfg.data, cfg.n and cfg.d are used as parallel lists, so their lengths must agree.\n",
    "assert len(cfg.data) == len(cfg.n) == len(cfg.d), \"Data, n and d must be lists of the same length\"\n",
    "\n",
    "# Base colour per metric.\n",
    "# NOTE: this rebinds the `color_dict` name imported from labproject.plotting above.\n",
    "color_dict = {\"wasserstein\": \"#cc241d\",\n",
    "              \"mmd\": \"#eebd35\",\n",
    "              \"c2st\": \"#458588\",\n",
    "              \"fid\": \"#8ec07c\",\n",
    "              \"kl\": \"#8ec07c\"}\n",
    "\n",
    "# Experiment class name -> key into color_dict.\n",
    "col_map = {'ScaleDimKL':'kl', 'ScaleDimSW':'wasserstein',\n",
    "           'ScaleDimMMD':'mmd', 'ScaleDimC2ST':'c2st',\n",
    "           'ScaleDimFID':'fid', }\n",
    "\n",
    "# One dark and one light shade per experiment, both derived from the same base colour.\n",
    "col_dark = {}\n",
    "col_light = {}\n",
    "for exp_name in cfg.experiments:\n",
    "    base_color = color_dict[col_map[exp_name]]\n",
    "    col_dark[exp_name] = generate_palette(base_color, saturation='dark')[2]\n",
    "    col_light[exp_name] = generate_palette(base_color, saturation='light')[-1]\n",
    "# Ordered list: index 0 -> light shades (true data), index 1 -> dark shades (shifted data).\n",
    "color_list = [col_light, col_dark]\n",
    "\n",
    "# Experiment class name -> abbreviation of the distance, used for axis labels.\n",
    "mapping = {'ScaleDimKL':'KL', 'ScaleDimSW':'SW',\n",
    "        'ScaleDimMMD':'MMD', 'ScaleDimC2ST':'C2ST',\n",
    "        'ScaleDimFID':'FD'}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loop over the three dataset conditions for the dimensionality scaling experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Comparison figure: one row per experiment (metric), one column per distortion.\n",
    "# squeeze=False keeps `axes` two-dimensional even with a single row or column,\n",
    "# so the axes[row, col] indexing below always works.\n",
    "fig, axes = plt.subplots(\n",
    "    len(cfg.experiments),\n",
    "    len(cfg.distort),\n",
    "    figsize=cm2inch((20, 10)),\n",
    "    sharex='col',\n",
    "    squeeze=False,\n",
    ")\n",
    "for ax in axes.flatten():\n",
    "    # move spines outward\n",
    "    ax.spines['bottom'].set_position(('outward', 4))\n",
    "    ax.spines['left'].set_position(('outward', 4))\n",
    "    ax.locator_params(nbins=4)\n",
    "\n",
    "# Base datasets: only the first entry of cfg.data / cfg.n / cfg.d is used in this cell.\n",
    "dataset_fn = get_dataset(name=cfg.data[0])\n",
    "n_total = cfg.n[0] * cfg.runs\n",
    "dataset_gt = dataset_fn(n_total, cfg.d[0], distort=None)\n",
    "dataset_intra = dataset_fn(n_total, cfg.d[0], distort=None)\n",
    "print(f\"Dataset shape: {dataset_gt.shape}\")\n",
    "print(f\"Dataset shape: {dataset_intra.shape}\")\n",
    "\n",
    "# Legend text for the distorted sample; an unlisted distortion falls back to its own name\n",
    "# instead of failing later with a KeyError.\n",
    "shift_labels = {'shift_one': 'shifted', 'shift_all': 'shifted', 'increase_var': 'increased var'}\n",
    "\n",
    "for col, distort in enumerate(cfg.distort):\n",
    "    print(f\"Distort: {distort.upper()}\")\n",
    "    label_true = {data_name: \"true\" for data_name in cfg.data}\n",
    "    label_shift = {data_name: shift_labels.get(distort, distort) for data_name in cfg.data}\n",
    "    # Index 0 -> undistorted comparison, index 1 -> distorted comparison.\n",
    "    label_list = [label_true, label_shift]\n",
    "\n",
    "    # Distorted sample for this column.\n",
    "    dataset_inter = dataset_fn(n_total, cfg.d[0], distort=distort)\n",
    "    # NOTE(review): dataset_inter_var is never read in this cell. The call is kept because\n",
    "    # dropping it could change the random draws made afterwards - confirm, then remove.\n",
    "    dataset_inter_var = dataset_fn(n_total, cfg.d[0], distort=distort)\n",
    "\n",
    "    for row, exp_name in enumerate(cfg.experiments):\n",
    "        # Experiment classes are looked up by name in the notebook namespace\n",
    "        # (presumably provided by `from labproject.experiments import *`).\n",
    "        experiment = globals()[exp_name]()\n",
    "        is_mmd = mapping[exp_name] == 'MMD'\n",
    "        dataset1 = dataset_gt\n",
    "        ax = axes[row, col]\n",
    "        ax.set_xscale('log')\n",
    "        # dc == 0: ground truth vs. a second undistorted sample; dc == 1: vs. the distorted sample.\n",
    "        for dc, dataset2 in enumerate([dataset_intra, dataset_inter]):\n",
    "            assert dataset1.shape == dataset2.shape, f\"Dataset shapes do not match: {dataset1.shape} vs. {dataset2.shape}\"\n",
    "            time_start = time.time()\n",
    "            run_kwargs = dict(\n",
    "                dataset1=dataset1,\n",
    "                dataset2=dataset2,\n",
    "                dataset_size=cfg.n[0],\n",
    "                dim_sizes=cfg.dim_sizes,\n",
    "                nb_runs=cfg.runs,\n",
    "            )\n",
    "            if is_mmd:\n",
    "                # Only the MMD experiment receives a bandwidth; it is configured per distortion.\n",
    "                print(f'MMD {cfg.data} {col} {cfg.mmd_bandwidth[col]}')\n",
    "                run_kwargs['bandwidth'] = cfg.mmd_bandwidth[col]\n",
    "            output = experiment.run_experiment(**run_kwargs)\n",
    "            time_end = time.time()\n",
    "            print(f\"Experiment {exp_name} finished in {time_end - time_start}\")\n",
    "\n",
    "            log_path = get_log_path(cfg)\n",
    "            os.makedirs(os.path.dirname(log_path), exist_ok=True)\n",
    "            experiment.log_results(output, log_path)\n",
    "            print(f\"Numerical results saved to {log_path}\")\n",
    "            experiment.plot_experiment(\n",
    "                *output,\n",
    "                cfg.data[0],\n",
    "                ax=ax,\n",
    "                color=color_list[dc][exp_name],\n",
    "                label=label_list[dc][cfg.data[0]],\n",
    "                linestyle='-' if dc == 0 else '--',\n",
    "                lw=2,\n",
    "                marker='o'\n",
    "            )\n",
    "            # Axis cosmetics are re-applied after every plot call, as in the original cell.\n",
    "            # NOTE(review): confirm whether plot_experiment itself changes labels or limits.\n",
    "            if is_mmd:\n",
    "                ax.set_ylabel(mapping[exp_name] + str(int(cfg.mmd_bandwidth[col])))\n",
    "            else:\n",
    "                ax.set_ylabel(mapping[exp_name])\n",
    "            ax.set_xlabel('')\n",
    "            if mapping[exp_name] == 'C2ST':\n",
    "                ax.set_ylim([0.45, 1.05])\n",
    "                ax.set_yticks([0.5, 1.0])\n",
    "\n",
    "            # Make all y-labels at the same position\n",
    "            ax.yaxis.set_label_coords(-0.25, 0.5)\n",
    "\n",
    "        ax.legend()\n",
    "\n",
    "# Label the x-axis on every column of the bottom row, however many distortions are configured.\n",
    "for bottom_ax in axes[-1, :]:\n",
    "    bottom_ax.set_xlabel('dimensions')\n",
    "\n",
    "fig.tight_layout()\n",
    "plot_stem = f\"./results/plots/{cfg.exp_log_name}_metric_comparison_sim_budget_{cfg.n[0]}_dim_size_{cfg.data[0]}_adapted_bandwidth\"\n",
    "# Create the output directory first; savefig does not create missing directories.\n",
    "os.makedirs(os.path.dirname(plot_stem), exist_ok=True)\n",
    "for ext in (\"png\", \"pdf\"):\n",
    "    fig.savefig(f\"{plot_stem}.{ext}\", dpi=300)\n",
    "\n",
    "print(\"Plots saved\")\n",
    "print(\"Finished running experiments.\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "labproject",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
