{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Aggregation experiment\n",
    "\n",
    "- applies TreeDSBM on the bike dataset example from Korotin et al. 2021."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os  # before importing anything jax\n",
    "\n",
    "# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = \"false\"\n",
    "# os.environ['CUDA_VISIBLE_DEVICES']='4'\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "from tqdm import trange\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from models import ScoreMLP, BasicModel\n",
    "\n",
    "from run_BarycentreDSBM import BarycentreDSBM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## define distributions\n",
    "- follows the preprocessing used in Korotin et al. 2021."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((100000, 8), (100000, 8), (100000, 8), (100000, 8), (100000, 8), (100000, 8))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scale = jnp.sqrt(1e7)\n",
    "\n",
    "dataset_1 = jnp.load('../data/bike_posterior/samples_0_0.npy')[:, 1:]\n",
    "dataset_1 -= dataset_1.mean(axis=0)\n",
    "dataset_1 *= scale\n",
    "\n",
    "dataset_2 = jnp.load('../data/bike_posterior/samples_0_1.npy')[:, 1:]\n",
    "dataset_2 -= dataset_2.mean(axis=0)\n",
    "dataset_2 *= scale\n",
    "\n",
    "dataset_3 = jnp.load('../data/bike_posterior/samples_0_2.npy')[:, 1:]\n",
    "dataset_3 -= dataset_3.mean(axis=0)\n",
    "dataset_3 *= scale\n",
    "\n",
    "dataset_4 = jnp.load('../data/bike_posterior/samples_0_3.npy')[:, 1:]\n",
    "dataset_4 -= dataset_4.mean(axis=0)\n",
    "dataset_4 *= scale\n",
    "\n",
    "dataset_5 = jnp.load('../data/bike_posterior/samples_0_4.npy')[:, 1:]\n",
    "dataset_5 -= dataset_5.mean(axis=0)\n",
    "dataset_5 *= scale\n",
    "\n",
    "ground_truth = jnp.load('../data/bike_posterior/samples_0_all.npy')[:, 1:]\n",
    "ground_truth -= ground_truth.mean(axis=0)\n",
    "ground_truth *= scale\n",
    "\n",
    "class DatasetDist:\n",
    "    def __init__(self, dataset):\n",
    "        self.dataset = dataset\n",
    "\n",
    "    @partial(jax.jit, static_argnums=(0,2))\n",
    "    def sample(self, key, num_samples):\n",
    "        return jax.random.choice(key, self.dataset, shape=(num_samples,), replace=True)\n",
    "    \n",
    "class t_Dist:\n",
    "    def sample(self, key, num_samples):\n",
    "        raise NotImplementedError\n",
    "    \n",
    "class UniformDist(t_Dist):\n",
    "    def sample(self, key, num_samples):\n",
    "        return jax.random.uniform(key, (num_samples,), minval=0.001, maxval=1.0-0.001)\n",
    "\n",
    "dataset_1.shape, dataset_2.shape, dataset_3.shape, dataset_4.shape, dataset_5.shape, ground_truth.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## define the problem, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = 8\n",
    "shape = (d,)\n",
    "N = 5\n",
    "\n",
    "epsilon = 0.001\n",
    "sigma = jnp.sqrt(epsilon / 2)   # convert from epsilon to sigma\n",
    "\n",
    "# define the fixed marginals\n",
    "mu_0 = DatasetDist(dataset_1)\n",
    "mu_1 = DatasetDist(dataset_2)\n",
    "mu_2 = DatasetDist(dataset_3)\n",
    "mu_3 = DatasetDist(dataset_4)\n",
    "mu_4 = DatasetDist(dataset_5)\n",
    "\n",
    "mu_lst = [mu_0, mu_1, mu_2, mu_3, mu_4]\n",
    "weights = jnp.ones((N,)) / N\n",
    "weights = weights / jnp.sum(weights)\n",
    "\n",
    "model = BasicModel(out_dim=d, d=d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running IMF step 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training: 100%|██████████| 10000/10000 [01:38<00:00, 101.22step/s, loss=13.5]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running IMF step 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training: 100%|██████████| 10000/10000 [01:35<00:00, 104.88step/s, loss=0.698]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running IMF step 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training: 100%|██████████| 10000/10000 [01:33<00:00, 106.62step/s, loss=0.638]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running IMF step 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training: 100%|██████████| 10000/10000 [01:32<00:00, 107.80step/s, loss=0.643]\n"
     ]
    }
   ],
   "source": [
    "baryDSBM = BarycentreDSBM(\n",
    "    mu_lst=mu_lst,\n",
    "    sigma=sigma,\n",
    "    shape=shape,\n",
    "    model=model,\n",
    ")\n",
    "\n",
    "train_config = OmegaConf.create({\n",
    "    'num_IMF_steps': 4,\n",
    "    'num_sampling_steps': 100,\n",
    "    'num_training_steps': 10_000,\n",
    "    'reflow_num_training_steps': None, # number of training steps for reflow (could be lower if desired)\n",
    "    'num_training_samples': 8192,  # number of samples to simulate for subsequent IMF iterations\n",
    "    'lr': 1e-3,\n",
    "    'batch_size': 4096,\n",
    "    'simulation_batch_size': None, # if num_training_samples is too large, set this to be smaller to simulate in batches\n",
    "    'ema_rate': 0.01,\n",
    "    'simultaneous_training': True, # True, False,\n",
    "    'warmstart': False, # True, False, # whether to warmstart the model with the params from the first iteration\n",
    "})\n",
    "\n",
    "key = jax.random.PRNGKey(0)\n",
    "all_states_lst, all_bms_lst = baryDSBM.train(key, train_config=train_config, model=model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "- using the BW2-UVP metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from Korotin et al. (2023), https://github.com/iamalexkorotin/Wasserstein2Barycenters\n",
    "\n",
    "import numpy as np\n",
    "import scipy.linalg as ln\n",
    "\n",
    "def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n",
    "    \"\"\"Numpy implementation of the Frechet Distance.\n",
    "    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)\n",
    "    and X_2 ~ N(mu_2, C_2) is\n",
    "            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n",
    "    Stable version by Dougal J. Sutherland.\n",
    "    Params:\n",
    "    -- mu1   : Numpy array containing the activations of a layer of the\n",
    "               inception net (like returned by the function 'get_predictions')\n",
    "               for generated samples.\n",
    "    -- mu2   : The sample mean over activations, precalculated on an\n",
    "               representative data set.\n",
    "    -- sigma1: The covariance matrix over activations for generated samples.\n",
    "    -- sigma2: The covariance matrix over activations, precalculated on an\n",
    "               representative data set.\n",
    "    Returns:\n",
    "    --   : The Frechet Distance.\n",
    "    \"\"\"\n",
    "\n",
    "    mu1 = np.atleast_1d(mu1)\n",
    "    mu2 = np.atleast_1d(mu2)\n",
    "\n",
    "    sigma1 = np.atleast_2d(sigma1)\n",
    "    sigma2 = np.atleast_2d(sigma2)\n",
    "\n",
    "    assert mu1.shape == mu2.shape, \\\n",
    "        'Training and test mean vectors have different lengths'\n",
    "    assert sigma1.shape == sigma2.shape, \\\n",
    "        'Training and test covariances have different dimensions'\n",
    "\n",
    "    diff = mu1 - mu2\n",
    "\n",
    "    # Product might be almost singular\n",
    "    covmean, _ = ln.sqrtm(sigma1.dot(sigma2), disp=False)\n",
    "    if not np.isfinite(covmean).all():\n",
    "        msg = ('fid calculation produces singular product; '\n",
    "               'adding %s to diagonal of cov estimates') % eps\n",
    "        print(msg)\n",
    "        offset = np.eye(sigma1.shape[0]) * eps\n",
    "        covmean = ln.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n",
    "\n",
    "    # Numerical error might give slight imaginary component\n",
    "    if np.iscomplexobj(covmean):\n",
    "        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n",
    "            m = np.max(np.abs(covmean.imag))\n",
    "            raise ValueError('Imaginary component {}'.format(m))\n",
    "        covmean = covmean.real\n",
    "\n",
    "    tr_covmean = np.trace(covmean)\n",
    "\n",
    "    return (diff.dot(diff) + np.trace(sigma1) +\n",
    "            np.trace(sigma2) - 2 * tr_covmean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_steps = 50\n",
    "num_samples = 100_000\n",
    "\n",
    "def get_UVP_along_edge(key, state, bm):\n",
    "\n",
    "    drift_fn = bm.get_drift_fn(state, use_ema_params=True, fwd=True)\n",
    "    _, nu_samples = bm.sample(key, drift_fn, num_samples, num_steps, fwd=True)\n",
    "\n",
    "    bary_samples_cov = np.cov(nu_samples.T)\n",
    "    bary_samples_mean = np.mean(nu_samples, axis=0)\n",
    "\n",
    "    ground_truth_mean = np.mean(ground_truth, axis=0)\n",
    "    ground_truth_cov = np.cov(ground_truth.T)\n",
    "    ground_truth_var = np.trace(ground_truth_cov)\n",
    "\n",
    "    UVP = 100 * calculate_frechet_distance(\n",
    "                bary_samples_mean, bary_samples_cov,\n",
    "                ground_truth_mean, ground_truth_cov,\n",
    "            ) / ground_truth_var\n",
    "    \n",
    "    return UVP\n",
    "\n",
    "def get_UVPs_from_run(key, states_lst, bm_lst):\n",
    "\n",
    "    run_UVP = []\n",
    "    for i in range(len(states_lst)):\n",
    "        state = states_lst[i]\n",
    "        bm = bm_lst[i]\n",
    "        UVP = get_UVP_along_edge(key, state, bm)\n",
    "        run_UVP.append(UVP)\n",
    "    return jnp.array(run_UVP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3722727/4166324493.py:39: DeprecationWarning: The `disp` argument is deprecated and will be removed in SciPy 1.18.0.\n",
      "  covmean, _ = ln.sqrtm(sigma1.dot(sigma2), disp=False)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "UVPs for each edge: [0.01108639 0.01151861 0.00821571 0.01161379 0.01200999]\n",
      "Average UVP: 0.010888897\n"
     ]
    }
   ],
   "source": [
    "key = jax.random.PRNGKey(0)\n",
    "\n",
    "IMF_idx = -1  # use the last IMF\n",
    "states_lst = all_states_lst[IMF_idx]\n",
    "bm_lst = all_bms_lst[IMF_idx]\n",
    "\n",
    "UVPs = get_UVPs_from_run(key, states_lst, bm_lst)\n",
    "\n",
    "print(\"UVPs for each edge:\", UVPs)\n",
    "print(\"Average UVP:\", jnp.mean(UVPs))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
