from sklearn.datasets import fetch_covtype,load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
import timeit
import arsenic as ars
import numpy as np
from scipy.io import loadmat
import scipy.sparse
import sys
from sklearn.linear_model import LogisticRegression
import time

# Alternative dataset (disabled): forest covertype.
#data = fetch_covtype(return_X_y=False)

# Iris, recast as a binary problem: class 0 -> -1.0, all other classes -> +1.0.
if False:
    data = load_iris(return_X_y=False)
    X = data['data']
    y = data['target']
    # Copy taken before X is preprocessed, so Xunorm is centered but not normalized.
    Xunorm = np.copy(X)
    # check if can handle views or force copy ?
    ars.preprocess(X, centering=True, normalize=True, columns=False)
    ars.preprocess(Xunorm, centering=True, normalize=False, columns=False)
    # Relabel in place, then convert the labels to float64.
    y[:] = np.where(y == 0, -1, 1)
    y = y.astype(np.float64)

# CKN features on MNIST, recast as one-vs-rest: label 0 -> +1.0, every other label -> -1.0.
if False:
    data = np.load('ckn_mnist.npz')
    y = data['y']
    X = data['X'].astype('float64')
    y = np.squeeze(np.float64(y))
    # Order matters: first send non-zero labels to -1, then the remaining zeros to +1.
    y[y != 0] = -1
    y[y == 0] = 1
    ars.preprocess(X, centering=True, normalize=True, columns=False)  # check if can handle views or force copy ?
    # Build an un-normalized variant by scaling each row by a random factor in
    # [0.01, 1.01). The previous `Xunorm = np.copy(X)` was a dead store (immediately
    # overwritten here) and has been removed.
    # Presumably this gives rows unequal norms for the non-uniform sampling tests.
    ch = np.random.rand(X.shape[0]) + 0.01
    Xunorm = ch[:, np.newaxis] * X
    ars.preprocess(Xunorm, centering=True, normalize=False, columns=False)

# RCV1 (sparse): the dataset used by the active run at the bottom of this script.
if True:
    # allow_pickle is required because 'X' is stored as a pickled Python object.
    # Only load trusted files this way.
    data = np.load('rcv1.npz', allow_pickle=True)
    y = data['y']
    # data['X'] is an object array wrapping a sparse matrix, so unwrap it with .item().
    # The original used .all(), which only returned the wrapped object by accident of
    # numpy's object-dtype reduction. The matrix is apparently stored transposed;
    # transposing the CSC matrix yields an n x p CSR matrix.
    X = scipy.sparse.csc_matrix(data['X'].item()).T  # n x p matrix, csr format
    X = X.astype('float64')
    # Un-normalized copy, taken before preprocess modifies X.
    Xunorm = X.copy()
    ars.preprocess(X, normalize=True, columns=False)  # check if can handle views or force copy ?

#data=np.load('data.npz')
#X=np.float64(data['X'])
#X=(data['X'])
#y=data['y']
#lambd=float(data['lambd'])



# Problem size and shared inputs for the SPAMS reference solver (spams.fistaFlat),
# which several of the disabled tests below compare arsenic against.
n = X.shape[0]
w0 = np.zeros([X.shape[1], 1])  # initial weights as a single column
# np.asfortranarray does not convert a scipy sparse matrix; it only wraps it in a 0-d
# object array. Reorder dense inputs only, and keep a sparse X as-is.
# NOTE(review): confirm spams.fistaFlat accepts a sparse X. Densifying instead would be
# very expensive on RCV1.
if scipy.sparse.issparse(X):
    XT = X
else:
    XT = np.asfortranarray(X)
yT = np.expand_dims(y, axis=1)  # labels as a column vector
# Keyword arguments forwarded to spams.fistaFlat via **param. The tests below
# overwrite 'loss', 'regul' and 'ista' in place, so later tests inherit earlier changes.
param = {'numThreads': -1, 'verbose': True,
         'it0': 10, 'max_it': 200,
         'L0': 0.1, 'tol': 1e-3,
         'pos': False}
param['ista'] = True
param['loss'] = 'square'
param['regul'] = 'l2'

## test l2: square loss + l2 penalty, fitted without and then with an intercept
if False:
    print("ISTA with square loss and l2 regularization")
    lambd = 0.1
    for with_intercept in (False, True):
        print("\nTest with intercept" if with_intercept else "Test without intercept")
        classifier = ars.BinaryClassifier(loss='square', penalty='l2', intercept=with_intercept)
        classifier.fit(X, y, it0=10, lambd=lambd)
    # Wait for the user before moving on.
    input("Press Enter to continue...")

