{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f2a1a0a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-02 15:59:46.455650: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9373] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-08-02 15:59:46.455711: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-08-02 15:59:46.457231: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1534] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-08-02 15:59:46.465121: I tensorflow/core/platform/cpu_feature_guard.cc:183] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "\n",
    "from tensorflow.keras.layers import Dense\n",
    "from tensorflow.keras import Model\n",
    "from tensorflow.keras.optimizers import Adam, SGD\n",
    "from tensorflow.keras.optimizers.schedules import ExponentialDecay\n",
    "\n",
    "from src.datasets import load_dataset, preprocess_dataset, prefetch_dataset\n",
    "from src.pmi_estimators import train_critic_model, neural_pmi\n",
    "from src.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c2c45042",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment settings: `cfg` is passed to the dataset helpers below, and the two\n",
    "# names derived from it are interpolated into the results paths.\n",
    "cfg = dict(\n",
    "    dataset='mnist',\n",
    "    model='mlp',\n",
    "    batch_size=512,\n",
    "    optimizer='Adam',\n",
    "    learning_rate=0.001,\n",
    "    epoch=100,\n",
    ")\n",
    "\n",
    "dataset_name = cfg['dataset']\n",
    "model_name = cfg['model']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "25f9a232",
   "metadata": {},
   "outputs": [],
   "source": [
    "def neural_pmi(x, y, pmi_model, estimator):\n",
    "    \"\"\"Return the pointwise mutual information estimate of every pair (x[i], y[i]).\n",
    "\n",
    "    Args:\n",
    "        x, y: batches forwarded unchanged to ``pmi_model``.\n",
    "        pmi_model: critic called as ``pmi_model(x, y)``; it must return a 2-D score\n",
    "            matrix whose diagonal holds the scores of the paired samples.\n",
    "        estimator: 'probabilistic_classifier', 'density_ratio_fitting' or\n",
    "            'variational_f_js'; selects how a score is mapped to a PMI value.\n",
    "\n",
    "    Returns:\n",
    "        1-D ``np.ndarray`` with one PMI value per pair of the batch.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if ``estimator`` is not one of the names above.\n",
    "    \"\"\"\n",
    "    scores = pmi_model(x, y)\n",
    "    paired_scores = tf.linalg.diag_part(scores)\n",
    "    batch_size = scores.shape[0]\n",
    "    if estimator == 'probabilistic_classifier':\n",
    "        if batch_size == 1:\n",
    "            # log(batch_size - 1) is -inf for a single pair, so keep the raw score.\n",
    "            pmi = paired_scores\n",
    "        else:\n",
    "            pmi = paired_scores + tf.math.log(tf.cast(batch_size - 1, dtype=tf.float32))\n",
    "    elif estimator == 'density_ratio_fitting':\n",
    "        # The score is used as a ratio here: clip it away from non-positive values and\n",
    "        # take the log. This now also applies to a single-pair batch, which previously\n",
    "        # returned the un-logged score.\n",
    "        pmi = tf.math.log(tf.maximum(paired_scores, 1e-4))\n",
    "    elif estimator == 'variational_f_js':\n",
    "        pmi = paired_scores\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Estimator ({estimator}) not supported.\")\n",
    "    return np.array(pmi)\n",
    "\n",
    "##############################################################\n",
    "#\n",
    "# Critic architectures\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "def mlp_critic(input_dim, output_dim):\n",
    "    \"\"\"Build an MLP with two 64-unit ReLU hidden layers and a linear output layer.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of the feature axis of the input.\n",
    "        output_dim: number of units of the final (activation-free) layer.\n",
    "\n",
    "    Returns:\n",
    "        An un-compiled ``tf.keras.Sequential`` model.\n",
    "    \"\"\"\n",
    "    model = tf.keras.Sequential()\n",
    "    # Only the first layer declares the input shape; the second layer receives the\n",
    "    # 64 features of the first one, so it must not repeat `input_shape=(input_dim,)`.\n",
    "    model.add(Dense(64, activation='relu', input_shape=(input_dim,)))\n",
    "    model.add(Dense(64, activation='relu'))\n",
    "    model.add(Dense(output_dim))\n",
    "    return model\n",
    "\n",
    "class SeparableCritic(Model):\n",
    "    \"\"\"Critic that scores a pair as the inner product of two separate embeddings.\n",
    "\n",
    "    x is passed to ``_g`` and y to ``_h`` (f(x, y) = g(x)^T h(y)), so a batch only\n",
    "    requires 2N forward passes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset, output_dim=128, **extra_kwargs):\n",
    "        # `extra_kwargs` is accepted but not used.\n",
    "        super().__init__()\n",
    "        x_spec = dataset.element_spec[0]\n",
    "        y_spec = dataset.element_spec[1]\n",
    "        self.output_dim = output_dim\n",
    "        # Feature sizes are read from axis 1 of the dataset's element specs.\n",
    "        self._g = mlp_critic(x_spec.shape[1], self.output_dim)\n",
    "        self._h = mlp_critic(y_spec.shape[1], self.output_dim)\n",
    "\n",
    "    def call(self, x, y):\n",
    "        embedded_x = self._g(x)\n",
    "        embedded_y = self._h(y)\n",
    "        # Entry [i, j] is the inner product of h(y[i]) and g(x[j]).\n",
    "        return tf.matmul(embedded_y, tf.transpose(embedded_x))   # shape = (batch_size, batch_size)\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super().get_config()\n",
    "        config['output_dim'] = self.output_dim\n",
    "        config['g'] = self._g.get_config()\n",
    "        config['h'] = self._h.get_config()\n",
    "        return config\n",
    "\n",
    "class ConcatCritic(Model):\n",
    "    \"\"\"Critic that scores every (x[i], y[j]) combination with one MLP.\n",
    "\n",
    "    x and y are concatenated before being scored, which requires batch_size^2\n",
    "    forward passes per batch.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dataset, **extra_kwargs):\n",
    "        # `extra_kwargs` is accepted but not used.\n",
    "        super(ConcatCritic, self).__init__()\n",
    "        dim_x = dataset.element_spec[0].shape[1]\n",
    "        dim_y = dataset.element_spec[1].shape[1]\n",
    "        self._f = mlp_critic(dim_x+dim_y, 1)  # output is scalar score\n",
    "\n",
    "    def call(self, x, y):\n",
    "        # shape of x: (batch_size, dim_x)\n",
    "        # shape of y: (batch_size, dim_y)\n",
    "        batch_size = tf.shape(x)[0]\n",
    "        x_tiled = tf.tile(tf.expand_dims(x, axis=1), [1, batch_size, 1])    # shape = (batch_size, batch_size, dim_x)\n",
    "        y_tiled = tf.tile(tf.expand_dims(y, axis=0), [batch_size, 1, 1])    # shape = (batch_size, batch_size, dim_y)\n",
    "        xy_pairs = tf.concat([x_tiled, y_tiled], axis=-1)                   # shape = (batch_size, batch_size, dim_x+dim_y)\n",
    "        scores = self._f(tf.reshape(xy_pairs, [batch_size * batch_size, -1]))\n",
    "        return tf.reshape(scores, [batch_size, batch_size])                 # shape = (batch_size, batch_size)\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super(ConcatCritic, self).get_config()\n",
    "        # Store the sub-model's config rather than the model object itself, so the\n",
    "        # returned dict stays serializable (same convention as SeparableCritic).\n",
    "        config['f'] = self._f.get_config()\n",
    "        return config\n",
    "\n",
    "##############################################################\n",
    "#\n",
    "# Training Objectives\n",
    "#\n",
    "# #############################################################\n",
    "\n",
    "# Method 1: Probabilistic Classifier\n",
    "@tf.function\n",
    "def probabilistic_classifier_obj(score):\n",
    "    \"\"\"Objective of the probabilistic-classifier estimator (negative cross-entropy).\n",
    "\n",
    "    Every entry of ``score`` is treated as the logit of a binary classifier:\n",
    "    diagonal entries (samples drawn from the joint density) get label 1 and\n",
    "    off-diagonal entries (samples drawn from the product of marginals) get label 0.\n",
    "\n",
    "    Args:\n",
    "        score: square (batch_size, batch_size) matrix of critic outputs.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor: minus the mean binary cross-entropy over all entries.\n",
    "    \"\"\"\n",
    "    criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)\n",
    "    # The identity matrix is exactly the labelling described above. Building it with\n",
    "    # tf.eye avoids materialising a batch_size^2 Python list on every trace and also\n",
    "    # works when the batch size is only known at run time.\n",
    "    labels = tf.eye(tf.shape(score)[0], dtype=score.dtype)\n",
    "    labels = tf.reshape(labels, (-1, 1))\n",
    "    logits = tf.reshape(score, (-1, 1))\n",
    "    loss = -1.*criterion(labels, logits)\n",
    "    return loss\n",
    "\n",
    "# Method 2: Density Ratio Fitting\n",
    "@tf.function\n",
    "def density_ratio_fitting_obj(score):\n",
    "    \"\"\"Objective of the density-ratio-fitting estimator.\n",
    "\n",
    "    Returns mean(diag(score)) - 0.5 * mean of score^2 over the off-diagonal entries.\n",
    "    \"\"\"\n",
    "    n_pairs = score.shape[0]  # batch_size\n",
    "    squared = tf.square(score)\n",
    "    # Off-diagonal sum = full sum minus the diagonal sum.\n",
    "    off_diag_sum = tf.reduce_sum(squared) - tf.reduce_sum(tf.linalg.diag_part(squared))\n",
    "    marginal_part = off_diag_sum / (n_pairs*(n_pairs-1.))\n",
    "    joint_part = tf.reduce_mean(tf.linalg.diag_part(score))\n",
    "    return joint_part - 0.5*marginal_part\n",
    "\n",
    "# Method 3: Variational Representation of f-divergence (JS)\n",
    "@tf.function\n",
    "def js_fgan_lower_bound_obj(score):\n",
    "    \"\"\"Objective of the variational f-divergence (JS) estimator.\n",
    "\n",
    "    Returns -mean(softplus(-diag(score))) minus the mean of softplus(score) over the\n",
    "    off-diagonal entries.\n",
    "    \"\"\"\n",
    "    n_pairs = score.shape[0]\n",
    "    paired = tf.linalg.diag_part(score)\n",
    "    joint_part = -tf.reduce_mean(tf.nn.softplus(-paired))\n",
    "    # Off-diagonal sum = full sum minus the diagonal sum.\n",
    "    off_diag_sum = tf.reduce_sum(tf.nn.softplus(score)) - tf.reduce_sum(tf.nn.softplus(paired))\n",
    "    marginal_part = off_diag_sum / (n_pairs * (n_pairs - 1.))\n",
    "    return joint_part - marginal_part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66399737",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-02 15:59:50.191806: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1926] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 77259 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:47:00.0, compute capability: 8.0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, probabilistic_classifier\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs:   0%|          | 0/100 [00:00<?, ?it/s]2024-08-02 15:59:53.756148: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f4b3f1a6f80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
      "2024-08-02 15:59:53.756194: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-80GB, Compute Capability 8.0\n",
      "2024-08-02 15:59:53.761898: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
      "2024-08-02 15:59:53.804579: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:467] Loaded cuDNN version 90100\n",
      "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
      "I0000 00:00:1722614393.897009  868334 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.\n",
      "Epochs: 100%|██████████| 100/100 [01:51<00:00,  1.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, probabilistic_classifier\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:51<00:00,  1.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, probabilistic_classifier\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:51<00:00,  1.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, probabilistic_classifier\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:53<00:00,  1.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, probabilistic_classifier\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:54<00:00,  1.15s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, density_ratio_fitting\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:51<00:00,  1.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, density_ratio_fitting\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:51<00:00,  1.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, density_ratio_fitting\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:51<00:00,  1.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, density_ratio_fitting\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:50<00:00,  1.11s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, density_ratio_fitting\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:54<00:00,  1.15s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, variational_f_js\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:52<00:00,  1.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, variational_f_js\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:53<00:00,  1.14s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, variational_f_js\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:56<00:00,  1.16s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, variational_f_js\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs: 100%|██████████| 100/100 [01:53<00:00,  1.13s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doing for concat, variational_f_js\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epochs:  95%|█████████▌| 95/100 [01:48<00:05,  1.17s/it]"
     ]
    }
   ],
   "source": [
    "critic_list = ['concat','separable']\n",
    "estimators_list = ['probabilistic_classifier', 'density_ratio_fitting', 'variational_f_js']\n",
    "# Lookup tables instead of if/elif chains: an unknown name fails immediately with a\n",
    "# KeyError instead of silently reusing the critic / loss of a previous iteration.\n",
    "critic_classes = {'concat': ConcatCritic, 'separable': SeparableCritic}\n",
    "objectives = {'probabilistic_classifier': probabilistic_classifier_obj,\n",
    "              'density_ratio_fitting': density_ratio_fitting_obj,\n",
    "              'variational_f_js': js_fgan_lower_bound_obj}\n",
    "n_runs = 5        # saved classifiers are read from run_1 ... run_5\n",
    "n_epochs = 100    # critic training epochs\n",
    "results = {}      # '{critic}_{estimator}' -> list (one entry per run) of per-epoch MI arrays\n",
    "\n",
    "ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "n_classes = ds_info.features['label'].num_classes\n",
    "ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=False, normalize=True, onehot=True)\n",
    "\n",
    "for critic in critic_list:\n",
    "    for estimator in estimators_list:\n",
    "        results[f'{critic}_{estimator}'] = []\n",
    "        for run in range(n_runs):\n",
    "            # Saved classifier of this run and a sub-model exposing its second-to-last layer.\n",
    "            model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/saved_models/trained_model.keras')\n",
    "            int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-2].output)\n",
    "            # Training pairs: (intermediate activation, y of the dataset).\n",
    "            ds_activity_trn = ds_train.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), y)).cache().prefetch(tf.data.AUTOTUNE)\n",
    "            # Validation pairs: (intermediate activation, one-hot of the classifier's predicted class).\n",
    "            # Neither model is updated below, so cache the pairs instead of recomputing\n",
    "            # both forward passes on every epoch.\n",
    "            ds_activity_val = ds_val.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), tf.one_hot(tf.argmax(model(x), axis=-1), depth=n_classes))).cache().prefetch(tf.data.AUTOTUNE)\n",
    "\n",
    "            print(f'Doing for {critic}, {estimator}')\n",
    "            pmi_model = critic_classes[critic](ds_activity_trn)\n",
    "            loss_fn = objectives[estimator]\n",
    "\n",
    "            optimizer = Adam(learning_rate=0.001)\n",
    "\n",
    "            # Defined inside the loop so every run gets a fresh tf.function for its\n",
    "            # newly created critic and optimizer.\n",
    "            @tf.function\n",
    "            def train_step(x, y, critic_model, optimizer, loss_fn):\n",
    "                with tf.GradientTape() as tape:\n",
    "                    scores = critic_model(x, y)\n",
    "                    # The objectives are meant to be maximised, so minimise their negative.\n",
    "                    loss_value = -loss_fn(scores)\n",
    "                grads = tape.gradient(loss_value, critic_model.trainable_weights)\n",
    "                optimizer.apply_gradients(zip(grads, critic_model.trainable_weights))\n",
    "                return -loss_value\n",
    "\n",
    "            mi = []\n",
    "            for epoch in tqdm(range(n_epochs), desc='Epochs'):\n",
    "                for x_batch, y_batch in ds_activity_trn:\n",
    "                    train_step(x_batch, y_batch, pmi_model, optimizer, loss_fn)\n",
    "                # MI estimate of this epoch = mean PMI over all validation pairs.\n",
    "                pmi_list = []\n",
    "                for x_batch, y_batch in ds_activity_val:\n",
    "                    pmi_list.extend(neural_pmi(x_batch, y_batch, pmi_model, estimator=estimator).tolist())\n",
    "                mi.append(np.mean(pmi_list))\n",
    "            results[f'{critic}_{estimator}'].append(np.array(mi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae98583",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One panel per critic; in each panel one mean +/- std curve (over runs) per estimator.\n",
    "panel_titles = {'concat': 'Joint Critic', 'separable': 'Separable Critic'}\n",
    "curve_labels = {'probabilistic_classifier': 'probabilistic classifier',\n",
    "                'density_ratio_fitting': 'density ratio fitting',\n",
    "                'variational_f_js': 'variational JS bound'}\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 3))\n",
    "\n",
    "for ax, (critic, title) in zip(axes, panel_titles.items()):\n",
    "    for estimator, label in curve_labels.items():\n",
    "        curves = np.asarray(results[f'{critic}_{estimator}'])   # shape = (runs, epochs)\n",
    "        mean = curves.mean(axis=0)\n",
    "        std = curves.std(axis=0)\n",
    "        # Derive the epoch axis from the data instead of assuming 100 epochs.\n",
    "        epochs = range(1, len(mean) + 1)\n",
    "        ax.plot(epochs, mean, label=label)\n",
    "        ax.fill_between(epochs, mean - std, mean + std, alpha=0.5)\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel('Epochs', fontsize=11)\n",
    "    ax.grid()\n",
    "    ax.legend()\n",
    "\n",
    "# Only the left panel carries the y-axis label.\n",
    "axes[0].set_ylabel('I (T;Y)', fontsize=11)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb7a19de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the `results` dict in the same results folder the trained models are loaded from.\n",
    "results_path = f'../results/PI_Explainability/{model_name}_{dataset_name}/pmi_sanity_check.pickle'\n",
    "with open(results_path, 'wb') as out_file:\n",
    "    pickle.dump(results, out_file, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
