{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "273bf6f3",
   "metadata": {},
   "source": [
    "# Smooth Monotonic Networks: Time comparison\n",
    "\n",
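    "Timing comparisons are always difficult: the results depend on the implementation, the hardware, etc.\n",
    "\n",
    "This notebook measures, for each method and input dimension, the total wall-clock time spent fitting the models over all trials (data generation and model construction are not timed)."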
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8e5836a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "# Mode 2 re-imports all modules (e.g. MonotonicNN) before every cell runs;\n",
    "# a bare `%autoreload` only reloads them once, immediately.\n",
    "%autoreload 2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ec865337",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "\n",
    "from sklearn.metrics import mean_squared_error as mse\n",
    "from sklearn.metrics import r2_score as r2\n",
    "from sklearn.isotonic import IsotonicRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tnrange\n",
    "\n",
    "from xgboost import XGBRegressor\n",
    "\n",
    "from pmlayer.torch.layers import HLattice\n",
    "\n",
    "from MonotonicNN import SmoothMonotonicNN, MonotonicNN, MonotonicNNAlt\n",
    "from MonotonicNNPaperUtils import Progress, total_params\n",
    "\n",
    "from monotonenorm import GroupSort, direct_norm, SigmaNet\n",
    "\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b08989c6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.1\n"
     ]
    }
   ],
   "source": [
    "# PyTorch version used for the timings recorded in this notebook.\n",
    "print(torch.__version__, flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "89fe55ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.jit.script\n",
    "def forward_(input, weights, biases, beta: torch.Tensor, K: int):\n",
    "    \"\"\"Scripted smooth min-max forward pass used by FastSMM2.\n",
    "\n",
    "    Computes `input @ weights.T + biases`, splits the K*K outputs into K\n",
    "    groups of K consecutive units, takes a log-sum-exp soft-max (sharpness\n",
    "    `beta`) inside each group and a soft-min across the groups, giving one\n",
    "    value per row of `input` (presumably (N, in_features) -> (N,)).\n",
    "\n",
    "    `beta` is annotated as a Tensor, not a float: with a `float` annotation\n",
    "    the (1,)-shaped `self.beta.exp()` passed in by FastSMM2 is converted to a\n",
    "    plain number at the call boundary, so the learnable beta never receives\n",
    "    a gradient.\n",
    "    \"\"\"\n",
    "    linear = nn.functional.linear(input, weights, biases).unfold(1, K, K)\n",
    "    linear = torch.logsumexp(beta * linear, dim=2) / beta\n",
    "    linear = -torch.logsumexp(-beta * linear, dim=1) / beta\n",
    "    return linear\n",
    "\n",
    "\n",
    "class FastSMM(nn.Module):\n",
    "    \"\"\"Smooth min-max (SMM) monotone module with a single output.\n",
    "\n",
    "    Holds K*K linear units over `in_features` inputs. Weights are stored in\n",
    "    log space and exponentiated in `forward`, so every effective weight is\n",
    "    positive and the output is non-decreasing in every input. The units are\n",
    "    split into K groups of K: a log-sum-exp soft-max is taken inside each\n",
    "    group and a soft-min across the groups. `beta` is stored in log space as\n",
    "    well; the effective sharpness is `exp(beta)`.\n",
    "\n",
    "    Args:\n",
    "        in_features: number of input features.\n",
    "        K: number of groups (and of units per group).\n",
    "        b_z: std of the truncated-normal weight init in `reset_parameters_`.\n",
    "        b_t: std of the truncated-normal bias init in `reset_parameters_`.\n",
    "        beta: initial value of the log-sharpness parameter.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, K, b_z = 1., b_t = 1., beta=-1.):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.K = K\n",
    "        self.beta_init = beta\n",
    "        self.b_z = b_z\n",
    "        self.b_t = b_t\n",
    "        self.beta = nn.Parameter(torch.ones(1))\n",
    "        # Uninitialised storage; the values are set by reset_parameters().\n",
    "        self.log_weight = nn.Parameter(torch.empty(K * K, in_features))\n",
    "        self.biases = nn.Parameter(torch.empty(K * K))\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters_(self):\n",
    "        \"\"\"Alternative random init (not called by default): truncated-normal weights and biases.\"\"\"\n",
    "        nn.init.trunc_normal_(self.log_weight, std=self.b_z)\n",
    "        nn.init.trunc_normal_(self.biases, std=self.b_t)\n",
    "        nn.init.constant_(self.beta, self.beta_init)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Deterministic init: log-weights 1, biases 0, 1, ..., K*K-1, beta = beta_init.\"\"\"\n",
    "        nn.init.constant_(self.log_weight, 1.)\n",
    "        with torch.no_grad():\n",
    "            # All weights are equal, so distinct biases are what tell the units apart.\n",
    "            self.biases.copy_(torch.arange(self.K * self.K, dtype=self.biases.dtype))\n",
    "        nn.init.constant_(self.beta, self.beta_init)\n",
    "\n",
    "    def forward(self, input):\n",
    "        beta = self.beta.exp()\n",
    "        linear = nn.functional.linear(input, self.log_weight.exp(), self.biases).unfold(1, self.K, self.K)\n",
    "        linear = torch.logsumexp(beta * linear, dim=2) / beta\n",
    "        return -torch.logsumexp(-beta * linear, dim=1) / beta\n",
    "\n",
    "\n",
    "class FastSMM2(FastSMM):\n",
    "    \"\"\"FastSMM whose forward pass runs through the TorchScript-compiled `forward_`.\"\"\"\n",
    "\n",
    "    def forward(self, input):\n",
    "        return forward_(input, self.log_weight.exp(), self.biases, self.beta.exp(), self.K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "96a84c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_torch(model, x, y, threshold=1e-3, max_iterations=1000, early_stopping=False):\n",
    "    \"\"\"Full-batch training of `model` on (x, y) with an MSE loss and Rprop.\n",
    "\n",
    "    Args:\n",
    "        model: torch module; its output should have the same shape as `y`\n",
    "            (nn.MSELoss broadcasts, with a warning, otherwise).\n",
    "        x, y: training inputs and targets, used as a single batch.\n",
    "        threshold: convergence threshold handed to `Progress`.\n",
    "        max_iterations: maximum number of optimisation steps.\n",
    "        early_stopping: if True, stop as soon as `Progress.update` returns a\n",
    "            truthy value (presumably signalling convergence). The default,\n",
    "            False, always runs all `max_iterations` steps, which is how the\n",
    "            timings in this notebook were taken.\n",
    "    \"\"\"\n",
    "    progress = Progress(5, threshold=threshold)\n",
    "    loss_function = nn.MSELoss()\n",
    "    optimizer = torch.optim.Rprop(model.parameters(), lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50))\n",
    "    for _ in range(max_iterations):\n",
    "        loss = loss_function(model(x), y)\n",
    "        # The loss is always reported to Progress, even when its verdict is not used.\n",
    "        stop = progress.update(loss.item())\n",
    "        if early_stopping and stop:\n",
    "            break\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model.zero_grad()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c15af35",
   "metadata": {},
   "source": [
    "## Multivariate experiments\n",
    "Timing runs for the multivariate polynomial experiments of Section 4.2 in the manuscript."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "904d464a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# T, ls, K, N_train, N_test and sigma are reassigned in a later configuration cell.\n",
    "K = 6                # number of SMM groups, we always use H_k = K\n",
    "T = 21               # number of trials, odd so that there is a median trial\n",
    "N_train = 100        # number of examples in the training data set\n",
    "N_test = 1000        # number of examples in the test data set\n",
    "sigma = 0.01         # noise level, feel free to vary\n",
    "ls = 75              # lattice points (k in the original paper)\n",
    "ls_small = 35\n",
    "# Not read by the cells below: do_it computes its own local network widths.\n",
    "width_small = K\n",
    "width = K + 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "249b98e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generatePoly(dim=2, degree=2, sigma_train=0., sigma_test=0, N_train=100, N_test=100):\n",
    "    \"\"\"Draw a random monotone polynomial regression problem on [0, 1)^dim.\n",
    "\n",
    "    The target is a weighted average of all monomials of total degree <= `degree`\n",
    "    (bias included), with weights drawn from U[0, 1). Since inputs and weights are\n",
    "    non-negative, the noise-free target is non-decreasing in every input and lies\n",
    "    in [0, 1]. Gaussian noise with std `sigma_train` / `sigma_test` is added.\n",
    "    Uses numpy's global RNG, so the result depends on the caller's np.random.seed.\n",
    "\n",
    "    Returns:\n",
    "        x_train (N_train, dim), y_train (N_train,), x_test (N_test, dim), y_test (N_test,)\n",
    "    \"\"\"\n",
    "    x_train = np.random.rand(N_train, dim)\n",
    "    x_test = np.random.rand(N_test, dim)\n",
    "    poly = PolynomialFeatures(degree)  # includes bias\n",
    "    x_poly_train = poly.fit_transform(x_train)\n",
    "    # Reuse the expansion fitted on the training inputs instead of refitting on the test set.\n",
    "    x_poly_test = poly.transform(x_test)\n",
    "    w = np.random.rand(x_poly_train.shape[1])\n",
    "    w_sum = w.sum()\n",
    "    # Noise samples are drawn even when sigma is 0, which keeps the number of RNG draws fixed.\n",
    "    y_train = np.sum(x_poly_train * w, axis=1) / w_sum + sigma_train * np.random.normal(0, 1., N_train)\n",
    "    y_test = np.sum(x_poly_test * w, axis=1) / w_sum + sigma_test * np.random.normal(0, 1., N_test)\n",
    "    return x_train, y_train, x_test, y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "31a02732",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the timing runs; these override the earlier defaults.\n",
    "T = 21            # trials per (method, dimension)\n",
    "N_train = 500     # training examples per trial\n",
    "N_test = 1000     # test examples per trial\n",
    "degree = 2        # degree of the random polynomial target\n",
    "sigma = 0.01      # noise level\n",
    "K = 6             # number of SMM groups\n",
    "dims = [2, 4, 6]  # input dimensions\n",
    "ls = [10, 3, 2]   # lattice points per input, one entry per element of dims\n",
    "trees = 100       # n_estimators for the 'xgboost' methods\n",
    "trees2 = 200      # n_estimators for the 'xgboost2' methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "d0c861a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _fit_xgboost(x_train, y_train, constraints, n_estimators, validate):\n",
    "    \"\"\"Fit a monotone XGBRegressor and return (model, fit time in seconds).\n",
    "\n",
    "    With `validate`, 25% of the training data is held out (fixed split) and\n",
    "    used for early stopping with patience n_estimators // 10.\n",
    "    \"\"\"\n",
    "    if validate:\n",
    "        x_fit, x_val, y_fit, y_val = train_test_split(x_train, y_train, test_size=.25, random_state=42)\n",
    "        model = XGBRegressor(monotone_constraints=constraints, n_estimators=n_estimators,\n",
    "                             early_stopping_rounds=(n_estimators // 10), verbosity=0)\n",
    "        t0 = time.time()\n",
    "        # Early stopping monitors the last eval_set entry, i.e. the held-out split.\n",
    "        model.fit(x_fit, y_fit, eval_set=[(x_fit, y_fit), (x_val, y_val)], verbose=0)\n",
    "    else:\n",
    "        model = XGBRegressor(monotone_constraints=constraints, n_estimators=n_estimators)\n",
    "        t0 = time.time()\n",
    "        model.fit(x_train, y_train)\n",
    "    return model, time.time() - t0\n",
    "\n",
    "\n",
    "def _lipschitz_net(dim, width):\n",
    "    \"\"\"Monotone Lipschitz network: direct_norm/GroupSort MLP wrapped in SigmaNet.\"\"\"\n",
    "    net = torch.nn.Sequential(\n",
    "        direct_norm(torch.nn.Linear(dim, width), kind='one-inf'),\n",
    "        GroupSort(width // 2),\n",
    "        direct_norm(torch.nn.Linear(width, width), kind='inf'),\n",
    "        GroupSort(width // 2),\n",
    "        direct_norm(torch.nn.Linear(width, 1), kind='inf'),\n",
    "    )\n",
    "    return SigmaNet(net, sigma=1, monotone_constraints=(1,))\n",
    "\n",
    "\n",
    "def _build_torch_model(method, dim, lattice_size, increasing):\n",
    "    \"\"\"Construct the torch model for `method` (uses the global K).\n",
    "\n",
    "    Returns (model, column_target); `column_target` tells whether the targets\n",
    "    are reshaped to an (N, 1) column before fitting. FastSMM2 produces a 1-D\n",
    "    output, so 'smooth' is fitted against the flat targets.\n",
    "    \"\"\"\n",
    "    match method:\n",
    "        case 'lattice' | 'lattice_plus':\n",
    "            size = lattice_size + 1 if method == 'lattice_plus' else lattice_size\n",
    "            sizes = torch.tensor(list(np.ones(dim) * size), dtype=torch.long)\n",
    "            return HLattice(dim, sizes, increasing), True\n",
    "        case 'smooth':\n",
    "            return FastSMM2(dim, K, beta=-1.), False\n",
    "        case 'lip_small':\n",
    "            return _lipschitz_net(dim, int(dim + K)), True\n",
    "        case 'lip':\n",
    "            return _lipschitz_net(dim, int(dim + K) + 2), True\n",
    "    raise ValueError(f'unknown method: {method!r}')\n",
    "\n",
    "\n",
    "def do_it(method, dim, lattice_size):\n",
    "    \"\"\"Total fitting time of `method` over T seeded trials on random polynomial data.\n",
    "\n",
    "    Each trial reseeds `random`, numpy and torch with task_id + trial * N_tasks\n",
    "    (task_id is the loop variable of the driver cell, read here as a global),\n",
    "    draws a fresh data set with generatePoly, builds a fresh model and times\n",
    "    only the fitting call. Also reads the notebook settings T, degree, sigma,\n",
    "    N_train, N_test, trees, trees2 and K.\n",
    "\n",
    "    Args:\n",
    "        method: 'xgboost', 'xgboost_val', 'xgboost2', 'xgboost2_val',\n",
    "            'lattice', 'lattice_plus', 'smooth', 'lip_small' or 'lip'.\n",
    "        dim: number of input features.\n",
    "        lattice_size: lattice points per input (lattice methods only).\n",
    "\n",
    "    Returns:\n",
    "        (total_t, no_params): summed fit time in seconds, and the parameter\n",
    "        count of the first trial's model (0 for the xgboost methods).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `method` is unknown.\n",
    "    \"\"\"\n",
    "    # XGBoost expects one constraint per feature (+1 = increasing), whereas\n",
    "    # HLattice expects the list of indices of the increasing features.\n",
    "    xgb_constraints = (1,) * dim\n",
    "    increasing = list(range(dim))\n",
    "    total_t = 0\n",
    "    no_params = 0\n",
    "    for trial in tnrange(T):\n",
    "        seed = task_id + trial * N_tasks\n",
    "        random.seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "        # Only the training split is used; the test split is still drawn so the random stream is unchanged.\n",
    "        x_train, y_train, _, _ = generatePoly(dim, degree=degree, sigma_train=sigma, sigma_test=0., N_train=N_train, N_test=N_test)\n",
    "\n",
    "        match method:\n",
    "            case 'xgboost' | 'xgboost_val' | 'xgboost2' | 'xgboost2_val':\n",
    "                n_trees = trees2 if method.startswith('xgboost2') else trees\n",
    "                _, t = _fit_xgboost(x_train, y_train, xgb_constraints, n_trees, validate=method.endswith('_val'))\n",
    "            case _:\n",
    "                model, column_target = _build_torch_model(method, dim, lattice_size, increasing)\n",
    "                if trial == 0:\n",
    "                    no_params = total_params(model)\n",
    "                    print(method, no_params, 'parameters')\n",
    "                x_fit = torch.from_numpy(x_train.astype(np.float32))\n",
    "                y_fit = torch.from_numpy(y_train.astype(np.float32))\n",
    "                if column_target:\n",
    "                    y_fit = y_fit.reshape(-1, 1)\n",
    "                t0 = time.time()\n",
    "                fit_torch(model, x_fit, y_fit)\n",
    "                t = time.time() - t0\n",
    "        total_t += t\n",
    "    return total_t, no_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "1066eec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Methods to time, one column each in the results table.\n",
    "methods = ['smooth', 'lattice', 'lattice_plus', 'lip_small', 'lip']\n",
    "# Input dimensions, one task (row) each; do_it also uses N_tasks for seeding.\n",
    "dims = [2, 4, 6]\n",
    "N_methods, N_tasks = len(methods), len(dims)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "252e0425",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dimension 2\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "32179a7ad71f4bfd900f7e1f8dd533ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "smooth 109 parameters\n",
      "smooth training time: 11.538730382919312\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ba0865c8a4ef48e29926a16a3361a325",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lattice 100 parameters\n",
      "lattice training time: 403.8710014820099\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "631b6592fdf94a398a43be18246240b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lattice_plus 121 parameters\n",
      "lattice_plus training time: 498.7088816165924\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "745848566f324e4a8a342bdfff295fb4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lip_small 105 parameters\n",
      "lip_small training time: 12.010257482528687\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "024994725b0746d6b4f3228d0a7fc6d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lip 151 parameters\n",
      "lip training time: 12.47550892829895\n",
      "dimension 4\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "47b86d0e655a45ed8e580ca4f4f79ce4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "smooth 181 parameters\n",
      "smooth training time: 11.466137886047363\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2056f0d243514036b7a59ec260c80cad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lattice 81 parameters\n",
      "lattice training time: 358.83428144454956\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "db3ef9e1c32243018e495c8005aaebaf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lattice_plus 256 parameters\n",
      "lattice_plus training time: 1351.89594912529\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d361e6879e54aa0bc5cba01b384cb83",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lip_small 171 parameters\n",
      "lip_small training time: 12.502386569976807\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "126cc7c2bff24a23a7b30d95033d0484",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lip 229 parameters\n",
      "lip training time: 12.745147943496704\n",
      "dimension 6\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "010e607b9398475c96049df6c87a0737",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "smooth 253 parameters\n",
      "smooth training time: 11.427737474441528\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "35314a2ad948473cb52fc3af27cb4e53",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lattice 64 parameters\n",
      "lattice training time: 301.45093059539795\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "81eac53328c443f193a8510431391be1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lattice_plus 729 parameters\n",
      "lattice_plus training time: 8215.190698862076\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "727a4463057249779afa884d7abe459f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lip_small 253 parameters\n",
      "lip_small training time: 12.806957006454468\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "92005c8fd85344eb91052d5e55f5fd39",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/21 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lip 323 parameters\n",
      "lip training time: 13.265650033950806\n"
     ]
    }
   ],
   "source": [
    "times = np.zeros((N_tasks, N_methods))   # seconds; rows = dims, columns = methods\n",
    "params = np.zeros_like(times)            # parameter counts, same layout\n",
    "# The loop variable must be named task_id: do_it reads it as a global for seeding.\n",
    "for task_id, dim in enumerate(dims):\n",
    "    print(\"dimension\", dim)\n",
    "    for method_id, method in enumerate(methods):\n",
    "        elapsed, n_params = do_it(method, dim, ls[task_id])\n",
    "        print(method, \"training time:\", elapsed)\n",
    "        times[task_id, method_id] = elapsed\n",
    "        params[task_id, method_id] = n_params\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "45b19df3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "$d=2$ & 11.54  &  109 & 403.87  &  100 & 498.71  &  121 & 12.01  &  105 & 12.48  &  151 \\\\\n",
      "$d=4$ & 11.47  &  181 & 358.83  &  81 & 1351.90  &  256 & 12.50  &  171 & 12.75  &  229 \\\\\n",
      "$d=6$ & 11.43  &  253 & 301.45  &  64 & 8215.19  &  729 & 12.81  &  253 & 13.27  &  323 \\\\\n"
     ]
    }
   ],
   "source": [
    "# One LaTeX table row per input dimension: `time  &  #params` for each method.\n",
    "functions = tuple(f\"$d={d}$\" for d in dims)  # row labels follow `dims`\n",
    "for f_id, f_name in enumerate(functions):\n",
    "    cells = [f\"{t:.2f}  &  {int(p)}\" for t, p in zip(times[f_id], params[f_id])]\n",
    "    print(f_name + \" & \" + \" & \".join(cells) + \" \\\\\\\\\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15b4bb41",
   "metadata": {},
   "source": [
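    "Results of earlier runs, kept for reference (same column layout as the table printed above). The second table lists only three methods; its parameter counts match `smooth`, `lattice` and `lip_small`.\n",
    "\n",
    "$d=2$ & 10.07  &  109 & 361.83  &  100 & 482.04  &  121 & 12.45  &  105 & 0.00  &  151 \\\\\n",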
    "$d=4$ & 11.09  &  181 & 349.35  &  81 & 1360.44  &  256 & 12.75  &  171 & 0.00  &  229 \\\\\n",
    "$d=6$ & 10.90  &  253 & 274.73  &  64 & 8082.60  &  729 & 13.00  &  253 & 0.00  &  323 \\\\\n",
    "\n",
    "$d=2$ & 11.38  &  109 & 379.72  &  100 & 11.93  &  105 \\\\\n",
    "$d=4$ & 11.06  &  181 & 336.10  &  81 & 12.54  &  171 \\\\\n",
    "$d=6$ & 11.24  &  253 & 302.88  &  64 & 12.82  &  253 \\\\"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d998fb25",
   "metadata": {},
   "source": [
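    "Another earlier run, kept for reference (same column layout as the table printed above).\n",
    "\n",
    "$d=2$ & 10.00  &  109 & 320.49  &  100 & 401.28  &  121 & 9.64  &  105 & 0.00  &  151 \\\\\n",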
    "$d=4$ & 10.07  &  181 & 268.68  &  81 & 1178.69  &  256 & 10.16  &  171 & 0.00  &  229 \\\\\n",
    "$d=6$ & 10.55  &  253 & 241.72  &  64 & 7326.46  &  729 & 10.55  &  253 & 0.00  &  323 \\\\"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efc756d0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
