{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Heavy-Tail Phenomenon in SGD\n",
    "\n",
    "This notebook contains the code for reproducing the synthetic data experiments reported in the paper. \n",
    "\n",
    "The current code handles the case where the step-size $\\eta$ and the batch-size $b$ are varying; however, it is straightforward to modify the code for varying dimension $d$ and curvature $\\sigma$.\n",
    "\n",
    "For the neural network experiments, we used the code provided in \n",
    "U. Simsekli, L. Sagun, M. Gurbuzbalaban, \"A Tail-Index Analysis of Stochastic Gradient Noise in Deep Neural Networks\", ICML 2019. We are also submitting our version for convenience. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Required libraries are listed below. \n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "from scipy.stats import levy_stable\n",
    "import scipy.io as sio\n",
    "import math\n",
    "import levy\n",
    "import powerlaw \n",
    "from joblib import Parallel, delayed\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corollary 2.4 in Mohammadi 2014 - for multi-d  - taken from Simsekli et al. ICML 2019.\n",
    "def alpha_estimator_multi(m, X):\n",
    "    # X is N by d matrix\n",
    "    N = len(X)\n",
    "    n = int(N/m) # must be an integer\n",
    "    Y = torch.sum(X.view(n, m, -1), 1)\n",
    "    eps = np.spacing(1)\n",
    "    Y_log_norm = torch.log(Y.norm(dim=1) + eps).mean()\n",
    "    X_log_norm = torch.log(X.norm(dim=1) + eps).mean()\n",
    "    diff = (Y_log_norm - X_log_norm) / math.log(m)\n",
    "    return 1 / diff.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# total number of data points \n",
    "n = 10000\n",
    "\n",
    "# dimension\n",
    "d = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic Data Generation\n",
    "sig_A = 1.0  # scale of the entries of A\n",
    "sig_x = 3    # scale of the entries of the true regressor\n",
    "sig_Y = 3    # scale of the additive noise on the responses\n",
    "\n",
    "# explanatory vars.\n",
    "A = sig_A * np.random.randn(n,d)\n",
    "\n",
    "# true regressor\n",
    "x0 = sig_x * np.random.randn(d)\n",
    "\n",
    "# responses\n",
    "Y = np.dot(A,x0) + sig_Y * np.random.randn(n)\n",
    "\n",
    "# Mode of the distribution: the least-squares solution of A x = Y.\n",
    "# The normal equations (A^T A) x = A^T Y are solved directly rather than\n",
    "# through an explicit matrix inverse, which is cheaper and numerically\n",
    "# more accurate while giving the same solution.\n",
    "x_star = np.linalg.solve(np.dot(A.T,A), np.dot(A.T,Y))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss function\n",
    "def loss_fn(x,A,y,nsamp):\n",
    "    res = np.sum((np.dot(A,x) - y)**2)/(2*nsamp)\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sgd_linreg(A,Y,n,d,K,eta,b, compute_loss=False):\n",
    "    burn_in = int(K*0.5)\n",
    "    loss_sgd = np.zeros(K)\n",
    "\n",
    "    x_cur = 5 * np.random.randn(d)\n",
    "    x_mc  = np.zeros(d)\n",
    "    \n",
    "    if(compute_loss is True):\n",
    "        loss_sgd[0] = loss_fn(x_cur,A,Y,n)\n",
    "\n",
    "    for k in range(1,K):\n",
    "        ix = np.random.permutation(n)\n",
    "        S_k = ix[0:int(b)]\n",
    "        A_k = A[S_k,:]\n",
    "        y_k = Y[S_k]\n",
    "        \n",
    "        grad = np.dot(A_k.T, np.dot(A_k,x_cur)) - np.dot(A_k.T, y_k)\n",
    "        grad = grad / b\n",
    "        x_cur = x_cur - eta * grad\n",
    "        \n",
    "        if(k >= burn_in):\n",
    "            k_loc = (k - burn_in) * 1.0\n",
    "            x_mc  = (k_loc/(k_loc+1)) * x_mc + 1/(k_loc+1)*x_cur\n",
    "\n",
    "        if(compute_loss is True):\n",
    "            loss_sgd[k] = loss_fn(x_cur,A,Y,n)\n",
    "\n",
    "    x_final = x_cur\n",
    "    if(compute_loss is True):\n",
    "        plt.figure()\n",
    "        plt.plot(loss_sgd)\n",
    "        \n",
    "    return [x_final,x_mc]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### SGD\n",
    "\n",
    "# Number of repetitions\n",
    "m = 40\n",
    "numRep = m*m\n",
    "\n",
    "# Number of iterations\n",
    "K = 1000\n",
    "\n",
    "# batch sizes\n",
    "b_all   = [1, 2,3,4,5,10,20]\n",
    "# step-sizes\n",
    "eta_all = np.linspace(0.02,0.2,10)\n",
    "\n",
    "# alpha_mc_sas[bi,ei]: tail-index estimate for b_all[bi] and eta_all[ei]\n",
    "alpha_mc_sas = np.zeros((len(b_all), len(eta_all)))\n",
    "\n",
    "for bi, b in enumerate(b_all):\n",
    "    for ei, eta in enumerate(eta_all):\n",
    "        print('b =',b,' eta =',eta)\n",
    "\n",
    "        # numRep independent SGD runs, dispatched through joblib\n",
    "        tmp_res = Parallel(n_jobs=10)(delayed(sgd_linreg)(A,Y,n,d,K,eta,b) for i in range(numRep))\n",
    "        tmp_res = np.asarray(tmp_res)\n",
    "\n",
    "        # index 0 along axis 1: last iterates, index 1: averaged iterates\n",
    "        x_final = np.squeeze(tmp_res[:,0,:])\n",
    "        x_mc    = np.squeeze(tmp_res[:,1,:])\n",
    "\n",
    "        # center both around x_star\n",
    "        Xfn = x_final - x_star\n",
    "        Xmc = x_mc - x_star\n",
    "\n",
    "        # estimate alpha for several group sizes and keep the median\n",
    "        alp_tmp = [alpha_estimator_multi(mm, torch.from_numpy(Xmc)) for mm in (2, 5, 10, 20, 50, 100, m)]\n",
    "        alpha_mc_sas[bi,ei] = np.median(alp_tmp)\n",
    "\n",
    "        print('\\t',alp_tmp)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
