import cyanure as ars
import numpy as np
import scipy.sparse
import argparse
import os

# Command-line interface.
# --dataset, --penalty and --fact are required: without them the script cannot
# load data, build the log file names, or compute the regularization strength.
parser = argparse.ArgumentParser(description="Benchmark cyanure solvers on one dataset.")
parser.add_argument("--dataset", required=True,
                    help="name of the dataset to load")
parser.add_argument("--penalty", required=True,
                    help="penalty passed to the cyanure estimator (e.g. 'l2')")
parser.add_argument("--fact", type=float, required=True,
                    help="regularization factor; divided by the number of rows "
                         "of X when --penalty is l2")
parser.add_argument("--seed", type=int,
                    help="random seed forwarded to the solver")
# NOTE: --lambda is still accepted so existing command lines keep working, but
# it is never read: the regularization strength is derived from --fact.
# (It could not be read as `args.lambda` anyway, since `lambda` is a keyword.)
parser.add_argument("--lambda", type=float,
                    help="ignored; use --fact instead")
args = parser.parse_args()

dataset = args.dataset
penalty = args.penalty
fact = args.fact
seed = args.seed  # may be None when --seed is omitted

# Fixed benchmark configuration shared by every run of this script.
nthreads = 8  # thread count forwarded to classifier.fit
datapath = '/scratch/clear/mairal/large_datasets/'  # directory holding the dataset files
logfiles = '/scratch/clear/mairal/logs_cyanure/'  # path prefix for the saved solver logs
normalize = True  # forwarded to ars.preprocess
centering = True  # forwarded to ars.preprocess
it0 = 5  # forwarded to classifier.fit
fit_intercept = False
# A tuple rather than a set: iterating a set of strings follows an order that
# depends on per-process hash randomization, so the solvers would run in a
# different order on every invocation. A tuple keeps the run order reproducible.
solvers = (
    'qning-miso',
    'catalyst-miso',
    'qning-svrg',
    'catalyst-svrg',
    'acc-svrg',
    'svrg',
    'ista',
    'fista',
    'catalyst-ista',
    'qning-ista',
)
multiclass = False  # switched on by the dataset-loading code for some datasets
classif = True  # False would select the square-loss regression estimator
classif=True
# Load the requested dataset from `datapath` into a design matrix X and a
# label vector y. Each supported dataset has its own on-disk layout (npz keys,
# sparse vs dense). ckn_mnist and svhn switch `multiclass` on, which selects
# the multiclass estimator further down. Each NpzFile is opened in a `with`
# block so the underlying file handle is closed once the arrays are read.
if dataset == 'ckn_mnist':
    with np.load(datapath + dataset + '.npz') as data:
        y = data['y']
        X = data['X'].astype('float64')
    y = np.squeeze(np.float64(y))
    multiclass = True
elif dataset == 'svhn':
    with np.load(datapath + dataset + '.npz') as data:
        y = data['arr_1']
        X = data['arr_0']
    multiclass = True
elif dataset == 'rcv1':
    with np.load(datapath + 'rcv1.npz', allow_pickle=True) as data:
        y = data['y']
        X = data['X']
    # X is presumably a pickled sparse matrix wrapped in a 0-d object array
    # (hence allow_pickle); .item() unwraps the stored object.
    X = scipy.sparse.csc_matrix(X.item()).T  # n x p matrix; transposing csc yields csr
    X = X.astype('float64')
elif dataset in ('alpha', 'covtype', 'epsilon', 'ocr'):
    with np.load(datapath + dataset + '.npz') as data:
        y = data['arr_1']
        X = data['arr_0']
elif dataset in ('real-sim', 'webspam', 'kddb', 'criteo'):
    # Labels and the sparse design matrix are stored in two separate files.
    with np.load(datapath + dataset + '_y.npz', allow_pickle=True) as dataY:
        y = dataY['arr_0']
    X = scipy.sparse.load_npz(datapath + dataset + '_X.npz')
else:
    # Without this, X and y would be undefined and the script would die with
    # an unhelpful NameError at the preprocessing step.
    raise ValueError('unknown dataset: {!r}'.format(dataset))

# Preprocess X with the configured centering/normalization. columns=False
# presumably means per-row (per-sample) processing -- confirm against the
# cyanure docs. The return value is discarded, so this relies on preprocess
# modifying X in place.
ars.preprocess(X, centering=centering, normalize=normalize, columns=False)

# Choose the estimator: square-loss regression when classif is off; otherwise
# a multiclass or binary logistic classifier depending on the dataset.
if not classif:
    classifier = ars.Regression(loss='square', penalty=penalty,
                                fit_intercept=fit_intercept)
elif multiclass:
    classifier = ars.MultiClassifier(loss='multiclass-logistic', penalty=penalty,
                                     fit_intercept=fit_intercept)
else:
    classifier = ars.BinaryClassifier(loss='logistic', penalty=penalty,
                                      fit_intercept=fit_intercept)

# With the l2 penalty, the factor from the command line is scaled by
# 1 / (number of rows of X); any other penalty uses the factor unchanged.
lambd = fact / (X.shape[0]) if penalty == 'l2' else fact

# Run every solver once and save what fit returns. A solver whose log file
# already exists is skipped, so an interrupted benchmark can be resumed.
# The ista-family solvers get a 500-epoch budget; every other solver gets 200.
ista_family = {'ista', 'fista', 'catalyst-ista', 'qning-ista'}
for solver in solvers:
    max_epochs = 500 if solver in ista_family else 200
    namelog = "".join([logfiles, dataset, "_", solver, "_", penalty,
                       "_lambda", str(fact), "seed", str(seed)])
    print(namelog)
    # np.save appends '.npy' to the name it is given, hence the suffix here.
    if os.path.isfile(namelog + '.npy'):
        print(namelog + " already exists")
        continue
    optim = classifier.fit(X, y, it0=it0, lambd=lambd, solver=solver,
                           max_epochs=max_epochs, nthreads=nthreads,
                           seed=seed, tol=1e-7)
    np.save(namelog, optim)

