{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "273bf6f3",
   "metadata": {},
   "source": [
    "# Smooth Monotonic Networks: Counting silent neurons\n",
    "\n",
    "## General definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8e5836a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ec865337",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "\n",
    "from sklearn.metrics import mean_squared_error as mse\n",
    "from sklearn.metrics import r2_score as r2\n",
    "from sklearn.isotonic import IsotonicRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tnrange\n",
    "\n",
    "from MonotonicNN import SmoothMonotonicNN, MonotonicNN, MonotonicNNAlt\n",
    "from MonotonicNNPaperUtils import Progress, total_params, fit_torch\n",
    "\n",
    "from monotonenorm import GroupSort, direct_norm, SigmaNet"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e87a426",
   "metadata": {},
   "source": [
    "## Univariate experiments \n",
    "Section 4.1 in the manuscript."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "904d464a",
   "metadata": {},
   "outputs": [],
   "source": [
    "T = 21  # number of trials, odd number for having a \"median trial\"\n",
    "ls = 75  # lattice points (k in original paper)\n",
    "ls_small = 35\n",
    "K = 6  # number of SMM groups, we always use H_k = K\n",
    "N_train = 100  # number of examples in training data set\n",
    "N_test = 1000 # number of examples in test data set\n",
    "sigma = 0.01  # noise level, feel free to vary \n",
    "width_small = K\n",
    "width = K+2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f0e9d8c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate1D(function_name, sigma=0., random=False, xrange=1., N=50):\n",
    "    if random:\n",
    "        x = np.random.rand(N) * xrange\n",
    "        x = np.sort(x, axis=0)\n",
    "    else:\n",
    "        xstep = xrange / N\n",
    "        x = np.arange(0, xrange, xstep)\n",
    "    match function_name:\n",
    "        case 'sigmoid10':\n",
    "            y = 1. /(1. + np.exp(-(x-xrange/2.) * 10.))\n",
    "        case 'sq':\n",
    "            y = x**2\n",
    "        case 'sqrt':\n",
    "            y = np.sqrt(x)\n",
    "    y = y + sigma*np.random.normal(0, 1., N)\n",
    "    return x.reshape(N, 1), y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "97309506",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d54faa3c3e704a66aac3dd48dc85507e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "monotonic_alt 73 parameters\n",
      "smooth 73 parameters\n"
     ]
    }
   ],
   "source": [
    "methods = ['monotonic_alt','smooth']\n",
    "tasks = ['sq', 'sqrt', 'sigmoid10']\n",
    "N_tasks = len(tasks)\n",
    "N_methods = len(methods)\n",
    "\n",
    "MSE_train = np.zeros((N_tasks, N_methods, T))\n",
    "MSE_test = np.zeros((N_tasks, N_methods, T))\n",
    "MSE_clip = np.zeros((N_tasks, N_methods, T))\n",
    "R2_train = np.zeros((N_tasks, N_methods, T))\n",
    "R2_test = np.zeros((N_tasks, N_methods, T))\n",
    "X_train = np.zeros((N_tasks, T, N_train))\n",
    "Y_train = np.zeros((N_tasks, T, N_train))\n",
    "X_test = np.zeros((N_tasks, T, N_test))\n",
    "Y_test = np.zeros((N_tasks, T, N_test))\n",
    "O_test = np.zeros((N_tasks, N_methods, T, N_test))\n",
    "no_params=np.zeros(N_methods)\n",
    "Active = np.zeros((N_tasks, T))\n",
    "Dead = np.zeros((N_tasks, T))\n",
    "ActiveInit = np.zeros((N_tasks, T))\n",
    "active = 0\n",
    "\n",
    "for trial in tnrange(T):\n",
    "    for task_id, task in enumerate(tasks):\n",
    "        seed = task_id + trial*N_tasks\n",
    "        random.seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "        x_train, y_train = generate1D(task, sigma=sigma, random=True, N=N_train)\n",
    "        x_test, y_test   = generate1D(task, sigma=0., random=False, N=N_test)\n",
    "        X_test[task_id, trial] = x_test.reshape(-1)\n",
    "        Y_test[task_id, trial] = y_test\n",
    "        X_train[task_id, trial] = x_train.reshape(-1)\n",
    "        Y_train[task_id, trial] = y_train\n",
    "        x_train_torch = torch.from_numpy(x_train.astype(np.float32)).clone()\n",
    "        y_train_torch = torch.from_numpy(y_train.astype(np.float32)).clone()\n",
    "        x_test_torch = torch.from_numpy(x_test.astype(np.float32)).clone()\n",
    "        y_test_torch = torch.from_numpy(y_test.astype(np.float32)).clone()\n",
    "\n",
    "        for method_id, method in enumerate(methods):\n",
    "            match method:\n",
    "                case 'smooth':\n",
    "                    model = SmoothMonotonicNN(1, K, K, beta=-1.)\n",
    "                    if(trial+task_id==0):\n",
    "                        no_params[method_id] = total_params(model)\n",
    "                        print(method, total_params(model), \"parameters\")\n",
    "                    fit_torch(model, x_train_torch, y_train_torch)\n",
    "                    y_pred_train = model(x_train_torch).detach().numpy()\n",
    "                    y_pred_test = model(x_test_torch).detach().numpy()\n",
    "                    \n",
    "                    model.zero_grad()\n",
    "                    sum_y = torch.sum(model(x_test_torch))\n",
    "                    sum_y.backward()\n",
    "                    dead = model.check_grad()\n",
    "                    Dead[task_id, trial] = dead\n",
    "             \n",
    "                case 'monotonic_alt':\n",
    "                    model = MonotonicNNAlt(1, K, K)\n",
    "                    model.reset_active_max()\n",
    "                    y_pred_test = model(x_test_torch).detach().numpy()\n",
    "                    activeInit, _ = model.active_max()\n",
    "                    if(trial+task_id==0):\n",
    "                        no_params[method_id] = total_params(model)\n",
    "                        print(method, total_params(model), \"parameters\")\n",
    "                    fit_torch(model, x_train_torch, y_train_torch)\n",
    "                    y_pred_train = model(x_train_torch).detach().numpy()\n",
    "                    model.reset_active_max()\n",
    "                    y_pred_test = model(x_test_torch).detach().numpy()\n",
    "                    active, _ = model.active_max()\n",
    "\n",
    "            MSE_train[task_id, method_id, trial] = mse(y_train, y_pred_train)\n",
    "            MSE_test[task_id, method_id, trial] = mse(y_test, y_pred_test)\n",
    "            Active[task_id, trial] = active\n",
    "            ActiveInit[task_id, trial] = activeInit\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "71b84249",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\fsq & 1 & 3.43 & 6 & 2 & 3.43 & 5 \\\\\n",
      "\\fsqrt & 2 & 3.86 & 7 & 1 & 2.14 & 4 \\\\\n",
      "\\fsig & 1 & 3.67 & 7 & 2 & 2.95 & 5 \\\\\n",
      "\n",
      "\\fsq & 34 & 35.71 & 36 \\\\\n",
      "\\fsqrt & 29 & 33.76 & 36 \\\\\n",
      "\\fsig & 34 & 35.62 & 36 \\\\\n",
      "\n",
      "\\fsq & 1 & 3.4 & 6 & 2 & 3.4 & 5 & 34 & 35.7 & 36 \\\\\n",
      "\\fsqrt & 2 & 3.9 & 7 & 1 & 2.1 & 4 & 29 & 33.8 & 36 \\\\\n",
      "\\fsig & 1 & 3.7 & 7 & 2 & 3.0 & 5 & 34 & 35.6 & 36 \\\\\n",
      "overall & 1 & 3.7 & 7 & 1 & 2.8 & 5 & 29 & 35.0 & 36 \\\\\n"
     ]
    }
   ],
   "source": [
    "functions = (\"\\\\fsq\", \"\\\\fsqrt\", \"\\\\fsig\")\n",
    "for f_id, f_name in enumerate(functions):\n",
    "    print(f_name, end=' & ')\n",
    "    print(int(ActiveInit.min(axis=1)[f_id]), end=' & ')\n",
    "    print(\"{:.2f}\".format((np.mean(ActiveInit, axis=1)[f_id])), end=' & ')\n",
    "    print(int(ActiveInit.max(axis=1)[f_id]), end=' & ')\n",
    "    print(int(Active.min(axis=1)[f_id]), end=' & ')\n",
    "    print(\"{:.2f}\".format((np.mean(Active, axis=1)[f_id])), end=' & ')\n",
    "    print(int(Active.max(axis=1)[f_id]), end=' ')\n",
    "    print(\"\\\\\\\\\")\n",
    "    \n",
    "print()\n",
    "\n",
    "n_neurons = K*K\n",
    "for f_id, f_name in enumerate(functions):\n",
    "    print(f_name, end=' & ')\n",
    "    print(n_neurons - int(Dead.max(axis=1)[f_id]), end=' & ')\n",
    "    print(\"{:.2f}\".format(n_neurons - (np.mean(Dead, axis=1)[f_id])), end=' & ')\n",
    "    print(n_neurons - int(Dead.min(axis=1)[f_id]), end=' ')\n",
    "    print(\"\\\\\\\\\")\n",
    "    \n",
    "print()\n",
    "\n",
    "for f_id, f_name in enumerate(functions):\n",
    "    print(f_name, end=' & ')\n",
    "    print(int(ActiveInit.min(axis=1)[f_id]), end=' & ')\n",
    "    print(\"{:.1f}\".format((np.mean(ActiveInit, axis=1)[f_id])), end=' & ')\n",
    "    print(int(ActiveInit.max(axis=1)[f_id]), end=' & ')\n",
    "    print(int(Active.min(axis=1)[f_id]), end=' & ')\n",
    "    print(\"{:.1f}\".format((np.mean(Active, axis=1)[f_id])), end=' & ')\n",
    "    print(int(Active.max(axis=1)[f_id]), end=' & ')\n",
    "    print(n_neurons - int(Dead.max(axis=1)[f_id]), end=' & ')\n",
    "    print(\"{:.1f}\".format(n_neurons - (np.mean(Dead, axis=1)[f_id])), end=' & ')\n",
    "    print(n_neurons - int(Dead.min(axis=1)[f_id]), end=' ')\n",
    "    print(\"\\\\\\\\\")\n",
    "print(\"overall & \", end='')\n",
    "print(int(ActiveInit.min()), end=' & ')\n",
    "print(\"{:.1f}\".format(np.mean(ActiveInit)), end=' & ')\n",
    "print(int(ActiveInit.max()), end=' & ')\n",
    "print(int(Active.min()), end=' & ')\n",
    "print(\"{:.1f}\".format(np.mean(Active)), end=' & ')\n",
    "print(int(Active.max()), end=' & ')\n",
    "print(n_neurons - int(Dead.max()), end=' & ')\n",
    "print(\"{:.1f}\".format(n_neurons - np.mean(Dead)), end=' & ')\n",
    "print(n_neurons - int(Dead.min()), end=' ')\n",
    "print(\"\\\\\\\\\")\n",
    "\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3cbc8265",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['monotonic_alt', 'smooth']\n",
      "[[1.0385e-04 9.7103e-06]\n",
      " [3.1755e-04 1.6241e-05]\n",
      " [2.1992e-04 6.3369e-06]] \n",
      "\n",
      "[[5.0000 3.0000 2.0000 5.0000 3.0000 3.0000 4.0000 3.0000 3.0000 4.0000\n",
      "  3.0000 4.0000 3.0000 3.0000 4.0000 3.0000 5.0000 2.0000 3.0000 5.0000\n",
      "  2.0000]\n",
      " [1.0000 3.0000 1.0000 4.0000 2.0000 2.0000 4.0000 3.0000 2.0000 1.0000\n",
      "  3.0000 2.0000 2.0000 4.0000 1.0000 1.0000 1.0000 3.0000 2.0000 1.0000\n",
      "  2.0000]\n",
      " [4.0000 3.0000 5.0000 3.0000 2.0000 3.0000 2.0000 2.0000 2.0000 3.0000\n",
      "  4.0000 4.0000 2.0000 2.0000 2.0000 3.0000 4.0000 2.0000 2.0000 4.0000\n",
      "  4.0000]]\n",
      "max: 5.0\n",
      "median: 3.0\n",
      "mean: 2.8412698412698414\n",
      "count 1: 7\n",
      "[[4.0000 1.0000 5.0000 4.0000 3.0000 3.0000 2.0000 4.0000 4.0000 4.0000\n",
      "  1.0000 3.0000 3.0000 3.0000 5.0000 4.0000 5.0000 2.0000 6.0000 3.0000\n",
      "  3.0000]\n",
      " [2.0000 4.0000 3.0000 3.0000 3.0000 7.0000 5.0000 6.0000 6.0000 4.0000\n",
      "  3.0000 5.0000 2.0000 6.0000 4.0000 5.0000 2.0000 3.0000 2.0000 2.0000\n",
      "  4.0000]\n",
      " [3.0000 4.0000 5.0000 3.0000 4.0000 3.0000 2.0000 2.0000 3.0000 4.0000\n",
      "  3.0000 4.0000 5.0000 2.0000 4.0000 1.0000 6.0000 5.0000 2.0000 5.0000\n",
      "  7.0000]]\n",
      "max: 7.0\n",
      "median: 4.0\n",
      "mean: 3.6507936507936507\n",
      "count 1: 3\n",
      "[[0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 2.0000\n",
      "  0.0000 0.0000 0.0000 0.0000 0.0000 2.0000 0.0000 2.0000 0.0000 0.0000\n",
      "  0.0000]\n",
      " [2.0000 6.0000 2.0000 5.0000 7.0000 0.0000 0.0000 6.0000 1.0000 0.0000\n",
      "  0.0000 0.0000 4.0000 0.0000 0.0000 4.0000 4.0000 0.0000 0.0000 6.0000\n",
      "  0.0000]\n",
      " [0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 2.0000 0.0000\n",
      "  0.0000 2.0000 0.0000 0.0000 0.0000 0.0000 0.0000 2.0000 0.0000 2.0000\n",
      "  0.0000]]\n",
      "max: 7.0\n",
      "median: 0.0\n",
      "mean: 0.9682539682539683\n"
     ]
    }
   ],
   "source": [
    "np.set_printoptions(precision=4, floatmode='fixed')\n",
    "print(methods)\n",
    "print(np.median(MSE_test, axis=2), '\\n')\n",
    "\n",
    "print(Active)\n",
    "print(\"max:\", Active.max())\n",
    "print(\"median:\", np.median(Active))\n",
    "print(\"mean:\", np.mean(Active))\n",
    "print(\"count 1:\", np.count_nonzero(Active == 1))\n",
    "\n",
    "print(ActiveInit)\n",
    "print(\"max:\", ActiveInit.max())\n",
    "print(\"median:\", np.median(ActiveInit))\n",
    "print(\"mean:\", np.mean(ActiveInit))\n",
    "print(\"count 1:\", np.count_nonzero(ActiveInit == 1))\n",
    "\n",
    "n_neurons = K*K\n",
    "print(Dead)\n",
    "print(\"max:\", Dead.max())\n",
    "print(\"median:\", np.median(Dead))\n",
    "print(\"mean:\", np.mean(Dead))\n",
    "#print(\"count 1:\", np.count_nonzero(Dead == 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dea75173",
   "metadata": {},
   "source": [
    "[[ 0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000  0.0000\n",
    "   2.0000  0.0000  0.0000  0.0000  2.0000  0.0000  2.0000  0.0000  2.0000\n",
    "   0.0000  0.0000  0.0000]\n",
    " [ 2.0000  8.0000  2.0000  6.0000 12.0000  2.0000  6.0000  6.0000  2.0000\n",
    "   0.0000  0.0000  0.0000  4.0000  1.0000  6.0000  4.0000  4.0000  0.0000\n",
    "   3.0000  6.0000  0.0000]\n",
    " [ 0.0000  0.0000  0.0000  2.0000  0.0000  1.0000  0.0000  0.0000  2.0000\n",
    "   0.0000  1.0000  4.0000  0.0000  0.0000  0.0000  0.0000  0.0000  4.0000\n",
    "   0.0000  6.0000  0.0000]]\n",
    "max: 0.16666666666666666\n",
    "mean: 0.022486772486772486\n",
    "median: 0.0\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
