{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# *AI-SARAH* (Algorithm 1 in the main paper)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Put the parent directory on sys.path; presumably this is where the project-local\n",
    "# Sparse_Init package imported below lives.\n",
    "path = os.getcwd()\n",
    "parent_path = os.path.abspath(os.path.join(path, os.pardir))\n",
    "sys.path.append(parent_path)\n",
    "\n",
    "import random\n",
    "\n",
    "import pprint as pp\n",
    "import numpy as np\n",
    "import time\n",
    "import os  # NOTE(review): duplicate of the import at the top of this cell; harmless\n",
    "import shutil\n",
    "from numpy import genfromtxt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "# Global torch settings: float64 is the default tensor dtype and torch uses one CPU thread.\n",
    "torch.set_default_dtype(torch.float64)\n",
    "torch.set_num_threads(1) # number of CPU threads\n",
    "\n",
    "import itertools\n",
    "import numpy.linalg  as lin\n",
    "\n",
    "import cProfile, pstats\n",
    "\n",
    "from collections import OrderedDict\n",
    "\n",
    "# Project-local modules. The star imports are assumed to provide SparseData and\n",
    "# ConvexModel used in later cells - TODO confirm which module defines which.\n",
    "from Sparse_Init.sparseinit import *    \n",
    "from Sparse_Init.sparsedata import *\n",
    "from Sparse_Init.sparsemodule_v2 import * # here, for AI-SARAH, user should use version 2 implementation\n",
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "# Run on the first CUDA GPU when one is available, otherwise on the CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "print (device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration - edit these values before running the notebook.\n",
    "algo = 'ai_sarah'      # algorithm tag, part of the log folder path\n",
    "dname = 'news20'       # LIBSVM dataset: covtype, ijcnn1, rcv1, news20 or real-sim\n",
    "BS = 64                # mini-batch size\n",
    "StrongConvex = True    # True: add L2 regularization (strongly convex objective)\n",
    "\n",
    "# sub-folder name that separates regularized from un-regularized runs\n",
    "case = 'reg' if StrongConvex else 'non_reg'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data (the datasets must first be downloaded from the LIBSVM website)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data directory - download the LIBSVM dataset files into this folder before running.\n",
    "datafolder = '../Data/'+dname+'/'\n",
    "# Directory for per-seed result files and RUN/DONE status markers.\n",
    "logfolder = '../Logs/'+dname+'/'+case+'/'+algo+'/'\n",
    "# To compare with other algorithms that have all hyper-parameters, use this log folder instead:\n",
    "# logfolder = '../AllLogs/'+dname+'/'+case+'/'+algo+'/'\n",
    "os.makedirs(logfolder, exist_ok=True)\n",
    "\n",
    "# LIBSVM file names per dataset. Some datasets come as a single file, others as separate\n",
    "# train/test files; SparseData is constructed and loaded differently for the two layouts.\n",
    "SINGLE_FILE = {\n",
    "    'covtype': 'covtype.libsvm.binary.scale.bz2',\n",
    "    'news20': 'news20.binary.bz2',\n",
    "    'real-sim': 'real-sim.bz2',\n",
    "}\n",
    "TRAIN_TEST_FILES = {\n",
    "    'ijcnn1': ('ijcnn1.bz2', 'ijcnn1.t.bz2'),\n",
    "    'rcv1': ('rcv1_train.binary.bz2', 'rcv1_test.binary.bz2'),\n",
    "}\n",
    "\n",
    "# Pick the loader explicitly from dname (instead of try/except on a missing `file`), so a\n",
    "# stale `file`/`trfile` left in the kernel by an earlier dataset can never be loaded by\n",
    "# mistake, and genuine read errors are reported instead of being swallowed.\n",
    "if dname in SINGLE_FILE:\n",
    "    file = datafolder+SINGLE_FILE[dname]\n",
    "    data = SparseData(dname,device,file=file)\n",
    "    csr = data.read()\n",
    "    normalize(csr[0],copy=False)  # in-place L2 row normalization of csr[0] (presumably the features)\n",
    "    data.load(_csr=csr)\n",
    "elif dname in TRAIN_TEST_FILES:\n",
    "    trname, tename = TRAIN_TEST_FILES[dname]\n",
    "    trfile = datafolder+trname\n",
    "    tefile = datafolder+tename\n",
    "    data = SparseData(dname,device,trfile=trfile,tefile=tefile)\n",
    "    train_csr, test_csr = data.read()\n",
    "    normalize(train_csr[0],copy=False)\n",
    "    normalize(test_csr[0],copy=False)\n",
    "    data.load(_trainCSR=train_csr,_testCSR=test_csr)\n",
    "else:\n",
    "    raise ValueError('unknown dataset name: %r' % dname)\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# L2 regularization weight: lambda = 1/n with n = data.trSize (regularized case only)\n",
    "if StrongConvex:\n",
    "    lam = 1 / data.trSize"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### a. experiment setup "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one independent run per random seed: 0, 1, ..., 9\n",
    "SEED = list(range(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### b. parameters "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running budget: total number of effective passes over the training data,\n",
    "# per regularization case and dataset.\n",
    "BUDGET = {\n",
    "    'reg':     {'rcv1': 30.0, 'ijcnn1': 20.0, 'news20': 40.0, 'covtype': 20.0, 'real-sim': 20.0},\n",
    "    'non_reg': {'rcv1': 40.0, 'ijcnn1': 20.0, 'news20': 50.0, 'covtype': 20.0, 'real-sim': 30.0},\n",
    "}\n",
    "# Fail loudly for an unconfigured case/dataset instead of leaving TotalEP undefined\n",
    "# or silently reusing a stale value from an earlier run in the same kernel.\n",
    "try:\n",
    "    TotalEP = BUDGET[case][dname]\n",
    "except KeyError:\n",
    "    raise ValueError('no running budget configured for case=%r, dname=%r' % (case, dname)) from None\n",
    "\n",
    "# number of mini-batches per pass over the training set\n",
    "perEpoch = data.trSize//BS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### c. run "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Run AI-SARAH once per seed. Each seed first claims a RUN-seed-N/ marker directory,\n",
    "# saves its results to seed-N.tar and finally replaces the RUN marker by DONE-seed-N/.\n",
    "# A seed is skipped when any of these already exists (presumably so that several\n",
    "# concurrent runs can share one log folder). If a run crashes, its RUN marker is left\n",
    "# behind and the seed stays skipped until the marker is removed by hand.\n",
    "for seed in SEED:\n",
    "    run_status = logfolder+'RUN-seed-%s/'%seed\n",
    "    done_status = logfolder+'DONE-seed-%s/'%seed\n",
    "    savefile = logfolder+'seed-%s.tar'%seed\n",
    "\n",
    "    if os.path.exists(run_status) or os.path.exists(done_status) or os.path.exists(savefile):\n",
    "        print(done_status)\n",
    "        continue\n",
    "    else:\n",
    "        os.makedirs(run_status)\n",
    "    print('======\\nseed - %s\\n======'%seed)\n",
    "\n",
    "    # results\n",
    "    HIST=[] # rows [ep, loss, squared gradient norm, ComputeAccuracy value]\n",
    "    STAT=[] # rows [ep, outerT, innerT, elapsed time, 1 if logged by the outer loop else 0]\n",
    "    alpha_max = np.inf # initial bound for first iteration - no upper bound\n",
    "    alpha = 0.0 # initial alpha - no initial step-size\n",
    "    TIME = time.time() # total run timer\n",
    "\n",
    "    # initialize random stream\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    # define one layer model for linear model with logistic regression;\n",
    "    # model holds the trial/current iterate w_t, prev_net the previous iterate w_{t-1}\n",
    "    if StrongConvex:\n",
    "        # L2 regularized case\n",
    "        model = ConvexModel(data.num_feature,data.num_label,lam=lam,StrongConvex=True)\n",
    "        prev_net = ConvexModel(data.num_feature,data.num_label,lam=lam,StrongConvex=True)\n",
    "    else:\n",
    "        # un-regularized case\n",
    "        model = ConvexModel(data.num_feature,data.num_label)\n",
    "        prev_net = ConvexModel(data.num_feature,data.num_label)\n",
    "\n",
    "    # push model to GPU for non-leaf variable\n",
    "    model.to(device)\n",
    "    prev_net.to(device)\n",
    "\n",
    "    # for weight wrt features in testing but not in training dataset, set them to ZERO\n",
    "    if len(data.in_te_not_tr)>0:\n",
    "        model.del_in_te_not_tr(data.in_te_not_tr)\n",
    "        prev_net.del_in_te_not_tr(data.in_te_not_tr)\n",
    "\n",
    "    # start both networks from the same point: copy model's weights into prev_net\n",
    "    for wi,pi in zip(model.parameters(),prev_net.parameters()):\n",
    "        with torch.no_grad():\n",
    "            pi.set_(wi+0.0)\n",
    "\n",
    "    allSamples = list(range(data.trSize))\n",
    "\n",
    "    # initialize counters\n",
    "    ep=0.0 # count effective pass\n",
    "    innerT=0 # count inner iterations\n",
    "    outerT=0 # count outer iterations\n",
    "    odmT=0 # count one-dim-minimization iterations\n",
    "    st=0 # mini-batch loop counter\n",
    "    # initialize stopping flags\n",
    "    converge=False\n",
    "    fatal=False\n",
    "    # epoch time - time for one epoch\n",
    "    epoch_time = time.time()\n",
    "    t=0\n",
    "\n",
    "    outer_record=True # for print&save log purpose\n",
    "    while ep<=TotalEP+1:\n",
    "\n",
    "        if converge or fatal:\n",
    "            break\n",
    "\n",
    "        # compute batch loss, grad, test at the current point; V_0 is this gradient\n",
    "        Loss, V = prev_net.LossGrad(data)\n",
    "        Grad = np.sum([(gi.data**2).sum().item() for gi in V])\n",
    "        Test = prev_net.ComputeAccuracy(data)\n",
    "\n",
    "        if outer_record:\n",
    "            timeT = time.time() - epoch_time\n",
    "            if ep==0 and t==0:\n",
    "                timeT=0.0\n",
    "            HIST.append([ep,Loss,Grad,Test])\n",
    "            STAT.append([ep,outerT,innerT,timeT,1])\n",
    "            print('outer-ep: %.2f, alpha: %.4f, max: %.4f, loss: %.2e, Grad: %.2e, Test: %.4f, Time: %.2f, t: %d'\n",
    "                  %(ep,alpha,alpha_max,Loss,Grad,Test,timeT,t))\n",
    "            epoch_time = time.time()\n",
    "\n",
    "        normV0 = Grad # squared norm of V_0, reference for the inner stopping rule\n",
    "\n",
    "        ep+=1.0\n",
    "        outerT+=data.trSize//BS\n",
    "\n",
    "        if np.isnan(Loss) or np.isnan(Grad) or np.isnan(Test):\n",
    "            fatal = True\n",
    "        if Grad < 1e-15:\n",
    "            converge=True\n",
    "\n",
    "        # initialize inner loop\n",
    "        t=0\n",
    "        normV = np.inf\n",
    "        outer_record=True\n",
    "        # must exist before the inner loop: when the loop exits immediately (converge/fatal on\n",
    "        # the first outer pass) the check after it would otherwise raise NameError\n",
    "        inner_record=False\n",
    "        # inner loop - ends once ||V_t||^2 < ||V_0||^2/32 or the budget is used up\n",
    "        while ep<=TotalEP+1:\n",
    "\n",
    "            if fatal or converge or normV<normV0.item()/32.0:\n",
    "                break\n",
    "\n",
    "            # random mini-batch; the last batch of each pass also takes the leftover samples\n",
    "            st=st%perEpoch\n",
    "            if st==0:\n",
    "                np.random.shuffle(allSamples)\n",
    "            if st==perEpoch-1:\n",
    "                sample = allSamples[st*BS:]\n",
    "            else:\n",
    "                sample = allSamples[st*BS:(st+1)*BS]\n",
    "\n",
    "            # NOTE(review): x_sample/y_sample are unused; call kept in case data.mb has side effects\n",
    "            x_sample,y_sample = data.mb(sample)\n",
    "            # compute sample grad at the previous iterate: g0\n",
    "            _,g0 = prev_net.LossGrad(data,sample=sample)\n",
    "\n",
    "            # implicit step size: one Newton step, from alpha = 0, on ||V(alpha)||^2 where\n",
    "            # V(alpha) = g(w_prev - alpha*V) - g0 + V on this mini-batch\n",
    "            # destruct computing graph\n",
    "            for w in model.parameters():\n",
    "                w.detach_()\n",
    "            alpha = torch.tensor(0.0,requires_grad=True,device=device)\n",
    "            # update model parameter to potential w_1\n",
    "            for wi,pi,vi in zip(model.parameters(),prev_net.parameters(),V):\n",
    "                with torch.no_grad():\n",
    "                    wi.set_(pi+0.0) # no gradient in operation\n",
    "                wi.sub_(alpha*vi) # with gradient\n",
    "            # compute sample grad: g1\n",
    "            _,g1 = model.LossGrad(data,sample=sample,second_order=True)\n",
    "            Vtemp = [g1i - g0i + vi for g1i,g0i,vi in zip(g1,g0,V)]\n",
    "            normVtemp = torch.stack([(vi**2).sum() for vi in Vtemp]).sum()\n",
    "            # compute 1st/2nd derivative of ||V(alpha)||^2 w.r.t. alpha\n",
    "            alphaGrad = torch.autograd.grad(normVtemp,alpha,create_graph=True)\n",
    "            alphaHess = torch.autograd.grad(alphaGrad,alpha)\n",
    "            # newton direction, scaled by abs(hessian)\n",
    "            newtonD = 1.0/np.abs(alphaHess[0].item())*alphaGrad[0].item()\n",
    "            odmT+=1\n",
    "            alpha_newton = alpha.item()-newtonD\n",
    "\n",
    "            # update alpha for current iterate, capped by the running upper bound\n",
    "            alpha = min(alpha_newton,alpha_max)\n",
    "\n",
    "            # update upper bound for next iterate: alpha_max = 1/delta, where delta is an\n",
    "            # exponential moving average (factor 0.999) of 1/alpha_newton\n",
    "            if ep==1.0 and t==0:\n",
    "                delta = 1.0/alpha_newton\n",
    "                alpha_max = alpha_newton\n",
    "            else:\n",
    "                delta = 0.999*delta + 0.001*(1.0/alpha_newton)\n",
    "                alpha_max = 1.0/delta\n",
    "\n",
    "            # update iterate - wt\n",
    "            for wi,pi,vi in zip(model.parameters(),prev_net.parameters(),V):\n",
    "                with torch.no_grad():\n",
    "                    wi.set_(pi-alpha*vi)\n",
    "            _,g1 = model.LossGrad(data,sample=sample)\n",
    "            # update V_t, squared norm of V_t\n",
    "            V = [g1i.data - g0i.data + vi.data for g1i,g0i,vi in zip(g1,g0,V)]\n",
    "            normV = np.sum([(vi.data**2).sum().item() for vi in V])\n",
    "\n",
    "            # update for next inner iteration: previous iterate <- current iterate\n",
    "            for wi,pi in zip(model.parameters(),prev_net.parameters()):\n",
    "                with torch.no_grad():\n",
    "                    pi.set_(wi+0.0)\n",
    "\n",
    "            st+=1 # sample counter\n",
    "            t+=1 # count inner iteration\n",
    "            innerT+=1\n",
    "            ep+=BS/data.trSize # count effective pass\n",
    "\n",
    "            # log at t = 1, perEpoch+1, 2*perEpoch+1, ...\n",
    "            inner_record=False\n",
    "            if (t-1)%perEpoch==0:\n",
    "                timeT = time.time()-epoch_time\n",
    "                inner_record=True\n",
    "                Lossprint, Vprint = prev_net.LossGrad(data)\n",
    "                Gradprint = np.sum([(gi.data**2).sum().item() for gi in Vprint])\n",
    "                Testprint = prev_net.ComputeAccuracy(data)\n",
    "\n",
    "                HIST.append([ep,Lossprint,Gradprint,Testprint])\n",
    "                STAT.append([ep,outerT,innerT,timeT,0])\n",
    "                print('inner-ep: %.2f, alpha: %.4f, max: %.4f, loss: %.2e, Grad: %.2e, Test: %.4f, Time: %.2f, t: %d'\n",
    "                      %(ep,alpha,alpha_max,Lossprint,Gradprint,Testprint,timeT,t))\n",
    "\n",
    "                epoch_time = time.time()\n",
    "\n",
    "        # the next outer step evaluates the point just logged by the inner loop, so skip its log\n",
    "        if inner_record:\n",
    "            outer_record=False\n",
    "\n",
    "    TIME = time.time() - TIME # total running time per run\n",
    "\n",
    "    RESULTS = {\n",
    "        'parm': [BS,seed],\n",
    "        'end': [Loss,Grad,Test,TIME,converge,fatal],\n",
    "        'hist': HIST,\n",
    "        'stat': STAT\n",
    "    }\n",
    "    torch.save(RESULTS,savefile)\n",
    "\n",
    "    # update running status\n",
    "    if os.path.exists(run_status):\n",
    "        os.rmdir(run_status)\n",
    "    if not os.path.exists(done_status):\n",
    "        os.mkdir(done_status)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably here to stop the process when this notebook is executed as a\n",
    "# converted script; running this cell interactively will likely end the kernel session.\n",
    "exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 Anaconda",
   "language": "python",
   "name": "python3anaconda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
