import arsenic as ars
import numpy as np
import scipy.sparse

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Dataset to run on. Names recognised by the loading code further down:
# 'ckn_mnist', 'svhn', 'rcv1', 'alpha', 'covtype', 'epsilon', 'ocr',
# 'real-sim', 'webspam', 'kddb', 'criteo'.
dataset = 'criteo'

# Values forwarded unchanged to classifier.fit() as it0= / nepochs= / nthreads=.
it0 = 10
nepochs = 200
nthreads = 4

# Options forwarded to ars.preprocess().
normalize = True
centering = True

# Model options: penalty name and whether an intercept is fitted.
penalty = 'l1'
intercept = False

# Regularisation factor; used as-is for lambd, or divided by X.shape[0]
# when penalty == 'l2'.
fact = 1 / 100

# Directory prefix for the dataset files (must end with '/', since file names
# are appended by plain string concatenation).
# Alternative location: '/scratch/clear/mairal/large_datasets/'
datapath = '/local_scratch/mairal/large_datasets/'


# ---------------------------------------------------------------------------
# Dataset loading
#
# Reads `dataset` and `datapath` and defines:
#   X          -- the data matrix (dense array or scipy sparse matrix)
#   y          -- the labels / targets
#   multiclass -- True for the datasets flagged as multi-class below
#   classif    -- True: a classifier (not a regressor) is built later
#
# NOTE: some files are loaded with allow_pickle=True, which unpickles
# arbitrary Python objects -- only point `datapath` at trusted files.
# ---------------------------------------------------------------------------
multiclass = False
classif = True

if dataset == 'ckn_mnist':
    data = np.load(datapath + dataset + '.npz')
    X = data['X'].astype('float64')
    # Labels are converted to float64 and singleton dimensions are dropped.
    y = np.squeeze(np.float64(data['y']))
    multiclass = True
elif dataset == 'svhn':
    data = np.load(datapath + dataset + '.npz')
    y = data['arr_1']
    X = data['arr_0']
    multiclass = True
elif dataset == 'rcv1':
    data = np.load(datapath + 'rcv1.npz', allow_pickle=True)
    y = data['y']
    X = data['X']
    # NOTE(review): X is presumably a 0-d object array wrapping a sparse
    # matrix and .all() is used to unwrap it -- confirm against the file.
    X = scipy.sparse.csc_matrix(X.all()).T  # n x p matrix, csr format
    X = X.astype('float64')
elif dataset in ('alpha', 'covtype', 'epsilon', 'ocr'):
    data = np.load(datapath + dataset + '.npz')
    y = data['arr_1']
    X = data['arr_0']
elif dataset in ('real-sim', 'webspam', 'kddb', 'criteo'):
    # Labels and the sparse design matrix live in two separate files.
    dataY = np.load(datapath + dataset + '_y.npz', allow_pickle=True)
    y = dataY['arr_0']
    X = scipy.sparse.load_npz(datapath + dataset + '_X.npz')
else:
    # Without this branch an unrecognised name would leave X and y unbound
    # and only fail later with an unhelpful NameError.
    raise ValueError('unknown dataset: {!r}'.format(dataset))

print(y)

# ---------------------------------------------------------------------------
# Preprocessing, model construction and regularisation strength
# ---------------------------------------------------------------------------

# The return value is not used, so X is expected to be modified in place.
# TODO: check whether preprocess() can handle views or whether a copy must
# be forced.
ars.preprocess(X, centering=centering, normalize=normalize, columns=False)

# Choose the model class first, then instantiate it once with the shared
# options.
if not classif:
    # NOTE(review): a regression model with loss='logistic' looks unusual --
    # confirm the library accepts it. Not reached while classif is True.
    _model_cls = ars.Regression
elif multiclass:
    _model_cls = ars.MultiClassifier
else:
    _model_cls = ars.BinaryClassifier
classifier = _model_cls(loss='logistic', penalty=penalty, intercept=intercept)

# With an l2 penalty the factor is scaled by X.shape[0]; for any other
# penalty it is used directly.
lambd = fact / (X.shape[0]) if penalty == 'l2' else fact

# ---------------------------------------------------------------------------
# Run every solver in turn on the same model, data and hyper-parameters.
# The order of the calls is significant only in that it is kept fixed:
# each fit() is invoked with identical arguments apart from `solver`.
# ---------------------------------------------------------------------------
_fit_options = dict(it0=it0, lambd=lambd, nepochs=nepochs, nthreads=nthreads)

# For the two *-miso solvers the return value of fit() is kept; after the
# loop `optim` holds the result of the 'catalyst-miso' run.
for _solver in ('qning-miso', 'catalyst-miso'):
    optim = classifier.fit(X, y, solver=_solver, **_fit_options)

# Remaining solvers: return values are discarded.
_other_solvers = (
    # SVRG family
    'qning-svrg',
    'catalyst-svrg',
    'svrg',
    'acc-svrg',
    # ISTA family
    'ista',
    'fista',
    'catalyst-ista',
    'qning-ista',
)
for _solver in _other_solvers:
    classifier.fit(X, y, solver=_solver, **_fit_options)

