{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HOAG usage example\n",
    "\n",
    "We load an example dataset (20 newsgroups, from scikit-learn) and use HOAG to learn the optimal L2 regularization parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "# explicit imports instead of %pylab, which star-imports numpy and pyplot\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import datasets, linear_model\n",
    "\n",
    "# HOAG estimator: learns the L2 regularization parameter jointly with the model weights\n",
    "from hoag import LogisticRegressionCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get a training set and test set\n",
    "data_train = datasets.fetch_20newsgroups_vectorized(subset='train')\n",
    "data_test = datasets.fetch_20newsgroups_vectorized(subset='test')\n",
    "\n",
    "X_train = data_train.data\n",
    "X_test = data_test.data\n",
    "\n",
    "# binarize labels: class ids < 10 -> -1, all others -> +1.\n",
    "# np.where builds new arrays, so data_train.target / data_test.target are\n",
    "# not mutated (the old code aliased them and overwrote them in place).\n",
    "y_train = np.where(data_train.target < 10, -1, 1)\n",
    "y_test = np.where(data_test.target < 10, -1, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# HOAG fits the model weights and the L2 regularization\n",
    "# hyperparameter together in a single call.\n",
    "clf = LogisticRegressionCV()\n",
    "clf.fit(X_train, y_train,  # data the weights are fitted on\n",
    "        X_test, y_test)    # held-out data used to tune the regularization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clf.alpha_ holds the regularization value(s) learned by HOAG; report the first\n",
    "print('Regularization chosen by HOAG: alpha=%s' % (clf.alpha_[0],))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Above is the optimal regularization parameter chosen by HOAG.\n",
    "\n",
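    "Let's now plot the cost function (the logistic loss on the held-out test set) to see how well HOAG did. Note that HOAG uses a parametrization in which `clf.alpha_` appears inside an exponential (the L2 penalty is $e^{\\alpha}$), so we need to take this into account when comparing with the scikit-learn estimator, whose `C` is the inverse of that penalty ($C = e^{-\\alpha}$).\n",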
    "\n",
    "**Warning:** the next cell fits one scikit-learn model per value of $\\alpha$ on the grid (40 fits), so it may take a long time to compute (~10 min)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# grid of regularization parameters, in the same parametrization as HOAG's alpha\n",
    "alphas = np.linspace(-25, 10, 40)\n",
    "\n",
    "\n",
    "def cost_func(a):\n",
    "    \"\"\"Return the held-out logistic loss of a model fit with regularization ``a``.\n",
    "\n",
    "    The model is an L2-regularized logistic regression without intercept,\n",
    "    trained on (X_train, y_train). scikit-learn's ``C`` is the inverse\n",
    "    regularization strength, so ``C = exp(-a)`` corresponds to HOAG's\n",
    "    ``alpha = a``. The returned value is the unregularized loss\n",
    "    ``sum_i log(1 + exp(-y_i * <x_i, w>))`` over (X_test, y_test), which\n",
    "    assumes labels in {-1, +1} (as produced by the binarization above).\n",
    "    \"\"\"\n",
    "    model = linear_model.LogisticRegression(\n",
    "        solver='lbfgs',\n",
    "        C=np.exp(-a), fit_intercept=False,\n",
    "        tol=1e-22, max_iter=500)  # very tight tol: optimize as far as max_iter allows\n",
    "    model.fit(X_train, y_train)\n",
    "\n",
    "    # Compute the loss directly rather than via the private helper\n",
    "    # sklearn.linear_model._logistic._logistic_loss, which is not public API\n",
    "    # and no longer exists in recent scikit-learn releases.\n",
    "    # np.logaddexp(0, -m) == log(1 + exp(-m)), without overflow for large -m.\n",
    "    margins = y_test * (X_test @ model.coef_.ravel())\n",
    "    return np.logaddexp(0, -margins).sum()\n",
    "\n",
    "\n",
    "scores = [cost_func(a) for a in alphas]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make the plot bigger than default\n",
    "plt.rcParams['figure.figsize'] = (8.0, 6.0)\n",
    "plt.rcParams['font.size'] = 20\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "# held-out loss of the scikit-learn models over the grid of alphas\n",
    "ax.plot(alphas, scores, lw=3, label='held-out loss')\n",
    "\n",
    "# mark the value of alpha found by HOAG; axvline spans the whole y-range,\n",
    "# so the marker stays visible whatever the scale of the loss\n",
    "ax.axvline(clf.alpha_[0], c='k', linestyle='--',\n",
    "           label=r'value of regularization ($\\alpha$) found by HOAG')\n",
    "\n",
    "ax.set_xlabel(r'$\\alpha$', fontsize=40)\n",
    "ax.set_ylabel('logistic loss (test set)')\n",
    "ax.legend(fontsize=20)\n",
    "ax.grid()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
