{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "xUHm2okAuA8m"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "%matplotlib inline\n",
        "import tensorflow as tf\n",
        "import matplotlib\n",
        "from matplotlib.patches import Ellipse\n",
        "\n",
        "import utils"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "ensemble_sizes = [2, 3, 5, 10, 15, 20, 30, 50]\n",
        "number_repeat_training_runs = 5  ## the number of times to repeat the experiment for statistics\n",
        "fingerprint_size = 200"
      ],
      "metadata": {
        "id": "leQz8ndCHNNc"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#### Train the ensemble of weak learners, beta-VAEs with 1D latent spaces\n",
        "#### Should take a couple seconds each\n",
        "#### Then store the fingerprint matrices that we will use for the model fusion afterward\n",
        "bhat_mats = []\n",
        "\n",
        "display_training_curves = False\n",
        "\n",
        "thetas_fingerprint = np.linspace(0, 2*np.pi,fingerprint_size, endpoint=False)\n",
        "pts_fingerprint = np.stack([np.cos(thetas_fingerprint), np.sin(thetas_fingerprint)], -1)\n",
        "\n",
        "number_training_steps = 3000\n",
        "batch_size = 2048\n",
        "beta = 3e-2\n",
        "learning_rate = 1e-3\n",
        "data_dimensionality = 2\n",
        "enc_arch_spec = [256]*2\n",
        "dec_arch_spec = [256]*2\n",
        "activation_fn = 'tanh'\n",
        "number_bottleneck_channels = 1\n",
        "\n",
        "for trial in range(number_repeat_training_runs*np.max(ensemble_sizes)):\n",
        "  encoder = tf.keras.Sequential([tf.keras.Input((data_dimensionality,))] + \\\n",
        "                                [tf.keras.layers.Dense(number_units, activation_fn) for number_units in enc_arch_spec] + \\\n",
        "                                [tf.keras.layers.Dense(2*number_bottleneck_channels)])\n",
        "  decoder = tf.keras.Sequential([tf.keras.Input((number_bottleneck_channels,))] + \\\n",
        "                                [tf.keras.layers.Dense(number_units, activation_fn) for number_units in dec_arch_spec] + \\\n",
        "                                [tf.keras.layers.Dense(data_dimensionality)])\n",
        "  all_trainable_variables = encoder.trainable_variables + decoder.trainable_variables\n",
        "  optimizer = tf.keras.optimizers.Adam(learning_rate)\n",
        "  mse = tf.keras.losses.MeanSquaredError()\n",
        "  @tf.function\n",
        "  def train_step(pts):\n",
        "    with tf.GradientTape() as tape:\n",
        "      embs_mus, embs_logvars = tf.split(encoder(pts), 2, axis=-1)\n",
        "      kl = tf.reduce_mean(tf.reduce_sum(0.5 * (tf.square(embs_mus) + tf.exp(embs_logvars) - embs_logvars - 1.), axis=-1), axis=0)\n",
        "      reparameterized_embs = tf.random.normal(embs_mus.shape, mean=embs_mus, stddev=tf.exp(embs_logvars/2.))\n",
        "      recon = decoder(reparameterized_embs)\n",
        "      reconstruction_loss = mse(pts, recon)\n",
        "      loss = tf.reduce_mean(reconstruction_loss) + beta * kl\n",
        "    grads = tape.gradient(loss, all_trainable_variables)\n",
        "    optimizer.apply_gradients(zip(grads, all_trainable_variables))\n",
        "    return reconstruction_loss, kl\n",
        "\n",
        "  recon_loss_series, kl_loss_series = [[], []]\n",
        "  for step in range(number_training_steps):\n",
        "    batch_theta = np.random.uniform(0, 2*np.pi, size=batch_size)\n",
        "    batch_pts = np.stack([np.cos(batch_theta), np.sin(batch_theta)], -1)\n",
        "    recon_loss, kl = train_step(batch_pts)\n",
        "    recon_loss_series.append(recon_loss.numpy())\n",
        "    kl_loss_series.append(kl.numpy())\n",
        "\n",
        "  if display_training_curves:\n",
        "    plt.figure(figsize=(8, 5))\n",
        "    plt.plot(recon_loss_series, 'k', lw=2)\n",
        "    plt.ylabel('Recon', fontsize=15)\n",
        "    plt.xlabel('Step', fontsize=15)\n",
        "    plt.gca().twinx().plot(kl_loss_series, lw=2)\n",
        "    plt.ylabel('KL', color='b', fontsize=15)\n",
        "    plt.show()\n",
        "\n",
        "  fingerprint_embs = encoder(pts_fingerprint)\n",
        "  mus, logvars = tf.split(fingerprint_embs, 2, -1)\n",
        "  bhat_mats.append(utils.bhattacharyya_dist_mat(mus, logvars))\n",
        "\n",
        "  print(f'Computed bhat mat for run {len(bhat_mats)}/{number_repeat_training_runs*np.max(ensemble_sizes)}.')"
      ],
      "metadata": {
        "id": "zSlp-6BVuLyx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "@tf.function\n",
        "def compute_vi_similarity(bhat1, bhat2):\n",
        "  i11 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat1*2, axis=1))\n",
        "  i22 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat2*2, axis=1))\n",
        "  i12 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat1-bhat2, axis=1))\n",
        "  return tf.exp(-(i12*2 - i11 - i22))\n",
        "\n",
        "@tf.function\n",
        "def compute_nmi(bhat1, bhat2):\n",
        "  i1 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat1, axis=1))\n",
        "  i2 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat2, axis=1))\n",
        "  i11 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat1*2, axis=1))\n",
        "  i22 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat2*2, axis=1))\n",
        "  i12 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat1-bhat2, axis=1))\n",
        "  return (i1+i2-i12) / tf.sqrt((2*i1-i11)*(2*i2-i22))\n",
        "\n",
        "@tf.function\n",
        "def compute_info(bhat1, bhat2):\n",
        "  i1 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat1, axis=1))\n",
        "  i2 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat2, axis=1))\n",
        "  i12 = -tf.reduce_mean(tf.math.reduce_logsumexp(-bhat1-bhat2, axis=1))\n",
        "  return i1+i2-i12"
      ],
      "metadata": {
        "id": "MVFJb3yzvA4N"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def continuity_metric(mus, logvars, percentile=90):\n",
        "  adj_bhat_distances = np.diag(utils.bhattacharyya_dist_mat(mus, logvars), k=1)\n",
        "  return np.max(adj_bhat_distances) / np.percentile(adj_bhat_distances, percentile)"
      ],
      "metadata": {
        "id": "FOqDhvn5vTa2"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "display_periodically_during_training = False\n",
        "\n",
        "fusion_space_dimensionality = 2\n",
        "number_opt_steps = 20000\n",
        "continuity_metric_values = []\n",
        "learning_rate = 3e0\n",
        "cmap = plt.get_cmap('hsv')\n",
        "gaussian_display_alpha = 0.5\n",
        "for sim_method_name, sim_method in zip(['info', 'vi', 'nmi'], [compute_info, compute_vi_similarity, compute_nmi]):\n",
        "  for ensemble_size in ensemble_sizes:\n",
        "    for repeat_ind in range(number_repeat_training_runs):\n",
        "      ## Randomly initialize the posterior parameters for each of the fingerprint points\n",
        "      mus_var = tf.Variable(tf.random.normal((fingerprint_size, fusion_space_dimensionality), stddev=0.05), trainable=True)\n",
        "      logvars_var = tf.Variable(tf.zeros((fingerprint_size, fusion_space_dimensionality)), trainable=True)\n",
        "      all_trainable_variables = [mus_var, logvars_var]\n",
        "      bhats_to_use = bhat_mats[repeat_ind*np.max(ensemble_sizes):repeat_ind*np.max(ensemble_sizes)+ensemble_size]\n",
        "      optimizer = tf.keras.optimizers.SGD(learning_rate)\n",
        "      @tf.function\n",
        "      def compute_avg_similarity():\n",
        "        with tf.GradientTape() as tape:\n",
        "          bhat_mat_opt = utils.bhattacharyya_dist_mat_tf(mus_var, logvars_var)\n",
        "          total_similarity = 0.\n",
        "          for other_bhat in bhats_to_use:\n",
        "            sim = sim_method(bhat_mat_opt, tf.cast(other_bhat, tf.float32))\n",
        "            total_similarity -= sim\n",
        "          loss = total_similarity / ensemble_size\n",
        "        grads = tape.gradient(loss, all_trainable_variables)\n",
        "        optimizer.apply_gradients(zip(grads, all_trainable_variables))\n",
        "        return total_similarity\n",
        "      sim_series = []\n",
        "      for opt_step in range(number_opt_steps):\n",
        "        total_similarity = compute_avg_similarity()\n",
        "        sim_series.append(total_similarity)\n",
        "        if display_periodically_during_training and (opt_step % 1000) == 0:\n",
        "          bhat_mat_opt = utils.bhattacharyya_dist_mat_tf(mus_var, logvars_var)\n",
        "          plt.figure(figsize=(10, 5))\n",
        "          plt.subplot(121)\n",
        "          plt.imshow(tf.exp(-bhat_mat_opt), vmin=0, vmax=1)\n",
        "          plt.axis('off')\n",
        "          plt.subplot(122)\n",
        "          plt.scatter(*np.float32(mus_var.numpy()).T, c=np.linspace(0, 1, fingerprint_size), cmap='hsv')\n",
        "          for i in range(fingerprint_size):\n",
        "            ell = Ellipse(xy=mus_var.numpy()[i],\n",
        "                          width=2*np.exp(logvars_var.numpy()[i, 0]/2.), height=2*np.exp(logvars_var.numpy()[i, 1]/2.),\n",
        "                          facecolor=cmap(i/(fingerprint_size-1)), alpha=gaussian_display_alpha, edgecolor='k')\n",
        "            plt.gca().add_artist(ell)\n",
        "          plt.show()\n",
        "      plt.figure(figsize=(7, 4))\n",
        "      plt.plot(sim_series)\n",
        "      plt.ylabel(f'Similarity: {sim_method_name}', fontsize=15)\n",
        "      plt.xlabel('Step', fontsize=15)\n",
        "      plt.show()\n",
        "\n",
        "      bhat_mat_opt = utils.bhattacharyya_dist_mat_tf(mus_var, logvars_var)\n",
        "      plt.figure(figsize=(10, 5))\n",
        "      plt.subplot(121)\n",
        "      plt.imshow(tf.exp(-bhat_mat_opt), vmin=0, vmax=1, cmap='Blues_r')\n",
        "      plt.axis('off')\n",
        "      plt.subplot(122)\n",
        "      plt.scatter(*np.float32(mus_var.numpy()).T, c=np.linspace(0, 1, fingerprint_size), cmap='hsv')\n",
        "      for i in range(fingerprint_size):\n",
        "        ell = Ellipse(xy=mus_var.numpy()[i],\n",
        "                      width=2*np.exp(logvars_var.numpy()[i, 0]/2.), height=2*np.exp(logvars_var.numpy()[i, 1]/2.),\n",
        "                      facecolor=cmap(i/(fingerprint_size-1)), alpha=gaussian_display_alpha, edgecolor='k')\n",
        "        plt.gca().add_artist(ell)\n",
        "      plt.xticks([]); plt.yticks([])\n",
        "      plt.savefig(f'outs/SGD_{sim_method_name}_{ensemble_size}_{repeat_ind}.svg')\n",
        "      plt.show()\n",
        "\n",
        "      continuity_metric_values.append(continuity_metric(mus_var.numpy(), logvars_var.numpy()))\n",
        "\n",
        "np.savez('fusion_results.npz',\n",
        "         continuity_metric_values=continuity_metric_values,\n",
        "         ensemble_sizes=ensemble_sizes)"
      ],
      "metadata": {
        "id": "oAMjtMylvJ2X"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "plt.figure(figsize=(8, 4))\n",
        "continuity_metric_values = np.reshape(continuity_metric_values, [3, -1, number_repeat_training_runs])\n",
        "for method_ind, sim_method_name in enumerate(['info', 'vi', 'nmi']):\n",
        "  plt.errorbar(np.array(ensemble_sizes)*1.1**(method_ind-1),\n",
        "               np.mean(continuity_metric_values[method_ind], -1),\n",
        "               yerr=np.std(continuity_metric_values[method_ind], -1),\n",
        "              ls='-', marker='dos'[method_ind],\n",
        "               markersize=8, lw=2, elinewidth=4, label=sim_method_name)\n",
        "plt.xscale('log')\n",
        "plt.yscale('log')\n",
        "plt.xticks(ensemble_sizes, ensemble_sizes)\n",
        "plt.xticks([2, 5, 10, 20, 50], [2, 5, 10, 20, 50])\n",
        "buffer_factor = 1.15\n",
        "plt.xlim(ensemble_sizes[0]/buffer_factor, ensemble_sizes[-1]*buffer_factor)\n",
        "plt.ylim(1, 1000)\n",
        "plt.xlabel('Ensemble size', fontsize=15)\n",
        "plt.ylabel('Continuity', fontsize=15)\n",
        "plt.tick_params(width=2, length=4, which='major')\n",
        "plt.tick_params(width=2, length=3, which='minor')\n",
        "plt.legend()\n",
        "plt.tight_layout()\n",
        "plt.savefig('fusion_performance_comparison.svg')\n",
        "plt.show()"
      ],
      "metadata": {
        "id": "BzTE30naHfdJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "rQwnhMT7HyaW"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}