{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LassoNet instabilities\n",
    "\n",
    "This notebook demonstrates that the official LassoNet implementation does not train well and consistently obtains almost random performance.\n",
    "\n",
    "In the official experiments presented in the paper, we grid-searched the $L_1$ penalty hyper-parameter $\\lambda \\in \\{0.001, 0.01, 0.1, 1, 10, \\texttt{auto}\\}$ and the hierarchy coefficient $M \\in \\{0.1, 1, 3, 10\\}$. In this toy experiment we use the best hyper-parameters found by that search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, './src')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from src.dataset import *\n",
    "from main import parse_arguments\n",
    "from lassonet import LassoNetClassifier\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train size: 141\n",
      "\n",
      "Valid size: 16\n",
      "\n",
      "Test size: 40\n",
      "\n",
      "Weights for the classification loss: [0.3525   2.9375   2.35     2.517857]\n"
     ]
    }
   ],
   "source": [
    "# load dataset\n",
    "args = parse_arguments([\n",
    "\t'--dataset', 'lung'\n",
    "])\n",
    "\n",
    "data_module = create_data_module(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.3525  , 2.9375  , 2.35    , 2.517857], dtype=float32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "args.class_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's global RNG so the torch-driven randomness in training is\n",
    "# repeatable when the notebook is re-run from a fresh kernel.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Hyper-parameters of the LassoNet run demonstrated in this notebook.\n",
    "model = LassoNetClassifier(\n",
    "    lambda_start=0.1,\n",
    "    M=10,\n",
    "    n_iters=200,\n",
    "    # torch documents `betas` as a tuple of two floats\n",
    "    optim=partial(torch.optim.AdamW, lr=1e-4, betas=(0.9, 0.98)),\n",
    "    hidden_dims=(100, 100, 10),\n",
    "    class_weight=args.class_weights,  # use weighted loss\n",
    "    dropout=0.2,\n",
    "    batch_size=8,\n",
    "    backtrack=True,  # if True, ensure the objective is decreasing\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `path` trains over a whole sequence of penalties and returns one checkpoint\n",
    "# per step; afterwards the model is left at the last, most regularized step.\n",
    "# Keep the checkpoints and restore the one with the lowest validation loss so\n",
    "# that the scores below describe a selected model rather than the path's end.\n",
    "# NOTE(review): relies on lassonet checkpoints exposing `val_loss` and\n",
    "# `state_dict` -- confirm against the installed lassonet version.\n",
    "lasso_path = model.path(\n",
    "    data_module.X_train, data_module.y_train,\n",
    "    X_val=data_module.X_valid, y_val=data_module.y_valid,\n",
    ")\n",
    "best_checkpoint = min(lasso_path, key=lambda save: save.val_loss)\n",
    "model.load(best_checkpoint.state_dict);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Note: LassoNet fails to fit even the training data. It obtains only about 40% training accuracy on the 'lung' dataset, whereas the other models obtain about 90% test accuracy and close to 100% training accuracy on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.40425531914893614"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.score(data_module.X_train, data_module.y_train) # train accuracy (fraction of correct predictions, not class-balanced)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.3125"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.score(data_module.X_valid, data_module.y_valid) # validation accuracy (fraction of correct predictions, not class-balanced)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.9 ('WPS')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "219ae80cc2456d1acf1023715d83e5ed138f590fd0d6277c2235c49a9ecd13fe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
