{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "'''\n",
        "This code reproduces the results from Fig 2, where similarity measures\n",
        "are compared for 9 synthetic representation spaces\n",
        "\n",
        "To evaluate the stochastic shape metrics, please install the netrep package:\n",
        "https://github.com/ahwillia/netrep\n",
        "'''"
      ],
      "metadata": {
        "id": "12yoWes5Kd1Y"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qyjLuSAE5Vjx"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "%matplotlib inline\n",
        "import tensorflow as tf\n",
        "tfkl = tf.keras.layers\n",
        "from sklearn import cluster\n",
        "import os, time\n",
        "import PIL\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "from matplotlib.gridspec import GridSpec\n",
        "import scipy.ndimage as nim\n",
        "\n",
        "from matplotlib.patches import Ellipse\n",
        "\n",
        "from netrep.metrics import GaussianStochasticMetric\n",
        "\n",
        "default_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']\n",
        "fingerprint_size = 64  ## the size of the dataset used in these examples"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Monte Carlo information evaluation\n",
        "def monte_carlo_info(mus, logvars, number_random_samples=10):\n",
        "  sample_size = 2000\n",
        "  chunk_eval_size = 10_000\n",
        "  info_estimates = []\n",
        "  emb_dim = mus.shape[-1]\n",
        "  for rand_sample in range(number_random_samples):\n",
        "    rand_inds = np.random.choice(mus.shape[0], size=sample_size)\n",
        "    rand_sample = tf.random.normal(shape=(sample_size, emb_dim),\n",
        "                                  mean=mus[rand_inds],\n",
        "                                  stddev=tf.exp(logvars[rand_inds]/2.),\n",
        "                                   dtype=tf.float64)\n",
        "    # rand_sample = tf.cast(rand_sample, tf.float64)\n",
        "    posterior_probs = compute_likelihoods(rand_sample, mus[rand_inds], logvars[rand_inds], diag=True)\n",
        "    marginal_probs = np.zeros((sample_size))\n",
        "    for start_ind in range(0, mus.shape[0], chunk_eval_size):\n",
        "      end_ind = min(start_ind+chunk_eval_size, mus.shape[0])\n",
        "      marginal_probs = marginal_probs + compute_likelihoods(rand_sample, mus[start_ind:end_ind], logvars[start_ind:end_ind])\n",
        "    marginal_probs = marginal_probs / mus.shape[0]\n",
        "\n",
        "    info_estimates.append(tf.math.log(posterior_probs/marginal_probs))\n",
        "  return np.mean(info_estimates)/np.log(2)\n",
        "\n",
        "\n",
        "@tf.function(experimental_relax_shapes=True)\n",
        "def compute_likelihoods(samples, mus, logvars, diag=False):\n",
        "  mus = tf.cast(mus, tf.float64)\n",
        "  logvars = tf.cast(logvars, tf.float64)\n",
        "  sample_size = tf.shape(samples)[0]\n",
        "  evaluation_batch_size = tf.shape(mus)[0]\n",
        "  embedding_dimension = tf.shape(mus)[-1]\n",
        "  stddevs = tf.exp(logvars/2.)\n",
        "  # Expand dimensions to broadcast and compute the pairwise distances between\n",
        "  # the sampled points and the centers of the conditional distributions\n",
        "  samples = tf.reshape(samples,\n",
        "    [sample_size, 1, embedding_dimension])\n",
        "  mus = tf.reshape(mus, [1, evaluation_batch_size, embedding_dimension])\n",
        "  distances_ui_muj = samples - mus\n",
        "\n",
        "  normalized_distances_ui_muj = distances_ui_muj / tf.reshape(stddevs, [1, evaluation_batch_size, embedding_dimension])\n",
        "  p_ui_cond_xj = tf.exp(-tf.reduce_sum(normalized_distances_ui_muj**2, axis=-1)/2. - \\\n",
        "    tf.reshape(tf.reduce_sum(logvars, axis=-1), [1, evaluation_batch_size])/2.)\n",
        "  normalization_factor = (2.*np.pi)**(tf.cast(embedding_dimension, tf.float64)/2.)\n",
        "  p_ui_cond_xj = p_ui_cond_xj / normalization_factor\n",
        "  if diag:\n",
        "    return tf.linalg.diag_part(p_ui_cond_xj)\n",
        "  else:\n",
        "    return tf.reduce_sum(p_ui_cond_xj, axis=-1)"
      ],
      "metadata": {
        "id": "J-uCw4DR5u8c",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Bhattacharyya-based information evaluation\n",
        "def bhattacharyya_dist_mat(mus, logvars):\n",
        "  \"\"\"Computes Bhattacharyya distances between multivariate Gaussians.\n",
        "  The Bhattacharyya coefficient is the exponentiated negative distance.\n",
        "  Args:\n",
        "    mus: [N, d] float array of the means of the Gaussians.\n",
        "    logvars: [N, d] float array of the log variances of the Gaussians (so we're assuming diagonal\n",
        "    covariance matrices; these are the logs of the diagonal).\n",
        "  Returns:\n",
        "    [N, N] array of distances.\n",
        "  \"\"\"\n",
        "  N = mus.shape[0]\n",
        "  embedding_dimension = mus.shape[1]\n",
        "\n",
        "  ## Manually broadcast\n",
        "  mus1 = np.tile(mus[:, np.newaxis], [1, N, 1])\n",
        "  logvars1 = np.tile(logvars[:, np.newaxis], [1, N, 1])\n",
        "  mus2 = np.tile(mus[np.newaxis], [N, 1, 1])\n",
        "  logvars2 = np.tile(logvars[np.newaxis], [N, 1, 1])\n",
        "  difference_mus = mus1 - mus2  # [N, M, embedding_dimension]; we want [N, N, embedding_dimension, 1]\n",
        "  difference_mus = difference_mus[..., np.newaxis]\n",
        "  difference_mus_T = np.transpose(difference_mus, [0, 1, 3, 2])\n",
        "\n",
        "  sigma_diag = 0.5 * (np.exp(logvars1) + np.exp(logvars2))  ## [N, N, embedding_dimension], but we want a diag mat [N, N, embedding_dimension, embedding_dimension]\n",
        "  sigma_mat = np.expand_dims(sigma_diag, -1) * np.expand_dims(np.ones_like(sigma_diag), -2) * np.reshape(np.eye(embedding_dimension), [1, 1, embedding_dimension, embedding_dimension])\n",
        "  sigma_mat_inv = np.expand_dims(1./sigma_diag, -1) * np.expand_dims(np.ones_like(sigma_diag), -2) * np.reshape(np.eye(embedding_dimension), [1, 1, embedding_dimension, embedding_dimension])\n",
        "\n",
        "  log_determinant_sigma = np.sum(np.log(sigma_diag), axis=-1)\n",
        "  log_determinant_sigma1 = np.sum(logvars1, axis=-1)\n",
        "  log_determinant_sigma2 = np.sum(logvars2, axis=-1)\n",
        "  term1 = 0.125 * (difference_mus_T @ sigma_mat_inv @ difference_mus).reshape([N, N])\n",
        "  term2 = 0.5 * (log_determinant_sigma - 0.5 * (log_determinant_sigma1  + log_determinant_sigma2))\n",
        "  return term1+term2\n",
        "\n",
        "@tf.function(experimental_relax_shapes=True)\n",
        "def bhattacharyya_dist_mat_tf(mus, logvars):\n",
        "  \"\"\"Computes Bhattacharyya distances between multivariate Gaussians.\n",
        "  Args:\n",
        "    mus1: [N, d] float array of the means of the Gaussians.\n",
        "    logvars1: [N, d] float array of the log variances of the Gaussians (so we're assuming diagonal\n",
        "    covariance matrices; these are the logs of the diagonal).\n",
        "  Returns:\n",
        "    [N, M] array of distances.\n",
        "  \"\"\"\n",
        "  N = tf.shape(mus)[0]\n",
        "  embedding_dimension = tf.shape(mus)[1]\n",
        "\n",
        "  mus = tf.cast(mus, tf.float64)\n",
        "  logvars = tf.cast(logvars, tf.float64)\n",
        "\n",
        "  ## Manually broadcast in case either M or N is 1\n",
        "  mus1 = tf.tile(tf.expand_dims(mus, 1), [1, N, 1])\n",
        "  logvars1 = tf.tile(tf.expand_dims(logvars, 1), [1, N, 1])\n",
        "  mus2 = tf.tile(tf.expand_dims(mus, 0), [N, 1, 1])\n",
        "  logvars2 = tf.tile(tf.expand_dims(logvars, 0), [N, 1, 1])\n",
        "  difference_mus = mus1 - mus2  # [N, M, embedding_dimension]; we want [N, M, embedding_dimension, 1]\n",
        "  difference_mus = tf.expand_dims(difference_mus, -1)\n",
        "  difference_mus_T = tf.transpose(difference_mus, [0, 1, 3, 2])\n",
        "\n",
        "  sigma_diag = 0.5 * (tf.exp(logvars1) + tf.exp(logvars2))  ## [N, M, embedding_dimension], but we want a diag mat [N, M, embedding_dimension, embedding_dimension]\n",
        "  # sigma_mat = np.apply_along_axis(np.diag, -1, sigma_diag)\n",
        "  sigma_mat = tf.expand_dims(sigma_diag, -1) * tf.expand_dims(tf.ones_like(sigma_diag, dtype=tf.float64), -2) * tf.reshape(tf.eye(embedding_dimension, dtype=tf.float64), [1, 1, embedding_dimension, embedding_dimension])\n",
        "  # sigma_mat_inv = np.apply_along_axis(np.diag, -1, 1./sigma_diag)\n",
        "  sigma_mat_inv = tf.expand_dims(1./sigma_diag, -1) * tf.expand_dims(tf.ones_like(sigma_diag, dtype=tf.float64), -2) * tf.reshape(tf.eye(embedding_dimension, dtype=tf.float64), [1, 1, embedding_dimension, embedding_dimension])\n",
        "\n",
        "  log_determinant_sigma = tf.reduce_sum(tf.math.log(sigma_diag), axis=-1)\n",
        "  log_determinant_sigma1 = tf.reduce_sum(logvars1, axis=-1)\n",
        "  log_determinant_sigma2 = tf.reduce_sum(logvars2, axis=-1)\n",
        "  term1 = 0.125 * tf.reshape(difference_mus_T @ sigma_mat_inv @ difference_mus, [N, N])\n",
        "  term2 = 0.5 * (log_determinant_sigma - 0.5 * (log_determinant_sigma1 + log_determinant_sigma2))\n",
        "  return term1+term2\n",
        "\n",
        "@tf.function(experimental_relax_shapes=True)\n",
        "def bhat_info_tf(mus, logvars):\n",
        "  bhat_dist_mat = bhattacharyya_dist_mat_tf(mus, logvars)\n",
        "  info = -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-bhat_dist_mat))))\n",
        "  return info\n",
        "\n",
        "@tf.function\n",
        "def compute_nmi_bhat_tf(bhat1, bhat2):\n",
        "  i1 = -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-bhat1), axis=1)))\n",
        "  i2 = -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-bhat2), axis=1)))\n",
        "  i11 = -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-bhat1*2), axis=1)))\n",
        "  i22 = -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-bhat2*2), axis=1)))\n",
        "  i12 = -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-bhat1-bhat2), axis=1)))\n",
        "  return (i1+i2-i12) / tf.sqrt((2*i1-i11)*(2*i2-i22))\n",
        "\n",
        "@tf.function\n",
        "def compute_vi_bhat_tf(bhat1, bhat2):\n",
        "  i11 = -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-bhat1*2), axis=1)))\n",
        "  i22 = -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-bhat2*2), axis=1)))\n",
        "  i12 = -tf.reduce_mean(tf.math.log(tf.reduce_mean(tf.exp(-bhat1-bhat2), axis=1)))\n",
        "  return 2*i12 - i11 - i22"
      ],
      "metadata": {
        "id": "5jKIRrgudIXV",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title CKA with Bhattacharyya matrices\n",
        "\n",
        "centering_matrix = np.eye(fingerprint_size) - np.ones((fingerprint_size, fingerprint_size))/fingerprint_size\n",
        "def compute_cka_bhat(bhat1, bhat2):\n",
        "  sim11 = np.trace(bhat1 @ centering_matrix @ bhat1 @ centering_matrix)\n",
        "  sim22 = np.trace(bhat2 @ centering_matrix @ bhat2 @ centering_matrix)\n",
        "  sim12 = np.trace(bhat1 @ centering_matrix @ bhat2 @ centering_matrix)\n",
        "  cka = sim12 / np.sqrt(sim11*sim22)\n",
        "  return cka\n",
        "\n",
        "@tf.function\n",
        "def compute_cka_bhat_tf(bhat1, bhat2):\n",
        "  sim11 = tf.linalg.trace(bhat1 @ centering_matrix @ bhat1 @ centering_matrix)\n",
        "  sim22 = tf.linalg.trace(bhat2 @ centering_matrix @ bhat2 @ centering_matrix)\n",
        "  sim12 = tf.linalg.trace(bhat1 @ centering_matrix @ bhat2 @ centering_matrix)\n",
        "  cka = sim12 / tf.sqrt(sim11*sim22)\n",
        "  return cka\n"
      ],
      "metadata": {
        "cellView": "form",
        "id": "VaIbBUfyLFBc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n"
      ],
      "metadata": {
        "id": "aoAXkv3EG_m8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Generate the 9 representation spaces\n",
        "sqrt_N = int(np.round(np.sqrt(N)))\n",
        "\n",
        "cmap = plt.get_cmap('viridis')\n",
        "alpha = 0.5\n",
        "\n",
        "x = np.arange(N)\n",
        "\n",
        "u_mus_all, u_logvars_all = [[], []]\n",
        "\n",
        "constant_variance_offset = 0.1\n",
        "\n",
        "###### Spiral: constant variance\n",
        "\n",
        "spiral_freq = 0.2\n",
        "u_mus = np.sqrt(x).reshape([-1, 1])*np.stack([np.cos(2*np.pi*np.sqrt(x)*spiral_freq), np.sin(2*np.pi*np.sqrt(x)*spiral_freq)], -1)\n",
        "u_logvars = np.zeros_like(u_mus) + constant_variance_offset\n",
        "\n",
        "u_mus_all.append(u_mus)\n",
        "u_logvars_all.append(u_logvars)\n",
        "\n",
        "###### Bloated spiral\n",
        "\n",
        "u_mus = u_mus_all[0].copy()\n",
        "u_logvars = u_logvars_all[0].copy() + 1.25\n",
        "\n",
        "u_mus_all.append(u_mus)\n",
        "u_logvars_all.append(u_logvars)\n",
        "\n",
        "\n",
        "###### Bloated bloated spiral\n",
        "\n",
        "u_mus = u_mus_all[0].copy()\n",
        "u_logvars = u_logvars_all[0].copy() + 2.5\n",
        "\n",
        "u_mus_all.append(u_mus)\n",
        "u_logvars_all.append(u_logvars)\n",
        "\n",
        "\n",
        "###### Square spiral\n",
        "\n",
        "position = np.zeros(2)\n",
        "u_mus = [position.copy()]\n",
        "step_size = 0.5\n",
        "movements = np.float32([[0, 1],\n",
        "                        [1, 0],\n",
        "                        [0, -1],\n",
        "                        [-1, 0]])\n",
        "\n",
        "movement_ind = 0\n",
        "step_ind = 0\n",
        "side_length = 1\n",
        "for i in range(N-1):\n",
        "  position += step_size*movements[movement_ind]\n",
        "  step_ind += 1\n",
        "  if step_ind == side_length:\n",
        "    step_ind = 0\n",
        "    movement_ind = (movement_ind+1) % 4\n",
        "    step_size += 0.1\n",
        "    if not(movement_ind % 2):\n",
        "      side_length += 1\n",
        "  u_mus.append(position.copy())\n",
        "u_mus = np.array(u_mus)\n",
        "\n",
        "u_logvars = np.zeros_like(u_mus)\n",
        "\n",
        "u_mus_all.append(u_mus)\n",
        "u_logvars_all.append(u_logvars)\n",
        "\n",
        "\n",
        "########### more variance\n",
        "position = np.zeros(2)\n",
        "u_mus = [position.copy()]\n",
        "step_size = 0.5\n",
        "movements = np.float32([[0, 1],\n",
        "                        [1, 0],\n",
        "                        [0, -1],\n",
        "                        [-1, 0]])\n",
        "\n",
        "movement_ind = 0\n",
        "step_ind = 0\n",
        "side_length = 1\n",
        "for i in range(N-1):\n",
        "  position += step_size*movements[movement_ind]\n",
        "  step_ind += 1\n",
        "  if step_ind == side_length:\n",
        "    step_ind = 0\n",
        "    movement_ind = (movement_ind+1) % 4\n",
        "    step_size += 0.1\n",
        "    if not(movement_ind % 2):\n",
        "      side_length += 1\n",
        "  u_mus.append(position.copy())\n",
        "u_mus = np.array(u_mus)\n",
        "\n",
        "u_logvars = np.zeros_like(u_mus) + 2.5\n",
        "\n",
        "u_mus_all.append(u_mus)\n",
        "u_logvars_all.append(u_logvars)\n",
        "\n",
        "###### 1D line\n",
        "\n",
        "u_mus = np.linspace(-N/2, N/2, N).reshape([-1, 1])\n",
        "u_logvars = np.zeros_like(u_mus)-1.\n",
        "\n",
        "u_mus_all.append(u_mus)\n",
        "u_logvars_all.append(u_logvars)\n",
        "\n",
        "###### Discrete: two\n",
        "\n",
        "u_mus = np.concatenate([np.ones((N//2, 2))*[[-sqrt_N*0.8, 0]],\n",
        "                            np.ones((N//2, 2))*[[sqrt_N*0.8, 0]]], 0)\n",
        "u_mus = u_mus + np.random.randn(N, 2)\n",
        "\n",
        "u_logvars = np.zeros_like(u_mus)+2\n",
        "\n",
        "u_mus_all.append(u_mus)\n",
        "u_logvars_all.append(u_logvars)\n",
        "\n",
        "###### Discrete: four\n",
        "\n",
        "u_mus = np.concatenate([\n",
        "    np.ones((N//4, 2))*[[-sqrt_N*0.7, sqrt_N*0.7]],\n",
        "    np.ones((N//4, 2))*[[sqrt_N*0.7, sqrt_N*0.7]],\n",
        "    np.ones((N//4, 2))*[[sqrt_N*0.7, -sqrt_N*0.7]],\n",
        "    np.ones((N//4, 2))*[[-sqrt_N*0.7, -sqrt_N*0.7]]\n",
        "    ], 0)\n",
        "\n",
        "u_mus = u_mus + np.random.randn(N, 2)*0\n",
        "u_logvars = np.zeros_like(u_mus)+2\n",
        "\n",
        "u_mus_all.append(u_mus)\n",
        "u_logvars_all.append(u_logvars)\n",
        "\n",
        "u_mus = np.concatenate([\n",
        "    np.ones((N//4, 2))*[[-sqrt_N*0.7, sqrt_N*0.7]],\n",
        "    np.ones((N//4, 2))*[[-sqrt_N*0.7, -sqrt_N*0.7]],\n",
        "    np.ones((N//4, 2))*[[sqrt_N*0.7, -sqrt_N*0.7]],\n",
        "    np.ones((N//4, 2))*[[sqrt_N*0.7, sqrt_N*0.7]]\n",
        "    ], 0)\n",
        "\n",
        "vert_variance = 3.\n",
        "\n",
        "u_mus = u_mus + np.random.randn(N, 2)*0\n",
        "u_logvars = np.ones((N, 2))*[[0, vert_variance]]+1\n",
        "u_mus_all.append(u_mus)\n",
        "u_logvars_all.append(u_logvars)\n",
        "\n",
        "\n",
        "plt.figure(figsize=(8, 8))\n",
        "for plt_ind, (mus, logvars) in enumerate(zip(u_mus_all, u_logvars_all)):\n",
        "  plt.subplot(3, 3, plt_ind+1)\n",
        "  if mus.shape[1] == 2:\n",
        "    if plt_ind < 6:\n",
        "      for i in range(N):\n",
        "        ell = Ellipse(xy=mus[i],\n",
        "                      width=2*np.exp(logvars[i, 0]/2.), height=2*np.exp(logvars[i, 1]/2.),\n",
        "                      facecolor=cmap(i/(N-1)), alpha=alpha, edgecolor='k')\n",
        "        plt.gca().add_artist(ell)\n",
        "      plt.ylim(-sqrt_N*1.5, sqrt_N*1.5)\n",
        "      plt.xlim(-sqrt_N*1.5, sqrt_N*1.5)\n",
        "    else:\n",
        "      for i in range(N):\n",
        "        ell = Ellipse(xy=mus[i],\n",
        "                      width=2*np.exp(logvars[i, 0]/2.), height=2*np.exp(logvars[i, 1]/2.),\n",
        "                      facecolor=cmap(i/(N-1)), alpha=1, edgecolor='k')\n",
        "        plt.gca().add_artist(ell)\n",
        "      plt.ylim(-sqrt_N*2, sqrt_N*2)\n",
        "      plt.xlim(-sqrt_N*2, sqrt_N*2)\n",
        "  else:\n",
        "    plt_x = np.linspace(-N/2-20, N/2+20, 10000)\n",
        "    for i in range(20, 40):\n",
        "\n",
        "      sig = np.exp(logvars[i]/2.)\n",
        "      plt_y = np.exp(-np.power((plt_x - mus[i]) / sig, 2.0) / 2) /  (np.sqrt(2.0 * np.pi) * sig)\n",
        "      plt.plot(plt_x, plt_y, lw=4, color=cmap(i/(N-1)))\n",
        "    plt.xlim(-3, 3)\n",
        "  plt.axis('off')\n",
        "plt.tight_layout()\n",
        "\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "ROkrccKnBOOK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Now compute the pairwise similarities\n",
        "ssm_scores = []\n",
        "nmis_bhat, vis_bhat = [np.eye(len(u_mus_all)), np.eye(len(u_mus_all))]\n",
        "cka_bhat, cka_reg = [np.eye(len(u_mus_all)), np.eye(len(u_mus_all))]\n",
        "nmis_mc, vis_mc = [np.eye(len(u_mus_all)), np.eye(len(u_mus_all))]\n",
        "number_mc_random_samples = 5\n",
        "fractional_infos = []\n",
        "for embedding_space_ind1 in range(len(u_mus_all)):\n",
        "  u_mus = u_mus_all[embedding_space_ind1]\n",
        "  u_logvars = u_logvars_all[embedding_space_ind1]\n",
        "  # Only the MC evaluation takes long enough to care about saving some computation by doing the following in the outer loop\n",
        "  i1 = monte_carlo_info(u_mus, u_logvars, number_random_samples=number_mc_random_samples)\n",
        "  i11 = monte_carlo_info(np.tile(u_mus, [1, 2]),\n",
        "                        np.tile(u_logvars, [1, 2]), number_random_samples=number_mc_random_samples)\n",
        "  fractional_infos.append(i1/np.log2(N))\n",
        "  for embedding_space_ind2 in range(embedding_space_ind1, len(u_mus_all)):\n",
        "\n",
        "    v_mus = u_mus_all[embedding_space_ind2]\n",
        "    v_logvars = u_logvars_all[embedding_space_ind2]\n",
        "\n",
        "    i2 = monte_carlo_info(v_mus, v_logvars, number_random_samples=number_mc_random_samples)\n",
        "    i22 = monte_carlo_info(np.tile(v_mus, [1, 2]),\n",
        "                          np.tile(v_logvars, [1, 2]), number_random_samples=number_mc_random_samples)\n",
        "    i12 = monte_carlo_info(np.concatenate([u_mus, v_mus], 1),\n",
        "                          np.concatenate([u_logvars, v_logvars], 1), number_random_samples=number_mc_random_samples)\n",
        "\n",
        "    nmi = (i1+i2-i12) / tf.sqrt((2*i1-i11)*(2*i2-i22))\n",
        "    nmis_mc[embedding_space_ind1, embedding_space_ind2] = nmi\n",
        "    nmis_mc[embedding_space_ind2, embedding_space_ind1] = nmi\n",
        "\n",
        "    vi = 2*i12 - i11 - i22\n",
        "    vis_mc[embedding_space_ind1, embedding_space_ind2] = vi\n",
        "    vis_mc[embedding_space_ind2, embedding_space_ind1] = vi\n",
        "\n",
        "    bhat1 = bhattacharyya_dist_mat_tf(u_mus, u_logvars)\n",
        "    bhat2 = bhattacharyya_dist_mat_tf(v_mus, v_logvars)\n",
        "\n",
        "    nmi = compute_nmi_bhat_tf(bhat1, bhat2)\n",
        "    nmis_bhat[embedding_space_ind1, embedding_space_ind2] = nmi\n",
        "    nmis_bhat[embedding_space_ind2, embedding_space_ind1] = nmi\n",
        "\n",
        "    vi = compute_vi_bhat_tf(bhat1, bhat2)\n",
        "    vis_bhat[embedding_space_ind1, embedding_space_ind2] = vi\n",
        "    vis_bhat[embedding_space_ind2, embedding_space_ind1] = vi\n",
        "    cka = compute_cka_bhat_tf(tf.exp(-bhat1), tf.exp(-bhat2))\n",
        "    cka_bhat[embedding_space_ind1, embedding_space_ind2] = cka\n",
        "    cka_bhat[embedding_space_ind2, embedding_space_ind1] = cka\n",
        "\n",
        "\n",
        "Xs = []\n",
        "for embedding_space_ind in range(len(u_mus_all)):\n",
        "  mus = u_mus_all[embedding_space_ind]\n",
        "  logvars = u_logvars_all[embedding_space_ind]\n",
        "  if mus.shape[1] != 2:  ## Since the netrep code does not allow different dimensionalities, just fill the values in with something valid\n",
        "    mus = u_mus_all[embedding_space_ind-1]\n",
        "    logvars = u_logvars_all[embedding_space_ind-1]\n",
        "  covs = np.apply_along_axis(np.diag, 1, np.exp(-logvars))\n",
        "  X = (mus, covs)\n",
        "  Xs.append(X)\n",
        "\n",
        "ct = time.time()\n",
        "alpha = 1  ## the Wasserstein thing comparing means and covariances\n",
        "metric = GaussianStochasticMetric(alpha, init='rand', n_restarts=50)\n",
        "dist_matrix, _ = metric.pairwise_distances(Xs)\n",
        "ssm_scores = dist_matrix\n",
        "## Fill in the NAN values\n",
        "one_dim_ind = 5\n",
        "ssm_scores[one_dim_ind] = np.nan\n",
        "ssm_scores[:, one_dim_ind] = np.nan\n",
        "ssm_scores[one_dim_ind, one_dim_ind] = 0.\n",
        "print(f'Computed SSM scores, time taken: {time.time()-ct:.3f} sec')"
      ],
      "metadata": {
        "id": "Fw7ftakF8a6M"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "labels = ['nmi_mc', 'nmi_bhat', 'cka_bhat', 'vi_mc', 'vi_bhat', 'SSM']\n",
        "similarities = [nmis_mc, nmis_bhat, cka_bhat, vis_mc, vis_bhat, SSM_scores]\n",
        "cmaps = ['Blues', 'Blues', 'Blues', 'Reds_r', 'Reds_r', 'Reds_r']\n",
        "vmins = [0, 0, 0, 0, 0, 0]\n",
        "vmaxes = [1, 1, 1, 3.2, 3.2, None]\n",
        "plt.figure(figsize=(8, 5))\n",
        "for plt_ind, (similarity_values, label, cmap, vmin, vmax) in enumerate(zip(similarities, labels, cmaps, vmins, vmaxes)):\n",
        "  plt.subplot(2, 3, plt_ind+1)\n",
        "  plt.imshow(np.reshape(similarity_values, [len(u_mus_all), -1]), vmin=vmin, vmax=vmax, cmap=cmap)\n",
        "  plt.colorbar()\n",
        "  plt.title(label, fontsize=15)\n",
        "  plt.xticks(np.arange(len(u_mus_all)), 'abcdefghi')\n",
        "  plt.yticks(np.arange(len(u_mus_all)), 'abcdefghi')\n",
        "  plt.tick_params(axis='both', which='both',length=0)\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "uOkF9F-x9BI8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "LKND83Buup7g"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}