from sklearn.datasets import fetch_covtype,load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
import timeit
import arsenic as ars
import numpy as np
from scipy.io import loadmat
import scipy.sparse
import sys

# Dataset (active): dense CKN features for MNIST.
# Produces X (float64, preprocessed) and labels y, plus Xunorm: a copy of X whose
# rows are multiplied by random positive factors so they no longer share one norm.
if True:
    data = np.load('ckn_mnist.npz')
    #data=np.load('small_mnist.npz')
    y = data['y']
    # astype() returns a new array, so X does not alias the arrays held by `data`.
    X = data['X'].astype('float64')
#    X=np.copy(X[0:2000,0:100])
#    y=np.copy(y[0:2000])
    # The return value is ignored, so this relies on ars.preprocess modifying X in
    # place (columns=False presumably means per-row processing — TODO confirm).
    ars.preprocess(X, centering=True, normalize=True, columns=False)  # check if can handle views or force copy ?
    # One scale factor per row, drawn from [0.01, 1.01).
    # NOTE(review): no RNG seed is set in this script, so Xunorm differs between runs.
    ch = np.random.rand(X.shape[0]) + 0.01
    # Broadcasting allocates a fresh array, so no explicit copy of X is needed first.
    Xunorm = ch[:, np.newaxis] * X
    ars.preprocess(Xunorm, centering=True, normalize=False, columns=False)

# Dataset (disabled): sparse RCV1.
# The .npz stores the sparse matrix inside a NumPy object array, which np.load
# refuses to read unless allow_pickle=True (the default is False since NumPy 1.16.3).
# Unpickling can execute arbitrary code: only use this with trusted local files.
if False:
    data = np.load('rcv1.npz', allow_pickle=True)
    y = data['y']
    # .item() unwraps the single-element object array to the stored matrix
    # (presumably a scipy.sparse matrix — TODO confirm how rcv1.npz was written).
    X = data['X'].item()
    # Transposing a CSC matrix yields a CSR matrix; the result is the n x p design matrix.
    X = scipy.sparse.csc_matrix(X).T
    X = X.astype('float32')
    # Keep an un-normalized copy before in-place preprocessing below.
    Xunorm = X.copy()
    ars.preprocess(X, normalize=True, columns=False)  # check if can handle views or force copy ?



## test: square loss + l2 penalty (disabled)
# Fits once without and once with an intercept; only the first run sets nepochs.
if False:
    print("ISTA with square loss and l2 regularization")
    lambd = 0.1
    square_l2_runs = (
        ("Test without intercept", False, {"nepochs": 30}),
        ("\nTest with intercept", True, {}),
    )
    for header, use_intercept, extra_fit_args in square_l2_runs:
        print(header)
        classifier = ars.MultiClassifier(loss='square', penalty='l2', intercept=use_intercept)
        classifier.fit(X, y, it0=10, lambd=lambd, **extra_fit_args)
    input("Press Enter to continue...")

## test: square loss + l1 penalty (disabled)
# Same two-run pattern: no intercept (30 epochs), then intercept (default epochs).
if False:
    print("ISTA with square loss and l1 regularization")
    lambd = 0.1
    for banner, with_bias, fit_extras in (
        ("Test without intercept", False, {"nepochs": 30}),
        ("\nTest with intercept", True, {}),
    ):
        print(banner)
        classifier = ars.MultiClassifier(loss='square', penalty='l1', intercept=with_bias)
        classifier.fit(X, y, it0=10, lambd=lambd, **fit_extras)
#    input("Press Enter to continue...")

## test: logistic loss + l2 penalty (disabled)
# Fits once without an intercept (30 epochs) and once with one (default epochs).
if False:
    # Header matches the loss actually used below (loss='logistic').
    print("ISTA with logistic loss and l2 regularization")
    lambd = 0.1
    print("Test without intercept")
    classifier = ars.MultiClassifier(loss='logistic', penalty='l2', intercept=False)
    classifier.fit(X, y, it0=10, lambd=lambd, nepochs=30)
    print("\nTest with intercept")
    classifier = ars.MultiClassifier(loss='logistic', penalty='l2', intercept=True)
    classifier.fit(X, y, it0=10, lambd=lambd)
    input("Press Enter to continue...")

## test: multiclass-logistic loss + l2 penalty (disabled)
if False:
    # Header matches the loss actually used below (loss='multiclass-logistic').
    print("ISTA with multiclass-logistic loss and l2 regularization")
    lambd = 0.1
    print("Test without intercept")
    classifier = ars.MultiClassifier(loss='multiclass-logistic', penalty='l2', intercept=False)
    classifier.fit(X, y, it0=10, lambd=lambd, nepochs=10)
    print("\nTest with intercept")
    classifier = ars.MultiClassifier(loss='multiclass-logistic', penalty='l2', intercept=True)
    # NOTE(review): the intercept fit is disabled, so the header above is printed
    # without a corresponding run.
    #classifier.fit(X,y,it0=10,lambd=lambd)
    input("Press Enter to continue...")

# Open questions from earlier experiments:
#   - logistic + l1l2 not working ?
#   - intercept not working ?
#   - parallel ok with intercept
## solver comparison: multiclass-logistic loss + l1 penalty, with intercept (active)
if True:
    # Alternative, sample-size-scaled regularization: lambd = 1 / (100 * X.shape[0])
    lambd = 0.0001
    # The classifier below is built with intercept=True, so the header says so.
    # TODO(review): if a no-intercept run was intended (as in the sections above),
    # switch to intercept=False and update this header.
    print("Test with intercept (l1 penalty)")
    classifier = ars.MultiClassifier(loss='multiclass-logistic', penalty='l1', intercept=True)
    solvers = (
        'qning-miso',
        'catalyst-miso',
        'qning-ista',
        'qning-svrg',
        'acc-svrg',
        'catalyst-svrg',
        'svrg',
        'miso',
    )
    # NOTE(review): one classifier object is refit with every solver; confirm that
    # fit() does not warm-start from the previous solution, otherwise later solvers
    # begin from an already-optimized model and the comparison is skewed.
    for solver in solvers:
        # Label each run so the solver's output can be told apart.
        print("\nSolver:", solver)
        classifier.fit(X, y, it0=10, lambd=lambd, nepochs=200, solver=solver)
    print("\nTest with intercept")
    classifier = ars.MultiClassifier(loss='multiclass-logistic', penalty='l2', intercept=True)
    # NOTE(review): fit disabled — this only constructs the l2 classifier.
    #classifier.fit(X,y,it0=10,lambd=lambd,solver='svrg')


