{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SARAH+ with fine-tuned hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "path = os.getcwd()\n",
    "parent_path = os.path.abspath(os.path.join(path, os.pardir))\n",
    "sys.path.append(parent_path)\n",
    "\n",
    "import random\n",
    "\n",
    "import pprint as pp\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "import shutil\n",
    "from numpy import genfromtxt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "torch.set_default_dtype(torch.float64)\n",
    "torch.set_num_threads(1) #cpu num\n",
    "\n",
    "import itertools\n",
    "import numpy.linalg  as lin\n",
    "\n",
    "import cProfile, pstats\n",
    "\n",
    "from collections import OrderedDict\n",
    "\n",
    "from Sparse_Init.sparseinit import *    \n",
    "from Sparse_Init.sparsedata import *\n",
    "from Sparse_Init.sparsemodule import * \n",
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "print (device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "algo = 'plus' # algorithm\n",
    "dname = 'rcv1' # dataset name\n",
    "BS = 64 # mini-batch size\n",
    "StrongConvex = True # L2 regularization\n",
    "if StrongConvex:\n",
    "    case = 'reg'\n",
    "else:\n",
    "    case = 'non_reg'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data - user need to download datasets from LIBSVM# generate data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify data directory\n",
    "datafolder = '../Data/'+dname+'/' # please download libsvm dataset to this folder before executing this code\n",
    "# Specify directory to save log files - optional\n",
    "logfolder = '../Logs/'+dname+'/'+case+'/'+algo+'/'\n",
    "# to run all hyper-parameters, please use the following log folders\n",
    "# logfolder = '../AllLogs/'+dname+'/'+case+'/'+algo+'/'\n",
    "\n",
    "if not os.path.exists(logfolder):\n",
    "    os.makedirs(logfolder)\n",
    "    \n",
    "\n",
    "# dataset files - need to be downloaded from LIBSVM website\n",
    "if dname == 'covtype':\n",
    "    file = datafolder+'covtype.libsvm.binary.scale.bz2'\n",
    "    \n",
    "if dname == 'ijcnn1':\n",
    "    trfile = datafolder+'ijcnn1.bz2'\n",
    "    tefile = datafolder+'ijcnn1.t.bz2'\n",
    "    \n",
    "if dname == 'rcv1':\n",
    "    trfile = datafolder+'rcv1_train.binary.bz2'\n",
    "    tefile = datafolder+'rcv1_test.binary.bz2'\n",
    "    \n",
    "if dname == 'news20':\n",
    "    file = datafolder+'news20.binary.bz2'\n",
    "    \n",
    "if dname == 'real-sim':\n",
    "    file = datafolder+'real-sim.bz2'\n",
    "    \n",
    "    \n",
    "try:\n",
    "    data = SparseData(dname,device,file=file)\n",
    "    csr = data.read()\n",
    "    normalize(csr[0],copy=False)\n",
    "    data.load(_csr=csr)\n",
    "except:\n",
    "    data = SparseData(dname,device,trfile=trfile,tefile=tefile)\n",
    "    train_csr, test_csr = data.read()\n",
    "    normalize(train_csr[0],copy=False)\n",
    "    normalize(test_csr[0],copy=False)\n",
    "    data.load(_trainCSR=train_csr,_testCSR=test_csr)\n",
    "print(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if StrongConvex:\n",
    "    lam = 1/data.trSize"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# use best parameters\n",
    "## please see appendix for best hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# put hyper-parameters into the list\n",
    "LR = [0] # constant step-size\n",
    "GAMMA = [0] # early stopping parameter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LR,GAMMA"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### a. experiment setup "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = [0,1,2,3,4,5,6,7,8,9] # 10 random seeds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### b. parameters "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# running budget\n",
    "if case=='reg':\n",
    "    if dname =='rcv1':\n",
    "        TotalEP = 30.0\n",
    "    if dname =='ijcnn1':\n",
    "        TotalEP = 20.0\n",
    "    if dname =='news20':\n",
    "        TotalEP = 40.0\n",
    "    if dname =='covtype':\n",
    "        TotalEP = 20.0\n",
    "    if dname =='real-sim':\n",
    "        TotalEP = 20.0\n",
    "        \n",
    "if case=='non_reg':\n",
    "    if dname =='rcv1':\n",
    "        TotalEP = 40.0\n",
    "    if dname =='ijcnn1':\n",
    "        TotalEP = 20.0\n",
    "    if dname =='news20':\n",
    "        TotalEP = 50.0\n",
    "    if dname =='covtype':\n",
    "        TotalEP = 20.0\n",
    "    if dname =='real-sim':\n",
    "        TotalEP = 30.0\n",
    "        \n",
    "perEpoch = data.trSize//BS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logfolder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### c. run "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "for seed,alpha,gamma in itertools.product(SEED,LR,GAMMA):\n",
    "    timer=[] # timer\n",
    "    \n",
    "    run_status = logfolder+'RUN-lr-%s-gamma-%s-seed-%s/'%(alpha,gamma,seed)\n",
    "    done_status = logfolder+'DONE-lr-%s-gamma-%s-seed-%s/'%(alpha,gamma,seed)\n",
    "    savefile = logfolder+'lr-%s-gamma-%s-seed-%s.tar'%(alpha,gamma,seed)\n",
    "    \n",
    "    if os.path.exists(run_status) or os.path.exists(done_status) or os.path.exists(savefile):\n",
    "        print(done_status)\n",
    "        continue\n",
    "    else:\n",
    "        os.makedirs(run_status)\n",
    "    print('======\\nlr - %s | gamma - %s | seed - %s\\n======'%(alpha,gamma,seed))  \n",
    "    \n",
    "    # results\n",
    "    HIST=[]\n",
    "    STAT=[]\n",
    "    \n",
    "    TIME = time.time()# total run timer\n",
    "    \n",
    "    # initialize random stream\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    \n",
    "    # define one layer model for linear model with logistic regression\n",
    "    if StrongConvex:\n",
    "        # L2 regularized case\n",
    "        model = ConvexModel(data.num_feature,data.num_label,lam=lam,StrongConvex=True).to(device)\n",
    "        prev_net = ConvexModel(data.num_feature,data.num_label,lam=lam,StrongConvex=True).to(device)\n",
    "    else:\n",
    "        # un-regularized case\n",
    "        model = ConvexModel(data.num_feature,data.num_label).to(device)\n",
    "        prev_net = ConvexModel(data.num_feature,data.num_label).to(device)\n",
    "      \n",
    "    # for weight wrt features in testing but not in training dataset, set them to ZERO\n",
    "    if len(data.in_te_not_tr)>0:\n",
    "        model.del_in_te_not_tr(data.in_te_not_tr)\n",
    "        prev_net.del_in_te_not_tr(data.in_te_not_tr) # redundant - remember to remove before SUBMISSION !!!\n",
    "    \n",
    "    allSamples = list(range(data.trSize))\n",
    "        \n",
    "    # intialize counter    \n",
    "    ep=0.0 # count effective pass \n",
    "    innerT=0 # count inner iterations\n",
    "    outerT=0 # count outer iterations\n",
    "    st=0 # mini-batch loop counter\n",
    "    # initialize stopping flag\n",
    "    converge=False\n",
    "    fatal=False\n",
    "    # epoch time - time for one epoch\n",
    "    epoch_time = time.time()\n",
    "    \n",
    "    t=0\n",
    "    \n",
    "    outer_record=True # for print&save log purpose\n",
    "    while ep<=TotalEP+1:\n",
    "        \n",
    "        if converge or fatal: \n",
    "            break\n",
    "            \n",
    "        for wi,pi in zip(model.parameters(),prev_net.parameters()):\n",
    "            with torch.no_grad():\n",
    "                pi.set_(wi+0.0)   \n",
    "          \n",
    "        # compute batch loss,grad,test\n",
    "        Loss, V = prev_net.LossGrad(data)\n",
    "        Grad = np.sum([(gi.data**2).sum().item() for gi in V])\n",
    "        Test = prev_net.ComputeAccuracy(data)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            for wi,vi in zip(model.parameters(),V):\n",
    "                wi.sub_(alpha*vi)\n",
    "        \n",
    "        if outer_record:\n",
    "            timeT = time.time() - epoch_time\n",
    "            if ep==0 and t==0:\n",
    "                timeT=0.0\n",
    "            HIST.append([ep,Loss,Grad,Test])\n",
    "            STAT.append([ep,outerT,innerT,timeT,1])\n",
    "            print('outer-ep: %.2f, alpha: %.4f, loss: %.2e, Grad: %.2e, Test: %.4f, Time: %.2f, t: %d'\\\n",
    "                  %(ep,alpha,Loss,Grad,Test,timeT,t))\n",
    "            epoch_time = time.time()\n",
    " \n",
    "        ep+=1.0\n",
    "        outerT+=data.trSize//BS\n",
    "        \n",
    "        if np.isnan(Loss) or np.isnan(Grad) or np.isnan(Test):\n",
    "            fatal = True\n",
    "        if Grad < 1e-15:\n",
    "            converge=True\n",
    "                    \n",
    "        # initialize inner loop\n",
    "        normV0 = Grad\n",
    "        normV = normV0\n",
    "        t=0\n",
    "        outer_record=True\n",
    "        # inner loop\n",
    "        while ep<=TotalEP+1:\n",
    "            \n",
    "            if fatal or converge or normV<gamma*normV0:\n",
    "                break\n",
    "            \n",
    "            # random mini-batch\n",
    "            st=st%perEpoch\n",
    "            if st==0:\n",
    "                np.random.shuffle(allSamples)\n",
    "            if st==perEpoch-1:\n",
    "                sample = allSamples[st*BS:]\n",
    "            else:\n",
    "                sample = allSamples[st*BS:(st+1)*BS]\n",
    "                \n",
    "            x_sample,y_sample = data.mb(sample)\n",
    "            # compute sample grad: g0\n",
    "            _,g0 = prev_net.LossGrad(data,sample=sample)  \n",
    "            # compute sample grad: g1\n",
    "            _,g1 = model.LossGrad(data,sample=sample)\n",
    "            # update recusrive gradient\n",
    "            V = [g1i.data - g0i.data + vi.data for g1i,g0i,vi in zip(g1,g0,V)]\n",
    "            \n",
    "            normV = np.sum([(vi.data**2).sum().item() for vi in V])\n",
    "            \n",
    "            with torch.no_grad():\n",
    "                for wi,pi,vi in zip(model.parameters(),prev_net.parameters(),V):\n",
    "                    pi.set_(wi+0.0)\n",
    "                    wi.sub_(alpha*vi)\n",
    "            \n",
    "            st+=1 # sample counter\n",
    "            t+=1 # count inner iteration\n",
    "            innerT+=1\n",
    "            ep+=BS/data.trSize # count effective pass\n",
    "            \n",
    "            inner_record=False\n",
    "            if (t-1)%perEpoch==0:\n",
    "                \n",
    "                timeT = time.time()-epoch_time\n",
    "                \n",
    "                inner_record=True\n",
    "                Lossprint, Vprint = model.LossGrad(data)\n",
    "                Gradprint = np.sum([(gi.data**2).sum().item() for gi in Vprint])\n",
    "                Testprint = model.ComputeAccuracy(data)\n",
    "\n",
    "                HIST.append([ep,Lossprint,Gradprint,Testprint])\n",
    "                STAT.append([ep,outerT,innerT,timeT,0])\n",
    "                print('inner-ep: %.2f, alpha: %.4f, loss: %.2e, Grad: %.2e, Test: %.4f, Time: %.2f, t: %d'\\\n",
    "                      %(ep,alpha,Lossprint,Gradprint,Testprint,timeT,t))\n",
    "                \n",
    "                epoch_time = time.time()\n",
    "            \n",
    "        if inner_record:\n",
    "            outer_record=False  \n",
    "            \n",
    "    TIME = time.time() - TIME # total running time per run\n",
    "    \n",
    "    RESULTS = OrderedDict()\n",
    "    RESULTS = {\n",
    "        'parm': [BS,seed,alpha,gamma],\n",
    "        'end': [Loss,Grad,Test,TIME,converge,fatal],\n",
    "        'hist': HIST,\n",
    "        'stat': STAT\n",
    "    }\n",
    "    torch.save(RESULTS,savefile)\n",
    "    \n",
    "    # update running status\n",
    "    if os.path.exists(run_status):\n",
    "        os.rmdir(run_status)\n",
    "    if not os.path.exists(done_status):\n",
    "        os.mkdir(done_status)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 Anaconda",
   "language": "python",
   "name": "python3anaconda"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
