{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "executionInfo": {
     "elapsed": 2215,
     "status": "ok",
     "timestamp": 1605804223736,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "hnL7BWDOf1U8"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "data = tf.keras.datasets.mnist\n",
    "import numpy as np\n",
    "import copy\n",
    "import time\n",
    "import tqdm\n",
    "from tensorflow.keras import layers\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2971,
     "status": "ok",
     "timestamp": 1605804231090,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "kM7NxPkoEb-Y",
    "outputId": "87d64766-08c3-49fb-ddf0-e94a98856b95"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
      "11493376/11490434 [==============================] - 0s 0us/step\n"
     ]
    }
   ],
   "source": [
    "data = tf.keras.datasets.mnist\n",
    "(train_images, train_labels), (test_images, test_labels) = data.load_data()\n",
    "train_images = train_images / 255.0\n",
    "test_images = test_images / 255.0\n",
    "\n",
    "X2 = np.zeros_like(train_images)\n",
    "X3 = np.zeros_like(train_images)\n",
    "X4 = np.zeros_like(train_images)\n",
    "X5 = np.zeros_like(train_images)\n",
    "X2[:,1:,1:]=train_images[:,:-1,:-1]\n",
    "X3[:,1:,:-1]=train_images[:,:-1,1:]\n",
    "X4[:,:-1,1:]=train_images[:,1:,:-1]\n",
    "X5[:,:-1,:-1]=train_images[:,1:,1:]\n",
    "\n",
    "Train_images= np.concatenate([train_images, X2, X3, X4, X5], 0)\n",
    "Train_labels =  np.concatenate([train_labels,train_labels,train_labels,train_labels,train_labels])\n",
    "\n",
    "\n",
    "Train_images=Train_images.reshape(-1,28,28,1)\n",
    "Test_images =test_images.reshape(-1,28,28,1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bCh8RhY-DqpI"
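    },
    "source": [
     "# $\\epsilon=10^{-1}$\n",
     "\n",
     "Each $\\epsilon$ section below trains the same CNN, whose output head is `layer_Poisson_linear`, with a different smoothing constant $\\epsilon$ in that layer's denominator: 5 epochs at batch size 128, then 15 epochs at batch size 1024 with the Poisson layer frozen."
    ]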
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "executionInfo": {
     "elapsed": 913,
     "status": "ok",
     "timestamp": 1605804235567,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "rfmc55IkDZLI"
   },
   "outputs": [],
   "source": [
    "eps=1e-1\n",
    "class layer_Poisson_linear(tf.keras.layers.Layer):\n",
    "  def __init__(self, in_features, out_features, **kwargs):\n",
    "    super().__init__(**kwargs)\n",
    "\n",
    "    self.w = tf.Variable(\n",
    "      tf.random.normal([in_features, out_features]), name='w')\n",
    "    self.a = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='a')\n",
    "    self.b = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='b')\n",
    "  def call(self, x):\n",
    "    xnorm2 = tf.reduce_sum(x*x, axis=-1, keepdims=True)\n",
    "    Wnorm2 = tf.reduce_sum(self.w*self.w, 0)\n",
    "    xW = tf.matmul(x, self.w)\n",
    "    return self.a*(Wnorm2-xnorm2)/(Wnorm2+xnorm2-2*xW+eps)+self.b\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 212208,
     "status": "ok",
     "timestamp": 1605804449471,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "AXb2W2jqDoaY",
    "outputId": "f7c0a19f-3796-44a8-cd21-e58a3b345a5a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.2254 - accuracy: 0.9638 - val_loss: 0.0386 - val_accuracy: 0.9932\n",
      "Epoch 2/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0292 - accuracy: 0.9949 - val_loss: 0.0222 - val_accuracy: 0.9941\n",
      "Epoch 3/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0166 - accuracy: 0.9961 - val_loss: 0.0192 - val_accuracy: 0.9944\n",
      "Epoch 4/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0113 - accuracy: 0.9972 - val_loss: 0.0399 - val_accuracy: 0.9893\n",
      "Epoch 5/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0089 - accuracy: 0.9976 - val_loss: 0.0182 - val_accuracy: 0.9956\n",
      "Epoch 1/15\n",
      "293/293 [==============================] - 8s 28ms/step - loss: 0.0024 - accuracy: 0.9995 - val_loss: 0.0143 - val_accuracy: 0.9961\n",
      "Epoch 2/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 0.0013 - accuracy: 0.9998 - val_loss: 0.0136 - val_accuracy: 0.9965\n",
      "Epoch 3/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 0.0010 - accuracy: 0.9999 - val_loss: 0.0137 - val_accuracy: 0.9965\n",
      "Epoch 4/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 7.9787e-04 - accuracy: 0.9999 - val_loss: 0.0139 - val_accuracy: 0.9964\n",
      "Epoch 5/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 6.6254e-04 - accuracy: 0.9999 - val_loss: 0.0144 - val_accuracy: 0.9960\n",
      "Epoch 6/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 5.4681e-04 - accuracy: 1.0000 - val_loss: 0.0142 - val_accuracy: 0.9961\n",
      "Epoch 7/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 4.5972e-04 - accuracy: 1.0000 - val_loss: 0.0144 - val_accuracy: 0.9962\n",
      "Epoch 8/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 3.8430e-04 - accuracy: 1.0000 - val_loss: 0.0147 - val_accuracy: 0.9961\n",
      "Epoch 9/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 3.2573e-04 - accuracy: 1.0000 - val_loss: 0.0150 - val_accuracy: 0.9959\n",
      "Epoch 10/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.7919e-04 - accuracy: 1.0000 - val_loss: 0.0150 - val_accuracy: 0.9961\n",
      "Epoch 11/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.4069e-04 - accuracy: 1.0000 - val_loss: 0.0153 - val_accuracy: 0.9959\n",
      "Epoch 12/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.0358e-04 - accuracy: 1.0000 - val_loss: 0.0155 - val_accuracy: 0.9960\n",
      "Epoch 13/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.8100e-04 - accuracy: 1.0000 - val_loss: 0.0154 - val_accuracy: 0.9960\n",
      "Epoch 14/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.6243e-04 - accuracy: 1.0000 - val_loss: 0.0157 - val_accuracy: 0.9960\n",
      "Epoch 15/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.3430e-04 - accuracy: 1.0000 - val_loss: 0.0160 - val_accuracy: 0.9961\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fe41421ff60>"
      ]
     },
     "execution_count": 5,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "model = tf.keras.Sequential([\n",
    "    layers.Conv2D(32, (3, 3),padding='same', activation='relu', input_shape=(28, 28, 1)),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(128, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.Flatten(),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Dense(1000,activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layer_Poisson_linear(1000,10),\n",
    "    tf.keras.layers.Activation('sigmoid'),\n",
    "    layers.BatchNormalization()\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=128, epochs=5)\n",
    "position_poisson =[ layer.name.find('poisson')>0 for  layer in   model.layers].index(True)\n",
    "model.layers[position_poisson].trainable=False\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=1024, epochs=15)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ljRBopFTFvFQ"
   },
   "source": [
    "# $\\epsilon=10^{-2}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "executionInfo": {
     "elapsed": 709,
     "status": "ok",
     "timestamp": 1605804522226,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "bX_SMa5dF7DS"
   },
   "outputs": [],
   "source": [
    "eps=1e-2\n",
    "class layer_Poisson_linear(tf.keras.layers.Layer):\n",
    "  def __init__(self, in_features, out_features, **kwargs):\n",
    "    super().__init__(**kwargs)\n",
    "\n",
    "    self.w = tf.Variable(\n",
    "      tf.random.normal([in_features, out_features]), name='w')\n",
    "    self.a = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='a')\n",
    "    self.b = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='b')\n",
    "  def call(self, x):\n",
    "    xnorm2 = tf.reduce_sum(x*x, axis=-1, keepdims=True)\n",
    "    Wnorm2 = tf.reduce_sum(self.w*self.w, 0)\n",
    "    xW = tf.matmul(x, self.w)\n",
    "    return self.a*(Wnorm2-xnorm2)/(Wnorm2+xnorm2-2*xW+eps)+self.b\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 197627,
     "status": "ok",
     "timestamp": 1605804721818,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "GtQUffB2F7DS",
    "outputId": "9b357c8f-cb3a-4c81-ec23-99afacffff5b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.1934 - accuracy: 0.9794 - val_loss: 0.0383 - val_accuracy: 0.9923\n",
      "Epoch 2/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0300 - accuracy: 0.9947 - val_loss: 0.0345 - val_accuracy: 0.9910\n",
      "Epoch 3/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0168 - accuracy: 0.9963 - val_loss: 0.0232 - val_accuracy: 0.9932\n",
      "Epoch 4/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0115 - accuracy: 0.9971 - val_loss: 0.0199 - val_accuracy: 0.9944\n",
      "Epoch 5/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0089 - accuracy: 0.9976 - val_loss: 0.0162 - val_accuracy: 0.9951\n",
      "Epoch 1/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 0.0019 - accuracy: 0.9996 - val_loss: 0.0118 - val_accuracy: 0.9968\n",
      "Epoch 2/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 0.0010 - accuracy: 0.9999 - val_loss: 0.0120 - val_accuracy: 0.9966\n",
      "Epoch 3/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 7.8867e-04 - accuracy: 0.9999 - val_loss: 0.0119 - val_accuracy: 0.9966\n",
      "Epoch 4/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 6.3269e-04 - accuracy: 1.0000 - val_loss: 0.0121 - val_accuracy: 0.9964\n",
      "Epoch 5/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 5.3554e-04 - accuracy: 1.0000 - val_loss: 0.0123 - val_accuracy: 0.9964\n",
      "Epoch 6/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 4.4682e-04 - accuracy: 1.0000 - val_loss: 0.0125 - val_accuracy: 0.9963\n",
      "Epoch 7/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 3.9105e-04 - accuracy: 1.0000 - val_loss: 0.0126 - val_accuracy: 0.9964\n",
      "Epoch 8/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 3.3321e-04 - accuracy: 1.0000 - val_loss: 0.0131 - val_accuracy: 0.9961\n",
      "Epoch 9/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.8406e-04 - accuracy: 1.0000 - val_loss: 0.0131 - val_accuracy: 0.9963\n",
      "Epoch 10/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.5695e-04 - accuracy: 1.0000 - val_loss: 0.0130 - val_accuracy: 0.9962\n",
      "Epoch 11/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.3265e-04 - accuracy: 1.0000 - val_loss: 0.0131 - val_accuracy: 0.9963\n",
      "Epoch 12/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.0276e-04 - accuracy: 1.0000 - val_loss: 0.0136 - val_accuracy: 0.9962\n",
      "Epoch 13/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.7497e-04 - accuracy: 1.0000 - val_loss: 0.0137 - val_accuracy: 0.9963\n",
      "Epoch 14/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.5792e-04 - accuracy: 1.0000 - val_loss: 0.0137 - val_accuracy: 0.9963\n",
      "Epoch 15/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.3811e-04 - accuracy: 1.0000 - val_loss: 0.0140 - val_accuracy: 0.9963\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fe3ff562a58>"
      ]
     },
     "execution_count": 7,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "model = tf.keras.Sequential([\n",
    "    layers.Conv2D(32, (3, 3),padding='same', activation='relu', input_shape=(28, 28, 1)),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(128, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.Flatten(),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Dense(1000,activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layer_Poisson_linear(1000,10),\n",
    "    tf.keras.layers.Activation('sigmoid'),\n",
    "    layers.BatchNormalization()\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=128, epochs=5)\n",
    "position_poisson =[ layer.name.find('poisson')>0 for  layer in   model.layers].index(True)\n",
    "model.layers[position_poisson].trainable=False\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=1024, epochs=15)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UOeAQdeTHNxh"
   },
   "source": [
    "# $\\epsilon=10^{-4}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 1174,
     "status": "ok",
     "timestamp": 1605805027734,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "7oWuKbpDHKcO"
   },
   "outputs": [],
   "source": [
    "eps=1e-4\n",
    "class layer_Poisson_linear(tf.keras.layers.Layer):\n",
    "  def __init__(self, in_features, out_features, **kwargs):\n",
    "    super().__init__(**kwargs)\n",
    "\n",
    "    self.w = tf.Variable(\n",
    "      tf.random.normal([in_features, out_features]), name='w')\n",
    "    self.a = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='a')\n",
    "    self.b = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='b')\n",
    "  def call(self, x):\n",
    "    xnorm2 = tf.reduce_sum(x*x, axis=-1, keepdims=True)\n",
    "    Wnorm2 = tf.reduce_sum(self.w*self.w, 0)\n",
    "    xW = tf.matmul(x, self.w)\n",
    "    return self.a*(Wnorm2-xnorm2)/(Wnorm2+xnorm2-2*xW+eps)+self.b\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 199244,
     "status": "ok",
     "timestamp": 1605805228741,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "3vPaqtKuHKcO",
    "outputId": "7dcdb082-6be3-4500-f03f-d2e2670b7548"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.2294 - accuracy: 0.9663 - val_loss: 0.0399 - val_accuracy: 0.9929\n",
      "Epoch 2/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0289 - accuracy: 0.9947 - val_loss: 0.0268 - val_accuracy: 0.9932\n",
      "Epoch 3/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0166 - accuracy: 0.9963 - val_loss: 0.0212 - val_accuracy: 0.9941\n",
      "Epoch 4/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0113 - accuracy: 0.9972 - val_loss: 0.0160 - val_accuracy: 0.9957\n",
      "Epoch 5/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0084 - accuracy: 0.9978 - val_loss: 0.0147 - val_accuracy: 0.9955\n",
      "Epoch 1/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 0.0020 - accuracy: 0.9996 - val_loss: 0.0117 - val_accuracy: 0.9967\n",
      "Epoch 2/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 0.0013 - accuracy: 0.9998 - val_loss: 0.0112 - val_accuracy: 0.9966\n",
      "Epoch 3/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 9.7998e-04 - accuracy: 0.9999 - val_loss: 0.0114 - val_accuracy: 0.9969\n",
      "Epoch 4/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 7.9734e-04 - accuracy: 0.9999 - val_loss: 0.0117 - val_accuracy: 0.9969\n",
      "Epoch 5/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 6.3033e-04 - accuracy: 0.9999 - val_loss: 0.0114 - val_accuracy: 0.9971\n",
      "Epoch 6/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 5.3931e-04 - accuracy: 0.9999 - val_loss: 0.0118 - val_accuracy: 0.9970\n",
      "Epoch 7/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 4.5177e-04 - accuracy: 1.0000 - val_loss: 0.0123 - val_accuracy: 0.9969\n",
      "Epoch 8/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 3.7253e-04 - accuracy: 1.0000 - val_loss: 0.0121 - val_accuracy: 0.9969\n",
      "Epoch 9/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 3.1372e-04 - accuracy: 1.0000 - val_loss: 0.0125 - val_accuracy: 0.9968\n",
      "Epoch 10/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.6909e-04 - accuracy: 1.0000 - val_loss: 0.0131 - val_accuracy: 0.9963\n",
      "Epoch 11/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.3494e-04 - accuracy: 1.0000 - val_loss: 0.0127 - val_accuracy: 0.9968\n",
      "Epoch 12/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.0027e-04 - accuracy: 1.0000 - val_loss: 0.0128 - val_accuracy: 0.9967\n",
      "Epoch 13/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 1.7253e-04 - accuracy: 1.0000 - val_loss: 0.0132 - val_accuracy: 0.9968\n",
      "Epoch 14/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.5140e-04 - accuracy: 1.0000 - val_loss: 0.0133 - val_accuracy: 0.9969\n",
      "Epoch 15/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.3013e-04 - accuracy: 1.0000 - val_loss: 0.0135 - val_accuracy: 0.9966\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fe3ff0e76a0>"
      ]
     },
     "execution_count": 9,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "model = tf.keras.Sequential([\n",
    "    layers.Conv2D(32, (3, 3),padding='same', activation='relu', input_shape=(28, 28, 1)),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(128, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.Flatten(),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Dense(1000,activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layer_Poisson_linear(1000,10),\n",
    "    tf.keras.layers.Activation('sigmoid'),\n",
    "    layers.BatchNormalization()\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=128, epochs=5)\n",
    "position_poisson =[ layer.name.find('poisson')>0 for  layer in   model.layers].index(True)\n",
    "model.layers[position_poisson].trainable=False\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=1024, epochs=15)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wNlY_QLPJucw"
   },
   "source": [
    "# $\\epsilon=10^{-6}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "executionInfo": {
     "elapsed": 1023,
     "status": "ok",
     "timestamp": 1605805425406,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "-KW-OhKAJo-o"
   },
   "outputs": [],
   "source": [
    "eps=1e-6\n",
    "class layer_Poisson_linear(tf.keras.layers.Layer):\n",
    "  def __init__(self, in_features, out_features, **kwargs):\n",
    "    super().__init__(**kwargs)\n",
    "\n",
    "    self.w = tf.Variable(\n",
    "      tf.random.normal([in_features, out_features]), name='w')\n",
    "    self.a = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='a')\n",
    "    self.b = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='b')\n",
    "  def call(self, x):\n",
    "    xnorm2 = tf.reduce_sum(x*x, axis=-1, keepdims=True)\n",
    "    Wnorm2 = tf.reduce_sum(self.w*self.w, 0)\n",
    "    xW = tf.matmul(x, self.w)\n",
    "    return self.a*(Wnorm2-xnorm2)/(Wnorm2+xnorm2-2*xW+eps)+self.b\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 196934,
     "status": "ok",
     "timestamp": 1605805639557,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "Ru1RXC9PJo-o",
    "outputId": "090eae0b-14f7-499f-fe43-d203eb4b1886"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.2336 - accuracy: 0.9669 - val_loss: 0.0351 - val_accuracy: 0.9946\n",
      "Epoch 2/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0297 - accuracy: 0.9946 - val_loss: 0.0218 - val_accuracy: 0.9952\n",
      "Epoch 3/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0169 - accuracy: 0.9960 - val_loss: 0.0193 - val_accuracy: 0.9948\n",
      "Epoch 4/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0120 - accuracy: 0.9969 - val_loss: 0.0165 - val_accuracy: 0.9956\n",
      "Epoch 5/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0089 - accuracy: 0.9975 - val_loss: 0.0182 - val_accuracy: 0.9939\n",
      "Epoch 1/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 0.0023 - accuracy: 0.9996 - val_loss: 0.0130 - val_accuracy: 0.9959\n",
      "Epoch 2/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 0.0013 - accuracy: 0.9998 - val_loss: 0.0127 - val_accuracy: 0.9960\n",
      "Epoch 3/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 0.0010 - accuracy: 0.9999 - val_loss: 0.0127 - val_accuracy: 0.9963\n",
      "Epoch 4/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 8.0197e-04 - accuracy: 0.9999 - val_loss: 0.0126 - val_accuracy: 0.9963\n",
      "Epoch 5/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 6.9145e-04 - accuracy: 0.9999 - val_loss: 0.0125 - val_accuracy: 0.9964\n",
      "Epoch 6/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 6.0296e-04 - accuracy: 1.0000 - val_loss: 0.0129 - val_accuracy: 0.9966\n",
      "Epoch 7/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 5.1377e-04 - accuracy: 1.0000 - val_loss: 0.0130 - val_accuracy: 0.9968\n",
      "Epoch 8/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 4.3005e-04 - accuracy: 1.0000 - val_loss: 0.0132 - val_accuracy: 0.9966\n",
      "Epoch 9/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 3.6353e-04 - accuracy: 1.0000 - val_loss: 0.0132 - val_accuracy: 0.9967\n",
      "Epoch 10/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 3.1380e-04 - accuracy: 1.0000 - val_loss: 0.0133 - val_accuracy: 0.9967\n",
      "Epoch 11/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 2.6608e-04 - accuracy: 1.0000 - val_loss: 0.0131 - val_accuracy: 0.9963\n",
      "Epoch 12/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.2919e-04 - accuracy: 1.0000 - val_loss: 0.0131 - val_accuracy: 0.9964\n",
      "Epoch 13/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.0137e-04 - accuracy: 1.0000 - val_loss: 0.0133 - val_accuracy: 0.9966\n",
      "Epoch 14/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.7259e-04 - accuracy: 1.0000 - val_loss: 0.0134 - val_accuracy: 0.9967\n",
      "Epoch 15/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 1.4987e-04 - accuracy: 1.0000 - val_loss: 0.0136 - val_accuracy: 0.9967\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fe3cc0df630>"
      ]
     },
     "execution_count": 11,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "model = tf.keras.Sequential([\n",
    "    layers.Conv2D(32, (3, 3),padding='same', activation='relu', input_shape=(28, 28, 1)),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(128, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.Flatten(),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Dense(1000,activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layer_Poisson_linear(1000,10),\n",
    "    tf.keras.layers.Activation('sigmoid'),\n",
    "    layers.BatchNormalization()\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=128, epochs=5)\n",
    "position_poisson =[ layer.name.find('poisson')>0 for  layer in   model.layers].index(True)\n",
    "model.layers[position_poisson].trainable=False\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=1024, epochs=15)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SF7-yvgUK3vA"
   },
   "source": [
    "# $\\epsilon=10^{-8}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "executionInfo": {
     "elapsed": 1060,
     "status": "ok",
     "timestamp": 1605805660844,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "zaQnghzfK3vA"
   },
   "outputs": [],
   "source": [
    "eps=1e-8\n",
    "class layer_Poisson_linear(tf.keras.layers.Layer):\n",
    "  def __init__(self, in_features, out_features, **kwargs):\n",
    "    super().__init__(**kwargs)\n",
    "\n",
    "    self.w = tf.Variable(\n",
    "      tf.random.normal([in_features, out_features]), name='w')\n",
    "    self.a = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='a')\n",
    "    self.b = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='b')\n",
    "  def call(self, x):\n",
    "    xnorm2 = tf.reduce_sum(x*x, axis=-1, keepdims=True)\n",
    "    Wnorm2 = tf.reduce_sum(self.w*self.w, 0)\n",
    "    xW = tf.matmul(x, self.w)\n",
    "    return self.a*(Wnorm2-xnorm2)/(Wnorm2+xnorm2-2*xW+eps)+self.b\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 196878,
     "status": "ok",
     "timestamp": 1605805860260,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "aJRIdre0K3vA",
    "outputId": "4c2b7a33-c952-4f4c-8973-39a7bca31736"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.2804 - accuracy: 0.9631 - val_loss: 0.0377 - val_accuracy: 0.9946\n",
      "Epoch 2/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0289 - accuracy: 0.9946 - val_loss: 0.0270 - val_accuracy: 0.9936\n",
      "Epoch 3/5\n",
      "2344/2344 [==============================] - 15s 6ms/step - loss: 0.0162 - accuracy: 0.9962 - val_loss: 0.0280 - val_accuracy: 0.9924\n",
      "Epoch 4/5\n",
      "2344/2344 [==============================] - 15s 6ms/step - loss: 0.0112 - accuracy: 0.9970 - val_loss: 0.0248 - val_accuracy: 0.9931\n",
      "Epoch 5/5\n",
      "2344/2344 [==============================] - 15s 6ms/step - loss: 0.0083 - accuracy: 0.9977 - val_loss: 0.0178 - val_accuracy: 0.9949\n",
      "Epoch 1/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 0.0024 - accuracy: 0.9995 - val_loss: 0.0150 - val_accuracy: 0.9963\n",
      "Epoch 2/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 0.0012 - accuracy: 0.9998 - val_loss: 0.0144 - val_accuracy: 0.9964\n",
      "Epoch 3/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 9.0233e-04 - accuracy: 0.9999 - val_loss: 0.0146 - val_accuracy: 0.9963\n",
      "Epoch 4/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 7.0284e-04 - accuracy: 0.9999 - val_loss: 0.0152 - val_accuracy: 0.9962\n",
      "Epoch 5/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 5.5922e-04 - accuracy: 1.0000 - val_loss: 0.0148 - val_accuracy: 0.9962\n",
      "Epoch 6/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 4.6038e-04 - accuracy: 1.0000 - val_loss: 0.0150 - val_accuracy: 0.9962\n",
      "Epoch 7/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 3.7760e-04 - accuracy: 1.0000 - val_loss: 0.0150 - val_accuracy: 0.9962\n",
      "Epoch 8/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 3.2079e-04 - accuracy: 1.0000 - val_loss: 0.0155 - val_accuracy: 0.9962\n",
      "Epoch 9/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 2.7624e-04 - accuracy: 1.0000 - val_loss: 0.0153 - val_accuracy: 0.9965\n",
      "Epoch 10/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 2.4513e-04 - accuracy: 1.0000 - val_loss: 0.0152 - val_accuracy: 0.9961\n",
      "Epoch 11/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 2.0975e-04 - accuracy: 1.0000 - val_loss: 0.0155 - val_accuracy: 0.9962\n",
      "Epoch 12/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 1.8567e-04 - accuracy: 1.0000 - val_loss: 0.0159 - val_accuracy: 0.9963\n",
      "Epoch 13/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 1.5858e-04 - accuracy: 1.0000 - val_loss: 0.0159 - val_accuracy: 0.9963\n",
      "Epoch 14/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 1.3989e-04 - accuracy: 1.0000 - val_loss: 0.0164 - val_accuracy: 0.9963\n",
      "Epoch 15/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 1.2231e-04 - accuracy: 1.0000 - val_loss: 0.0161 - val_accuracy: 0.9962\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fe332a46828>"
      ]
     },
     "execution_count": 13,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# CNN feature extractor with the Poisson layer as the classification head.\n",
    "model = tf.keras.Sequential([\n",
    "    layers.Conv2D(32, (3, 3),padding='same', activation='relu', input_shape=(28, 28, 1)),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(128, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.Flatten(),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Dense(1000,activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layer_Poisson_linear(1000,10),\n",
    "    tf.keras.layers.Activation('sigmoid'),\n",
    "    layers.BatchNormalization()\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "# Stage 1: train every layer, including the Poisson head.\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=128, epochs=5)\n",
    "\n",
    "# Stage 2: freeze the Poisson head and keep training the rest with larger batches.\n",
    "# Locate the head by type rather than by a substring of its auto-generated name.\n",
    "position_poisson = next(i for i, layer in enumerate(model.layers)\n",
    "                        if isinstance(layer, layer_Poisson_linear))\n",
    "model.layers[position_poisson].trainable = False\n",
    "# Keras only takes `trainable` changes into account after compile() is called again;\n",
    "# passing the existing optimizer keeps Adam's state from stage 1.\n",
    "model.compile(optimizer=model.optimizer,\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=1024, epochs=15)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KRhX_edHM34u"
   },
   "source": [
    "# $\\epsilon=10^{-10}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "executionInfo": {
     "elapsed": 663,
     "status": "ok",
     "timestamp": 1605805920559,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "Ul8TDSpMM34v"
   },
   "outputs": [],
   "source": [
    "eps=1e-10\n",
    "class layer_Poisson_linear(tf.keras.layers.Layer):\n",
    "  \"\"\"Custom Keras layer computing a Poisson-kernel-style ratio per unit.\n",
    "\n",
    "  For output unit j, with weight vector w_j (column j of `w`), it returns\n",
    "      a_j * (|w_j|^2 - |x|^2) / (|x - w_j|^2 + epsilon) + b_j\n",
    "  where |x - w_j|^2 is expanded as |w_j|^2 + |x|^2 - 2 x.w_j.\n",
    "\n",
    "  Args:\n",
    "    in_features: size of the last axis of the input `x`.\n",
    "    out_features: number of output units.\n",
    "    epsilon: added to the denominator for numerical stability. Defaults to\n",
    "      the value of the global `eps` when the layer is constructed, so a\n",
    "      later cell that reassigns `eps` cannot silently change this layer.\n",
    "  \"\"\"\n",
    "  def __init__(self, in_features, out_features, epsilon=None, **kwargs):\n",
    "    super().__init__(**kwargs)\n",
    "    # Captured per instance instead of reading the global `eps` at trace time.\n",
    "    self.epsilon = eps if epsilon is None else epsilon\n",
    "\n",
    "    self.w = tf.Variable(\n",
    "      tf.random.normal([in_features, out_features]), name='w')\n",
    "    self.a = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='a')\n",
    "    self.b = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='b')\n",
    "\n",
    "  def call(self, x):\n",
    "    # Assumes x is (batch, in_features); the shape notes below follow from that.\n",
    "    xnorm2 = tf.reduce_sum(x*x, axis=-1, keepdims=True)  # |x|^2, (batch, 1)\n",
    "    Wnorm2 = tf.reduce_sum(self.w*self.w, 0)             # |w_j|^2, (out_features,)\n",
    "    xW = tf.matmul(x, self.w)                            # x.w_j, (batch, out_features)\n",
    "    return self.a*(Wnorm2-xnorm2)/(Wnorm2+xnorm2-2*xW+self.epsilon)+self.b\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 198597,
     "status": "ok",
     "timestamp": 1605806121142,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "JoFBEIVCM34v",
    "outputId": "5a3704fb-126a-49e7-dfc2-c4fdd7c561b5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.2356 - accuracy: 0.9572 - val_loss: 0.0341 - val_accuracy: 0.9944\n",
      "Epoch 2/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0296 - accuracy: 0.9946 - val_loss: 0.0230 - val_accuracy: 0.9953\n",
      "Epoch 3/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0167 - accuracy: 0.9961 - val_loss: 0.0232 - val_accuracy: 0.9935\n",
      "Epoch 4/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0108 - accuracy: 0.9972 - val_loss: 0.0149 - val_accuracy: 0.9957\n",
      "Epoch 5/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0088 - accuracy: 0.9976 - val_loss: 0.0242 - val_accuracy: 0.9935\n",
      "Epoch 1/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 0.0020 - accuracy: 0.9996 - val_loss: 0.0140 - val_accuracy: 0.9958\n",
      "Epoch 2/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 0.0012 - accuracy: 0.9999 - val_loss: 0.0136 - val_accuracy: 0.9962\n",
      "Epoch 3/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 8.9953e-04 - accuracy: 0.9999 - val_loss: 0.0139 - val_accuracy: 0.9961\n",
      "Epoch 4/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 7.3894e-04 - accuracy: 1.0000 - val_loss: 0.0137 - val_accuracy: 0.9962\n",
      "Epoch 5/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 6.0464e-04 - accuracy: 1.0000 - val_loss: 0.0135 - val_accuracy: 0.9962\n",
      "Epoch 6/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 5.1470e-04 - accuracy: 1.0000 - val_loss: 0.0137 - val_accuracy: 0.9964\n",
      "Epoch 7/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 4.4443e-04 - accuracy: 1.0000 - val_loss: 0.0136 - val_accuracy: 0.9961\n",
      "Epoch 8/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 3.8079e-04 - accuracy: 1.0000 - val_loss: 0.0139 - val_accuracy: 0.9961\n",
      "Epoch 9/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 3.2704e-04 - accuracy: 1.0000 - val_loss: 0.0143 - val_accuracy: 0.9961\n",
      "Epoch 10/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 2.8086e-04 - accuracy: 1.0000 - val_loss: 0.0142 - val_accuracy: 0.9963\n",
      "Epoch 11/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 2.3822e-04 - accuracy: 1.0000 - val_loss: 0.0143 - val_accuracy: 0.9964\n",
      "Epoch 12/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 2.1059e-04 - accuracy: 1.0000 - val_loss: 0.0144 - val_accuracy: 0.9961\n",
      "Epoch 13/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 1.8928e-04 - accuracy: 1.0000 - val_loss: 0.0148 - val_accuracy: 0.9961\n",
      "Epoch 14/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 1.5808e-04 - accuracy: 1.0000 - val_loss: 0.0149 - val_accuracy: 0.9964\n",
      "Epoch 15/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 1.4090e-04 - accuracy: 1.0000 - val_loss: 0.0151 - val_accuracy: 0.9965\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fe2f6a8ac88>"
      ]
     },
     "execution_count": 15,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# CNN feature extractor with the Poisson layer as the classification head.\n",
    "model = tf.keras.Sequential([\n",
    "    layers.Conv2D(32, (3, 3),padding='same', activation='relu', input_shape=(28, 28, 1)),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(128, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.Flatten(),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Dense(1000,activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layer_Poisson_linear(1000,10),\n",
    "    tf.keras.layers.Activation('sigmoid'),\n",
    "    layers.BatchNormalization()\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "# Stage 1: train every layer, including the Poisson head.\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=128, epochs=5)\n",
    "\n",
    "# Stage 2: freeze the Poisson head and keep training the rest with larger batches.\n",
    "# Locate the head by type rather than by a substring of its auto-generated name.\n",
    "position_poisson = next(i for i, layer in enumerate(model.layers)\n",
    "                        if isinstance(layer, layer_Poisson_linear))\n",
    "model.layers[position_poisson].trainable = False\n",
    "# Keras only takes `trainable` changes into account after compile() is called again;\n",
    "# passing the existing optimizer keeps Adam's state from stage 1.\n",
    "model.compile(optimizer=model.optimizer,\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=1024, epochs=15)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oTOA5FaLOaaf"
   },
   "source": [
    "# $\\epsilon=10^{-20}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "executionInfo": {
     "elapsed": 1303,
     "status": "ok",
     "timestamp": 1605806126378,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "a43U5zzuOaaf"
   },
   "outputs": [],
   "source": [
    "eps=1e-20\n",
    "class layer_Poisson_linear(tf.keras.layers.Layer):\n",
    "  \"\"\"Custom Keras layer computing a Poisson-kernel-style ratio per unit.\n",
    "\n",
    "  For output unit j, with weight vector w_j (column j of `w`), it returns\n",
    "      a_j * (|w_j|^2 - |x|^2) / (|x - w_j|^2 + epsilon) + b_j\n",
    "  where |x - w_j|^2 is expanded as |w_j|^2 + |x|^2 - 2 x.w_j.\n",
    "\n",
    "  Args:\n",
    "    in_features: size of the last axis of the input `x`.\n",
    "    out_features: number of output units.\n",
    "    epsilon: added to the denominator for numerical stability. Defaults to\n",
    "      the value of the global `eps` when the layer is constructed, so a\n",
    "      later cell that reassigns `eps` cannot silently change this layer.\n",
    "  \"\"\"\n",
    "  def __init__(self, in_features, out_features, epsilon=None, **kwargs):\n",
    "    super().__init__(**kwargs)\n",
    "    # Captured per instance instead of reading the global `eps` at trace time.\n",
    "    self.epsilon = eps if epsilon is None else epsilon\n",
    "\n",
    "    self.w = tf.Variable(\n",
    "      tf.random.normal([in_features, out_features]), name='w')\n",
    "    self.a = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='a')\n",
    "    self.b = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='b')\n",
    "\n",
    "  def call(self, x):\n",
    "    # Assumes x is (batch, in_features); the shape notes below follow from that.\n",
    "    xnorm2 = tf.reduce_sum(x*x, axis=-1, keepdims=True)  # |x|^2, (batch, 1)\n",
    "    Wnorm2 = tf.reduce_sum(self.w*self.w, 0)             # |w_j|^2, (out_features,)\n",
    "    xW = tf.matmul(x, self.w)                            # x.w_j, (batch, out_features)\n",
    "    return self.a*(Wnorm2-xnorm2)/(Wnorm2+xnorm2-2*xW+self.epsilon)+self.b\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 197689,
     "status": "ok",
     "timestamp": 1605806325792,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "E1TY-alSOaaf",
    "outputId": "3b3c7807-7a84-4954-b4a0-0d797cc025d3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.2183 - accuracy: 0.9668 - val_loss: 0.0353 - val_accuracy: 0.9940\n",
      "Epoch 2/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0296 - accuracy: 0.9947 - val_loss: 0.0266 - val_accuracy: 0.9943\n",
      "Epoch 3/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0169 - accuracy: 0.9961 - val_loss: 0.0155 - val_accuracy: 0.9960\n",
      "Epoch 4/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0118 - accuracy: 0.9970 - val_loss: 0.0226 - val_accuracy: 0.9925\n",
      "Epoch 5/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0087 - accuracy: 0.9977 - val_loss: 0.0165 - val_accuracy: 0.9946\n",
      "Epoch 1/15\n",
      "293/293 [==============================] - 8s 28ms/step - loss: 0.0024 - accuracy: 0.9995 - val_loss: 0.0112 - val_accuracy: 0.9963\n",
      "Epoch 2/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 0.0012 - accuracy: 0.9998 - val_loss: 0.0106 - val_accuracy: 0.9963\n",
      "Epoch 3/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 9.1080e-04 - accuracy: 0.9999 - val_loss: 0.0108 - val_accuracy: 0.9960\n",
      "Epoch 4/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 6.9919e-04 - accuracy: 0.9999 - val_loss: 0.0112 - val_accuracy: 0.9963\n",
      "Epoch 5/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 5.5252e-04 - accuracy: 1.0000 - val_loss: 0.0113 - val_accuracy: 0.9963\n",
      "Epoch 6/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 4.7808e-04 - accuracy: 1.0000 - val_loss: 0.0116 - val_accuracy: 0.9961\n",
      "Epoch 7/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 3.9548e-04 - accuracy: 1.0000 - val_loss: 0.0117 - val_accuracy: 0.9961\n",
      "Epoch 8/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 3.4155e-04 - accuracy: 1.0000 - val_loss: 0.0120 - val_accuracy: 0.9961\n",
      "Epoch 9/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.9580e-04 - accuracy: 1.0000 - val_loss: 0.0123 - val_accuracy: 0.9961\n",
      "Epoch 10/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.6680e-04 - accuracy: 1.0000 - val_loss: 0.0122 - val_accuracy: 0.9959\n",
      "Epoch 11/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.2423e-04 - accuracy: 1.0000 - val_loss: 0.0125 - val_accuracy: 0.9959\n",
      "Epoch 12/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.9679e-04 - accuracy: 1.0000 - val_loss: 0.0124 - val_accuracy: 0.9961\n",
      "Epoch 13/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 1.7443e-04 - accuracy: 1.0000 - val_loss: 0.0129 - val_accuracy: 0.9962\n",
      "Epoch 14/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.5602e-04 - accuracy: 1.0000 - val_loss: 0.0127 - val_accuracy: 0.9960\n",
      "Epoch 15/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.3445e-04 - accuracy: 1.0000 - val_loss: 0.0130 - val_accuracy: 0.9962\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fe2f63fd7f0>"
      ]
     },
     "execution_count": 17,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# CNN feature extractor with the Poisson layer as the classification head.\n",
    "model = tf.keras.Sequential([\n",
    "    layers.Conv2D(32, (3, 3),padding='same', activation='relu', input_shape=(28, 28, 1)),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(128, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.Flatten(),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Dense(1000,activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layer_Poisson_linear(1000,10),\n",
    "    tf.keras.layers.Activation('sigmoid'),\n",
    "    layers.BatchNormalization()\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "# Stage 1: train every layer, including the Poisson head.\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=128, epochs=5)\n",
    "\n",
    "# Stage 2: freeze the Poisson head and keep training the rest with larger batches.\n",
    "# Locate the head by type rather than by a substring of its auto-generated name.\n",
    "position_poisson = next(i for i, layer in enumerate(model.layers)\n",
    "                        if isinstance(layer, layer_Poisson_linear))\n",
    "model.layers[position_poisson].trainable = False\n",
    "# Keras only takes `trainable` changes into account after compile() is called again;\n",
    "# passing the existing optimizer keeps Adam's state from stage 1.\n",
    "model.compile(optimizer=model.optimizer,\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=1024, epochs=15)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PJL81pHPQGl6"
   },
   "source": [
    "# $\\epsilon=0$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "U725EzYQiGWx"
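   },
   "source": [
    "Training also succeeded with $\\epsilon=0$, but we believe that was luck: without $\\epsilon$ the denominator $|x-w|^2$ vanishes when an input coincides with a weight vector (giving $0/0$) and is tiny nearby, so the output can overflow or become NaN. We therefore think it is important to add $\\epsilon$ in $\\frac{|w|^2-|x|^2}{|x-w|^2+\\epsilon}$ to ensure numerical stability."
   ]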
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "executionInfo": {
     "elapsed": 919,
     "status": "ok",
     "timestamp": 1605806343997,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "F1KCfVCmQGl6"
   },
   "outputs": [],
   "source": [
    "eps=0  # no stabilisation term: the denominator is left unguarded\n",
    "class layer_Poisson_linear(tf.keras.layers.Layer):\n",
    "  \"\"\"Custom Keras layer computing a Poisson-kernel-style ratio per unit.\n",
    "\n",
    "  For output unit j, with weight vector w_j (column j of `w`), it returns\n",
    "      a_j * (|w_j|^2 - |x|^2) / (|x - w_j|^2 + epsilon) + b_j\n",
    "  where |x - w_j|^2 is expanded as |w_j|^2 + |x|^2 - 2 x.w_j.\n",
    "\n",
    "  Args:\n",
    "    in_features: size of the last axis of the input `x`.\n",
    "    out_features: number of output units.\n",
    "    epsilon: added to the denominator for numerical stability. Defaults to\n",
    "      the value of the global `eps` when the layer is constructed, so a\n",
    "      later cell that reassigns `eps` cannot silently change this layer.\n",
    "  \"\"\"\n",
    "  def __init__(self, in_features, out_features, epsilon=None, **kwargs):\n",
    "    super().__init__(**kwargs)\n",
    "    # Captured per instance instead of reading the global `eps` at trace time.\n",
    "    self.epsilon = eps if epsilon is None else epsilon\n",
    "\n",
    "    self.w = tf.Variable(\n",
    "      tf.random.normal([in_features, out_features]), name='w')\n",
    "    self.a = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='a')\n",
    "    self.b = tf.Variable(\n",
    "        tf.random.normal([out_features]), name='b')\n",
    "\n",
    "  def call(self, x):\n",
    "    # Assumes x is (batch, in_features); the shape notes below follow from that.\n",
    "    xnorm2 = tf.reduce_sum(x*x, axis=-1, keepdims=True)  # |x|^2, (batch, 1)\n",
    "    Wnorm2 = tf.reduce_sum(self.w*self.w, 0)             # |w_j|^2, (out_features,)\n",
    "    xW = tf.matmul(x, self.w)                            # x.w_j, (batch, out_features)\n",
    "    # With epsilon == 0 this is exactly the unguarded ratio.\n",
    "    return self.a*(Wnorm2-xnorm2)/(Wnorm2+xnorm2-2*xW+self.epsilon)+self.b\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 196994,
     "status": "ok",
     "timestamp": 1605806542611,
     "user": {
      "displayName": "Mingxi Wang",
      "photoUrl": "",
      "userId": "01165593761653277599"
     },
     "user_tz": -60
    },
    "id": "VovwYQvyQGl6",
    "outputId": "343a0d55-7890-4393-c5c8-e4a787cdfd62"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.2104 - accuracy: 0.9726 - val_loss: 0.0371 - val_accuracy: 0.9938\n",
      "Epoch 2/5\n",
      "2344/2344 [==============================] - 16s 7ms/step - loss: 0.0304 - accuracy: 0.9945 - val_loss: 0.0185 - val_accuracy: 0.9953\n",
      "Epoch 3/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0169 - accuracy: 0.9961 - val_loss: 0.0227 - val_accuracy: 0.9939\n",
      "Epoch 4/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0117 - accuracy: 0.9970 - val_loss: 0.0188 - val_accuracy: 0.9954\n",
      "Epoch 5/5\n",
      "2344/2344 [==============================] - 15s 7ms/step - loss: 0.0086 - accuracy: 0.9977 - val_loss: 0.0181 - val_accuracy: 0.9946\n",
      "Epoch 1/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 0.0023 - accuracy: 0.9995 - val_loss: 0.0132 - val_accuracy: 0.9964\n",
      "Epoch 2/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 0.0012 - accuracy: 0.9999 - val_loss: 0.0132 - val_accuracy: 0.9967\n",
      "Epoch 3/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 8.9770e-04 - accuracy: 0.9999 - val_loss: 0.0132 - val_accuracy: 0.9968\n",
      "Epoch 4/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 6.9337e-04 - accuracy: 1.0000 - val_loss: 0.0136 - val_accuracy: 0.9968\n",
      "Epoch 5/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 5.7583e-04 - accuracy: 1.0000 - val_loss: 0.0137 - val_accuracy: 0.9969\n",
      "Epoch 6/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 4.7911e-04 - accuracy: 1.0000 - val_loss: 0.0142 - val_accuracy: 0.9969\n",
      "Epoch 7/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 3.8751e-04 - accuracy: 1.0000 - val_loss: 0.0142 - val_accuracy: 0.9969\n",
      "Epoch 8/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 3.3340e-04 - accuracy: 1.0000 - val_loss: 0.0143 - val_accuracy: 0.9969\n",
      "Epoch 9/15\n",
      "293/293 [==============================] - 8s 26ms/step - loss: 2.9754e-04 - accuracy: 1.0000 - val_loss: 0.0141 - val_accuracy: 0.9971\n",
      "Epoch 10/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.5702e-04 - accuracy: 1.0000 - val_loss: 0.0145 - val_accuracy: 0.9968\n",
      "Epoch 11/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.2126e-04 - accuracy: 1.0000 - val_loss: 0.0146 - val_accuracy: 0.9970\n",
      "Epoch 12/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.9764e-04 - accuracy: 1.0000 - val_loss: 0.0149 - val_accuracy: 0.9968\n",
      "Epoch 13/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.8587e-04 - accuracy: 1.0000 - val_loss: 0.0149 - val_accuracy: 0.9970\n",
      "Epoch 14/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 2.1325e-04 - accuracy: 1.0000 - val_loss: 0.0151 - val_accuracy: 0.9972\n",
      "Epoch 15/15\n",
      "293/293 [==============================] - 8s 27ms/step - loss: 1.4265e-04 - accuracy: 1.0000 - val_loss: 0.0148 - val_accuracy: 0.9970\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fe2f2112668>"
      ]
     },
     "execution_count": 20,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# CNN feature extractor with the Poisson layer as the classification head.\n",
    "model = tf.keras.Sequential([\n",
    "    layers.Conv2D(32, (3, 3),padding='same', activation='relu', input_shape=(28, 28, 1)),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(64, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Conv2D(128, (3, 3),padding='same', activation='relu'),\n",
    "    layers.MaxPooling2D((2, 2)),\n",
    "    layers.Flatten(),\n",
    "    layers.BatchNormalization(),\n",
    "    layers.Dense(1000,activation='relu'),\n",
    "    layers.BatchNormalization(),\n",
    "    layer_Poisson_linear(1000,10),\n",
    "    tf.keras.layers.Activation('sigmoid'),\n",
    "    layers.BatchNormalization()\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "# Stage 1: train every layer, including the Poisson head.\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=128, epochs=5)\n",
    "\n",
    "# Stage 2: freeze the Poisson head and keep training the rest with larger batches.\n",
    "# Locate the head by type rather than by a substring of its auto-generated name.\n",
    "position_poisson = next(i for i, layer in enumerate(model.layers)\n",
    "                        if isinstance(layer, layer_Poisson_linear))\n",
    "model.layers[position_poisson].trainable = False\n",
    "# Keras only takes `trainable` changes into account after compile() is called again;\n",
    "# passing the existing optimizer keeps Adam's state from stage 1.\n",
    "model.compile(optimizer=model.optimizer,\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "model.fit(Train_images, Train_labels, validation_data=(Test_images, test_labels), batch_size=1024, epochs=15)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyOdxgYEQgETfP/OpFPHa/M8",
   "collapsed_sections": [],
   "name": "Stability_epsilon.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
