{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "273bf6f3",
   "metadata": {},
   "source": [
    "# Smooth Monotonic Networks: Time comparison\n",
    "\n",
    "Timing comparisons are always difficult to interpret: the results depend on the implementation, the hardware, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8e5836a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mode 2 re-imports edited local modules (MonotonicNN, MonotonicNNPaperUtils, ...)\n",
    "# before every cell execution; a bare `%autoreload` only reloads once, immediately.\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ec865337",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "\n",
    "from sklearn.metrics import mean_squared_error as mse\n",
    "from sklearn.metrics import r2_score as r2\n",
    "from sklearn.isotonic import IsotonicRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tnrange\n",
    "\n",
    "from xgboost import XGBRegressor\n",
    "\n",
    "from pmlayer.torch.layers import HLattice\n",
    "\n",
    "from MonotonicNN import SmoothMonotonicNN, MonotonicNN, MonotonicNNAlt\n",
    "from MonotonicNNPaperUtils import Progress, total_params\n",
    "\n",
    "from monotonenorm import GroupSort, direct_norm, SigmaNet\n",
    "\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b08989c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.1\n"
     ]
    }
   ],
   "source": [
    "# PyTorch version the timings below were measured with.\n",
    "print(f\"{torch.__version__}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "89fe55ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FastSMM(nn.Module):\n",
    "    \"\"\"Smooth min-max (SMM) monotonic layer, eager PyTorch implementation.\n",
    "\n",
    "    The input is mapped by a linear layer with positive weights (exp(log_weight))\n",
    "    to K*K units, split into K groups of K units. A log-sum-exp soft maximum is\n",
    "    taken inside each group and a soft minimum across the groups; the smoothness\n",
    "    is beta.exp(), i.e. self.beta stores its logarithm.\n",
    "\n",
    "    Args:\n",
    "        in_features: number of input features.\n",
    "        K: number of groups (and of units per group).\n",
    "        b_z: std of the truncated-normal log-weight init in reset_parameters_.\n",
    "        b_t: std of the truncated-normal bias init in reset_parameters_.\n",
    "        beta: initial value of the log-smoothness parameter self.beta.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, K, b_z = 1., b_t = 1., beta=-1.):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.K = K\n",
    "        self.beta_init = beta\n",
    "        self.b_z = b_z\n",
    "        self.b_t = b_t\n",
    "        self.beta = torch.nn.Parameter(torch.ones(1), requires_grad=True)\n",
    "        self.log_weight = nn.Parameter(torch.empty(K * K, in_features))\n",
    "        self.biases = nn.Parameter(torch.empty(K * K))\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters_(self):\n",
    "        \"\"\"Random initialisation (alternative to reset_parameters; not called by __init__).\"\"\"\n",
    "        nn.init.trunc_normal_(self.log_weight, std=self.b_z)\n",
    "        # Previously referenced `self.biasses` (typo), which raised AttributeError.\n",
    "        nn.init.trunc_normal_(self.biases, std=self.b_t)\n",
    "        nn.init.constant_(self.beta, self.beta_init)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Deterministic initialisation: all log-weights 1, biases 0, 1, ..., K*K-1.\"\"\"\n",
    "        nn.init.constant_(self.log_weight, 1.)\n",
    "        with torch.no_grad():\n",
    "            self.biases.copy_(torch.arange(self.K * self.K, dtype=torch.float))\n",
    "        nn.init.constant_(self.beta, self.beta_init)\n",
    "\n",
    "    def forward(self, input):\n",
    "        # assumes input is (batch, in_features); the result is 1-D with one value per row\n",
    "        beta = self.beta.exp()\n",
    "        linear = nn.functional.linear(input, self.log_weight.exp(), self.biases).unfold(1, self.K, self.K)\n",
    "        linear = torch.logsumexp(beta * linear, dim=2) / beta  # soft max within each group\n",
    "        return -torch.logsumexp(-beta * linear, dim=1) / beta  # soft min over the groups\n",
    "\n",
    "\n",
    "@torch.jit.script\n",
    "def forward_(input, weights, biases, beta: torch.Tensor, K: int):\n",
    "    \"\"\"TorchScript version of FastSMM.forward (weights and beta already exponentiated).\n",
    "\n",
    "    beta is annotated as a Tensor on purpose: with a `float` annotation the\n",
    "    (1,)-shaped parameter was converted to a Python float at the call, so it left\n",
    "    the autograd graph and FastSMM2 never received a gradient for beta.\n",
    "    \"\"\"\n",
    "    linear = nn.functional.linear(input, weights, biases).unfold(1, K, K)\n",
    "    linear = torch.logsumexp(beta * linear, dim=2) / beta\n",
    "    return -torch.logsumexp(-beta * linear, dim=1) / beta\n",
    "\n",
    "\n",
    "class FastSMM2(FastSMM):\n",
    "    \"\"\"FastSMM whose forward pass runs through the TorchScript function forward_.\"\"\"\n",
    "\n",
    "    def forward(self, input):\n",
    "        return forward_(input, self.log_weight.exp(), self.biases, self.beta.exp(), self.K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "96a84c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_torch(model, x, y, threshold=1e-3, max_iterations=1000, early_stopping=False):\n",
    "    \"\"\"Full-batch training of `model` on (x, y) with MSE loss and Rprop.\n",
    "\n",
    "    Args:\n",
    "        model: torch module; model(x) should have the same shape as y.\n",
    "        x, y: training inputs and targets.\n",
    "        threshold: threshold handed to the Progress monitor.\n",
    "        max_iterations: maximum number of optimisation steps.\n",
    "        early_stopping: if True, stop once Progress.update returns a truthy value\n",
    "            (presumably: the loss stopped improving by more than `threshold`).\n",
    "            The default False keeps the previous behaviour of always running\n",
    "            max_iterations steps, so every model gets the same number of steps.\n",
    "    \"\"\"\n",
    "    progress = Progress(5, threshold=threshold)\n",
    "    loss_function = nn.MSELoss()\n",
    "    optimizer = torch.optim.Rprop(model.parameters(), lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50))\n",
    "    for _ in range(max_iterations):\n",
    "        loss = loss_function(model(x), y)\n",
    "        stop = progress.update(loss.item())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model.zero_grad()\n",
    "        if early_stopping and stop:\n",
    "            break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c15af35",
   "metadata": {},
   "source": [
    "## Multivariate experiments\n",
    "Section 4.2 in the manuscript."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "904d464a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): T, ls, K, N_train, N_test and sigma are reassigned by the multivariate\n",
    "# settings cell below (ls becomes a per-dimension list), and do_it builds its own\n",
    "# Lipschitz-net widths, so these values are not the ones the timing runs use.\n",
    "T = 21  # number of trials; odd so that there is a \"median trial\"\n",
    "ls, ls_small = 75, 35  # lattice points (k in the original paper)\n",
    "K = 6  # number of SMM groups; H_k = K throughout\n",
    "N_train, N_test = 100, 1000  # training / test set sizes\n",
    "sigma = 0.01  # label noise level, feel free to vary\n",
    "width_small, width = K, K + 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "249b98e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generatePoly(dim=2, degree=2, sigma_train=0., sigma_test=0, N_train=100, N_test=100):\n",
    "    \"\"\"Sample a random monotone polynomial regression task on [0, 1]^dim.\n",
    "\n",
    "    The clean target is a weighted average (non-negative random weights,\n",
    "    normalised to sum to one) of all monomials of degree <= `degree`, including\n",
    "    the constant one; Gaussian noise with std sigma_train / sigma_test is added.\n",
    "    Draws from NumPy's global RNG, so it is reproducible via np.random.seed.\n",
    "\n",
    "    Returns:\n",
    "        x_train, y_train, x_test, y_test\n",
    "    \"\"\"\n",
    "    x_train = np.random.rand(N_train, dim)\n",
    "    x_test = np.random.rand(N_test, dim)\n",
    "    expander = PolynomialFeatures(degree)  # first column is the constant (bias) feature\n",
    "    features_train = expander.fit_transform(x_train)\n",
    "    features_test = expander.transform(x_test)\n",
    "    weights = np.random.rand(features_train.shape[1])\n",
    "    weight_total = weights.sum()\n",
    "\n",
    "    def noisy_target(features, noise_std, n_samples):\n",
    "        clean = np.sum(features * weights, axis=1) / weight_total\n",
    "        return clean + noise_std * np.random.normal(0, 1., n_samples)\n",
    "\n",
    "    y_train = noisy_target(features_train, sigma_train, N_train)\n",
    "    y_test = noisy_target(features_test, sigma_test, N_test)\n",
    "    return x_train, y_train, x_test, y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "31a02732",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings actually used by the timing runs (they replace the earlier values).\n",
    "N_train, N_test = 500, 1000  # training / test set sizes passed to generatePoly\n",
    "T = 21  # trials per (method, dimension); do_it sums their training times\n",
    "trees, trees2 = 100, 200  # n_estimators of the 'xgboost*' / 'xgboost2*' variants\n",
    "\n",
    "dims = [2, 4, 6]  # input dimensions d\n",
    "\n",
    "degree = 2  # degree of the synthetic polynomial targets\n",
    "ls = [10, 3, 2]  # lattice points per axis, one entry per element of dims\n",
    "K = 6  # number of SMM groups (and units per group)\n",
    "sigma = 0.01  # noise level of the training labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d0c861a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "_XGB_METHODS = ('xgboost', 'xgboost_val', 'xgboost2', 'xgboost2_val')\n",
    "\n",
    "\n",
    "def _timed(fit, *args, **kwargs):\n",
    "    \"\"\"Call fit(*args, **kwargs) and return the elapsed wall-clock time in seconds.\"\"\"\n",
    "    t0 = time.perf_counter()\n",
    "    fit(*args, **kwargs)\n",
    "    return time.perf_counter() - t0\n",
    "\n",
    "\n",
    "def _xgb_fit_time(x_train, y_train, n_estimators, constraints, validate):\n",
    "    \"\"\"Fit a monotone XGBRegressor and return the fitting time in seconds.\n",
    "\n",
    "    With validate=True, 25% of the data is held out as the last eval_set entry\n",
    "    and early stopping uses n_estimators // 10 rounds.\n",
    "    \"\"\"\n",
    "    if not validate:\n",
    "        model = XGBRegressor(monotone_constraints=constraints, n_estimators=n_estimators)\n",
    "        return _timed(model.fit, x_train, y_train)\n",
    "    x_fit, x_val, y_fit, y_val = train_test_split(x_train, y_train, test_size=.25, random_state=42)\n",
    "    model = XGBRegressor(monotone_constraints=constraints, n_estimators=n_estimators,\n",
    "                         early_stopping_rounds=(n_estimators // 10), verbosity=0)\n",
    "    return _timed(model.fit, x_fit, y_fit, eval_set=[(x_fit, y_fit), (x_val, y_val)], verbose=0)\n",
    "\n",
    "\n",
    "def _lipschitz_net(dim, width):\n",
    "    \"\"\"Monotone Lipschitz network (monotonenorm) with two GroupSort layers of `width` units.\"\"\"\n",
    "    model = torch.nn.Sequential(\n",
    "        direct_norm(torch.nn.Linear(dim, width), kind='one-inf'),\n",
    "        GroupSort(width//2),\n",
    "        direct_norm(torch.nn.Linear(width, width), kind='inf'),\n",
    "        GroupSort(width//2),\n",
    "        direct_norm(torch.nn.Linear(width, 1), kind='inf'),\n",
    "    )\n",
    "    return SigmaNet(model, sigma=1, monotone_constraints=(1,))\n",
    "\n",
    "\n",
    "def _torch_model(method, dim, lattice_size, increasing, y_train_torch):\n",
    "    \"\"\"Return an untrained PyTorch model for `method` and the target tensor to fit it to.\"\"\"\n",
    "    match method:\n",
    "        case 'lattice' | 'lattice_plus':\n",
    "            size = lattice_size + 1 if method == 'lattice_plus' else lattice_size\n",
    "            sizes = torch.tensor(list(np.ones(dim) * size), dtype=torch.long)\n",
    "            return HLattice(dim, sizes, increasing), y_train_torch.reshape(-1, 1)\n",
    "        case 'smooth':\n",
    "            # FastSMM2 outputs a 1-D tensor, so the target is not reshaped.\n",
    "            return FastSMM2(dim, K, beta=-1.), y_train_torch\n",
    "        case 'lip_small' | 'lip':\n",
    "            width = int(dim + K) + (2 if method == 'lip' else 0)\n",
    "            return _lipschitz_net(dim, width), y_train_torch.reshape(-1, 1)\n",
    "        case _:\n",
    "            raise ValueError(f'unknown method {method!r}')\n",
    "\n",
    "\n",
    "def do_it(method, dim, lattice_size):\n",
    "    \"\"\"Return (training time summed over T trials, parameter count) for `method`.\n",
    "\n",
    "    Each trial seeds `random`, NumPy and PyTorch with task_id + trial * N_tasks\n",
    "    (task_id and N_tasks are notebook globals set by the experiment loop), draws\n",
    "    a fresh task with generatePoly and fits a fresh model. Only the fit call is\n",
    "    timed (time.perf_counter). The parameter count comes from the first trial\n",
    "    and is 0 for the XGBoost methods.\n",
    "\n",
    "    Fixed here: 'xgboost2_val' never measured its time, 'lip' stopped the clock\n",
    "    before fitting, and XGBoost got monotone_constraints=(0, 1, ..., dim-1)\n",
    "    instead of one +1 flag per feature.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `method` is unknown.\n",
    "    \"\"\"\n",
    "    total_t = 0\n",
    "    no_params = 0\n",
    "    increasing = list(range(dim))  # indices of increasing features, as HLattice expects\n",
    "    xgb_constraints = tuple([1] * dim)  # XGBoost expects a +1 (increasing) flag per feature\n",
    "    for trial in tnrange(T):\n",
    "        seed = task_id + trial*N_tasks\n",
    "        random.seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "        # generatePoly draws the test inputs before the polynomial weights, so it is\n",
    "        # still called with N_test to keep the per-seed training data unchanged.\n",
    "        x_train, y_train, _, _ = generatePoly(dim, degree=degree, sigma_train=sigma, sigma_test=0.,\n",
    "                                              N_train=N_train, N_test=N_test)\n",
    "\n",
    "        if method in _XGB_METHODS:\n",
    "            n_estimators = trees2 if method.startswith('xgboost2') else trees\n",
    "            t = _xgb_fit_time(x_train, y_train, n_estimators, xgb_constraints,\n",
    "                              validate=method.endswith('_val'))\n",
    "        else:\n",
    "            x_train_torch = torch.from_numpy(x_train.astype(np.float32))\n",
    "            y_train_torch = torch.from_numpy(y_train.astype(np.float32))\n",
    "            model, target = _torch_model(method, dim, lattice_size, increasing, y_train_torch)\n",
    "            if trial == 0:\n",
    "                no_params = total_params(model)\n",
    "                print(method, no_params, 'parameters')\n",
    "            t = _timed(fit_torch, model, x_train_torch, target)\n",
    "        total_t += t\n",
    "    return total_t, no_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1066eec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Methods and input dimensions to compare; do_it also reads N_tasks to derive seeds.\n",
    "methods = ['smooth', 'lattice', 'lip_small']\n",
    "dims = [2, 4, 6]\n",
    "N_methods, N_tasks = len(methods), len(dims)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "252e0425",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dimension 2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "699afb707eb34e298240325313622394",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "smooth 109 parameters\n",
      "smooth training time: 11.383250713348389\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8593e3a205ee4cc99e7a2176b8086757",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lattice 100 parameters\n"
     ]
    }
   ],
   "source": [
    "# Summed training time and parameter count per (dimension, method).\n",
    "# task_id must stay a notebook global: do_it reads it to derive the per-trial seeds.\n",
    "times, params = np.zeros((N_tasks, N_methods)), np.zeros((N_tasks, N_methods))\n",
    "for task_id, dim in enumerate(dims):\n",
    "    print('dimension', dim)\n",
    "    lattice_size = ls[task_id]\n",
    "    for method_id, method in enumerate(methods):\n",
    "        t, p = do_it(method, dim, lattice_size)\n",
    "        print(method, 'training time:', t)\n",
    "        times[task_id, method_id], params[task_id, method_id] = t, p\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45b19df3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One LaTeX table row per dimension: label, then 'time  &  #params' for each method.\n",
    "# Row labels are derived from `dims` so they cannot drift out of sync with the runs.\n",
    "functions = tuple(f'$d={d}$' for d in dims)\n",
    "for f_id, f_name in enumerate(functions):\n",
    "    cells = [f'{t:.2f}  &  {int(p)}' for t, p in zip(times[f_id, :], params[f_id, :])]\n",
    "    print(f'{f_name} & ' + ' & '.join(cells) + ' \\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bdbbee1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
