{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "## generate data\n",
    "\n",
    "# Seed the global RNG so that Restart & Run All regenerates the same data set.\n",
    "np.random.seed(0)\n",
    "\n",
    "d = 10000  # length of theta / number of columns of X\n",
    "n = 100    # number of rows of X\n",
    "\n",
    "# Ground-truth parameter: i.i.d. uniform entries, scaled into [0, 0.1).\n",
    "theta_star = np.random.rand(d,1)\n",
    "# Disabled variant: zero out coordinates 10 .. d-2 of theta_star.\n",
    "# theta_star[10:-1,0]=0\n",
    "theta_star /= 10\n",
    "\n",
    "zeta = np.random.randn(n, d)\n",
    "\n",
    "# Rescale every row of zeta to Euclidean norm xi, with xi drawn uniformly from [0.5, 1).\n",
    "xi = np.random.rand(n, 1)*0.5+0.5\n",
    "zeta *= xi/np.linalg.norm(zeta, axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "X = zeta\n",
    "Y = X.dot(theta_star)\n",
    "\n",
    "# Projector onto the row space of X: X^T (X X^T)^{-1} X.\n",
    "# NOTE: Proj is a dense d x d array (roughly 800 MB of float64 for d = 10000).\n",
    "X_inv = np.linalg.inv(X.dot(X.T))\n",
    "Proj = np.dot(np.dot(X.T, X_inv), X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## helper functions\n",
    "\n",
    "def get_grad(X_batch, Y_batch, theta):\n",
    "    batch_size = len(X_batch)\n",
    "    return 2/batch_size * (np.dot(X_batch.T,X_batch.dot(theta)) - np.dot(X_batch.T, Y_batch))\n",
    "\n",
    "def loss(theta):\n",
    "    \"\"\"Mean squared residual of ``theta`` on the full data set (module-level ``X``, ``Y``, ``n``).\"\"\"\n",
    "    residual = X.dot(theta) - Y\n",
    "    squared_error = np.sum(residual**2)\n",
    "    return 1/n * squared_error\n",
    "def distance(theta):\n",
    "    \"\"\"Squared Euclidean norm of ``Proj.dot(theta - theta_star)`` (module-level ``Proj``, ``theta_star``).\"\"\"\n",
    "    projected_error = Proj.dot(theta - theta_star)\n",
    "    return np.sum(projected_error**2)\n",
    "\n",
    "def raleigh(theta):\n",
    "    \"\"\"Rayleigh quotient of ``X^T X`` at the error ``e = theta - theta_star``.\n",
    "\n",
    "    Returns ``e^T X^T X e / ||Proj e||^2`` as a ``(1, 1)`` array, so callers\n",
    "    index the result with ``[0, 0]``. Reads the module-level ``X``, ``Proj``\n",
    "    and ``theta_star``. The quotient is undefined when ``Proj e`` is zero.\n",
    "    \"\"\"\n",
    "    error = theta - theta_star\n",
    "    # e^T (X^T X) e == (X e)^T (X e); this form never materialises the d x d\n",
    "    # matrix X^T X, which would otherwise be rebuilt on every call.\n",
    "    residual = X.dot(error)\n",
    "    direction = np.dot(residual.T, residual)\n",
    "    \n",
    "    return direction/(np.sum((Proj.dot(error))**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "## moderate LR regime\n",
    "# SGD (b = 1) and full-batch GD (b = n) are both started from theta = 0 with step size\n",
    "# `initial_lr`; after iteration `early_stop_iter` the step size is lowered to `decayed_lr`.\n",
    "iter_num = 3000\n",
    "early_stop_iter = 2500\n",
    "initial_lr = 1.05\n",
    "decayed_lr = 0.1\n",
    "log_every = 100  # metrics are recorded every `log_every` iterations\n",
    "\n",
    "theta = np.zeros([d, 1])\n",
    "b = 1\n",
    "lr = initial_lr\n",
    "batches_per_epoch = n // b\n",
    "\n",
    "raleigh_sgd = []\n",
    "\n",
    "for i in range(iter_num):\n",
    "    batch_ind = i % batches_per_epoch\n",
    "    if batch_ind == 0:\n",
    "        # Start of an epoch: draw a fresh random order of the n samples. Indexing through\n",
    "        # `perm` rather than re-binding the global X, Y keeps the shared data untouched,\n",
    "        # so re-running this cell does not change what other cells see.\n",
    "        perm = np.random.choice(n, n, replace=False)\n",
    "\n",
    "    batch_rows = perm[b*batch_ind:b*(batch_ind+1)]\n",
    "    grad = get_grad(X[batch_rows], Y[batch_rows], theta)\n",
    "    theta -= lr * grad\n",
    "\n",
    "    if i % log_every == 0:\n",
    "        # Evaluate every metric once and reuse it for the printout and the history.\n",
    "        rq, dist, cur_loss = raleigh(theta), distance(theta), loss(theta)\n",
    "        print(rq, dist, cur_loss)\n",
    "        raleigh_sgd += [[rq[0,0], dist, cur_loss]]\n",
    "    if i > early_stop_iter:\n",
    "        lr = decayed_lr\n",
    "\n",
    "raleigh_gd = []\n",
    "theta = np.zeros([d, 1])\n",
    "b = n\n",
    "lr = initial_lr\n",
    "for i in range(iter_num):\n",
    "    # Full batch: every step uses the whole data set.\n",
    "    grad = get_grad(X, Y, theta)\n",
    "    theta -= lr * grad\n",
    "\n",
    "    if i % log_every == 0:\n",
    "        rq, dist, cur_loss = raleigh(theta), distance(theta), loss(theta)\n",
    "        print(rq, dist, cur_loss)\n",
    "        raleigh_gd += [[rq[0,0], dist, cur_loss]]\n",
    "    # Same decay schedule as the SGD run above (shared constants, no second literal).\n",
    "    if i > early_stop_iter:\n",
    "        lr = decayed_lr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## small LR regime\n",
    "# SGD (b = 1) and full-batch GD (b = n) are both started from theta = 0 and run with the\n",
    "# constant step size `small_lr` for `iter_num` iterations.\n",
    "iter_num = 10000\n",
    "small_lr = 0.2\n",
    "log_every = 100  # metrics are recorded every `log_every` iterations\n",
    "\n",
    "theta = np.zeros([d, 1])\n",
    "b = 1\n",
    "lr = small_lr\n",
    "batches_per_epoch = n // b\n",
    "\n",
    "raleigh_sgd_smallLR = []\n",
    "\n",
    "for i in range(iter_num):\n",
    "    batch_ind = i % batches_per_epoch\n",
    "    if batch_ind == 0:\n",
    "        # Start of an epoch: draw a fresh random order of the n samples. Indexing through\n",
    "        # `perm` rather than re-binding the global X, Y keeps the shared data untouched,\n",
    "        # so re-running this cell does not change what other cells see.\n",
    "        perm = np.random.choice(n, n, replace=False)\n",
    "\n",
    "    batch_rows = perm[b*batch_ind:b*(batch_ind+1)]\n",
    "    grad = get_grad(X[batch_rows], Y[batch_rows], theta)\n",
    "    theta -= lr * grad\n",
    "\n",
    "    if i % log_every == 0:\n",
    "        # Evaluate every metric once and reuse it for the printout and the history.\n",
    "        rq, dist, cur_loss = raleigh(theta), distance(theta), loss(theta)\n",
    "        print(rq, dist, cur_loss)\n",
    "        raleigh_sgd_smallLR += [[rq[0,0], dist, cur_loss]]\n",
    "\n",
    "\n",
    "theta = np.zeros([d, 1])\n",
    "b = n\n",
    "raleigh_gd_smallLR = []\n",
    "lr = small_lr\n",
    "for i in range(iter_num):\n",
    "    # Full batch: every step uses the whole data set.\n",
    "    grad = get_grad(X, Y, theta)\n",
    "    theta -= lr * grad\n",
    "\n",
    "    if i % log_every == 0:\n",
    "        rq, dist, cur_loss = raleigh(theta), distance(theta), loss(theta)\n",
    "        print(rq, dist, cur_loss)\n",
    "        raleigh_gd_smallLR += [[rq[0,0], dist, cur_loss]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the Rayleigh-quotient entry (index 0) of every logged record is a plain\n",
    "# Python float rather than a 1x1 array. This loop used to index `value_sgd` / `value_gd`,\n",
    "# which no visible cell defines, so it raised a NameError on a fresh kernel.\n",
    "# NOTE(review): assumes those names referred to `raleigh_sgd` / `raleigh_gd` -- confirm.\n",
    "for history in (raleigh_sgd, raleigh_gd):\n",
    "    for record in history:\n",
    "        # .item() accepts both numpy scalars and size-1 arrays, so re-running is harmless.\n",
    "        record[0] = np.asarray(record[0]).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest / smallest eigenvalue of X X^T (the non-zero eigenvalues of X^T X). They are\n",
    "# computed from the current X so the reference lines always match the generated data.\n",
    "gram_eigs = np.linalg.eigvalsh(X.dot(X.T))  # eigvalsh returns ascending order\n",
    "max_eigen, min_eigen = gram_eigs[-1], gram_eigs[0]\n",
    "\n",
    "# Records were logged every 100 iterations, hence the factor 100 on the x axis.\n",
    "iterations = np.arange(0, len(raleigh_sgd_smallLR), 1) * 100\n",
    "\n",
    "plt.plot(iterations, [value[0] for value in raleigh_sgd_smallLR], \"--\", color=\"darkorange\",markersize = 4, linewidth=2)\n",
    "plt.plot(iterations, [value[0] for value in raleigh_gd_smallLR], \"--\", color=\"royalblue\", markersize = 4, linewidth=2)\n",
    "\n",
    "# The moderate-LR histories live in `raleigh_sgd` / `raleigh_gd`.\n",
    "iterations = np.arange(0, len(raleigh_sgd), 1) * 100\n",
    "\n",
    "plt.plot(iterations, [value[0] for value in raleigh_sgd], \"-r\",markersize = 3,linewidth=2)\n",
    "plt.plot(iterations, [value[0] for value in raleigh_gd], \"-b\",markersize = 3, linewidth=2)\n",
    "\n",
    "\n",
    "plt.xlabel(\"# Iteration\", fontsize=16)\n",
    "plt.ylabel(\"Rayleigh quotient\", fontsize=16)\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.xlim([-200,10750])\n",
    "plt.hlines(max_eigen, 0, 10000, linestyle=\":\", color=\"brown\", linewidth=2)\n",
    "plt.hlines(min_eigen, 0, 10000, linestyle=\":\", color=\"darkgreen\",linewidth=2)\n",
    "plt.legend([\"SGD, small LR\", \"GD, small LR\", \"SGD, moderate LR\", \"GD, moderate LR\"], fontsize=16, loc=7)\n",
    "# Raw strings: \"\\g\" is not a valid Python escape sequence.\n",
    "plt.text(10050, max_eigen - 0.01, r\"$\\gamma_1$\", fontsize=16)\n",
    "plt.text(10050, min_eigen + 0.005, r\"$\\gamma_n$\", fontsize=16)\n",
    "plt.savefig(\"synthetic.pdf\",bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small-LR runs only. The histories are stored in `raleigh_sgd_smallLR` /\n",
    "# `raleigh_gd_smallLR`; each record is [Rayleigh quotient, distance, loss], logged every\n",
    "# 100 iterations. `min_eigen` comes from the previous plotting cell.\n",
    "iterations = np.arange(0, len(raleigh_sgd_smallLR), 1) * 100\n",
    "\n",
    "plt.plot(iterations, [value[0] for value in raleigh_sgd_smallLR], linewidth=2)\n",
    "plt.plot(iterations, [value[0] for value in raleigh_gd_smallLR], linewidth=2)\n",
    "plt.xlabel(\"# Iteration\", fontsize=16)\n",
    "plt.ylabel(\"Rayleigh quotient\", fontsize=16)\n",
    "plt.legend([\"SGD\", \"GD\"], fontsize=16)\n",
    "plt.xticks(fontsize=16)\n",
    "plt.yticks(fontsize=16)\n",
    "plt.hlines(min_eigen, 0, 10000, linestyle=\"--\", color=\"green\")\n",
    "plt.xlim([-500,10800])\n",
    "# Raw string: \"\\g\" is not a valid Python escape sequence.\n",
    "plt.text(10050, 0.245, r\"$\\gamma_n$\", fontsize=16)\n",
    "plt.savefig(\"synthetic_smallLR.pdf\",bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
