import numpy as np
from numba import jit, prange

from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC,SVR
from sklearn.preprocessing import MinMaxScaler

from persistence_spheres_utils import make_weighting, from_DGMS_to_H

from persim import sliced_wasserstein as SW
from persim import PersistenceImager

from pers_splines import make_PersSplines_vec

from persistence_spheres import PerSPhere

def make_CV_fold(LIST):
    """Fit and score one cross-validation fold on PerSPhere features.

    ``LIST`` is a single packed argument (so the function can be mapped over
    a list of jobs) holding, in order:

    Y         -- target array; its shape is also used to build the fold masks.
    DGMS      -- diagrams handed to ``PerSPhere``.
    n_theta   -- theta resolution for ``PerSPhere``; ``n_phi`` is ``2 * n_theta``.
    alpha, k  -- forwarded to ``make_weighting(K=k, alpha=alpha)``.
    model     -- estimator exposing ``fit`` and ``score``; it is fitted in
                 place and returned in the result.
    train_idx -- indices of the training samples.
    test_idx  -- indices of the test samples.

    Returns a dict with keys 'n_theta', 'alpha', 'k', 'model' and 'score'.
    """
    Y, DGMS, n_theta, alpha, k, model, train_idx, test_idx = LIST

    # Features are the ``sph_armonics`` attribute of a PerSPhere built over
    # all diagrams at once; the fold split happens afterwards.
    features = PerSPhere(
        DGMS,
        weighting=make_weighting(K=k, alpha=alpha),
        n_theta=n_theta,
        n_phi=2 * n_theta,
    ).sph_armonics

    def fold_mask(indices):
        # A boolean mask selects rows in ascending index order whatever the
        # order of `indices`, so features and targets stay aligned.
        mask = np.zeros(Y.shape, dtype=bool)
        mask[indices] = True
        return mask

    in_train = fold_mask(train_idx)
    in_test = fold_mask(test_idx)

    # No feature scaling is applied: the model sees the raw features.
    model.fit(features[in_train, :], Y[in_train])
    fold_score = model.score(features[in_test, :], Y[in_test])

    return {
        'n_theta': n_theta,
        'alpha': alpha,
        'k': k,
        'model': model,
        'score': fold_score,
    }


def make_CV_fold_PI(LIST):
    """Fit and score one cross-validation fold on persistence-image features.

    ``LIST`` packs, in order:

    Y         -- target array; its shape is also used to build the fold masks.
    DGMS      -- diagrams passed to ``PersistenceImager.fit_transform``.
    model     -- estimator exposing ``fit`` and ``score``; fitted in place.
    pix       -- the imager's ``pixel_size``.
    n_sigma   -- the kernel ``sigma`` is set to ``pix / n_sigma``.
    p         -- stored as the imager's ``weight_params['n']``.
    train_idx -- indices of the training samples.
    test_idx  -- indices of the test samples.

    Returns a dict with keys 'p', 'pix', 'sigma', 'model' and 'score'.
    """
    Y, DGMS, model, pix, n_sigma, p, train_idx, test_idx = LIST

    imager = PersistenceImager()
    imager.pixel_size = pix
    kernel_sigma = pix / n_sigma
    # NOTE(review): persim's default Gaussian kernel uses a 2x2 covariance for
    # 'sigma'; confirm the installed version accepts a scalar here and whether
    # it is read as a standard deviation or as a variance.
    imager.kernel_params = {'sigma': kernel_sigma}
    imager.weight_params['n'] = p

    # The imager is fitted on every diagram (train and test) before the split.
    images = imager.fit_transform(DGMS)
    features = np.array([image.flatten() for image in images])

    in_train, in_test = (np.zeros(Y.shape, dtype=bool) for _ in range(2))
    in_train[train_idx] = True
    in_test[test_idx] = True

    model.fit(features[in_train, :], Y[in_train])
    fold_score = model.score(features[in_test, :], Y[in_test])

    return {
        'p': p,
        'pix': pix,
        'sigma': kernel_sigma,
        'model': model,
        'score': fold_score,
    }


def make_CV_fold_PL(LIST):
    """Fit and score one cross-validation fold on a precomputed feature matrix.

    ``LIST`` packs, in order:

    dataset   -- array with one row per sample (indexed as ``dataset[mask, :]``,
                 so at least 2-D).
    Y         -- target array aligned with the rows of ``dataset``; its shape
                 is also used to build the fold masks.
    model     -- estimator exposing ``fit`` and ``score``; fitted in place.
    train_idx -- indices of the training samples.
    test_idx  -- indices of the test samples.

    Rows are selected with boolean masks, so each side of the split is taken
    in ascending index order regardless of the order of the given indices.

    Returns a dict with keys 'model' and 'score'.
    """
    dataset, Y, model, train_idx, test_idx = LIST

    masks = []
    for indices in (train_idx, test_idx):
        mask = np.zeros(Y.shape, dtype=bool)
        mask[indices] = True
        masks.append(mask)
    train_mask, test_mask = masks

    model.fit(dataset[train_mask, :], Y[train_mask])

    return {
        'model': model,
        'score': model.score(dataset[test_mask, :], Y[test_mask]),
    }
    
    