## test l1: square loss + l1 penalty, compared against SPAMS (fistaFlat)
if False:
    # NOTE(review): `spams` is referenced here but never imported in this file, so
    # enabling this block raises NameError unless `import spams` is added.
    lambd = 0.1
    param['regul'] = 'l1'
    for with_intercept in (False, True):
        print("\nTest with intercept" if with_intercept else "Test without intercept")
        classifier = ars.BinaryClassifier(loss='square', penalty='l1', intercept=with_intercept)
        classifier.fit(X, y, it0=10, lambd=lambd)
        if not with_intercept:
            # SPAMS reference run. lambda is scaled by n and the reported objective is
            # divided by n, presumably because SPAMS sums the loss rather than averaging it.
            (w, optim_info) = spams.fistaFlat(yT, XT, w0, True, **param,
                                              lambda1=lambd * n, intercept=False)
            print(optim_info[0] / X.shape[0])

## test elastic-net: square loss + elastic-net penalty, compared against SPAMS
if False:
    # NOTE(review): `spams` is never imported in this file; enabling this block raises
    # NameError unless `import spams` is added.
    lambd, lambd2 = 0.1, 0.01
    param['regul'] = 'elastic-net'
    for with_intercept in (False, True):
        print("\nTest with intercept" if with_intercept else "Test without intercept")
        classifier = ars.BinaryClassifier(loss='square', penalty='elastic-net', intercept=with_intercept)
        classifier.fit(X, y, it0=10, lambd=lambd, lambd2=lambd2)
        if not with_intercept:
            # SPAMS reference run. Penalties are scaled by n and the reported objective is
            # divided by n, presumably because SPAMS sums the loss instead of averaging it.
            (w, optim_info) = spams.fistaFlat(yT, XT, w0, True, **param,
                                              lambda1=lambd * n, lambda2=lambd2 * n,
                                              intercept=False)
            print(optim_info[0] / X.shape[0])
            w = classifier.get_weights()
        else:
            # Recompute the objective by hand from the learned (w, b):
            # 0.5 * mean squared residual + lambd * ||w||_1 + 0.5 * lambd2 * ||w||_2^2.
            (w, b) = classifier.get_weights()
            residual = np.squeeze(X.dot(w) + b) - y
            print(0.5 * np.sum(residual ** 2) / X.shape[0]
                  + lambd * np.sum(np.abs(w))
                  + lambd2 * 0.5 * np.sum(w ** 2))

## test l1-ball: square loss with an l1-norm constraint, compared against SPAMS
if False:
    # NOTE(review): `spams` is never imported in this file; enabling this block raises
    # NameError unless `import spams` is added.
    lambd = 1
    param['regul'] = 'l1-constraint'
    for with_intercept in (False, True):
        print("\nTest with intercept" if with_intercept else "Test without intercept")
        classifier = ars.BinaryClassifier(loss='square', penalty='l1-ball', intercept=with_intercept)
        classifier.fit(X, y, it0=10, lambd=lambd)
        if not with_intercept:
            # For the constraint, lambd is passed to SPAMS unscaled (not multiplied by n).
            (w, optim_info) = spams.fistaFlat(yT, XT, w0, True, **param,
                                              lambda1=lambd, intercept=False)
            print(optim_info[0] / X.shape[0])
            w = classifier.get_weights()

## test fused Lasso: square loss + fused-lasso penalty, compared against SPAMS
if False:
    # NOTE(review): `spams` is never imported in this file; enabling this block raises
    # NameError unless `import spams` is added.
    lambd, lambd2, lambd3 = 0.1, 0.01, 0.01
    param['regul'] = 'fused-lasso'
    for with_intercept in (False, True):
        print("\nTest with intercept" if with_intercept else "Test without intercept")
        classifier = ars.BinaryClassifier(loss='square', penalty='fused-lasso', intercept=with_intercept)
        classifier.fit(X, y, it0=10, lambd=lambd, lambd2=lambd2, lambd3=lambd3)
        if not with_intercept:
            # SPAMS reference run with all three penalty weights scaled by n.
            (w, optim_info) = spams.fistaFlat(yT, XT, w0, True, **param,
                                              lambda1=lambd * n, lambda2=lambd2 * n,
                                              lambda3=lambd3 * n, intercept=False)
            print(optim_info[0] / X.shape[0])
            w = classifier.get_weights()

