{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SARAH with fine-tuned hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "path = os.getcwd()\n",
    "parent_path = os.path.abspath(os.path.join(path, os.pardir))\n",
    "sys.path.append(parent_path)\n",
    "\n",
    "import random\n",
    "\n",
    "import pprint as pp\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "import shutil\n",
    "from numpy import genfromtxt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "torch.set_default_dtype(torch.float64)\n",
    "torch.set_num_threads(1) #cpu num\n",
    "\n",
    "import itertools\n",
    "import numpy.linalg  as lin\n",
    "\n",
    "import cProfile, pstats\n",
    "\n",
    "from collections import OrderedDict\n",
    "\n",
    "from Sparse_Init.sparseinit import *    \n",
    "from Sparse_Init.sparsedata import *\n",
    "from Sparse_Init.sparsemodule import * \n",
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "print (device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# experiment configuration\n",
    "algo = 'sarah'  # algorithm name (part of the log folder path)\n",
    "dname = 'ijcnn1'  # LIBSVM dataset name\n",
    "BS = 64  # mini-batch size\n",
    "StrongConvex = True  # add L2 regularization to the objective\n",
    "\n",
    "# tag used in the log folder path and to pick the epoch budget\n",
    "case = 'reg' if StrongConvex else 'non_reg'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data - datasets must first be downloaded from the LIBSVM website"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify data directory\n",
    "datafolder = '../Data/'+dname+'/' # please download libsvm dataset to this folder before executing this code\n",
    "# Specify directory to save log files - optional\n",
    "logfolder = '../Logs/'+dname+'/'+case+'/'+algo+'/'\n",
    "# to run all hyper-parameters, please use the following log folders\n",
    "# logfolder = '../AllLogs/'+dname+'/'+case+'/'+algo+'/'\n",
    "\n",
    "os.makedirs(logfolder, exist_ok=True)\n",
    "\n",
    "# dataset files - need to be downloaded from LIBSVM website.\n",
    "# Some datasets are a single file, others a separate train/test pair.\n",
    "SINGLE_FILE = {\n",
    "    'covtype': 'covtype.libsvm.binary.scale.bz2',\n",
    "    'news20': 'news20.binary.bz2',\n",
    "    'real-sim': 'real-sim.bz2',\n",
    "}\n",
    "TRAIN_TEST_FILES = {\n",
    "    'ijcnn1': ('ijcnn1.bz2', 'ijcnn1.t.bz2'),\n",
    "    'rcv1': ('rcv1_train.binary.bz2', 'rcv1_test.binary.bz2'),\n",
    "}\n",
    "\n",
    "# Pick the loader explicitly from dname: a bare `except:` fallback hides real read errors\n",
    "# and can silently reuse a stale `file` variable left in the kernel by an earlier dataset.\n",
    "# normalize() with sklearn defaults (norm='l2', axis=1) rescales each row in place.\n",
    "if dname in SINGLE_FILE:\n",
    "    file = datafolder+SINGLE_FILE[dname]\n",
    "    data = SparseData(dname,device,file=file)\n",
    "    csr = data.read()\n",
    "    normalize(csr[0],copy=False)\n",
    "    data.load(_csr=csr)\n",
    "elif dname in TRAIN_TEST_FILES:\n",
    "    trname, tename = TRAIN_TEST_FILES[dname]\n",
    "    trfile = datafolder+trname\n",
    "    tefile = datafolder+tename\n",
    "    data = SparseData(dname,device,trfile=trfile,tefile=tefile)\n",
    "    train_csr, test_csr = data.read()\n",
    "    normalize(train_csr[0],copy=False)\n",
    "    normalize(test_csr[0],copy=False)\n",
    "    data.load(_trainCSR=train_csr,_testCSR=test_csr)\n",
    "else:\n",
    "    raise ValueError('unknown dataset name: %s' % dname)\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# L2 regularization weight lam = 1/(number of training samples);\n",
    "# it is only needed (and only defined) in the StrongConvex case\n",
    "if StrongConvex:\n",
    "    lam = 1 / data.trSize"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use the best parameters\n",
    "## Please see the appendix for the fine-tuned parameters of each dataset and case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyper-parameter grids: the run loop tries every (seed, LR, LOOPSIZE) combination.\n",
    "# NOTE: 0 is a placeholder (alpha = 0 never moves the weights) - fill in the tuned values from the appendix.\n",
    "LR = [0]  # constant step-size(s) alpha\n",
    "LOOPSIZE = [0]  # inner-loop size(s) T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyper-parameter grids that will be swept\n",
    "(LR, LOOPSIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### a. experiment setup "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = list(range(10))  # 10 random seeds: 0, 1, ..., 9"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### b. parameters "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# running budget: total number of effective passes over the training set per run\n",
    "EPOCH_BUDGET = {\n",
    "    'reg':     {'rcv1': 30.0, 'ijcnn1': 20.0, 'news20': 40.0, 'covtype': 20.0, 'real-sim': 20.0},\n",
    "    'non_reg': {'rcv1': 40.0, 'ijcnn1': 20.0, 'news20': 50.0, 'covtype': 20.0, 'real-sim': 30.0},\n",
    "}\n",
    "# fail loudly instead of leaving TotalEP undefined (or stale from an earlier run in this kernel)\n",
    "try:\n",
    "    TotalEP = EPOCH_BUDGET[case][dname]\n",
    "except KeyError:\n",
    "    raise ValueError('no epoch budget for case=%s, dataset=%s' % (case, dname)) from None\n",
    "\n",
    "# number of mini-batches per effective pass; the training loop computes st % perEpoch\n",
    "perEpoch = data.trSize//BS\n",
    "if perEpoch < 1:\n",
    "    raise ValueError('mini-batch size BS=%s exceeds the training set size %s' % (BS, data.trSize))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# folder that receives the .tar results and the RUN-/DONE- status folders\n",
    "logfolder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### c. run "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run SARAH once for every (seed, step-size, loop-size) combination.\n",
    "# Status markers in logfolder: a RUN- folder marks a configuration in progress; a DONE- folder\n",
    "# or the saved .tar marks it finished. Any of them makes the loop skip that configuration.\n",
    "# NOTE: a RUN- folder left behind by an interrupted run must be deleted by hand before that\n",
    "# configuration is retried.\n",
    "for seed,alpha,T in itertools.product(SEED,LR,LOOPSIZE):\n",
    "    run_status = logfolder+'RUN-lr-%s-loop-%s-seed-%s/'%(alpha,T,seed)\n",
    "    done_status = logfolder+'DONE-lr-%s-loop-%s-seed-%s/'%(alpha,T,seed)\n",
    "    savefile = logfolder+'lr-%s-loop-%s-seed-%s.tar'%(alpha,T,seed)\n",
    "\n",
    "    if os.path.exists(run_status) or os.path.exists(done_status) or os.path.exists(savefile):\n",
    "        print(done_status)\n",
    "        continue\n",
    "    else:\n",
    "        os.makedirs(run_status)\n",
    "    print('======\\nlr - %s | T - %s | seed - %s\\n======'%(alpha,T,seed))\n",
    "\n",
    "    # results: one row per logged point\n",
    "    HIST=[] # [ep, Loss, squared norm of V, Test]\n",
    "    STAT=[] # [ep, outerT, innerT, seconds since previous log, 1 = outer log / 0 = inner log]\n",
    "\n",
    "    TIME = time.time() # total run timer\n",
    "\n",
    "    # initialize random stream\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "    # define one sparse layer model for linear model with logistic regression;\n",
    "    # prev_net holds the previous iterate used by the recursive gradient update\n",
    "    if StrongConvex:\n",
    "        # L2 regularized case\n",
    "        model = ConvexModel(data.num_feature,data.num_label,lam=lam,StrongConvex=True).to(device)\n",
    "        prev_net = ConvexModel(data.num_feature,data.num_label,lam=lam,StrongConvex=True).to(device)\n",
    "    else:\n",
    "        # non-regularized case\n",
    "        model = ConvexModel(data.num_feature,data.num_label).to(device)\n",
    "        prev_net = ConvexModel(data.num_feature,data.num_label).to(device)\n",
    "\n",
    "    # for weight wrt features in testing but not in training dataset, set them to ZERO\n",
    "    if len(data.in_te_not_tr)>0:\n",
    "        model.del_in_te_not_tr(data.in_te_not_tr)\n",
    "        prev_net.del_in_te_not_tr(data.in_te_not_tr)\n",
    "\n",
    "    allSamples = list(range(data.trSize))\n",
    "\n",
    "    # initialize counters\n",
    "    ep=0.0 # count effective passes\n",
    "    innerT=0 # count inner iterations\n",
    "    outerT=0 # count outer iterations\n",
    "    st=0 # mini-batch loop counter\n",
    "    # initialize stopping flags\n",
    "    converge=False\n",
    "    fatal=False\n",
    "    # epoch time - reset after every logged point\n",
    "    epoch_time = time.time()\n",
    "\n",
    "    t=0\n",
    "\n",
    "    outer_record=True # for print&save log purpose\n",
    "    while ep<=TotalEP+1:\n",
    "\n",
    "        if converge or fatal:\n",
    "            break\n",
    "\n",
    "        # snapshot the current weights into prev_net\n",
    "        for wi,pi in zip(model.parameters(),prev_net.parameters()):\n",
    "            with torch.no_grad():\n",
    "                pi.set_(wi+0.0)\n",
    "\n",
    "        # compute batch loss,grad,test at the snapshot (no `sample` argument - presumably the full training set)\n",
    "        Loss, V = prev_net.LossGrad(data)\n",
    "        Grad = np.sum([(gi.data**2).sum().item() for gi in V]) # squared norm of V\n",
    "        Test = prev_net.ComputeAccuracy(data)\n",
    "\n",
    "        # first step of the outer iteration: w <- w - alpha*V\n",
    "        with torch.no_grad():\n",
    "            for wi,vi in zip(model.parameters(),V):\n",
    "                wi.sub_(alpha*vi)\n",
    "\n",
    "        if outer_record:\n",
    "            timeT = time.time() - epoch_time\n",
    "            if ep==0 and t==0:\n",
    "                timeT=0.0\n",
    "            HIST.append([ep,Loss,Grad,Test])\n",
    "            STAT.append([ep,outerT,innerT,timeT,1])\n",
    "            print('outer-ep: %.2f, alpha: %.4f, loss: %.2e, Grad: %.2e, Test: %.4f, Time: %.2f, t: %d'\\\n",
    "                  %(ep,alpha,Loss,Grad,Test,timeT,t))\n",
    "            epoch_time = time.time()\n",
    "\n",
    "        ep+=1.0\n",
    "        outerT+=data.trSize//BS\n",
    "\n",
    "        if np.isnan(Loss) or np.isnan(Grad) or np.isnan(Test):\n",
    "            fatal = True\n",
    "        if Grad < 1e-15:\n",
    "            converge=True\n",
    "\n",
    "        # initialize inner loop\n",
    "        t=0\n",
    "        outer_record=True\n",
    "        # reset every outer iteration: when the inner loop runs zero times (e.g. T == 0)\n",
    "        # the check after it must not read an undefined or stale value\n",
    "        inner_record=False\n",
    "        # inner loop\n",
    "        while t<T and ep<=TotalEP+1:\n",
    "\n",
    "            if fatal or converge:\n",
    "                break\n",
    "\n",
    "            # random mini-batch: reshuffle when st wraps to 0; the last batch of a pass takes the remainder\n",
    "            st=st%perEpoch\n",
    "            if st==0:\n",
    "                np.random.shuffle(allSamples)\n",
    "            if st==perEpoch-1:\n",
    "                sample = allSamples[st*BS:]\n",
    "            else:\n",
    "                sample = allSamples[st*BS:(st+1)*BS]\n",
    "\n",
    "            # NOTE(review): x_sample/y_sample are not used below (LossGrad receives `sample` directly)\n",
    "            x_sample,y_sample = data.mb(sample)\n",
    "            # compute sample grad at the previous iterate: g0\n",
    "            _,g0 = prev_net.LossGrad(data,sample=sample)\n",
    "            # compute sample grad at the current iterate: g1\n",
    "            _,g1 = model.LossGrad(data,sample=sample)\n",
    "            # update recursive gradient: V <- g1 - g0 + V\n",
    "            V = [g1i.data - g0i.data + vi.data for g1i,g0i,vi in zip(g1,g0,V)]\n",
    "\n",
    "            # prev_net <- current weights, then step: w <- w - alpha*V\n",
    "            with torch.no_grad():\n",
    "                for wi,pi,vi in zip(model.parameters(),prev_net.parameters(),V):\n",
    "                    pi.set_(wi+0.0)\n",
    "                    wi.sub_(alpha*vi)\n",
    "\n",
    "            st+=1 # sample counter\n",
    "            t+=1 # count inner iteration\n",
    "            innerT+=1\n",
    "            ep+=BS/data.trSize # count effective pass\n",
    "\n",
    "            # log at the first inner step and then every perEpoch inner steps\n",
    "            inner_record=False\n",
    "            if (t-1)%perEpoch==0:\n",
    "\n",
    "                timeT = time.time()-epoch_time\n",
    "\n",
    "                inner_record=True\n",
    "                Lossprint, Vprint = model.LossGrad(data)\n",
    "                Gradprint = np.sum([(gi.data**2).sum().item() for gi in Vprint])\n",
    "                Testprint = model.ComputeAccuracy(data)\n",
    "\n",
    "                HIST.append([ep,Lossprint,Gradprint,Testprint])\n",
    "                STAT.append([ep,outerT,innerT,timeT,0])\n",
    "                print('inner-ep: %.2f, alpha: %.4f, loss: %.2e, Grad: %.2e, Test: %.4f, Time: %.2f, t: %d'\\\n",
    "                      %(ep,alpha,Lossprint,Gradprint,Testprint,timeT,t))\n",
    "\n",
    "                epoch_time = time.time()\n",
    "\n",
    "        # skip the next outer log when the last inner step was just logged\n",
    "        if inner_record:\n",
    "            outer_record=False\n",
    "\n",
    "    TIME = time.time() - TIME # total running time per run\n",
    "\n",
    "    RESULTS = {\n",
    "        'parm': [BS,seed,alpha,T],\n",
    "        'end': [Loss,Grad,Test,TIME,converge,fatal],\n",
    "        'hist': HIST,\n",
    "        'stat': STAT\n",
    "    }\n",
    "    torch.save(RESULTS,savefile)\n",
    "\n",
    "    # update running status\n",
    "    if os.path.exists(run_status):\n",
    "        os.rmdir(run_status)\n",
    "    if not os.path.exists(done_status):\n",
    "        os.mkdir(done_status)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): exit() asks the kernel to shut down when run interactively; presumably meant to end\n",
    "# batch executions - remove this cell to keep the session (and its variables) alive\n",
    "exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 Anaconda",
   "language": "python",
   "name": "python3anaconda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
