{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# High-dimensional Gaussian experiment\n",
    "- Find the barycentre of 3 Gaussian marginals\n",
    "- Parameters for Gaussian marginals were generated using the generation procedure from Kolesov et al. (2023)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os  # before importing anything jax\n",
    "\n",
    "# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = \"false\"\n",
    "# os.environ['CUDA_VISIBLE_DEVICES']='5'\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tqdm import trange\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "from models import ScoreMLP, BasicModel\n",
    "\n",
    "from run_BarycentreDSBM import BarycentreDSBM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sqrtm_jax(matrix, eps=1e-12):\n",
    "    \"\"\"Return the symmetric square root of `matrix`.\n",
    "\n",
    "    The matrix is treated as symmetric (it is decomposed with\n",
    "    `jnp.linalg.eigh`). Eigenvalues are clipped from below at `eps` before\n",
    "    their square roots are taken, so small negative eigenvalues do not\n",
    "    turn into NaNs.\n",
    "    \"\"\"\n",
    "    spectrum, basis = jnp.linalg.eigh(matrix)\n",
    "    root_spectrum = jnp.sqrt(jnp.clip(spectrum, eps))\n",
    "    root_diag = jnp.diag(root_spectrum)\n",
    "    return basis @ root_diag @ basis.T\n",
    "\n",
    "class Gaussian:\n",
    "    \"\"\"Gaussian sampler: standard normal noise scaled by `std` and shifted by `mean`.\"\"\"\n",
    "\n",
    "    def __init__(self, shape, mean=0.0, std=1.0):\n",
    "        self.shape = shape\n",
    "        self.std = std\n",
    "        # Broadcast the mean to an array of the event shape.\n",
    "        self.mean = mean * jnp.ones(shape)\n",
    "\n",
    "    def sample(self, key, num_samples):\n",
    "        \"\"\"Draw samples with shape `(num_samples,) + self.shape` using PRNG `key`.\"\"\"\n",
    "        noise = jax.random.normal(key, (num_samples,) + self.shape)\n",
    "        return noise * self.std + self.mean\n",
    "    \n",
    "class GaussianCov:\n",
    "    \"\"\"Gaussian with a full covariance matrix, sampled through a square-root factor.\"\"\"\n",
    "\n",
    "    def __init__(self, shape, mean, cov):\n",
    "        self.shape = shape\n",
    "        self.mean = mean\n",
    "        self.cov = cov\n",
    "        # Factor used to colour standard normal noise: the matrix square root\n",
    "        # returned by `sqrtm_jax`.\n",
    "        # Alternative: a Cholesky factor, `jnp.linalg.cholesky(cov)`, applied\n",
    "        # the same way in `sample`.\n",
    "        self.weight = sqrtm_jax(cov)\n",
    "\n",
    "    def sample(self, key, num_samples):\n",
    "        \"\"\"Draw `num_samples` samples as `z @ weight.T + mean` with standard normal `z`.\"\"\"\n",
    "        standard_normal = jax.random.normal(key, (num_samples,) + self.shape)\n",
    "        return standard_normal @ self.weight.T + self.mean\n",
    "    \n",
    "class t_Dist:\n",
    "    \"\"\"Base class for samplers; subclasses must override `sample`.\n",
    "\n",
    "    NOTE(review): the name suggests a distribution over the time variable t - confirm.\n",
    "    \"\"\"\n",
    "\n",
    "    def sample(self, key, num_samples):\n",
    "        \"\"\"Abstract: always raises `NotImplementedError` in the base class.\"\"\"\n",
    "        raise NotImplementedError()\n",
    "    \n",
    "class UniformDist(t_Dist):\n",
    "    \"\"\"Uniform sampler on [0.001, 0.999), i.e. the unit interval with both ends trimmed.\"\"\"\n",
    "\n",
    "    def sample(self, key, num_samples):\n",
    "        \"\"\"Draw `num_samples` scalars, returned with shape `(num_samples,)`.\"\"\"\n",
    "        margin = 0.001  # keeps samples away from the endpoints 0 and 1\n",
    "        return jax.random.uniform(key, (num_samples,), minval=margin, maxval=1.0-margin)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the problem and the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem size: d is the dimension of each Gaussian, N the number of marginals.\n",
    "d = 128 # 64, 96, 128\n",
    "shape = (d,)\n",
    "N = 3\n",
    "\n",
    "import json\n",
    "\n",
    "# Load the stored sampler statistics (means and covariances) for dimension d\n",
    "with open(f\"../data/gaussian/sampler_stats_{d}.json\", \"r\") as f:\n",
    "    loaded_gaussian_params = json.load(f)\n",
    "\n",
    "# Convert the loaded lists to JAX arrays, in place, for every entry\n",
    "for key, value in loaded_gaussian_params.items():\n",
    "    value[\"cov\"] = jnp.array(value[\"cov\"])\n",
    "    value[\"mean\"] = jnp.array(value[\"mean\"])\n",
    "\n",
    "# Noise level handed to BarycentreDSBM: sigma = sqrt(epsilon / 2)\n",
    "epsilon = 0.0001\n",
    "sigma = jnp.sqrt(epsilon / 2)   # convert from epsilon to sigma\n",
    "\n",
    "\n",
    "# Marginals, built from the entries sampler_0 ... sampler_{N-1}\n",
    "mu_lst = []\n",
    "for i in range(N):\n",
    "    mu = GaussianCov(shape=shape, mean=loaded_gaussian_params[f\"sampler_{i}\"][\"mean\"], cov=loaded_gaussian_params[f\"sampler_{i}\"][\"cov\"])\n",
    "    mu_lst.append(mu)\n",
    "\n",
    "# Reference barycentre, built from the ground_truth_sampler entry\n",
    "ground_truth = GaussianCov(shape=shape, mean=loaded_gaussian_params[\"ground_truth_sampler\"][\"mean\"], cov=loaded_gaussian_params[\"ground_truth_sampler\"][\"cov\"])\n",
    "\n",
    "# Uniform weights over the N marginals; the second line renormalises them to sum to one.\n",
    "# NOTE(review): `weights` is not passed to anything in the visible cells - confirm\n",
    "# whether BarycentreDSBM is meant to receive it.\n",
    "weights = jnp.ones((N,)) / N\n",
    "weights = weights / jnp.sum(weights)\n",
    "\n",
    "model = BasicModel(out_dim=d, d=d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running IMF step 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training: 100%|██████████| 10000/10000 [01:19<00:00, 125.23step/s, loss=71] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running IMF step 2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training: 100%|██████████| 10000/10000 [01:14<00:00, 134.08step/s, loss=0.739]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running IMF step 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training: 100%|██████████| 10000/10000 [01:14<00:00, 134.99step/s, loss=0.528]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running IMF step 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training: 100%|██████████| 10000/10000 [01:15<00:00, 131.96step/s, loss=0.519]\n"
     ]
    }
   ],
   "source": [
    "# Build the solver from the marginals, noise level, sample shape and model defined above.\n",
    "# NOTE(review): `model` is passed both here and to `train` below - confirm both are required.\n",
    "baryDSBM = BarycentreDSBM(\n",
    "    mu_lst=mu_lst,\n",
    "    sigma=sigma,\n",
    "    shape=shape,\n",
    "    model=model,\n",
    ")\n",
    "\n",
    "# Training hyperparameters, passed to `baryDSBM.train` as `train_config`.\n",
    "train_config = OmegaConf.create({\n",
    "    'num_IMF_steps': 4,\n",
    "    'num_sampling_steps': 50,\n",
    "    'num_training_steps': 10_000,\n",
    "    'reflow_num_training_steps': None, # number of training steps for reflow (could be lower if desired)\n",
    "    'num_training_samples': 50_000,  # number of samples to simulate for subsequent IMF iterations\n",
    "    'lr': 1e-3,\n",
    "    'batch_size': 4096,\n",
    "    'simulation_batch_size': 10_000, # if num_training_samples is too large, set this to be smaller to simulate in batches\n",
    "    'ema_rate': 0.01,\n",
    "    'simultaneous_training': True, # True, False,\n",
    "    'warmstart': False, # True, False, # whether to warmstart the model with the params from the first iteration\n",
    "})\n",
    "\n",
    "# Fixed seed for the training run.\n",
    "key = jax.random.PRNGKey(0)\n",
    "# Both results are indexed by IMF iteration in the evaluation cell (index -1 = last);\n",
    "# presumably each entry holds one state / one bm per edge - TODO confirm.\n",
    "all_states_lst, all_bms_lst = baryDSBM.train(key, train_config=train_config, model=model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "- Using the BW2-UVP and L2-UVP metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from Korotin et al. (2023), https://github.com/iamalexkorotin/Wasserstein2Barycenters\n",
    "\n",
    "import numpy as np\n",
    "import scipy.linalg as ln\n",
    "\n",
    "def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):\n",
    "    \"\"\"Numpy implementation of the squared Frechet distance between two Gaussians.\n",
    "\n",
    "    For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) the returned value is\n",
    "            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).\n",
    "    Stable version by Dougal J. Sutherland.\n",
    "\n",
    "    Params:\n",
    "    -- mu1   : Mean vector of the first Gaussian.\n",
    "    -- sigma1: Covariance matrix of the first Gaussian.\n",
    "    -- mu2   : Mean vector of the second Gaussian.\n",
    "    -- sigma2: Covariance matrix of the second Gaussian.\n",
    "    -- eps   : Offset added to the diagonal of both covariances when the\n",
    "               matrix square root of their product is not finite.\n",
    "    Returns:\n",
    "    --   : The squared Frechet distance d^2.\n",
    "    Raises:\n",
    "    -- AssertionError: the mean vectors or the covariance matrices have\n",
    "                       mismatched shapes.\n",
    "    -- ValueError    : the matrix square root has a non-negligible imaginary\n",
    "                       part on its diagonal.\n",
    "    \"\"\"\n",
    "\n",
    "    mu1 = np.atleast_1d(mu1)\n",
    "    mu2 = np.atleast_1d(mu2)\n",
    "\n",
    "    sigma1 = np.atleast_2d(sigma1)\n",
    "    sigma2 = np.atleast_2d(sigma2)\n",
    "\n",
    "    assert mu1.shape == mu2.shape, \\\n",
    "        'Training and test mean vectors have different lengths'\n",
    "    assert sigma1.shape == sigma2.shape, \\\n",
    "        'Training and test covariances have different dimensions'\n",
    "\n",
    "    diff = mu1 - mu2\n",
    "\n",
    "    # Product might be almost singular.\n",
    "    # `sqrtm` is called without the deprecated `disp` argument (scheduled for\n",
    "    # removal in SciPy 1.18); with the default it returns only the matrix.\n",
    "    covmean = ln.sqrtm(sigma1.dot(sigma2))\n",
    "    if not np.isfinite(covmean).all():\n",
    "        msg = ('fid calculation produces singular product; '\n",
    "               'adding %s to diagonal of cov estimates') % eps\n",
    "        print(msg)\n",
    "        offset = np.eye(sigma1.shape[0]) * eps\n",
    "        covmean = ln.sqrtm((sigma1 + offset).dot(sigma2 + offset))\n",
    "\n",
    "    # Numerical error might give slight imaginary component\n",
    "    if np.iscomplexobj(covmean):\n",
    "        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):\n",
    "            m = np.max(np.abs(covmean.imag))\n",
    "            raise ValueError('Imaginary component {}'.format(m))\n",
    "        covmean = covmean.real\n",
    "\n",
    "    tr_covmean = np.trace(covmean)\n",
    "\n",
    "    return (diff.dot(diff) + np.trace(sigma1) +\n",
    "            np.trace(sigma2) - 2 * tr_covmean)\n",
    "\n",
    "# get ground-truth maps\n",
    "# based on Korotin et al. (2022)\n",
    "\n",
    "def get_map_to_barycentre_from_params(mean_leaf, cov_leaf, mean_bary, cov_bary):\n",
    "    \"\"\"Build the closed-form affine map from a Gaussian leaf to the Gaussian barycentre.\n",
    "\n",
    "    With C_l = cov_leaf and C_b = cov_bary the map is T(x) = weight @ x + bias, where\n",
    "        weight = C_l^{-1/2} (C_l^{1/2} C_b C_l^{1/2})^{1/2} C_l^{-1/2}\n",
    "        bias   = mean_bary - weight @ mean_leaf\n",
    "\n",
    "    The matrix square roots and the inverse are computed in NumPy float64.\n",
    "    JAX ops run in float32 unless x64 is enabled, so routing the float64\n",
    "    inputs through `jnp` would discard the precision the caller asked for.\n",
    "\n",
    "    Returns:\n",
    "        (map_to_barycentre, weight, bias): the map as a function acting on\n",
    "        row-stacked samples `x`, plus its weight matrix and bias vector as\n",
    "        JAX arrays.\n",
    "    \"\"\"\n",
    "    mean_leaf = np.asarray(mean_leaf, dtype=np.float64)\n",
    "    cov_leaf = np.asarray(cov_leaf, dtype=np.float64)\n",
    "    mean_bary = np.asarray(mean_bary, dtype=np.float64)\n",
    "    cov_bary = np.asarray(cov_bary, dtype=np.float64)\n",
    "\n",
    "    def _sqrtm_psd(matrix, eps=1e-12):\n",
    "        # Symmetric square root via eigendecomposition, eigenvalues clipped at eps.\n",
    "        eigvals, eigvecs = np.linalg.eigh(matrix)\n",
    "        sqrt_eigvals = np.sqrt(np.clip(eigvals, eps, None))\n",
    "        return eigvecs @ np.diag(sqrt_eigvals) @ eigvecs.T\n",
    "\n",
    "    root_cov_leaf = _sqrtm_psd(cov_leaf)\n",
    "    inv_root_cov_leaf = np.linalg.inv(root_cov_leaf)\n",
    "    middle = root_cov_leaf @ cov_bary @ root_cov_leaf\n",
    "    sqrt_middle = _sqrtm_psd(middle)\n",
    "\n",
    "    # Hand JAX arrays back to callers, as before; only the linear algebra\n",
    "    # above needed the extra precision.\n",
    "    weight = jnp.asarray(inv_root_cov_leaf @ sqrt_middle @ inv_root_cov_leaf)\n",
    "    bias = jnp.asarray(mean_bary) - weight @ jnp.asarray(mean_leaf)  # if both are zero, bias = 0\n",
    "\n",
    "    def map_to_barycentre(x):\n",
    "        return x @ weight.T + bias\n",
    "\n",
    "    return map_to_barycentre, weight, bias\n",
    "\n",
    "# One closed-form map per marginal, each sending it to the reference barycentre.\n",
    "# The barycentre parameters do not depend on the marginal, so they are prepared once.\n",
    "mean_bary = np.array(ground_truth.mean, dtype=np.float64)\n",
    "cov_bary = np.array(ground_truth.cov, dtype=np.float64)\n",
    "\n",
    "maps_to_barycentre = []\n",
    "for i in range(N):\n",
    "    mean_leaf = np.array(mu_lst[i].mean, dtype=np.float64)\n",
    "    cov_leaf = np.array(mu_lst[i].cov, dtype=np.float64)\n",
    "\n",
    "    map_to_barycentre, _, _ = get_map_to_barycentre_from_params(\n",
    "        mean_leaf, cov_leaf, mean_bary, cov_bary\n",
    "    )\n",
    "    maps_to_barycentre.append(map_to_barycentre)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation settings: read as globals by `get_UVP_along_edge` and passed to `bm.sample`.\n",
    "num_steps = 50\n",
    "num_samples = 100_000\n",
    "\n",
    "def get_UVP_along_edge(key, state, bm, edge_idx):\n",
    "    \"\"\"Evaluate one edge and return `(BW2_UVP, L2_UVP)`, both scaled by 100.\n",
    "\n",
    "    Samples are drawn with the forward drift built from the EMA parameters of\n",
    "    `state`, using the global `num_samples` and `num_steps`. Both metrics are\n",
    "    normalised by the trace of the ground-truth covariance.\n",
    "    \"\"\"\n",
    "    fwd_drift = bm.get_drift_fn(state, use_ema_params=True, fwd=True)\n",
    "    traj, nu_samples = bm.sample(key, fwd_drift, num_samples, num_steps, fwd=True)\n",
    "\n",
    "    # Shared normaliser of both metrics.\n",
    "    ground_truth_var = jnp.trace(ground_truth.cov)\n",
    "\n",
    "    # BW2-UVP: Frechet distance between a Gaussian fitted to the generated\n",
    "    # samples and the ground-truth Gaussian.\n",
    "    sample_mean = jnp.mean(nu_samples, axis=0)\n",
    "    sample_cov = jnp.cov(nu_samples.T)\n",
    "    frechet = calculate_frechet_distance(\n",
    "        sample_mean, sample_cov,\n",
    "        ground_truth.mean, ground_truth.cov,\n",
    "    )\n",
    "    BW2_UVP = 100 * frechet / ground_truth_var\n",
    "\n",
    "    # L2-UVP: mean squared distance between the generated samples and the\n",
    "    # ground-truth map applied to `traj[:, 0]`.\n",
    "    # NOTE(review): assumes `traj[:, 0]` holds the initial samples - confirm against `bm.sample`.\n",
    "    reference_samples = maps_to_barycentre[edge_idx](traj[:, 0])\n",
    "    squared_errors = jnp.sum((reference_samples - nu_samples) ** 2, axis=1)\n",
    "    L2_UVP = 100 * jnp.mean(squared_errors) / ground_truth_var\n",
    "\n",
    "    return BW2_UVP, L2_UVP\n",
    "\n",
    "def get_UVPs_from_run(key, states_lst, bm_lst):\n",
    "    \"\"\"Compute BW2-UVP and L2-UVP for every edge of one run.\n",
    "\n",
    "    `states_lst[i]` and `bm_lst[i]` are evaluated together as edge `i`; the\n",
    "    same `key` is passed to every edge. Returns two arrays, the per-edge\n",
    "    BW2-UVPs and the per-edge L2-UVPs.\n",
    "    \"\"\"\n",
    "    per_edge = [\n",
    "        get_UVP_along_edge(key, state, bm_lst[i], i)\n",
    "        for i, state in enumerate(states_lst)\n",
    "    ]\n",
    "    run_BW2_UVP = [bw2 for bw2, _ in per_edge]\n",
    "    run_L2_UVP = [l2 for _, l2 in per_edge]\n",
    "    return jnp.array(run_BW2_UVP), jnp.array(run_L2_UVP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-08-07 14:57:28.190312: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3425] Can't reduce memory use below 5.13GiB (5508008948 bytes) by rematerialization; only reduced to 5.15GiB (5534452594 bytes), down from 5.15GiB (5534452594 bytes) originally\n",
      "/tmp/ipykernel_3737450/663433154.py:39: DeprecationWarning: The `disp` argument is deprecated and will be removed in SciPy 1.18.0.\n",
      "  covmean, _ = ln.sqrtm(sigma1.dot(sigma2), disp=False)\n",
      "2025-08-07 14:57:31.208034: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3425] Can't reduce memory use below 5.13GiB (5508008948 bytes) by rematerialization; only reduced to 5.15GiB (5534452594 bytes), down from 5.15GiB (5534452594 bytes) originally\n",
      "2025-08-07 14:57:32.993137: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3425] Can't reduce memory use below 5.13GiB (5508008948 bytes) by rematerialization; only reduced to 5.15GiB (5534452594 bytes), down from 5.15GiB (5534452594 bytes) originally\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BW2_UVPs for each edge: [0.293589   0.26861173 0.28548774]\n",
      "Average BW2_UVP: 0.28256282\n",
      "L2_UVPs for each edge: [1.2588215 1.2496533 1.2154801]\n",
      "Average L2_UVP: 1.2413183\n"
     ]
    }
   ],
   "source": [
    "# Evaluation key; same seed value as the training cell.\n",
    "key = jax.random.PRNGKey(0)\n",
    "\n",
    "# Select which IMF iteration's states and bms to evaluate.\n",
    "IMF_idx = -1  # use the last IMF\n",
    "states_lst = all_states_lst[IMF_idx]\n",
    "bm_lst = all_bms_lst[IMF_idx]\n",
    "\n",
    "# One (BW2-UVP, L2-UVP) pair per edge.\n",
    "BW2_UVPs, L2_UVPs = get_UVPs_from_run(key, states_lst, bm_lst)\n",
    "\n",
    "print(\"BW2_UVPs for each edge:\", BW2_UVPs)\n",
    "print(\"Average BW2_UVP:\", jnp.mean(BW2_UVPs))\n",
    "\n",
    "print(\"L2_UVPs for each edge:\", L2_UVPs)\n",
    "print(\"Average L2_UVP:\", jnp.mean(L2_UVPs))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