# NOTE(review): `lambda_p_wrap` is neither defined nor imported in this file
# as visible here, so numba will fail to compile `make_coeffs` on first call
# unless that name is provided elsewhere at module level -- TODO confirm.
@jit(nopython=True, parallel=False, fastmath=False)
def make_coeffs(COEFFS,DGMS,scaling):
    """Collapse each diagram's coefficient matrix into a single feature row.

    For diagram ``i``, the vector returned by
    ``lambda_p_wrap(DGMS[i], heatmap=scaling)`` is reshaped to a column, each
    row of ``COEFFS[i]`` is multiplied by the matching entry, and the weighted
    rows are summed into ``out[i, :]``.

    Parameters
    ----------
    COEFFS : sequence of 2-D arrays, one per entry of DGMS. They presumably
        all have ``COEFFS[0].shape[1]`` columns (TODO confirm); the output
        width is taken from ``COEFFS[0]``.
    DGMS : sequence of (presumably) persistence diagrams, indexed in step
        with COEFFS.
    scaling : forwarded to ``lambda_p_wrap`` as its ``heatmap`` keyword.

    Returns
    -------
    out : float64 array of shape (len(DGMS), COEFFS[0].shape[1]).
    """
    
    out = np.zeros((len(DGMS),COEFFS[0].shape[1]))

    # With parallel=False in the decorator, prange behaves like a plain range.
    for i in prange(len(DGMS)):

        # Broadcast the per-row weights (as a column) across COEFFS[i].
        aux = COEFFS[i]*lambda_p_wrap(DGMS[i],heatmap=scaling).reshape(-1,1)

        # Sum over rows to get one feature vector per diagram.
        out[i,:]=np.sum(aux,axis=0)

    return out


def kernel_make_CV_fold(LIST):
    """Fit and score one cross-validation fold with a precomputed kernel.

    ``LIST`` packs, in order:

    D         -- square matrix of pairwise values between samples (for
                 example the output of ``make_graham``).
    Y         -- targets aligned with the rows/columns of ``D``; its shape is
                 also used to build the fold masks.
    sigma     -- bandwidth; the kernel is ``exp(-D / (2 * sigma**2))``.
    model     -- estimator exposing ``fit``/``score`` that accepts a
                 precomputed kernel matrix (presumably e.g. an SVM with
                 ``kernel='precomputed'`` -- confirm at the call site).
    train_idx -- indices of the training samples.
    test_idx  -- indices of the test samples.

    The model is fitted on the train-by-train kernel block and scored on the
    test-by-train block (rows are test samples, columns are training samples).

    Returns a dict with keys 'sigma', 'model' and 'score'.
    """
    D, Y, sigma, model, train_idx, test_idx = LIST

    gram = np.exp(-D / (2 * sigma ** 2))

    is_train = np.zeros(Y.shape, dtype=bool)
    is_test = np.zeros(Y.shape, dtype=bool)
    is_train[train_idx] = True
    is_test[test_idx] = True

    # np.ix_ turns the boolean masks into an open mesh selecting the sub-block.
    model.fit(gram[np.ix_(is_train, is_train)], Y[is_train])
    fold_score = model.score(gram[np.ix_(is_test, is_train)], Y[is_test])

    return {'sigma': sigma, 'model': model, 'score': fold_score}


def make_graham(DGMS,M=10):
    """Return the symmetric matrix of pairwise sliced Wasserstein distances.

    Despite the name, the result holds distances, not kernel values. Entry
    ``[i, j]`` is ``SW(DGMS[max(i, j)], DGMS[min(i, j)], M=M)`` and the
    diagonal is zero. Only the lower triangle is computed; each value is
    mirrored into the upper triangle.

    ``M`` is forwarded unchanged to ``SW`` (persim's ``sliced_wasserstein``,
    where it sets the number of slices).
    """
    n_dgms = len(DGMS)
    distances = np.zeros((n_dgms, n_dgms))

    for row, row_dgm in enumerate(DGMS):
        for col in range(row):
            distances[row, col] = distances[col, row] = SW(row_dgm, DGMS[col], M=M)

    return distances


def wrap_SW(LIST):
    """Compute one sliced Wasserstein distance from a packed argument.

    ``LIST`` is ``(D0, D1, M)``. The single-argument form presumably exists so
    this can be mapped over a list of jobs. Returns ``SW(D0, D1, M=M)``.
    """
    first_dgm, second_dgm, n_slices = LIST
    return SW(first_dgm, second_dgm, M=n_slices)


def make_CV_fold_PersSplines(LIST):
    """Fit and score one cross-validation fold on persistence-spline features.

    ``LIST`` packs, in order:

    Y             -- target array; its shape is also used to build the fold masks.
    DGMS          -- diagrams, each vectorised with ``make_PersSplines_vec``.
    model         -- estimator exposing ``fit`` and ``score``; fitted in place.
    m_b, M_b, M_p, h -- forwarded positionally to ``make_PersSplines_vec``.
    iterations    -- forwarded as its ``iteration`` keyword (``sig`` is fixed
                     at 1e-10).
    train_idx     -- indices of the training samples.
    test_idx      -- indices of the test samples.

    Returns a dict with keys 'h', 'iterations', 'model' and 'score'.
    """
    (Y, DGMS, model,
     m_b, M_b, M_p, h, iterations,
     train_idx, test_idx) = LIST

    spline_vecs = [
        make_PersSplines_vec(dgm, m_b, M_b, M_p, h, sig=1e-10, iteration=iterations)
        for dgm in DGMS
    ]
    # One flattened feature row per diagram.
    features = np.array([vec.flatten() for vec in spline_vecs])

    selectors = {}
    for side, indices in (('train', train_idx), ('test', test_idx)):
        selectors[side] = np.zeros(Y.shape, dtype=bool)
        selectors[side][indices] = True

    model.fit(features[selectors['train'], :], Y[selectors['train']])
    fold_score = model.score(features[selectors['test'], :], Y[selectors['test']])

    return {
        'h': h,
        'iterations': iterations,
        'model': model,
        'score': fold_score,
    }













    