## test l2 logistic: logistic loss + l2 penalty, with a SPAMS run for comparison
if False:
    # NOTE(review): `spams` is never imported in this file. Also, `param['regul']` holds
    # whatever the last enabled block set ('l2' from the setup when run alone).
    lambd = 0.001
    param['loss'] = 'logistic'
    for with_intercept in (False, True):
        print("\nTest with intercept" if with_intercept else "Test without intercept")
        classifier = ars.BinaryClassifier(loss='logistic', penalty='l2', intercept=with_intercept)
        classifier.fit(X, y, it0=10, lambd=lambd)
        if not with_intercept:
            # Unlike the square-loss tests, lambd is not scaled by n here, and the
            # SPAMS result is not printed.
            (w, optim_info) = spams.fistaFlat(yT, XT, w0, True, **param,
                                              lambda1=lambd, intercept=False)

## test l2 squared hinge: 'sqhinge' loss + l2 penalty, without and with an intercept
if False:
    lambd = 0.001
    for with_intercept in (False, True):
        print("\nTest with intercept" if with_intercept else "Test without intercept")
        classifier = ars.BinaryClassifier(loss='sqhinge', penalty='l2', intercept=with_intercept)
        classifier.fit(X, y, it0=10, lambd=lambd)

## test FISTA l1 logistic: each model is fitted with FISTA then ISTA; a SPAMS FISTA run
## is added for the no-intercept case
if False:
    # NOTE(review): `spams` is never imported in this file; enabling this block raises
    # NameError unless `import spams` is added.
    lambd = 0.01
    for with_intercept in (False, True):
        print("\nTest with intercept" if with_intercept else "Test without intercept")
        classifier = ars.BinaryClassifier(loss='logistic', penalty='l1', intercept=with_intercept)
        for solver in ('fista', 'ista'):
            classifier.fit(X, y, it0=10, lambd=lambd, solver=solver)
        if not with_intercept:
            # Switch the shared SPAMS parameters to l1-logistic with FISTA (ista=False).
            param.update(loss='logistic', regul='l1', ista=False)
            (w, optim_info) = spams.fistaFlat(yT, XT, w0, True, **param,
                                              lambda1=lambd, intercept=False)

## test FISTA l1 safe-logistic: 'safe-logistic' loss + l1 penalty, solved with FISTA
if False:
    lambd = 0.01
    for with_intercept in (False, True):
        print("\nTest with intercept" if with_intercept else "Test without intercept")
        classifier = ars.BinaryClassifier(loss='safe-logistic', penalty='l1', intercept=with_intercept)
        classifier.fit(X, y, it0=10, lambd=lambd, solver='fista')

## test SVRG uniform sampling: l1 logistic regression solved with SVRG on the
## preprocessed data X
if False:
    lambd = 0.01
    for with_intercept in (False, True):
        print("\nTest with intercept" if with_intercept else "Test without intercept")
        classifier = ars.BinaryClassifier(loss='logistic', penalty='l1', intercept=with_intercept)
        classifier.fit(X, y, it0=10, lambd=lambd, solver='svrg')

## test SVRG non-uniform sampling (plus its variants) on the un-normalized data Xunorm
if False:
    lambd = 0.01
    # 'svrg' on Xunorm presumably samples rows non-uniformly; 'svrg-uniform' forces
    # uniform sampling on the same data; then the Catalyst and accelerated variants.
    for solver in ('svrg', 'svrg-uniform', 'catalyst-svrg', 'acc-svrg'):
        for with_intercept in (False, True):
            print("\nTest with intercept" if with_intercept else "Test without intercept")
            classifier = ars.BinaryClassifier(loss='logistic', penalty='l1', intercept=with_intercept)
            classifier.fit(Xunorm, y, it0=50, lambd=lambd, solver=solver)

