{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "L4vFKVJOJ5Rx"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import tensorflow as tf\n",
        "import os, time\n",
        "\n",
        "import PIL\n",
        "import tensorflow_datasets as tfds\n",
        "from matplotlib.gridspec import GridSpec\n",
        "import tensorflow_hub as hub\n",
        "\n",
        "import utils"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title load dataset\n",
        "def load_dataset(dataset_name, data_dir='Data/'):\n",
        "  ####################################################################################################################\n",
        "  if dataset_name in ['dsprites', 'cifar10', 'shapes3d', 'mnist', 'fashion_mnist', 'plant_village']:\n",
        "    data_rescaling_factor = 1.\n",
        "\n",
        "    if dataset_name == 'dsprites':\n",
        "      data_rescaling_factor = 255.\n",
        "\n",
        "    dset_loaded, dset_info = tfds.load(dataset_name, data_dir=os.path.join(data_dir, 'tensorflow_datasets'), with_info=True, decoders={\n",
        "        'image': tfds.decode.SkipDecoding(),\n",
        "    })\n",
        "\n",
        "    dset = dset_loaded['train']\n",
        "    if dataset_name == 'cifar10':  ## I didn't use the test set pre-cifar\n",
        "      dset = dset.concatenate(dset_loaded['test'])\n",
        "    dset = dset.map(lambda example: example['image'])\n",
        "\n",
        "    # dset = dset.shuffle(1_000_000)\n",
        "\n",
        "    dset = dset.map(\n",
        "        lambda image: dset_info.features['image'].decode_example(image))\n",
        "    dset = dset.map(lambda image: tf.image.convert_image_dtype(image, tf.float32)*data_rescaling_factor)\n",
        "\n",
        "    return dset\n",
        "  ####################################################################################################################\n",
        "  elif dataset_name == 'smallnorb':\n",
        "    SMALLNORB_TEMPLATE = os.path.join(data_dir,\n",
        "        \"disentanglement_lib\", \"small_norb\",\n",
        "        \"smallnorb-{}-{}.mat\")\n",
        "\n",
        "    SMALLNORB_CHUNKS = [\n",
        "        \"5x46789x9x18x6x2x96x96-training\",\n",
        "        \"5x01235x9x18x6x2x96x96-testing\",\n",
        "    ]\n",
        "\n",
        "    def _load_small_norb_chunks(path_template, chunk_names):\n",
        "      \"\"\"Loads several chunks of the small norb data set for final use.\"\"\"\n",
        "      list_of_images, list_of_features = _load_chunks(path_template, chunk_names)\n",
        "      features = np.concatenate(list_of_features, axis=0)\n",
        "      features[:, 3] = features[:, 3] / 2  # azimuth values are 0, 2, 4, ..., 24\n",
        "      return np.concatenate(list_of_images, axis=0), features\n",
        "\n",
        "\n",
        "    def _load_chunks(path_template, chunk_names):\n",
        "      \"\"\"Loads several chunks of the small norb data set into lists.\"\"\"\n",
        "      list_of_images = []\n",
        "      list_of_features = []\n",
        "      for chunk_name in chunk_names:\n",
        "        norb = _read_binary_matrix(path_template.format(chunk_name, \"dat\"))\n",
        "        list_of_images.append(_resize_images(norb[:, 0]))\n",
        "        norb_class = _read_binary_matrix(path_template.format(chunk_name, \"cat\"))\n",
        "        norb_info = _read_binary_matrix(path_template.format(chunk_name, \"info\"))\n",
        "        list_of_features.append(np.column_stack((norb_class, norb_info)))\n",
        "      return list_of_images, list_of_features\n",
        "\n",
        "\n",
        "    def _read_binary_matrix(filename):\n",
        "      \"\"\"Reads and returns binary formatted matrix stored in filename.\"\"\"\n",
        "      with tf.io.gfile.GFile(filename, \"rb\") as f:\n",
        "        s = f.read()\n",
        "        magic = int(np.frombuffer(s, \"int32\", 1))\n",
        "        ndim = int(np.frombuffer(s, \"int32\", 1, 4))\n",
        "        eff_dim = max(3, ndim)\n",
        "        raw_dims = np.frombuffer(s, \"int32\", eff_dim, 8)\n",
        "        dims = []\n",
        "        for i in range(0, ndim):\n",
        "          dims.append(raw_dims[i])\n",
        "\n",
        "        dtype_map = {\n",
        "            507333717: \"int8\",\n",
        "            507333716: \"int32\",\n",
        "            507333713: \"float\",\n",
        "            507333715: \"double\"\n",
        "        }\n",
        "        data = np.frombuffer(s, dtype_map[magic], offset=8 + eff_dim * 4)\n",
        "      data = data.reshape(tuple(dims))\n",
        "      return data\n",
        "\n",
        "    def _resize_images(integer_images):\n",
        "      resized_images = np.zeros((integer_images.shape[0], 64, 64))\n",
        "      for i in range(integer_images.shape[0]):\n",
        "        image = PIL.Image.fromarray(integer_images[i, :, :])\n",
        "        image = image.resize((64, 64), PIL.Image.ANTIALIAS)\n",
        "        resized_images[i, :, :] = image\n",
        "      return resized_images / 255.\n",
        "\n",
        "    images, features = _load_small_norb_chunks(SMALLNORB_TEMPLATE,\n",
        "                                                SMALLNORB_CHUNKS)\n",
        "    factor_sizes = [5, 10, 9, 18, 6]\n",
        "    # Instances are not part of the latent space.\n",
        "    latent_factor_indices = [0, 2, 3, 4]\n",
        "    num_total_factors = features.shape[1]\n",
        "    np.random.shuffle(images)\n",
        "    return tf.data.Dataset.from_tensor_slices(np.expand_dims(images, -1).astype(np.float32))\n",
        "\n",
        "  ####################################################################################################################\n",
        "  elif dataset_name == 'cars3d':\n",
        "    import scipy.io as sio\n",
        "    from sklearn.utils import extmath\n",
        "\n",
        "    CARS3D_PATH = os.path.join(data_dir, \"disentanglement_lib\", \"cars\")\n",
        "    \"\"\"Cars3D data set.\n",
        "\n",
        "    The data set was first used in the paper \"Deep Visual Analogy-Making\"\n",
        "    (https://papers.nips.cc/paper/5845-deep-visual-analogy-making) and can be\n",
        "    downloaded from http://www.scottreed.info/. The images are rescaled to 64x64.\n",
        "\n",
        "    The ground-truth factors of variation are:\n",
        "    0 - elevation (4 different values)\n",
        "    1 - azimuth (24 different values)\n",
        "    2 - object type (183 different values)\n",
        "    \"\"\"\n",
        "\n",
        "    class StateSpaceAtomIndex(object):\n",
        "      \"\"\"Index mapping from features to positions of state space atoms.\"\"\"\n",
        "\n",
        "      def __init__(self, factor_sizes, features):\n",
        "        \"\"\"Creates the StateSpaceAtomIndex.\n",
        "\n",
        "        Args:\n",
        "          factor_sizes: List of integers with the number of distinct values for each\n",
        "            of the factors.\n",
        "          features: Numpy matrix where each row contains a different factor\n",
        "            configuration. The matrix needs to cover the whole state space.\n",
        "        \"\"\"\n",
        "        self.factor_sizes = factor_sizes\n",
        "        num_total_atoms = np.prod(self.factor_sizes)\n",
        "        self.factor_bases = num_total_atoms / np.cumprod(self.factor_sizes)\n",
        "        feature_state_space_index = self._features_to_state_space_index(features)\n",
        "        if np.unique(feature_state_space_index).size != num_total_atoms:\n",
        "          raise ValueError(\"Features matrix does not cover the whole state space.\")\n",
        "        lookup_table = np.zeros(num_total_atoms, dtype=np.int64)\n",
        "        lookup_table[feature_state_space_index] = np.arange(num_total_atoms)\n",
        "        self.state_space_to_save_space_index = lookup_table\n",
        "\n",
        "      def features_to_index(self, features):\n",
        "        \"\"\"Returns the indices in the input space for given factor configurations.\n",
        "\n",
        "        Args:\n",
        "          features: Numpy matrix where each row contains a different factor\n",
        "            configuration for which the indices in the input space should be\n",
        "            returned.\n",
        "        \"\"\"\n",
        "        state_space_index = self._features_to_state_space_index(features)\n",
        "        return self.state_space_to_save_space_index[state_space_index]\n",
        "\n",
        "      def _features_to_state_space_index(self, features):\n",
        "        \"\"\"Returns the indices in the atom space for given factor configurations.\n",
        "\n",
        "        Args:\n",
        "          features: Numpy matrix where each row contains a different factor\n",
        "            configuration for which the indices in the atom space should be\n",
        "            returned.\n",
        "        \"\"\"\n",
        "        if (np.any(features > np.expand_dims(self.factor_sizes, 0)) or\n",
        "            np.any(features < 0)):\n",
        "          raise ValueError(\"Feature indices have to be within [0, factor_size-1]!\")\n",
        "        return np.array(np.dot(features, self.factor_bases), dtype=np.int64)\n",
        "\n",
        "    def _load_data():\n",
        "      dataset = np.zeros((24 * 4 * 183, 64, 64, 3))\n",
        "      all_files = [x for x in tf.io.gfile.listdir(CARS3D_PATH) if \".mat\" in x]\n",
        "      for i, filename in enumerate(all_files):\n",
        "        data_mesh = _load_mesh(filename)\n",
        "        factor1 = np.array(list(range(4)))\n",
        "        factor2 = np.array(list(range(24)))\n",
        "        all_factors = np.transpose([\n",
        "            np.tile(factor1, len(factor2)),\n",
        "            np.repeat(factor2, len(factor1)),\n",
        "            np.tile(i,\n",
        "                    len(factor1) * len(factor2))\n",
        "        ])\n",
        "        indexes = index.features_to_index(all_factors)\n",
        "        dataset[indexes] = data_mesh\n",
        "      return dataset\n",
        "\n",
        "\n",
        "    def _load_mesh(filename):\n",
        "      \"\"\"Parses a single source file and rescales contained images.\"\"\"\n",
        "      with open(os.path.join(CARS3D_PATH, filename), \"rb\") as f:\n",
        "        mesh = np.einsum(\"abcde->deabc\", sio.loadmat(f)[\"im\"])\n",
        "      flattened_mesh = mesh.reshape((-1,) + mesh.shape[2:])\n",
        "      rescaled_mesh = np.zeros((flattened_mesh.shape[0], 64, 64, 3))\n",
        "      for i in range(flattened_mesh.shape[0]):\n",
        "        pic = PIL.Image.fromarray(flattened_mesh[i, :, :, :])\n",
        "        # pic.thumbnail((64, 64, 3), PIL.Image.ANTIALIAS)\n",
        "        pic = pic.resize((64, 64), PIL.Image.ANTIALIAS)\n",
        "        rescaled_mesh[i, :, :, :] = np.array(pic)\n",
        "      return rescaled_mesh * 1. / 255\n",
        "\n",
        "\n",
        "    factor_sizes = [4, 24, 183]\n",
        "\n",
        "    latent_factor_indices = [0, 1, 2]\n",
        "\n",
        "    features = extmath.cartesian(\n",
        "            [np.array(list(range(i))) for i in factor_sizes])\n",
        "    index = StateSpaceAtomIndex(factor_sizes, features)\n",
        "\n",
        "    data_shape = [64, 64, 3]\n",
        "    images = _load_data()\n",
        "\n",
        "    np.random.shuffle(images)\n",
        "    return tf.data.Dataset.from_tensor_slices(images.astype(np.float32))\n",
        "  elif dataset_name == 'celebA':\n",
        "      images = np.load(os.path.join(data_dir, 'celebA/data.npy'))\n",
        "      dataset = tf.data.Dataset.from_tensor_slices(images)\n",
        "      return dataset.map(lambda img: tf.image.convert_image_dtype(img, tf.float32))"
      ],
      "metadata": {
        "id": "dhjYEF7GOboG",
        "cellView": "form"
      },
      "execution_count": 2,
      "outputs": []
    },
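    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Added sanity check:** a minimal sketch (not part of the original analysis) that previews a few decoded images from `load_dataset`, assuming the `smallnorb` files are already downloaded under `Data/`. Any other supported `dataset_name` works the same way."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Added sanity check (assumes the smallnorb files exist under Data/):\n",
        "# preview the first few images produced by load_dataset.\n",
        "preview_ds = load_dataset('smallnorb', data_dir='Data/')\n",
        "\n",
        "fig, axes = plt.subplots(1, 8, figsize=(16, 2))\n",
        "for ax, image in zip(axes, preview_ds.take(8)):\n",
        "  img = image.numpy()\n",
        "  # smallnorb is single-channel; drop the channel axis for imshow\n",
        "  ax.imshow(img.squeeze(), cmap='gray' if img.shape[-1] == 1 else None, vmin=0., vmax=1.)\n",
        "  ax.axis('off')\n",
        "plt.show()"
      ]
    },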
    {
      "cell_type": "code",
      "source": [
        "model_start = 7750\n",
        "model_end = 7755\n",
        "number_models = model_end - model_start\n",
        "\n",
        "number_bottleneck_channels = 10\n",
        "monte_carlo_number_random_samples = 20_00\n",
        "models_dir = 'trained_models/'\n",
        "dataset_name = 'smallnorb'\n",
        "\n",
        "dataset_sizes = {'dsprites': 737280,\n",
        "                 'cars3d': 17568,\n",
        "                 'smallnorb': 48600}\n",
        "\n",
        "ct = time.time()\n",
        "image_dataset = load_dataset(dataset_name, data_dir='Data/')\n",
        "print(f'Loaded {dataset_name}, took {time.time()-ct:.3f} sec')\n",
        "\n",
        "### Embed the full dataset\n",
        "embs_mus_all, embs_logvars_all = [[], []]\n",
        "ct = time.time()\n",
        "for model_num in range(model_start, model_end):\n",
        "  embed = hub.load(os.path.join(models_dir, str(model_num), 'model/tfhub'))\n",
        "\n",
        "  image_chunk_size = 1000\n",
        "  embs_mus, embs_logvars = [[], []]\n",
        "  for image_chunk in image_dataset.batch(image_chunk_size):\n",
        "    embs = embed.signatures['gaussian_encoder'](image_chunk)\n",
        "    embs_mus.append(embs['mean'])\n",
        "    embs_logvars.append(embs['logvar'])\n",
        "  embs_mus_all.append(np.concatenate(embs_mus, 0))\n",
        "  embs_logvars_all.append(np.concatenate(embs_logvars, 0))\n",
        "\n",
        "print(f'Embedded full dataset (number instances: {embs_mus_all[0].shape[0]}) for {number_models} models.  Took {time.time()-ct:.3f} sec.')\n",
        "### Now we have them, run through everything\n",
        "single_infos, double_infos, combined_infos = [[], [], []]\n",
        "nmis, vis = [[], []]\n",
        "nmi_errs, vi_errs = [[], []]\n",
        "ct = time.time()\n",
        "for model_num1 in range(number_models):\n",
        "  single_infos.append(\n",
        "      utils.monte_carlo_info(embs_mus_all[model_num1],\n",
        "                             embs_logvars_all[model_num1],\n",
        "                             number_random_samples=monte_carlo_number_random_samples))\n",
        "  double_infos.append(\n",
        "      utils.monte_carlo_info(np.tile(embs_mus_all[model_num1], [1, 2]),\n",
        "                             np.tile(embs_logvars_all[model_num1], [1, 2]),\n",
        "                             number_random_samples=monte_carlo_number_random_samples))\n",
        "  for model_num2 in range(model_num1+1, number_models):\n",
        "    combined_infos.append(\n",
        "        utils.monte_carlo_info(np.concatenate([embs_mus_all[model_num1], embs_mus_all[model_num2]], 1),\n",
        "                               np.concatenate([embs_logvars_all[model_num1], embs_logvars_all[model_num2]], 1),\n",
        "                               number_random_samples=monte_carlo_number_random_samples))\n",
        "print(f'Computed infos {int(number_models*(number_models+3)/2)} times.  Took {time.time()-ct:.3f} sec.')\n",
        "running_index = 0\n",
        "for model_num1 in range(number_models):\n",
        "  for model_num2 in range(model_num1+1, number_models):\n",
        "    i1 = single_infos[model_num1][0]\n",
        "    i2 = single_infos[model_num2][0]\n",
        "    i11 = double_infos[model_num1][0]\n",
        "    i22 = double_infos[model_num2][0]\n",
        "    i12 = combined_infos[running_index][0]\n",
        "    nmis.append((i1+i2-i12) / np.sqrt((2*i1-i11)*(2*i2-i22)))\n",
        "    vis.append(2*i12 - i11 - i22)\n",
        "\n",
        "    i11_err = double_infos[model_num1][1]\n",
        "    i1_err = single_infos[model_num1][1]\n",
        "    i22_err = double_infos[model_num2][1]\n",
        "    i2_err = single_infos[model_num2][1]\n",
        "    i12_err = combined_infos[running_index][1]\n",
        "    vi_errs.append(np.sqrt(4*i12_err**2 - i11_err**2 - i22_err**2))\n",
        "\n",
        "    partial11_sq = (i1+i2-i12)**2/4/(2*i1-i11)**3/(2*i2-i22)\n",
        "    partial1_sq = (i1+i12-i11-i2)**2/(2*i1-i11)**3/(2*i2-i22)\n",
        "    partial12_sq = 1/(2*i1-i11)/(2*i2-i22)\n",
        "    partial22_sq = (i1+i2-i12)**2/4/(2*i2-i22)**3/(2*i1-i11)\n",
        "    partial2_sq = (i2+i12-i22-i1)**2/(2*i2-i22)**3/(2*i1-i11)\n",
        "    combined_nmi_err = np.sqrt(partial11_sq*i11_err*22 + partial1_sq*i1_err**2 + partial12_sq*i12_err**2 + partial22_sq*i22_err**2 + partial2_sq*i2_err**2)\n",
        "    nmi_errs.append(combined_nmi_err)\n",
        "\n",
        "\n",
        "    running_index += 1\n",
        "\n",
        "nmis = np.array(nmis)\n",
        "nmi_errs = np.array(nmi_errs)\n",
        "vis = np.array(vis)\n",
        "vi_errs = np.array(vi_errs)\n",
        "single_infos_errs = np.array(single_infos)[:, 1]\n",
        "single_infos = np.array(single_infos)[:, 0]\n",
        "\n",
        "comb_info = np.sum(single_infos/single_infos_errs**2)/np.sum(1./single_infos_errs**2)\n",
        "comb_info_err = 1/np.sqrt(np.sum(1./single_infos_errs**2))\n",
        "\n",
        "comb_nmi = np.sum(nmis/nmi_errs**2)/np.sum(1./nmi_errs**2)\n",
        "comb_nmi_err = 1/np.sqrt(np.sum(1./nmi_errs**2))\n",
        "\n",
        "comb_vi = np.sum(vis/vi_errs**2)/np.sum(1./vi_errs**2)\n",
        "comb_vi_err = 1/np.sqrt(np.sum(1./vi_errs**2))\n",
        "\n",
        "print(f'I(U;X)/H(X): {comb_info/np.log2(dataset_sizes[dataset_name]):.3f} +- {comb_info_err/np.log2(dataset_sizes[dataset_name]):.3f} bits')\n",
        "print(f'NMI: {comb_nmi:.3f} +- {comb_nmi_err:.3f}')\n",
        "print(f'VI: {comb_vi:.3f} +- {comb_vi_err:.3f} bits')"
      ],
      "metadata": {
        "id": "yNL5cI0nKScC",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "bb1f66f2-b794-4a0e-9e3c-2b4763b23507"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Loaded smallnorb, took 9.365 sec\n",
            "Embedded full dataset (number instances: 48600) for 5 models.  Took 5.194 sec.\n",
            "Computed infos 20 times.  Took 12.526 sec.\n",
            "I(U;X)/H(X): 0.753 +- 0.002 bits\n",
            "NMI: 0.965 +- 0.018\n",
            "VI: 0.731 +- 0.028 bits\n"
          ]
        }
      ]
    },
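    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "**Added toy check of the NMI/VI combinations:** for deterministic discrete encodings of a uniform $X$, $I(Z;X) = H(Z)$, a duplicated channel adds no information ($i_{11} = i_1$), and $i_{12} = H(Z_1, Z_2)$. The expressions above then reduce to the classic clustering NMI, $I(Z_1;Z_2)/\\sqrt{H(Z_1)H(Z_2)}$, and the variation of information, $H(Z_1|Z_2) + H(Z_2|Z_1)$. The sketch below verifies this with synthetic labels and plain numpy; it does not call `utils.monte_carlo_info`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Added toy check (synthetic labels only; independent of the models above).\n",
        "def entropy(labels):\n",
        "  \"\"\"Shannon entropy in bits of an empirical label distribution.\"\"\"\n",
        "  _, counts = np.unique(labels, return_counts=True)\n",
        "  p = counts / counts.sum()\n",
        "  return -np.sum(p * np.log2(p))\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "z1 = rng.integers(0, 4, size=10_000)\n",
        "z2 = (z1 + (rng.random(10_000) < 0.1)) % 4  # z2 is a noisy copy of z1\n",
        "\n",
        "i1, i2 = entropy(z1), entropy(z2)\n",
        "i11, i22 = i1, i2           # duplicating a deterministic channel adds nothing\n",
        "i12 = entropy(z1 * 4 + z2)  # joint entropy H(Z1, Z2), since z2 < 4\n",
        "\n",
        "nmi = (i1 + i2 - i12) / np.sqrt((2*i1 - i11) * (2*i2 - i22))\n",
        "vi = 2*i12 - i11 - i22\n",
        "print(f'toy NMI: {nmi:.3f} (1 for identical encodings)')\n",
        "print(f'toy VI:  {vi:.3f} bits (0 for identical encodings)')"
      ]
    },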
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "OxiEyef0yLei"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}