{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "41abf071",
   "metadata": {},
   "source": [
    "# Smooth Monotonic Networks: Plots and tables\n",
    "This file reproduces the plots, statistical tests, and LaTeX tables in the manuscript."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "af9bf1d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "\n",
    "from sklearn.metrics import mean_squared_error as mse\n",
    "from sklearn.metrics import r2_score as r2\n",
    "from sklearn.isotonic import IsotonicRegression\n",
    "\n",
    "from scipy.stats import wilcoxon\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n",
    "from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset\n",
    "\n",
    "from tqdm.notebook import tnrange\n",
    "\n",
    "from xgboost import XGBRegressor\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0ea6935",
   "metadata": {},
   "source": [
    "Set the prefix variable `path` below. It is prepended to the names of the result files produced by ``MonotonicNNFullyMonotoneExperiments.ipynb`` and ``UCI validation.ipynb``, and to the names of the figures saved by this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d204e74f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filename prefix shared by all result files loaded and all figures saved below.\n",
    "path = \"./iclr-\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e9fb400",
   "metadata": {},
   "source": [
    "## Univariate experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "039310a5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['MSE_train', 'MSE_test', 'MSE_clip', 'R2_train', 'R2_test', 'X_train', 'Y_train', 'X_test', 'Y_test', 'O_test', 'no_params']\n"
     ]
    }
   ],
   "source": [
    "# Load the stored results of the univariate experiments.\n",
    "# NOTE(review): allow_pickle=True permits unpickling of object arrays, so only load trusted files.\n",
    "fn = \"{}univariate.npz\".format(path)\n",
    "data = np.load(fn, allow_pickle=True)\n",
    "print(data.files)\n",
    "\n",
    "# Exchange axes 1 and 2 of every MSE array; the cells below index them as [function, :, method].\n",
    "mse_train, mse_test, mse_clip = (\n",
    "    np.swapaxes(data[key], 1, 2)\n",
    "    for key in ('MSE_train', 'MSE_test', 'MSE_clip')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cafd67bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 21, 8) 8\n",
      "0.3170680724066777 0.23604582805422694\n"
     ]
    }
   ],
   "source": [
    "# LaTeX tick labels, one per method column of the MSE arrays.\n",
    "labels = (\n",
    "    r\"MM\",\n",
    "    r\"SMM\",\n",
    "    r\"XG\",\n",
    "    r\"XG$_{\\text{val}}$\",\n",
    "    r\"Iso\",\n",
    "    r\"HLL\",\n",
    "    r\"LMN$^{\\text{s}}$\",\n",
    "    r\"LMN$^{\\text{l}}$\",\n",
    ")\n",
    "\n",
    "# Sanity check: array shape next to the label count, then the sums of `mse_test` and `mse_clip`.\n",
    "print(mse_train.shape, len(labels))\n",
    "print(mse_test.sum(), mse_clip.sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3c2a8eca",
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "invalid syntax (2572404131.py, line 13)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  Cell \u001b[0;32mIn[5], line 13\u001b[0;36m\u001b[0m\n\u001b[0;31m    f_alt = mse_train[f,:,m] b\u001b[0m\n\u001b[0m                             ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n"
     ]
    }
   ],
   "source": [
    "# Paired Wilcoxon tests of SMM (column 1 of `labels`) against every other method,\n",
    "# run separately per function on the test and on the training MSEs.\n",
    "functions = (\"\\\\fsq\", \"\\\\fsqrt\", \"\\\\fsig\")\n",
    "alpha = .001\n",
    "smm = 1  # column index of the SMM reference\n",
    "sig_test = np.zeros((len(functions), len(labels)))\n",
    "sig_train = np.zeros((len(functions), len(labels)))\n",
    "for f in range(len(functions)):\n",
    "    for m in range(len(labels)):\n",
    "        if m == smm:\n",
    "            continue  # the reference is never tested against itself; its entry stays 0\n",
    "        # an entry becomes 1 when the p-value falls below alpha, otherwise 0\n",
    "        sig_test[f, m] = wilcoxon(mse_test[f, :, smm], mse_test[f, :, m]).pvalue < alpha\n",
    "        sig_train[f, m] = wilcoxon(mse_train[f, :, smm], mse_train[f, :, m]).pvalue < alpha\n",
    "print(sig_test)\n",
    "print(sig_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c921782",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the LaTeX rows of the test-error table: median test MSE (times `scale`) for each\n",
    "# function/method pair. The lowest median of a row is wrapped in \\low{...}; entries flagged\n",
    "# in `sig_test` get a trailing \\sigdif.\n",
    "np.set_printoptions(precision=4, floatmode='fixed')\n",
    "meds = np.median(mse_test, axis=1)\n",
    "meds_min = np.min(meds, axis=1)\n",
    "\n",
    "scale = 1000\n",
    "for f, f_name in enumerate(functions):\n",
    "    cells = []\n",
    "    for m in range(len(labels)):\n",
    "        med = meds[f, m]\n",
    "        cell = \"{:.2f}\".format(med*scale)\n",
    "        if med == meds_min[f]:\n",
    "            cell = \"\\\\low{\" + cell + \"}\"\n",
    "        if sig_test[f, m]:\n",
    "            # the backslash is written as an escaped pair; a bare backslash-s is not a valid string escape\n",
    "            cell += \"\\\\sigdif\"\n",
    "        cells.append(cell)\n",
    "    print(f_name, '&', ' & '.join(cells), '\\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d639884",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX rows with the median training MSE (times `scale`) per function and method.\n",
    "scale=1000\n",
    "for f, f_name in enumerate(functions):\n",
    "    row = [\"{:.2f}\".format(np.median(mse_train[f,:,m])*scale) for m in range(len(labels))]\n",
    "    print(f_name, '&', ' & '.join(row), '\\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c6e48e3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Box plots of the training MSE (top row) and the test MSE (bottom row), one column per function.\n",
    "# Presumably the method order of the stored columns:\n",
    "# methods = ['monotonic', 'smooth', 'xgboost', 'xgboost_val', 'iso', 'hll', 'lip_small', 'lip']\n",
    "\n",
    "fn_tag = \"\"\n",
    "\n",
    "plt.rcParams['text.usetex'] = True\n",
    "plt.rc('text.latex', preamble=r'\\usepackage{amsmath}')\n",
    "\n",
    "colors=('b', 'g', 'r', 'c', 'm', \"k\")\n",
    "title = (r\"\\Large$f_{\\text{sq}}$\", r\"\\Large$f_{\\text{sqrt}}$\", r\"\\Large$f_{\\text{sig}}$\")\n",
    "fig, ax = plt.subplots(2, 3, sharex=True, figsize=(10, 6), layout='constrained')\n",
    "for i, name in enumerate(title):\n",
    "    top, bottom = ax[:, i]\n",
    "    top.set_title(name)\n",
    "    top.boxplot(mse_train[i], labels=labels);\n",
    "    bottom.boxplot(mse_test[i], labels=labels);\n",
    "ax[0,0].set_ylabel(r'training MSE')\n",
    "ax[1,0].set_ylabel(r'test MSE (w/o noise)')\n",
    "plt.savefig(path + \"bar1D\"+ fn_tag + \".pdf\")\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e6c11f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "# Training-MSE box plots only, one panel per function.\n",
    "plt.rcParams['text.usetex'] = True\n",
    "plt.rc('text.latex', preamble=r'\\usepackage{amsmath}')\n",
    "\n",
    "# eight entries: the example-curve cell below indexes colors[j] for j up to 7\n",
    "colors=('b', 'g', 'r', 'c', 'm', \"k\", 'b','y')\n",
    "title = (r\"\\Large$f_{\\text{sq}}$\", r\"\\Large$f_{\\text{sqrt}}$\", r\"\\Large$f_{\\text{sig}}$\")\n",
    "fig, ax = plt.subplots(1, 3, sharex=True, figsize=(10, 3), layout='constrained')\n",
    "for panel, name, errors in zip(ax, title, mse_train):\n",
    "    panel.set_title(name)\n",
    "    panel.boxplot(errors, labels=labels);\n",
    "ax[0].set_ylabel(r'training MSE')\n",
    "plt.savefig(path + \"bar1D_train\"+ fn_tag + \".pdf\")\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23650ea1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example fits of one trial: methods with label indices 2, 3, 4 in the top row and\n",
    "# 0, 1, 5, 7 in the bottom row, each drawn over the dashed grey test targets.\n",
    "# Relies on `title` and the 8-entry `colors` from the preceding plotting cell.\n",
    "trial = 11\n",
    "rows = ([2,3,4], [0,1,5,7])  # method indices per subplot row\n",
    "x1, x2, y1, y2 = .7, .9, .875, .975  # axis limits of the zoomed insets\n",
    "fig, ax = plt.subplots(2, 3, figsize=(10, 6), layout='constrained')\n",
    "for i in range(3):\n",
    "    xs = data['X_test'][i, trial]\n",
    "    ax[0,i].set_title(title[i])\n",
    "    for r, picks in enumerate(rows):\n",
    "        ax[r,i].plot(xs, data['Y_test'][i, trial], '--', color='0.8')\n",
    "        for j in picks:\n",
    "            ax[r,i].plot(xs, data['O_test'][i, j, trial], label=labels[j], color=colors[j])\n",
    "    ax[1,i].set_xlabel(r'$x$')\n",
    "    if i == 0:\n",
    "        continue  # the first column gets no insets\n",
    "    for r, picks in enumerate(rows):\n",
    "        axi = inset_axes(ax[r, i], width=\"40%\", height=\"20%\", loc=4, borderpad=1)\n",
    "        for j in picks:\n",
    "            axi.plot(xs, data['O_test'][i, j, trial], label=labels[j], color=colors[j])\n",
    "        axi.set_xlim(x1, x2)\n",
    "        axi.set_ylim(y1, y2)\n",
    "        axi.tick_params(labelleft=False, labelbottom=False)\n",
    "        mark_inset(ax[r, i], axi, loc1=2, loc2=4, fc=\"none\", ec=\"0.5\")\n",
    "for r in (0, 1):\n",
    "    ax[r,0].legend(handlelength=4)\n",
    "    ax[r,0].set_ylabel(r'$y$')\n",
    "plt.savefig(path + \"example\" + str(trial) + \"_1D\"+ \".pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01583b7a",
   "metadata": {},
   "source": [
    "## Multivariate experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95ddd14f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the multivariate results; `data`, `mse_train` and `mse_test` refer to them from here on.\n",
    "# NOTE(review): allow_pickle=True permits unpickling of object arrays, so only load trusted files.\n",
    "fn = \"{}multivariate.npz\".format(path)\n",
    "data = np.load(fn, allow_pickle=True)\n",
    "print(data.files)\n",
    "\n",
    "# Exchange axes 1 and 2; the cells below index the arrays as [dimension setting, :, method].\n",
    "mse_train, mse_test = (np.swapaxes(data[key], 1, 2) for key in ('MSE_train', 'MSE_test'))\n",
    "\n",
    "# Median test MSE over axis 1 and, per row, the lowest of these medians (used by the table cell).\n",
    "meds = np.median(mse_test, axis=1)\n",
    "meds_min = meds.min(axis=1)\n",
    "\n",
    "print(mse_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "599c0cf5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Box plots for the multivariate experiments: training MSE (top row) and test MSE (bottom row),\n",
    "# one column per value of d.\n",
    "# Presumably the method order of the stored columns:\n",
    "#methods = ['smooth', 'xgboost', 'xgboost_val', 'xgboost2', 'xgboost2_val','lattice', 'lattice_plus', 'lip_small', 'lip']\n",
    "labels = (\n",
    "    r\"SMM\",\n",
    "    r\"XG$^{\\text{s}}$\",\n",
    "    r\"XG$^{\\text{s}}_{\\text{val}}$\",\n",
    "    r\"XG$^{\\text{l}}$\",\n",
    "    r\"XG$^{\\text{l}}_{\\text{val}}$\",\n",
    "    r\"HLL$^{\\text{s}}$\",\n",
    "    r\"HLL$^{\\text{l}}$\",\n",
    "    r\"LMN$^{\\text{s}}$\",\n",
    "    r\"LMN$^{\\text{l}}$\",\n",
    ")\n",
    "print(len(labels))\n",
    "\n",
    "plt.rcParams['text.usetex'] = True\n",
    "plt.rc('text.latex', preamble=r'\\usepackage{amsmath}')\n",
    "\n",
    "colors=('b', 'g', 'r', 'c', 'm', \"k\")\n",
    "title = (r\"$d=2$\", r\"$d=4$\", r\"$d=6$\")\n",
    "fig, ax = plt.subplots(2, 3, sharex=True, figsize=(10, 6), layout='constrained')\n",
    "for i, name in enumerate(title):\n",
    "    top, bottom = ax[:, i]\n",
    "    top.set_title(name)\n",
    "    top.boxplot(mse_train[i], labels=labels);\n",
    "    bottom.boxplot(mse_test[i], labels=labels);\n",
    "ax[0,0].set_ylabel(r'training MSE')\n",
    "ax[1,0].set_ylabel(r'test MSE (w/o noise)')\n",
    "plt.savefig(path + \"barMulti\"+ \".pdf\")\n",
    "plt.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6ab5ba8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paired Wilcoxon tests of the reference method against every other method,\n",
    "# run separately per dimension setting on the test and on the training MSEs.\n",
    "functions = (\"$d=2$\", \"$d=4$\", \"$d=6$\")\n",
    "alpha = .001\n",
    "ref_index = 0 # reference is SMM, index 0\n",
    "sig_test = np.zeros((len(functions), len(labels)))\n",
    "sig_train = np.zeros((len(functions), len(labels)))\n",
    "for f in range(len(functions)):\n",
    "    for m in range(len(labels)):\n",
    "        if m == ref_index:\n",
    "            continue  # the reference is never tested against itself; its entry stays 0\n",
    "        # an entry becomes 1 when the p-value falls below alpha, otherwise 0\n",
    "        sig_test[f, m] = wilcoxon(mse_test[f, :, ref_index], mse_test[f, :, m]).pvalue < alpha\n",
    "        sig_train[f, m] = wilcoxon(mse_train[f, :, ref_index], mse_train[f, :, m]).pvalue < alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b41bb427",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the LaTeX rows of the multivariate test-error table from `meds` / `meds_min`\n",
    "# (computed in the cell that loads multivariate.npz). The lowest median of a row is wrapped\n",
    "# in \\low{...}; entries flagged in `sig_test` get a trailing \\sigdif.\n",
    "np.set_printoptions(precision=4, floatmode='fixed')\n",
    "scale = 1000\n",
    "\n",
    "for f, f_name in enumerate(functions):\n",
    "    cells = []\n",
    "    for m in range(len(labels)):\n",
    "        med = meds[f, m]\n",
    "        cell = \"{:.2f}\".format(med*scale)\n",
    "        if med == meds_min[f]:\n",
    "            cell = \"\\\\low{\" + cell + \"}\"\n",
    "        if sig_test[f, m]:\n",
    "            # the backslash is written as an escaped pair; a bare backslash-s is not a valid string escape\n",
    "            cell += \"\\\\sigdif\"\n",
    "        cells.append(cell)\n",
    "    print(f_name, '&', ' & '.join(cells), '\\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bb8b242",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX rows with the median training MSE (times `scale`) per dimension setting and method.\n",
    "# `scale` is set here explicitly: it is reassigned to 100 in the UCI section, so relying on\n",
    "# the value left over from another cell would silently change this table on a re-run.\n",
    "scale = 1000\n",
    "for f, f_name in enumerate(functions):\n",
    "    row = [\"{:.2f}\".format(np.median(mse_train[f,:,m])*scale) for m in range(len(labels))]\n",
    "    print(f_name, '&', ' & '.join(row), '\\\\\\\\')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "348e1e02",
   "metadata": {},
   "source": [
    "## UCI tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe12754e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the result archives of the four UCI tasks and list the arrays each one contains.\n",
    "# NOTE(review): allow_pickle=True permits unpickling of object arrays, so only load trusted files.\n",
    "path = \"./iclr-\"\n",
    "# LaTeX macro names used as column / row headers; every backslash is escaped explicitly.\n",
    "methods = (\"\\\\MLPSMM\", \"\\\\SMM\", \"\\\\XG\", \"\\\\HLL\", \"\\\\LMNs\", \"\\\\LMNl\")\n",
    "tasks = (\"\\\\energyOne\", \"\\\\energyTwo\", \"\\\\qsar\", \"\\\\concrete\")\n",
    "data_energy1 = np.load(path + \"energy-y1-results-val.npz\", allow_pickle = True)\n",
    "data_energy2 = np.load(path + \"energy-y2-results-val.npz\", allow_pickle = True)\n",
    "data_qsar = np.load(path + \"qsar-results-val.npz\", allow_pickle = True)\n",
    "data_concrete = np.load(path + \"concrete-results-val.npz\", allow_pickle = True)\n",
    "print(*data_energy1)\n",
    "print(*data_energy2)\n",
    "print(*data_qsar)\n",
    "print(*data_concrete)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "007c84dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per task, in the order of `tasks`: mean of 'MSE_test' over axis 1 and the stored 'no_params'.\n",
    "means = []\n",
    "no_params = []\n",
    "for results in (data_energy1, data_energy2, data_qsar, data_concrete):\n",
    "    mse_test = results['MSE_test']\n",
    "    means.append(np.mean(mse_test, axis=1))\n",
    "    no_params.append(results['no_params'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a7281ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the UCI results table: a header with one two-column group per method, then one row\n",
    "# per task holding, for every method, the mean test MSE (times `scale`) and the `no_params` value.\n",
    "# The lowest mean of a row is wrapped in \\low{...}.\n",
    "scale = 100\n",
    "for method in methods:\n",
    "    # backslash escaped explicitly; a bare backslash-m is not a valid string escape\n",
    "    print(\" & \\\\multicolumn{2}{c}{\", method, \"}\", end='')\n",
    "print(\" \\\\\\\\\")\n",
    "for i, (mean, no) in enumerate(zip(means, no_params)):\n",
    "    best = min(mean)  # computed once per row rather than once per entry\n",
    "    print(tasks[i], end='')\n",
    "    for j, k in zip(mean, no):\n",
    "        entry = \"{:.2f}\".format(j*scale)\n",
    "        if j == best:\n",
    "            entry = \"\\\\low{\" + entry + '}'\n",
    "        print(\" & \", entry, end='')\n",
    "        print(\" & \", \"{:d}\".format(int(k)), end='')\n",
    "    print(\" \\\\\\\\\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65ab16fb",
   "metadata": {},
   "source": [
    "## Hyperparameter evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4249fb9c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['MSE_train', 'MSE_test', 'MSE_clip', 'no_params']\n"
     ]
    }
   ],
   "source": [
    "# Load the results of the hyperparameter study; `data` refers to this archive from here on.\n",
    "# NOTE(review): allow_pickle=True permits unpickling of object arrays, so only load trusted files.\n",
    "fn = \"{}hyper.npz\".format(path)\n",
    "data = np.load(fn, allow_pickle=True)\n",
    "print(data.files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c99b2294",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 4, 5, 11)\n"
     ]
    }
   ],
   "source": [
    "# Test MSEs of the currently loaded archive; the saved run printed shape (3, 4, 5, 11).\n",
    "# NOTE(review): presumably (function, K, beta, trial), matching how the table cell below indexes it - confirm.\n",
    "d = data['MSE_test']\n",
    "print(d.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64eb041d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b71d9af0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[3.21876208e-06 4.04014247e-06 1.07526033e-06]\n"
     ]
    }
   ],
   "source": [
    "# Smallest test MSE for each entry of axis 0, taken over all remaining axes.\n",
    "print(d.min(axis=(1, 2, 3)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "371c07a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labels for the table printed below: `functions` holds the LaTeX macro per task, while\n",
    "# `K_values` and `beta_values` label axes 1 and 2 of the median array.\n",
    "# NOTE(review): presumably the hyperparameter grid used to produce hyper.npz - confirm.\n",
    "functions = (\"\\\\fsq\", \"\\\\fsqrt\", \"\\\\fsig\")\n",
    "K_values = (2, 4, 6, 8)\n",
    "beta_values = (-3., -2., -1., 0., 1.)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "c20d48f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " & -3.0 & -2.0 & -1.0 & 0.0 & 1.0\\\\\n",
      " & \\multicolumn{5}{c}{ \\fsq }\\\\\n",
      "2 &  0.0114 & \\low{ 0.0055} &  0.0111 &  0.0121 &  0.0087\\\\\n",
      "4 &  0.0069 &  0.0068 &  0.0064 &  0.0077 &  0.0077\\\\\n",
      "6 &  0.0064 &  0.0062 &  0.0072 &  0.0065 &  0.0062\\\\\n",
      "8 &  0.0070 &  0.0079 &  0.0080 &  0.0068 &  0.0073\\\\\n",
      " & \\multicolumn{5}{c}{ \\fsqrt }\\\\\n",
      "2 &  0.1030 &  0.0752 &  0.0756 &  0.0712 &  0.0700\\\\\n",
      "4 &  2.2955 &  2.2624 &  0.0157 &  0.0156 &  0.0184\\\\\n",
      "6 &  2.2960 &  2.2882 &  0.0220 &  0.0164 &  0.0180\\\\\n",
      "8 &  2.2976 &  0.0297 &  0.0177 & \\low{ 0.0123} &  0.0191\\\\\n",
      " & \\multicolumn{5}{c}{ \\fsig }\\\\\n",
      "2 &  7.8617 &  0.0154 &  0.0058 &  0.0056 &  0.0055\\\\\n",
      "4 &  7.8544 &  0.0096 & \\low{ 0.0051} &  0.0076 &  0.0123\\\\\n",
      "6 &  0.1012 &  0.0062 &  0.0052 &  0.0084 &  0.0113\\\\\n",
      "8 &  7.8559 &  0.0058 &  0.0054 &  0.0080 &  0.0119\\\\\n"
     ]
    }
   ],
   "source": [
    "# Print the hyperparameter table: median test MSE over the last axis of `d` (times 1000),\n",
    "# one block per task, one row per K and one column per beta. The lowest median of each task\n",
    "# is wrapped in \\low{...}.\n",
    "m = np.median(d, axis=3)\n",
    "mv = np.min(m, axis=(1,2))  # lowest median per task\n",
    "for beta in beta_values:\n",
    "    print(\" &\", beta, end='')\n",
    "print('\\\\\\\\')\n",
    "for task_id, task in enumerate(functions):\n",
    "    # backslash escaped explicitly (a bare backslash-m is not a valid string escape);\n",
    "    # the column span follows the number of beta values instead of a hard-coded 5\n",
    "    print(\" & \\\\multicolumn{\" + str(len(beta_values)) + \"}{c}{\", task, \"}\\\\\\\\\")\n",
    "\n",
    "    for K_id, K in enumerate(K_values):\n",
    "        print(K, end='')\n",
    "        for beta_id, beta in enumerate(beta_values):\n",
    "            v = m[task_id, K_id, beta_id]\n",
    "            if v == mv[task_id]:\n",
    "                print(\" & \\\\low{\", \"{:.4f}\".format(v*1000), end='}')\n",
    "            else:\n",
    "                print(\" & \", \"{:.4f}\".format(v*1000), end='')\n",
    "        print('\\\\\\\\')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e244a7cc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