## test SVRG variants with l2-regularized logistic loss and a tiny lambda on Xunorm
if False:
    lambd = 0.0000001
    it0 = 10
    # (solver, enabled) pairs; 'svrg-uniform' forces uniform sampling.
    for solver, enabled in (('svrg', False), ('svrg-uniform', False),
                            ('catalyst-svrg', True), ('acc-svrg', True)):
        if not enabled:
            continue
        print("Test without intercept")
        classifier = ars.BinaryClassifier(loss='logistic', penalty='l2', intercept=False)
        classifier.fit(Xunorm, y, it0=it0, lambd=lambd, solver=solver)
        # The intercept=True run is disabled; only its banner is still printed.
        print("\nTest with intercept")
        #classifier = ars.BinaryClassifier(loss='logistic', penalty='l2', intercept=True)
        #classifier.fit(Xunorm, y, it0=it0, lambd=lambd, solver=solver)

## Main run: l2-regularized logistic regression on the data loaded above, comparing
## arsenic solvers at lambda = 1/(10 n).
if True:
    it0 = 10
    # Other regularization values that were tried: 1e-3, 1e-6, 1e-9, 1e-12.
    lambd = 1/(10*X.shape[0])
    # NOTE(review): this call passes `fit_intercept`, whereas every disabled test above
    # passes `intercept`. Confirm which keyword the installed arsenic version expects.
    classifier = ars.BinaryClassifier(loss='logistic', penalty='l2', fit_intercept=True)
    # svc_ars = ars.BinaryClassifier(loss='logistic', penalty='l2', intercept=False)
    # optim = classifier.eval(X, y, lambd=lambd)
    # print(sys.getrefcount(optim))

    n = X.shape[0]

    # Optional scikit-learn reference (disabled). C = 1/(n*lambd) is presumably chosen
    # to match arsenic's regularization strength.
    start_time = time.time()
    # clf = LogisticRegression(random_state=0, solver='liblinear', C=1/(n*lambd),
    #                          fit_intercept=False, verbose=True).fit(X, y)
    # NOTE(review): with the fit above disabled, this timer measures nothing.
    print("--- %s seconds ---" % (time.time() - start_time))
    # save_coeff = np.copy(svc_ars.w)
    # svc_ars.w = np.copy(clf.coef_.ravel())
    # optim = svc_ars.eval(X, y, lambd=lambd)
    # print('OPTIM SKLEARN : ', optim[1])
    # svc_ars.w = np.copy(save_coeff)

    # NOTE(review): the message says "ista", but the first solver run is 'acc-svrg'.
    print('call ista')
    classifier.fit(X, y, it0=it0, lambd=lambd, solver='acc-svrg', nepochs=200)
    classifier.fit(X, y, it0=it0, lambd=lambd, solver='ista', nepochs=50, nthreads=4)
    classifier.fit(X, y, it0=it0, lambd=lambd, solver='qning-miso', nepochs=50, nthreads=4)

    # Deliberate stop. The original evaluated the undefined name `toto`, which aborted
    # with a NameError traceback; exit cleanly instead. Delete this line to run the
    # solver comparisons below.
    sys.exit(0)

    # QNing-MISO with decreasing memory sizes l_qning, then plain MISO.
    for l_qning in (50, 20, 10, 5, 1):
        classifier.fit(X, y, it0=it0, lambd=lambd, solver='qning-miso', nepochs=200, l_qning=l_qning)
    classifier.fit(X, y, it0=it0, lambd=lambd, solver='miso', nepochs=200)

    input('...')
    classifier.fit(X, y, it0=it0, lambd=lambd, solver='svrg', nepochs=200)
    classifier.fit(X, y, it0=it0, lambd=lambd, solver='catalyst-svrg', nepochs=200)
    classifier.fit(X, y, it0=it0, lambd=lambd, solver='qning-svrg', nepochs=200)

    input('...')
    classifier.fit(X, y, it0=it0, lambd=lambd, solver='fista', nepochs=200)
    classifier.fit(X, y, it0=it0, lambd=lambd, solver='catalyst-ista', nepochs=200)
    classifier.fit(X, y, it0=it0, lambd=lambd, solver='qning-ista', nepochs=200)

    # Other runs that were tried:
    # classifier.fit(Xunorm, y, it0=it0, lambd=lambd, solver='miso', nepochs=200)
    # classifier.fit(Xunorm, y, it0=it0, lambd=lambd, solver='svrg-uniform', nepochs=200)
    # classifier.fit(Xunorm, y, it0=it0, lambd=lambd, solver='svrg', nepochs=200)
    # classifier.fit(X, y, it0=it0, lambd=lambd, solver='svrg-uniform')
    # print(classifier.score(X, y))






