{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "from models import DenseReparam, DenseWN\n",
    "from utils import plot_loss\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# One seed shared by every random number generator seeded below.\n",
    "GLOBAL_SEED = 0\n",
    "\n",
    "# Seed Python's `random`, NumPy and TensorFlow from GLOBAL_SEED and set the\n",
    "# hash-seed / cuDNN-determinism environment variables.\n",
    "# NOTE(review): PYTHONHASHSEED is presumably only read at interpreter start-up, so\n",
    "# setting it here may not affect the already running kernel -- confirm.\n",
    "os.environ['PYTHONHASHSEED']=str(GLOBAL_SEED)\n",
    "os.environ['TF_CUDNN_DETERMINISTIC'] = '1'\n",
    "random.seed(GLOBAL_SEED)\n",
    "np.random.seed(GLOBAL_SEED)\n",
    "tf.random.set_seed(GLOBAL_SEED)\n",
    "\n",
    "#os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\" # disable GPU\n",
    "# Make float64 the default Keras float type for everything built after this cell.\n",
    "dtype = 'float64'\n",
    "tf.keras.backend.set_floatx(dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "banana_dataset = pd.read_csv(\"2d_banana.csv\", sep=\",\")\n",
    "\n",
    "x = banana_dataset[['At1','At2']].to_numpy()\n",
    "y = banana_dataset.Class.to_numpy().reshape(-1, 1)\n",
    "# Relabel class -1 as 0. np.where allocates a new array, so the buffer returned by\n",
    "# to_numpy() (which may be a view of the DataFrame, or read-only) is never written to.\n",
    "y = np.where(y == -1, 0, y)\n",
    "\n",
    "# Keep a random subset of rows, drawn without replacement from the seeded NumPy RNG.\n",
    "num_data_points = 1000\n",
    "random_idx = np.random.choice(x.shape[0], size=num_data_points, replace=False)\n",
    "x, y = x[random_idx, :], y[random_idx, :]\n",
    "\n",
    "# 80/20 train/test split with a fixed random state.\n",
    "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=GLOBAL_SEED)\n",
    "\n",
    "# standardize features using statistics of the training split only\n",
    "train_mean = tf.reduce_mean(x_train, axis=0, keepdims=True)\n",
    "train_std = tf.math.reduce_std(x_train, axis=0, keepdims=True)\n",
    "x_train = (x_train - train_mean) / train_std\n",
    "x_test = (x_test - train_mean) / train_std\n",
    "\n",
    "print(\"train set size:\", x_train.shape, y_train.shape)\n",
    "print(\"test set size:\", x_test.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Width of the hidden layer in every model built in this notebook.\n",
    "units = 10\n",
    "# Upper bound on training epochs passed to fit().\n",
    "n_epochs = 5000\n",
    "batch_size = 128\n",
    "# Patience of the EarlyStopping callbacks; the ReduceLROnPlateau callbacks use half of it.\n",
    "early_stop_patience = 100\n",
    "\n",
    "# Adam learning rate for the DenseReparam model, and for the three other models.\n",
    "lr_gmp = 0.1\n",
    "lr_others = 0.01\n",
    "\n",
    "# Output directory for the PDF figures saved by the plotting cells.\n",
    "os.makedirs(\"./figs\", exist_ok=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Geometric Parameterization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_reparam_dnn_model(lr):\n",
    "    \"\"\"Build and compile the classifier whose hidden layer is a DenseReparam layer.\n",
    "\n",
    "    Args:\n",
    "        lr: learning rate handed to the Adam optimizer.\n",
    "\n",
    "    Returns:\n",
    "        A compiled tf.keras.Sequential model: DenseReparam(units, relu) followed by a\n",
    "        single sigmoid unit, with binary cross-entropy loss and a binary-accuracy metric.\n",
    "    \"\"\"\n",
    "    hidden_layer = DenseReparam(units, activation='relu')\n",
    "    output_layer = tf.keras.layers.Dense(1, activation='sigmoid')\n",
    "    model = tf.keras.Sequential([hidden_layer, output_layer])\n",
    "\n",
    "    model.compile(\n",
    "        loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),\n",
    "        optimizer=tf.keras.optimizers.Adam(lr),\n",
    "        metrics=[tf.keras.metrics.BinaryAccuracy()],\n",
    "    )\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reparam_dnn_model = build_reparam_dnn_model(lr=lr_gmp)\n",
    "\n",
    "# Multiply the learning rate by 0.1 once validation accuracy has not improved for\n",
    "# half of the early-stopping patience.\n",
    "lr_decay = tf.keras.callbacks.ReduceLROnPlateau(\n",
    "    monitor='val_binary_accuracy',\n",
    "    mode='max',\n",
    "    factor=0.1,\n",
    "    patience=early_stop_patience//2,\n",
    "    verbose=0,\n",
    ")\n",
    "# Stop once validation accuracy has stalled for the full patience and restore the best weights.\n",
    "early_stop = tf.keras.callbacks.EarlyStopping(\n",
    "    monitor='val_binary_accuracy',\n",
    "    mode='max',\n",
    "    patience=early_stop_patience,\n",
    "    restore_best_weights=True,\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch history of the hidden units' feature locations; the first entry is an all-zero placeholder.\n",
    "reparam_locs_x = [np.zeros(units)]\n",
    "reparam_locs_y = [np.zeros(units)]\n",
    "\n",
    "def get_feature_locs(epoch, logs):\n",
    "    \"\"\"Append the current feature location of every hidden unit of reparam_dnn_model.\n",
    "\n",
    "    Rows 0 and 1 of the model's first weight array are read as an angle and an offset\n",
    "    per unit; the recorded point is -offset * (cos(angle), sin(angle)).\n",
    "    `epoch` and `logs` are accepted for the Keras callback signature and not used.\n",
    "    \"\"\"\n",
    "    first_layer_params = reparam_dnn_model.get_weights()[0]\n",
    "    angles, offsets = first_layer_params[0], first_layer_params[1]\n",
    "    signed_radius = -1.0*offsets\n",
    "\n",
    "    reparam_locs_x.append(signed_radius*np.cos(angles))\n",
    "    reparam_locs_y.append(signed_radius*np.sin(angles))\n",
    "\n",
    "print_reparam_locs = tf.keras.callbacks.LambdaCallback(on_epoch_end=get_feature_locs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the test split is passed as validation data, so any callback that\n",
    "# monitors a val_* metric selects on the test set -- confirm this is intended.\n",
    "history_reparam_dnn = reparam_dnn_model.fit(\n",
    "    x=x_train,\n",
    "    y=y_train,\n",
    "    validation_data=(x_test, y_test),\n",
    "    epochs=n_epochs,\n",
    "    batch_size=batch_size,\n",
    "    callbacks=[print_reparam_locs, lr_decay, early_stop],\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate() returns the loss followed by the compiled metrics; index 1 is the binary accuracy.\n",
    "train_results_reparam_dnn = reparam_dnn_model.evaluate(x_train, y_train, verbose=0)\n",
    "test_results_reparam_dnn = reparam_dnn_model.evaluate(x_test, y_test, verbose=0)\n",
    "\n",
    "print(\"Train Accuracy:\", train_results_reparam_dnn[1])\n",
    "print(\"Test Accuracy:\", test_results_reparam_dnn[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.style.use('ggplot')\n",
    "plt.figure(figsize=(6, 6))\n",
    "\n",
    "# Regular 100x100 grid over [-4.5, 4.5]^2 on which the classifier is evaluated.\n",
    "x_grid = np.linspace(-4.5, 4.5, 100)\n",
    "xx, yy = np.meshgrid(x_grid, x_grid)\n",
    "Xplot = np.vstack((xx.flatten(), yy.flatten())).T\n",
    "\n",
    "pred = reparam_dnn_model(Xplot).numpy().reshape(xx.shape)  # sigmoid output at every grid point\n",
    "\n",
    "# Shade the model output and draw its 0.5 level set as the decision boundary.\n",
    "plt.contourf(xx, yy, pred, alpha=0.3)\n",
    "cs = plt.contour(xx, yy, pred, [0.5], colors='black', linewidths=3)\n",
    "# Degenerate line that only carries a label (presumably a legend proxy for the contour).\n",
    "plt.plot([0, 0], [0, 0], label='Classification decision boundary', c='black')\n",
    "\n",
    "# Training points with the default marker, test points as crosses, coloured by label.\n",
    "plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train)\n",
    "plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, marker='x')\n",
    "\n",
    "\n",
    "# Rows 0 and 1 of the first weight array are read as per-unit angle and offset.\n",
    "weights = reparam_dnn_model.get_weights()[0]\n",
    "thetas = weights[0]\n",
    "lambdas = weights[1]\n",
    "# NOTE(review): prints row 2 of the first weight array; looks like leftover debug output.\n",
    "print(weights[2])\n",
    "\n",
    "feature_locs_x = -1.0*lambdas*np.cos(thetas)\n",
    "feature_locs_y = -1.0*lambdas*np.sin(thetas)\n",
    "plt.scatter(feature_locs_x, feature_locs_y, color='red', label=\"Spatial location of ReLU feature\", marker='*',s=200)\n",
    "\n",
    "plt.xlabel(\"$x_1$\")\n",
    "plt.ylabel(\"$x_2$\")\n",
    "plt.title(\"Test Acc: {0:.3f}\".format(test_results_reparam_dnn[1]), fontsize=20)\n",
    "plt.savefig(\"./figs/banana_gmp.pdf\", format=\"pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-epoch records and transpose so that each row is one unit's trajectory.\n",
    "reparam_locs_x_np = np.vstack(reparam_locs_x).T\n",
    "reparam_locs_y_np = np.vstack(reparam_locs_y).T\n",
    "\n",
    "plt.figure(figsize=(6, 6))\n",
    "\n",
    "# One polyline per hidden unit.\n",
    "for traj_x, traj_y in zip(reparam_locs_x_np[:units], reparam_locs_y_np[:units]):\n",
    "    plt.plot(traj_x, traj_y, linewidth=1, marker='x')\n",
    "\n",
    "plt.xlabel(\"$x_1$\")\n",
    "plt.ylabel(\"$x_2$\")\n",
    "plt.xlim(-3, 3)\n",
    "plt.ylim(-3, 3)\n",
    "plt.savefig(\"./figs/banana_gmp_trajectory.pdf\", format=\"pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Standard Parameterization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dnn_model(lr):\n",
    "    \"\"\"Build and compile the classifier with a standard Dense hidden layer.\n",
    "\n",
    "    Args:\n",
    "        lr: learning rate handed to the Adam optimizer.\n",
    "\n",
    "    Returns:\n",
    "        A compiled tf.keras.Sequential model: Dense(units, relu) followed by a single\n",
    "        sigmoid unit, with binary cross-entropy loss and a binary-accuracy metric.\n",
    "    \"\"\"\n",
    "    hidden_layer = tf.keras.layers.Dense(units, activation='relu')\n",
    "    output_layer = tf.keras.layers.Dense(1, activation='sigmoid')\n",
    "    model = tf.keras.Sequential([hidden_layer, output_layer])\n",
    "\n",
    "    model.compile(\n",
    "        loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),\n",
    "        optimizer=tf.keras.optimizers.Adam(lr),\n",
    "        metrics=[tf.keras.metrics.BinaryAccuracy()],\n",
    "    )\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dnn_model = build_dnn_model(lr_others)\n",
    "\n",
    "# Multiply the learning rate by 0.1 once validation accuracy has not improved for\n",
    "# half of the early-stopping patience.\n",
    "lr_decay = tf.keras.callbacks.ReduceLROnPlateau(\n",
    "    monitor='val_binary_accuracy',\n",
    "    mode='max',\n",
    "    factor=0.1,\n",
    "    patience=early_stop_patience//2,\n",
    "    verbose=0,\n",
    ")\n",
    "# Stop once validation accuracy has stalled for the full patience and restore the best weights.\n",
    "early_stop = tf.keras.callbacks.EarlyStopping(\n",
    "    monitor='val_binary_accuracy',\n",
    "    mode='max',\n",
    "    patience=early_stop_patience,\n",
    "    restore_best_weights=True,\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch history of the hidden units' feature locations; the first entry is an all-zero placeholder.\n",
    "dnn_locs_x = [np.zeros(units)]\n",
    "dnn_locs_y = [np.zeros(units)]\n",
    "\n",
    "def get_feature_locs(epoch, logs):\n",
    "    \"\"\"Append the current feature location of every hidden unit of dnn_model.\n",
    "\n",
    "    With kernel column w and bias b of a unit, the recorded point is -b * w / ||w||^2,\n",
    "    i.e. the point of the line w.x + b = 0 closest to the origin.\n",
    "    `epoch` and `logs` are accepted for the Keras callback signature and not used.\n",
    "    \"\"\"\n",
    "    kernel, bias = dnn_model.get_weights()[:2]\n",
    "    squared_norm = np.linalg.norm(kernel, axis=0)**2\n",
    "    negated_bias = -1.0*bias\n",
    "\n",
    "    dnn_locs_x.append(negated_bias*kernel[0]/squared_norm)\n",
    "    dnn_locs_y.append(negated_bias*kernel[1]/squared_norm)\n",
    "\n",
    "print_dnn_locs = tf.keras.callbacks.LambdaCallback(on_epoch_end=get_feature_locs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the test split is passed as validation data, so any callback that\n",
    "# monitors a val_* metric selects on the test set -- confirm this is intended.\n",
    "history_dnn = dnn_model.fit(\n",
    "    x=x_train,\n",
    "    y=y_train,\n",
    "    validation_data=(x_test, y_test),\n",
    "    epochs=n_epochs,\n",
    "    batch_size=batch_size,\n",
    "    callbacks=[print_dnn_locs, lr_decay, early_stop],\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate() returns the loss followed by the compiled metrics; index 1 is the binary accuracy.\n",
    "train_results_dnn = dnn_model.evaluate(x_train, y_train, verbose=0)\n",
    "test_results_dnn = dnn_model.evaluate(x_test, y_test, verbose=0)\n",
    "\n",
    "print(\"Train Accuracy:\", train_results_dnn[1])\n",
    "print(\"Test Accuracy:\", test_results_dnn[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.style.use('ggplot')\n",
    "plt.figure(figsize=(6, 6))\n",
    "\n",
    "# Regular 100x100 grid over [-4.5, 4.5]^2 on which the classifier is evaluated.\n",
    "x_grid = np.linspace(-4.5, 4.5, 100)\n",
    "xx, yy = np.meshgrid(x_grid, x_grid)\n",
    "Xplot = np.vstack((xx.flatten(), yy.flatten())).T\n",
    "\n",
    "# Model output at every grid point, reshaped back onto the grid.\n",
    "pred = dnn_model(Xplot).numpy().reshape(xx.shape)\n",
    "\n",
    "# Shade the model output and draw its 0.5 level set as the decision boundary.\n",
    "plt.contourf(xx, yy, pred, alpha=0.3)\n",
    "cs = plt.contour(xx, yy, pred, levels=[0.5], colors='black', linewidths=3)\n",
    "# Degenerate line that only carries the legend label for the boundary.\n",
    "plt.plot([0, 0], [0, 0], label='Classification decision boundary', c='black')\n",
    "\n",
    "# Training points with the default marker, test points as crosses, coloured by label.\n",
    "plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train)\n",
    "plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, marker='x')\n",
    "\n",
    "# Per unit: the point -b * w / ||w||^2 of the line w.x + b = 0.\n",
    "weights = dnn_model.get_weights()\n",
    "w, b = weights[0], weights[1]\n",
    "norm = np.linalg.norm(w, axis=0)\n",
    "feature_locs_x = -1.0*b*w[0]/norm**2\n",
    "feature_locs_y = -1.0*b*w[1]/norm**2\n",
    "plt.scatter(\n",
    "    feature_locs_x,\n",
    "    feature_locs_y,\n",
    "    color='red',\n",
    "    label=\"Neural activation boundary induced feature\",\n",
    "    marker='*',\n",
    "    s=200,\n",
    ")\n",
    "\n",
    "plt.title(\"Test Acc: {0:.3f}\".format(test_results_dnn[1]), fontsize=20)\n",
    "plt.legend(fontsize=12.5)\n",
    "plt.xlabel(\"$x_1$\")\n",
    "plt.ylabel(\"$x_2$\")\n",
    "plt.savefig(\"./figs/banana_sp.pdf\", format=\"pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-epoch records and transpose so that each row is one unit's trajectory.\n",
    "dnn_locs_x_np = np.vstack(dnn_locs_x).T\n",
    "dnn_locs_y_np = np.vstack(dnn_locs_y).T\n",
    "\n",
    "plt.figure(figsize=(6, 6))\n",
    "\n",
    "# One polyline per hidden unit.\n",
    "for traj_x, traj_y in zip(dnn_locs_x_np[:units], dnn_locs_y_np[:units]):\n",
    "    plt.plot(traj_x, traj_y, linewidth=1, marker='x')\n",
    "\n",
    "plt.xlabel(\"$x_1$\")\n",
    "plt.ylabel(\"$x_2$\")\n",
    "plt.xlim(-3, 3)\n",
    "plt.ylim(-3, 3)\n",
    "plt.savefig(\"./figs/banana_sp_trajectory.pdf\", format=\"pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batch Normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_bn_dnn_model(lr):\n",
    "    \"\"\"Build and compile the classifier with batch normalization before the ReLU.\n",
    "\n",
    "    Args:\n",
    "        lr: learning rate handed to the Adam optimizer.\n",
    "\n",
    "    Returns:\n",
    "        A compiled tf.keras.Sequential model: Dense(units) -> BatchNormalization -> ReLU\n",
    "        -> single sigmoid unit, with binary cross-entropy loss and a binary-accuracy metric.\n",
    "    \"\"\"\n",
    "    layer_stack = [\n",
    "        tf.keras.layers.Dense(units),\n",
    "        tf.keras.layers.BatchNormalization(),\n",
    "        tf.keras.layers.ReLU(),\n",
    "        tf.keras.layers.Dense(1, activation='sigmoid'),\n",
    "    ]\n",
    "    model = tf.keras.Sequential(layer_stack)\n",
    "\n",
    "    model.compile(\n",
    "        loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),\n",
    "        optimizer=tf.keras.optimizers.Adam(lr),\n",
    "        metrics=[tf.keras.metrics.BinaryAccuracy()],\n",
    "    )\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bn_dnn_model = build_bn_dnn_model(lr=lr_others)\n",
    "\n",
    "# Multiply the learning rate by 0.1 once validation accuracy has not improved for\n",
    "# half of the early-stopping patience.\n",
    "lr_decay = tf.keras.callbacks.ReduceLROnPlateau(\n",
    "    monitor='val_binary_accuracy',\n",
    "    mode='max',\n",
    "    factor=0.1,\n",
    "    patience=early_stop_patience//2,\n",
    "    verbose=0,\n",
    ")\n",
    "# Stop once validation accuracy has stalled for the full patience and restore the best weights.\n",
    "early_stop = tf.keras.callbacks.EarlyStopping(\n",
    "    monitor='val_binary_accuracy',\n",
    "    mode='max',\n",
    "    patience=early_stop_patience,\n",
    "    restore_best_weights=True,\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch history of the hidden units' feature locations; the first entry is an all-zero placeholder.\n",
    "bn_locs_x = [np.zeros(units)]\n",
    "bn_locs_y = [np.zeros(units)]\n",
    "\n",
    "def get_feature_locs(epoch, logs):\n",
    "    \"\"\"Append the current feature location of every hidden unit of bn_dnn_model.\n",
    "\n",
    "    The batch-norm statistics and affine parameters are folded into an effective bias\n",
    "    b_eff = b + sigma * beta / gamma - mu, which is where gamma * (w.x + b - mu) / sigma + beta\n",
    "    crosses zero; the recorded point is -b_eff * w / ||w||^2.\n",
    "    `epoch` and `logs` are accepted for the Keras callback signature and not used.\n",
    "    \"\"\"\n",
    "    # Find the layer by type rather than by its auto-generated name: Keras appends a\n",
    "    # numeric suffix to the name whenever the model is rebuilt in the same kernel.\n",
    "    bn_layer = next(\n",
    "        layer for layer in bn_dnn_model.layers\n",
    "        if isinstance(layer, tf.keras.layers.BatchNormalization)\n",
    "    )\n",
    "    mu = bn_layer.moving_mean.numpy()\n",
    "    sigma = np.sqrt(bn_layer.moving_variance.numpy() + bn_layer.epsilon)\n",
    "    gamma = bn_layer.gamma.numpy()\n",
    "    beta = bn_layer.beta.numpy()\n",
    "\n",
    "    # The first two weight arrays are the kernel and bias of the first Dense layer.\n",
    "    weights = bn_dnn_model.get_weights()\n",
    "    w = weights[0]\n",
    "    b = weights[1] + sigma * beta / gamma - mu\n",
    "    norm = np.linalg.norm(w, axis=0)\n",
    "\n",
    "    bn_locs_x.append(-1.0*b*w[0]/norm**2)\n",
    "    bn_locs_y.append(-1.0*b*w[1]/norm**2)\n",
    "\n",
    "print_bn_locs = tf.keras.callbacks.LambdaCallback(on_epoch_end=get_feature_locs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the test split is passed as validation data, so any callback that\n",
    "# monitors a val_* metric selects on the test set -- confirm this is intended.\n",
    "history_bn_dnn = bn_dnn_model.fit(\n",
    "    x=x_train,\n",
    "    y=y_train,\n",
    "    validation_data=(x_test, y_test),\n",
    "    epochs=n_epochs,\n",
    "    batch_size=batch_size,\n",
    "    callbacks=[print_bn_locs, lr_decay, early_stop],\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate() returns the loss followed by the compiled metrics; index 1 is the binary accuracy.\n",
    "train_results_bn_dnn = bn_dnn_model.evaluate(x_train, y_train, verbose=0)\n",
    "test_results_bn_dnn = bn_dnn_model.evaluate(x_test, y_test, verbose=0)\n",
    "\n",
    "print(\"Train Accuracy:\", train_results_bn_dnn[1])\n",
    "print(\"Test Accuracy:\", test_results_bn_dnn[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.style.use('ggplot')\n",
    "plt.figure(figsize=(6, 6))\n",
    "\n",
    "# Regular 100x100 grid over [-4.5, 4.5]^2 on which the classifier is evaluated.\n",
    "x_grid = np.linspace(-4.5, 4.5, 100)\n",
    "xx, yy = np.meshgrid(x_grid, x_grid)\n",
    "Xplot = np.vstack((xx.flatten(), yy.flatten())).T\n",
    "\n",
    "pred = bn_dnn_model(Xplot).numpy().reshape(xx.shape)  # sigmoid output at every grid point\n",
    "\n",
    "# Shade the model output and draw its 0.5 level set as the decision boundary.\n",
    "plt.contourf(xx, yy, pred, alpha=0.3)\n",
    "cs = plt.contour(xx, yy, pred, [0.5], colors='black', linewidths=3)\n",
    "# Degenerate line that only carries a label (presumably a legend proxy for the contour).\n",
    "plt.plot([0, 0], [0, 0], label='Classification decision boundary', c='black')\n",
    "\n",
    "# Training points with the default marker, test points as crosses, coloured by label.\n",
    "plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train)\n",
    "plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, marker='x')\n",
    "\n",
    "# Find the layer by type rather than by its auto-generated name: Keras appends a\n",
    "# numeric suffix to the name whenever the model is rebuilt in the same kernel.\n",
    "bn_layer = next(\n",
    "    layer for layer in bn_dnn_model.layers\n",
    "    if isinstance(layer, tf.keras.layers.BatchNormalization)\n",
    ")\n",
    "mu = bn_layer.moving_mean.numpy()\n",
    "sigma = np.sqrt(bn_layer.moving_variance.numpy() + bn_layer.epsilon)\n",
    "gamma = bn_layer.gamma.numpy()\n",
    "beta = bn_layer.beta.numpy()\n",
    "\n",
    "# Fold batch norm into an effective bias: gamma * (w.x + b - mu) / sigma + beta = 0\n",
    "# is equivalent to w.x + (b + sigma * beta / gamma - mu) = 0.\n",
    "weights = bn_dnn_model.get_weights()\n",
    "w = weights[0]\n",
    "b = weights[1] + sigma * beta / gamma - mu\n",
    "norm = np.linalg.norm(w, axis=0)\n",
    "feature_locs_x = -1.0*b*w[0]/norm**2\n",
    "feature_locs_y = -1.0*b*w[1]/norm**2\n",
    "plt.scatter(feature_locs_x, feature_locs_y, color='red', label=\"Spatial location of ReLU feature\", marker='*',s=200)\n",
    "\n",
    "plt.xlabel(\"$x_1$\")\n",
    "plt.ylabel(\"$x_2$\")\n",
    "plt.title(\"Test Acc: {0:.3f}\".format(test_results_bn_dnn[1]), fontsize=20)\n",
    "plt.savefig(\"./figs/banana_batchnorm.pdf\", format=\"pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-epoch records and transpose so that each row is one unit's trajectory.\n",
    "bn_locs_x_np = np.vstack(bn_locs_x).T\n",
    "bn_locs_y_np = np.vstack(bn_locs_y).T\n",
    "\n",
    "plt.figure(figsize=(6, 6))\n",
    "\n",
    "# One polyline per hidden unit.\n",
    "for traj_x, traj_y in zip(bn_locs_x_np[:units], bn_locs_y_np[:units]):\n",
    "    plt.plot(traj_x, traj_y, linewidth=1, marker='x')\n",
    "\n",
    "plt.xlabel(\"$x_1$\")\n",
    "plt.ylabel(\"$x_2$\")\n",
    "plt.xlim(-3, 3)\n",
    "plt.ylim(-3, 3)\n",
    "plt.savefig(\"./figs/banana_batchnorm_trajectory.pdf\", format=\"pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Weight Normalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_wn_dnn_model(lr):\n",
    "    \"\"\"Build and compile the classifier made of DenseWN layers.\n",
    "\n",
    "    Args:\n",
    "        lr: learning rate handed to the Adam optimizer.\n",
    "\n",
    "    Returns:\n",
    "        A compiled tf.keras.Sequential model: DenseWN(units, relu) -> DenseWN(1) -> sigmoid\n",
    "        activation, with binary cross-entropy loss and a binary-accuracy metric.\n",
    "    \"\"\"\n",
    "    layer_stack = [\n",
    "        DenseWN(units, activation='relu'),\n",
    "        DenseWN(1),\n",
    "        tf.keras.layers.Activation('sigmoid'),\n",
    "    ]\n",
    "    model = tf.keras.Sequential(layer_stack)\n",
    "\n",
    "    model.compile(\n",
    "        loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),\n",
    "        optimizer=tf.keras.optimizers.Adam(lr),\n",
    "        metrics=[tf.keras.metrics.BinaryAccuracy()],\n",
    "    )\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wn_dnn_model = build_wn_dnn_model(lr=lr_others)\n",
    "\n",
    "# Multiply the learning rate by 0.1 once validation accuracy has not improved for\n",
    "# half of the early-stopping patience.\n",
    "lr_decay = tf.keras.callbacks.ReduceLROnPlateau(\n",
    "    monitor='val_binary_accuracy',\n",
    "    mode='max',\n",
    "    factor=0.1,\n",
    "    patience=early_stop_patience//2,\n",
    "    verbose=0,\n",
    ")\n",
    "# Stop once validation accuracy has stalled for the full patience and restore the best weights.\n",
    "early_stop = tf.keras.callbacks.EarlyStopping(\n",
    "    monitor='val_binary_accuracy',\n",
    "    mode='max',\n",
    "    patience=early_stop_patience,\n",
    "    restore_best_weights=True,\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch history of the hidden units' feature locations; the first entry is an all-zero placeholder.\n",
    "wn_locs_x = [np.zeros(units)]\n",
    "wn_locs_y = [np.zeros(units)]\n",
    "\n",
    "def get_feature_locs(epoch, logs):\n",
    "    \"\"\"Append the current feature location of every hidden unit of wn_dnn_model.\n",
    "\n",
    "    The effective weight of a unit is its direction rows rescaled to unit norm times the\n",
    "    gain row; the recorded point is -b * w / ||w||^2.\n",
    "    `epoch` and `logs` are accepted for the Keras callback signature and not used.\n",
    "    \"\"\"\n",
    "    # NOTE(review): assumes DenseWN packs [direction rows; bias row; gain row] into one\n",
    "    # weight array, as the slicing implies -- confirm against models.DenseWN.\n",
    "    first_layer = wn_dnn_model.get_weights()[0]\n",
    "    direction = first_layer[:-2, :]\n",
    "    gain = first_layer[-1:, :]\n",
    "    w = direction / np.linalg.norm(direction, axis=0) * gain\n",
    "    # Bias row of the SAME array the direction and gain come from; it used to be read\n",
    "    # from get_weights()[1], which is a different weight array.\n",
    "    b = first_layer[-2:-1, :]\n",
    "    norm = np.linalg.norm(w, axis=0)\n",
    "\n",
    "    wn_locs_x.append(-1.0*b*w[0]/norm**2)\n",
    "    wn_locs_y.append(-1.0*b*w[1]/norm**2)\n",
    "\n",
    "print_wn_locs = tf.keras.callbacks.LambdaCallback(on_epoch_end=get_feature_locs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the test split is passed as validation data, so any callback that\n",
    "# monitors a val_* metric selects on the test set -- confirm this is intended.\n",
    "history_wn_dnn = wn_dnn_model.fit(\n",
    "    x=x_train,\n",
    "    y=y_train,\n",
    "    validation_data=(x_test, y_test),\n",
    "    epochs=n_epochs,\n",
    "    batch_size=batch_size,\n",
    "    callbacks=[print_wn_locs, lr_decay, early_stop],\n",
    "    verbose=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluate() returns the loss followed by the compiled metrics; index 1 is the binary accuracy.\n",
    "train_results_wn_dnn = wn_dnn_model.evaluate(x_train, y_train, verbose=0)\n",
    "test_results_wn_dnn = wn_dnn_model.evaluate(x_test, y_test, verbose=0)\n",
    "\n",
    "print(\"Train Accuracy:\", train_results_wn_dnn[1])\n",
    "print(\"Test Accuracy:\", test_results_wn_dnn[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.style.use('ggplot')\n",
    "plt.figure(figsize=(6, 6))\n",
    "\n",
    "# Regular 100x100 grid over [-4.5, 4.5]^2 on which the classifier is evaluated.\n",
    "x_grid = np.linspace(-4.5, 4.5, 100)\n",
    "xx, yy = np.meshgrid(x_grid, x_grid)\n",
    "Xplot = np.vstack((xx.flatten(), yy.flatten())).T\n",
    "\n",
    "pred = wn_dnn_model(Xplot).numpy().reshape(xx.shape)  # sigmoid output at every grid point\n",
    "\n",
    "# Shade the model output and draw its 0.5 level set as the decision boundary.\n",
    "plt.contourf(xx, yy, pred, alpha=0.3)\n",
    "cs = plt.contour(xx, yy, pred, [0.5], colors='black', linewidths=3)\n",
    "# Degenerate line that only carries a label (presumably a legend proxy for the contour).\n",
    "plt.plot([0, 0], [0, 0], label='Classification decision boundary', c='black')\n",
    "\n",
    "# Training points with the default marker, test points as crosses, coloured by label.\n",
    "plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train)\n",
    "plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, marker='x')\n",
    "\n",
    "\n",
    "# NOTE(review): assumes DenseWN packs [direction rows; bias row; gain row] into one\n",
    "# weight array, as the slicing implies -- confirm against models.DenseWN.\n",
    "weights = wn_dnn_model.get_weights()\n",
    "w = weights[0][:-2, :] / np.linalg.norm(weights[0][:-2, :], axis=0) * weights[0][-1:, :]\n",
    "# Bias row of the SAME array the direction and gain come from; it used to be read\n",
    "# from weights[1], which is a different weight array.\n",
    "b = weights[0][-2:-1, :]\n",
    "norm = np.linalg.norm(w, axis=0)\n",
    "feature_locs_x = -1.0*b*w[0]/norm**2\n",
    "feature_locs_y = -1.0*b*w[1]/norm**2\n",
    "plt.scatter(feature_locs_x, feature_locs_y, color='red', label=\"Spatial location of ReLU feature\", marker='*',s=200)\n",
    "\n",
    "plt.xlabel(\"$x_1$\")\n",
    "plt.ylabel(\"$x_2$\")\n",
    "plt.title(\"Test Acc: {0:.3f}\".format(test_results_wn_dnn[1]), fontsize=20)\n",
    "plt.savefig(\"./figs/banana_weightnorm.pdf\", format=\"pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-epoch records and transpose so that each row is one unit's trajectory.\n",
    "wn_locs_x_np = np.vstack(wn_locs_x).T\n",
    "wn_locs_y_np = np.vstack(wn_locs_y).T\n",
    "\n",
    "plt.figure(figsize=(6, 6))\n",
    "\n",
    "# One polyline per hidden unit.\n",
    "for traj_x, traj_y in zip(wn_locs_x_np[:units], wn_locs_y_np[:units]):\n",
    "    plt.plot(traj_x, traj_y, linewidth=1, marker='x')\n",
    "\n",
    "plt.xlabel(\"$x_1$\")\n",
    "plt.ylabel(\"$x_2$\")\n",
    "plt.xlim(-3, 3)\n",
    "plt.ylim(-3, 3)\n",
    "plt.savefig(\"./figs/banana_weightnorm_trajectory.pdf\", format=\"pdf\", bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
