{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "VSJOwNIG57pB",
    "outputId": "91e17eb0-6756-493e-b376-f4531751481d"
   },
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ld7OUnHi4ZHt",
    "outputId": "06063379-9c83-4542-940b-e69a923223a0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Num GPUs Available:  1\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import math\n",
    "import scipy\n",
    "import scipy.signal\n",
    "import scipy.stats\n",
    "from scipy.stats import norm\n",
    "from scipy.stats import laplace\n",
    "from scipy import integrate\n",
    "import argparse\n",
    "import numpy as np\n",
    "import pickle\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "\n",
    "# local libs\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import dpdl_privacy\n",
    "from train_tf import *\n",
    "\n",
    "print(\"Num GPUs Available: \", len(tf.config.experimental.list_physical_devices('GPU')))\n",
    "\n",
    "# gpu_devices = tf.config.experimental.list_physical_devices('GPU')\n",
    "# for device in gpu_devices: \n",
    "#     tf.config.experimental.set_memory_growth(device, True)\n",
    "    \n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "#%matplotlib widget\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import integrate\n",
    "matplotlib.rcParams['figure.dpi'] = 100  #Set quality of figures shown inline\n",
    "\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# JL projection dimensions: JLdim_list is used for training, JLdim_list_privacy for privacy accounting\n",
    "JLdim_list=[1,5,10,20,30]\n",
    "JLdim_list_privacy=[5,10,20,30]\n",
    "epochs=15 #number of epochs of training\n",
    "\n",
    "# Ideally batch_size divides the number of training samples. MNIST has 60000, and 60000 % 256 == 96,\n",
    "# so with this value the last batch of every epoch holds only 96 examples. Code that assumes full\n",
    "# batches (e.g. DPKerasSGDOptimizer with num_microbatches=batch_size) has to account for that.\n",
    "batch_size=256\n",
    "\n",
    "sigma=1 #noise multiplier\n",
    "clipping_norm = 1 # per-example L2 clipping bound\n",
    "delta=1e-5\n",
    "\n",
    "\n",
    "###################################################\n",
    "\n",
    "mesh_size=1e-4 # Controls accuracy of privacy calculations. If set to None, it is chosen automatically\n",
    "\n",
    "max_eps=100\n",
    "min_delta=10**(-10)\n",
    "precision=10**(-15)\n",
    "\n",
    "# Learning rate shared by all SGD variants below (Keras SGD, DPKerasSGDOptimizer and the manual update loops)\n",
    "lr = 0.15 # learning rate\n",
    "\n",
    "# Every result is pickled under pickles/; create it so a fresh checkout does not fail after training finishes\n",
    "os.makedirs('pickles', exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Qj3tJcqV01JF"
   },
   "source": [
    "## Load data and define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of samples:  60000\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Training a CNN on MNIST with Keras and the DP SGD optimizer.\"\"\"\n",
    "\n",
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "\n",
    "from absl import app\n",
    "from absl import flags\n",
    "from absl import logging\n",
    "\n",
    "from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp\n",
    "from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent\n",
    "from tensorflow_privacy.privacy.optimizers.dp_optimizer_keras import DPKerasSGDOptimizer\n",
    "\n",
    "def load_mnist():\n",
    "    \"\"\"Load MNIST as float32 images scaled to [0, 1] with one-hot labels.\n",
    "\n",
    "    The train and test splits are returned separately (nothing is merged).\n",
    "    Images are reshaped to (N, 28, 28, 1); labels are one-hot encoded over\n",
    "    10 classes.\n",
    "\n",
    "    Returns:\n",
    "        tuple: (train_data, train_labels, test_data, test_labels)\n",
    "    \"\"\"\n",
    "    train, test = tf.keras.datasets.mnist.load_data()\n",
    "    train_data, train_labels = train\n",
    "    test_data, test_labels = test\n",
    "\n",
    "    train_data = np.array(train_data, dtype=np.float32) / 255\n",
    "    test_data = np.array(test_data, dtype=np.float32) / 255\n",
    "\n",
    "    train_data = train_data.reshape((train_data.shape[0], 28, 28, 1))\n",
    "    test_data = test_data.reshape((test_data.shape[0], 28, 28, 1))\n",
    "\n",
    "    train_labels = np.array(train_labels, dtype=np.int32)\n",
    "    test_labels = np.array(test_labels, dtype=np.int32)\n",
    "\n",
    "    train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=10)\n",
    "    test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=10)\n",
    "\n",
    "    # Sanity-check the scaling: after dividing by 255 the pixel range must be exactly [0, 1].\n",
    "    assert train_data.min() == 0.\n",
    "    assert train_data.max() == 1.\n",
    "    assert test_data.min() == 0.\n",
    "    assert test_data.max() == 1.\n",
    "\n",
    "    return train_data, train_labels, test_data, test_labels\n",
    "\n",
    "\n",
    "# Load training and test data.\n",
    "train_data, train_labels, test_data, test_labels = load_mnist()\n",
    "\n",
    "# train_ds: shuffled mini-batches of batch_size for the manual training loops (the final batch may be smaller).\n",
    "# test_ds: the test set in batches of up to 10000 examples, used for evaluation after each epoch.\n",
    "train_ds = tf.data.Dataset.from_tensor_slices((train_data, train_labels)).shuffle(5000).batch(batch_size)\n",
    "test_ds = tf.data.Dataset.from_tensor_slices((test_data, test_labels)).batch(10000)\n",
    "\n",
    "\n",
    "def get_MNIST_model():\n",
    "    \"\"\"Return a freshly initialised small CNN for 28x28x1 inputs producing 10 logits.\n",
    "\n",
    "    The last Dense layer has no activation, so losses must use from_logits=True.\n",
    "    \"\"\"\n",
    "    model = tf.keras.Sequential([\n",
    "      tf.keras.layers.Conv2D(16, 8,\n",
    "                             strides=2,\n",
    "                             padding='same',\n",
    "                             activation='relu',\n",
    "                             input_shape=(28, 28, 1)),\n",
    "      tf.keras.layers.MaxPool2D(2, 1),\n",
    "      tf.keras.layers.Conv2D(32, 4,\n",
    "                             strides=2,\n",
    "                             padding='valid',\n",
    "                             activation='relu'),\n",
    "      tf.keras.layers.MaxPool2D(2, 1),\n",
    "      tf.keras.layers.Flatten(),\n",
    "      tf.keras.layers.Dense(32, activation='relu'),\n",
    "      tf.keras.layers.Dense(10)\n",
    "    ])\n",
    "    return model\n",
    "\n",
    "\n",
    "num_samples=len(train_data)\n",
    "print('Number of samples: ',num_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "conv2d (Conv2D)              (None, 14, 14, 16)        1040      \n",
      "_________________________________________________________________\n",
      "max_pooling2d (MaxPooling2D) (None, 13, 13, 16)        0         \n",
      "_________________________________________________________________\n",
      "conv2d_1 (Conv2D)            (None, 5, 5, 32)          8224      \n",
      "_________________________________________________________________\n",
      "max_pooling2d_1 (MaxPooling2 (None, 4, 4, 32)          0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 512)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 32)                16416     \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                330       \n",
      "=================================================================\n",
      "Total params: 26,010\n",
      "Trainable params: 26,010\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Layer output shapes and parameter counts of the architecture used in every experiment\n",
    "get_MNIST_model().summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Packaging privacy parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the settings the privacy accountant needs; num_JL is overwritten per experiment in the cells below.\n",
    "privacy_args = dpdl_privacy.privacy_params(\n",
    "    num_JL=None,\n",
    "    epochs=epochs,\n",
    "    sigma=sigma,\n",
    "    num_samples=num_samples,\n",
    "    batch_size=batch_size,\n",
    "    delta=delta,\n",
    "    mesh_size=mesh_size,\n",
    "    min_delta=min_delta,\n",
    "    precision=precision,\n",
    "    max_eps=max_eps,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bPJrfEw35sei"
   },
   "source": [
    "# Non-DP training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KGCwypcF9Wrs"
   },
   "source": [
    "## Direct implementation using tape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "Bm6qY6ls5Qwx",
    "outputId": "05be29d6-3d5a-4d8f-8a47-34d3a6a6e9fc"
   },
   "outputs": [],
   "source": [
    "# model=get_MNIST_model()\n",
    "# loss_obj = tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)\n",
    "# train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "# train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')\n",
    "\n",
    "# ## Process Parameters\n",
    "# n_var = len(model.trainable_variables)\n",
    "# total_time=0\n",
    "\n",
    "# ## Run SGD to train\n",
    "# for epoch in range(epochs):\n",
    "#     for images, labels in train_ds:\n",
    "\n",
    "#         n_images = images.shape[0]\n",
    "#         start_time=time.time()\n",
    "\n",
    "#         # Compute the gradient\n",
    "#         with tf.GradientTape() as tape:\n",
    "#             predictions = model(images, training=True)\n",
    "#             #print('predictions',predictions.shape)\n",
    "#             loss_ind = loss_obj(labels, predictions)\n",
    "#             loss = tf.reduce_mean(loss_ind)\n",
    "#         gradients = tape.gradient(loss, model.trainable_variables)\n",
    "\n",
    "#         # Run the gradient descent\n",
    "#         k = 0\n",
    "#         for p in model.trainable_variables:\n",
    "#             p.assign(p - lr * gradients[k])\n",
    "#             k = k + 1\n",
    "#         total_time+=time.time()-start_time\n",
    "        \n",
    "#     # Record the result\n",
    "#     for images, labels in test_ds:\n",
    "#         predictions = model(images, training=False)\n",
    "#         loss = loss_obj(labels, predictions)\n",
    "#         train_loss(loss)\n",
    "#         train_accuracy(labels, predictions)\n",
    "                    \n",
    "#     print('Epoch %i, Loss: %f, Accuracy: %f' % \n",
    "#           (epoch + 1, train_loss.result(), train_accuracy.result() * 100))\n",
    "    \n",
    "#     train_loss.reset_states()\n",
    "#     train_accuracy.reset_states()\n",
    "# print('Average Per epoch time:', total_time/epochs)\n",
    "# tf.keras.backend.clear_session()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wl2v0iB29MdN"
   },
   "source": [
    "## Implementation using keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Non-private baseline: plain SGD on the mean cross-entropy loss.\n",
    "MyModel = get_MNIST_model()\n",
    "optimizer = tf.keras.optimizers.SGD(learning_rate=lr)\n",
    "loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n",
    "MyModel.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])\n",
    "\n",
    "# Time the whole fit (training plus per-epoch validation) and attribute it evenly to each epoch.\n",
    "start_time = time.time()\n",
    "training_metrics = MyModel.fit(\n",
    "    train_data,\n",
    "    train_labels,\n",
    "    epochs=epochs,\n",
    "    validation_data=(test_data, test_labels),\n",
    "    batch_size=batch_size,\n",
    ").history\n",
    "total_time = time.time() - start_time\n",
    "\n",
    "accuracy_list = training_metrics['val_accuracy']\n",
    "epochtime_list = [total_time / epochs] * epochs\n",
    "print(epochtime_list, accuracy_list)\n",
    "\n",
    "tf.keras.backend.clear_session()\n",
    "\n",
    "# Persist per-epoch test accuracy and timing for later comparison.\n",
    "record_dict = {'accuracy': accuracy_list, 'time': epochtime_list}\n",
    "with open(f'pickles/nonDP_MNIST{batch_size}.pickle', 'wb') as handle:\n",
    "    pickle.dump(record_dict, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DPVanilla"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Privacy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vanilla DP-SGD accounting (no JL projection).\n",
    "privacy_args.num_JL = None\n",
    "privacy_engine = dpdl_privacy.privacy_engine_dpdl(privacy_args, verbose=True, interpolate_flag=True)\n",
    "\n",
    "# Compose one epoch at a time so that epsilon is recorded after each epoch.\n",
    "epsilon_list = []\n",
    "for _ in range(epochs):\n",
    "    privacy_engine.add_epochs(1)\n",
    "    eps = privacy_engine.calculate_eps()\n",
    "    print(eps)\n",
    "    epsilon_list.append(eps)\n",
    "\n",
    "plt.plot(np.arange(1, epochs + 1), epsilon_list)\n",
    "plt.xlabel('epochs')\n",
    "plt.ylabel('eps')\n",
    "plt.show()\n",
    "\n",
    "record_dict = {'privacy_params': privacy_args, 'epsilon': epsilon_list}\n",
    "with open(f'pickles/DPVanilla_MNIST_epsilons{batch_size}.pickle', 'wb') as handle:\n",
    "    pickle.dump(record_dict, handle)\n",
    "\n",
    "# Full epsilon-delta curve with the CLT_approx_DPSGD and MA overlays enabled, then stdQ diagnostics.\n",
    "privacy_engine.plot_eps_delta_curve(max_eps=10, min_delta=10**(-10), CLT_approx_DPSGD=True, MA=True)\n",
    "privacy_engine.print_stdQ_stats()\n",
    "print('1-total mass of stdQ_pdf:', 1 - np.sum(privacy_engine.stdQ_pdf))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorFlow Privacy implementation (DPKerasSGDOptimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vanilla DP-SGD with TF Privacy's Keras optimizer: one microbatch per example.\n",
    "if batch_size<=300: #The implementation of DPSGD in Tensorflow Privacy is crashing for batch size 512. So we only run this part of the code for small batch sizes. For large batch sizes, we use gradient accumulation.\n",
    "\n",
    "    MyModel=get_MNIST_model()\n",
    "    optimizer = DPKerasSGDOptimizer(\n",
    "        l2_norm_clip=clipping_norm,\n",
    "        noise_multiplier=sigma,\n",
    "        num_microbatches=batch_size,\n",
    "        learning_rate=lr)\n",
    "    # Compute vector of per-example loss rather than its mean over a minibatch.\n",
    "    loss = tf.keras.losses.CategoricalCrossentropy(\n",
    "        from_logits=True, reduction=tf.losses.Reduction.NONE)\n",
    "\n",
    "    # Compile MyModel with Keras\n",
    "    MyModel.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])\n",
    "\n",
    "    # TF Privacy splits every batch into num_microbatches (= batch_size) microbatches, so each\n",
    "    # batch must hold exactly batch_size examples. The training set size is not a multiple of\n",
    "    # batch_size in general (60000 % 256 == 96), and the short final batch that fit() would build\n",
    "    # from numpy arrays cannot be split that way, so drop it. The full-size shuffle buffer is\n",
    "    # reshuffled every epoch, so a different remainder is dropped each time.\n",
    "    dp_train_ds = (tf.data.Dataset.from_tensor_slices((train_data, train_labels))\n",
    "                   .shuffle(len(train_data))\n",
    "                   .batch(batch_size, drop_remainder=True))\n",
    "\n",
    "    # Train MyModel with Keras (batch size comes from dp_train_ds; test_ds is used for validation)\n",
    "    start_time = time.time()\n",
    "    training_metrics=MyModel.fit(dp_train_ds,\n",
    "            epochs=epochs,\n",
    "            validation_data=test_ds).history\n",
    "\n",
    "    total_time = time.time()-start_time\n",
    "    accuracy_list = training_metrics['val_accuracy']\n",
    "    epochtime_list = epochs*[total_time/epochs]\n",
    "\n",
    "    print(epochtime_list,accuracy_list)\n",
    "    tf.keras.backend.clear_session()\n",
    "\n",
    "    record_dict = {'accuracy':accuracy_list,'time':epochtime_list}\n",
    "    with open(f'pickles/DPVanilla_MNIST{batch_size}.pickle','wb') as handle:\n",
    "        pickle.dump(record_dict,handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gradient accumulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# DP-SGD via per-example gradient accumulation (the fallback when the TF Privacy optimizer is not run).\n",
    "# Each example's gradient is clipped to L2 norm clipping_norm, the clipped gradients are summed and\n",
    "# divided by batch_size, and Gaussian noise with std sigma*clipping_norm/batch_size is added before the\n",
    "# SGD step -- the same calibration as the DPJL training loop and DPKerasSGDOptimizer(noise_multiplier=sigma).\n",
    "model=get_MNIST_model()\n",
    "loss_obj = tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')\n",
    "\n",
    "# Std of the Gaussian noise added to the averaged clipped gradient.\n",
    "noise_level = sigma * clipping_norm / batch_size\n",
    "\n",
    "epochtime_list=[]\n",
    "accuracy_list=[]\n",
    "\n",
    "## Run DP-SGD to train\n",
    "for epoch in range(epochs):\n",
    "    start_time=time.time()\n",
    "    print('Starting epoch ',epoch+1)\n",
    "    for images, labels in tqdm(train_ds):\n",
    "\n",
    "        # Split the batch into single-example tensors (each keeps a leading axis of size 1).\n",
    "        images = tf.split(images,images.shape[0])\n",
    "        labels = tf.split(labels,labels.shape[0])\n",
    "\n",
    "        accumulated_gradient = [tf.zeros(w.shape) for w in model.trainable_variables]\n",
    "\n",
    "        # Compute the gradient on each sample, clip it and accumulate it\n",
    "        for image, label in zip(images, labels):\n",
    "            with tf.GradientTape() as tape:\n",
    "                predictions = model(image, training=True)\n",
    "                loss_ind = loss_obj(label, predictions)\n",
    "            gradient = tape.gradient(loss_ind, model.trainable_variables)\n",
    "            gradient_norm = tf.linalg.global_norm(gradient)\n",
    "            scale = tf.minimum(1.0, clipping_norm / gradient_norm) / batch_size\n",
    "            accumulated_gradient = [a + g * scale for a, g in zip(accumulated_gradient, gradient)]\n",
    "\n",
    "        # Add the Gaussian noise that makes the update differentially private\n",
    "        if sigma > 0.0:\n",
    "            accumulated_gradient = [g + noise_level * tf.random.normal(g.shape, dtype=tf.float32)\n",
    "                                    for g in accumulated_gradient]\n",
    "\n",
    "        # Run the gradient descent\n",
    "        for p, g in zip(model.trainable_variables, accumulated_gradient):\n",
    "            p.assign(p - lr * g)\n",
    "\n",
    "    epochtime = time.time()-start_time\n",
    "    epochtime_list.append(epochtime)\n",
    "    # Record the result\n",
    "    for images, labels in test_ds:\n",
    "        predictions = model(images, training=False)\n",
    "        loss = loss_obj(labels, predictions)\n",
    "        test_loss(loss)\n",
    "        test_accuracy(labels, predictions)\n",
    "\n",
    "    print('Epoch %i, Loss: %f, Accuracy: %f' % \n",
    "          (epoch + 1, test_loss.result(), test_accuracy.result() * 100))\n",
    "\n",
    "    accuracy_list.append(test_accuracy.result().numpy())\n",
    "    test_loss.reset_states()\n",
    "    test_accuracy.reset_states()\n",
    "\n",
    "tf.keras.backend.clear_session()\n",
    "\n",
    "record_dict = {'accuracy':accuracy_list,'time':epochtime_list}\n",
    "with open(f'pickles/DPVanillaGA_MNIST{batch_size}.pickle','wb') as handle:\n",
    "    pickle.dump(record_dict,handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1M0DqJKq7Jrz"
   },
   "source": [
    "# DPJL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Privacy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Privacy accounting for DP-SGD with JL-estimated gradient norms, one run per JL dimension.\n",
    "for JL_dim in JLdim_list_privacy:\n",
    "    privacy_args.num_JL = JL_dim\n",
    "    print('JL_dim:', privacy_args.num_JL)\n",
    "    privacy_engine = dpdl_privacy.privacy_engine_dpdl(privacy_args, verbose=True, interpolate_flag=True)\n",
    "\n",
    "    # epsilon after each epoch of composition\n",
    "    epsilon_list = []\n",
    "    for epoch_idx in range(1, epochs + 1):\n",
    "        print(f'Starting epoch {epoch_idx}....')\n",
    "        privacy_engine.add_epochs(1)\n",
    "        eps = privacy_engine.calculate_eps()\n",
    "        print('epsilon: ', eps)\n",
    "        epsilon_list.append(eps)\n",
    "\n",
    "    record_dict = {'epsilon': epsilon_list, 'privacy_params': privacy_args}\n",
    "    with open(f'pickles/DPJL{JL_dim}_MNIST_epsilons{batch_size}.pickle', 'wb') as handle:\n",
    "        pickle.dump(record_dict, handle)\n",
    "\n",
    "    plt.plot(range(1, epochs + 1), epsilon_list)\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel('eps')\n",
    "    plt.show()\n",
    "\n",
    "    privacy_engine.plot_eps_delta_curve(max_eps=50, min_delta=10**(-10))\n",
    "    privacy_engine.print_stdQ_stats()\n",
    "    print('1-total mass of stdQ_pdf:', 1 - np.sum(privacy_engine.stdQ_pdf))\n",
    "    print('\\n\\n\\n\\n\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bxY6ZNin5XIa"
   },
   "outputs": [],
   "source": [
    "# DP-SGD where per-example clipping factors come from JL estimates of the gradient norms.\n",
    "# Privacy accounting for these runs (for the dimensions in JLdim_list_privacy) is done in the Privacy cell above.\n",
    "for JL_dim in JLdim_list:\n",
    "    privacy_args.num_JL=JL_dim\n",
    "    print('JL_dim:',privacy_args.num_JL)\n",
    "\n",
    "    model = get_MNIST_model()\n",
    "    loss_obj = tf.keras.losses.CategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)\n",
    "    # These metrics are evaluated on the held-out test set (test_ds) after every epoch.\n",
    "    test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "    test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')\n",
    "\n",
    "    # Std of the Gaussian noise added to the averaged clipped gradient.\n",
    "    noise_level = sigma * clipping_norm / batch_size\n",
    "\n",
    "    accuracy_list=[]\n",
    "    epochtime_list=[]\n",
    "    epochJLtime_list=[]\n",
    "\n",
    "    for i in range(epochs):\n",
    "        print(f'Starting epoch {i+1}....')\n",
    "        epochJL_time=0\n",
    "        epoch_start_time=time.time()\n",
    "        for images, labels in tqdm(train_ds):\n",
    "\n",
    "            # Per-example clipping factors from the JL-approximated gradient norms\n",
    "            JL_time=time.time()\n",
    "            grad_norm = approximate_grad_norm(images, labels, model,loss_obj, JL_dim)\n",
    "            epochJL_time+=(time.time()-JL_time)\n",
    "            mask = tf.math.minimum(1.0, clipping_norm / grad_norm)\n",
    "\n",
    "            # mask is computed outside the tape, so it acts as a constant weight: the gradient of the\n",
    "            # weighted sum is the sum of per-example gradients, each scaled by its clipping factor.\n",
    "            with tf.GradientTape() as tape:\n",
    "                predictions = model(images, training=True)\n",
    "                loss_ind = loss_obj(labels, predictions)\n",
    "                loss = tf.reduce_sum(mask * loss_ind) / batch_size\n",
    "            gradients = tape.gradient(loss, model.trainable_variables)\n",
    "\n",
    "            # Add the noise into the gradient\n",
    "            if sigma > 0.0:\n",
    "                gradients = [g + noise_level * tf.random.normal(g.shape, dtype=tf.float32) for g in gradients]\n",
    "\n",
    "            # Run the gradient descent\n",
    "            for p, g in zip(model.trainable_variables, gradients):\n",
    "                p.assign(p - lr * g)\n",
    "\n",
    "        epoch_time = time.time()-epoch_start_time\n",
    "        epochtime_list.append(epoch_time)\n",
    "        epochJLtime_list.append(epochJL_time)\n",
    "\n",
    "        # Record the result on the test set\n",
    "        for images, labels in test_ds:\n",
    "            predictions = model(images, training=False)\n",
    "            loss = loss_obj(labels, predictions)\n",
    "            test_loss(loss)\n",
    "            test_accuracy(labels, predictions)\n",
    "\n",
    "        print('Epoch %i, Loss: %f, Accuracy: %f' % \n",
    "              (i + 1, test_loss.result(), test_accuracy.result() * 100))\n",
    "        accuracy_list.append(test_accuracy.result().numpy())\n",
    "\n",
    "        test_loss.reset_states()\n",
    "        test_accuracy.reset_states()\n",
    "        print('Per epoch time:', epoch_time)\n",
    "        print('Total per epoch JL time: ',epochJL_time)\n",
    "        print('Total per epoch JL time per JL projection: ',epochJL_time/JL_dim)\n",
    "        print('\\n')\n",
    "\n",
    "    tf.keras.backend.clear_session()\n",
    "\n",
    "    record_dict = {'accuracy':accuracy_list,'time':epochtime_list,\n",
    "                   'JLtime':epochJLtime_list,'privacy_params':privacy_args}\n",
    "    with open(f'pickles/DPJL{JL_dim}_MNIST{batch_size}.pickle','wb') as handle:\n",
    "        pickle.dump(record_dict,handle)\n",
    "\n",
    "    print('\\n\\n\\n\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "DPJLembedding.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc-autonumbering": true,
  "toc-showcode": true,
  "toc-showmarkdowntxt": false
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
