#########################################################################################################
# Experiment settings.
# NOTE(review): `directory` is not referenced anywhere in the visible part of this script
# (input is read from "./data/" and output is written to the working directory) - confirm
# whether it is still needed.
directory = "./"
# Tree depth handed to the CART and find_cluster calls; find_cluster is also asked for
# 2 ** depth clusters.
depth = 3
# Number of folds used by the KFold cross-validation.
n_splits = 10
#########################################################################################################
import numpy as np
import pandas as pd
import time
import GPy
from sklearn.model_selection import KFold
from SGP import SurrogateGaussianProcess
# Cross-validated comparison, run once per dataset.
#
# Each entry is (dataset name, m, likelihood_function, time_limit). `m` and
# `likelihood_function` are forwarded to SurrogateGaussianProcess.learn_GP;
# `time_limit` is forwarded to find_cluster and is in seconds (it is compared
# against a time.time() difference below).
dataset_settings = [
    ("diabetes", 10, 'Gaussian', 3600 * 1),
    ("abalone", 100, 'Poisson', 3600 * 5),
    ("cancer", 10, 'Bernoulli', 3600 * 1),
]
for data, m, likelihood_function, time_limit in dataset_settings:
    D = pd.read_csv("./data/" + data + ".csv")
    # The last column is the target; every other column is a feature.
    col = list(D.columns)[:-1]
    # Convert once per dataset so each fold is selected by plain array
    # indexing instead of a per-row membership scan over the index arrays.
    features = D.iloc[:, :-1].to_numpy()
    target = D.iloc[:, -1].to_numpy()
    # NOTE(review): no random_state is given, so the fold assignment (and thus
    # the reported scores) differs from run to run.
    kf = KFold(n_splits=n_splits, shuffle=True)
    result = []
    for num, (train_index, test_index) in enumerate(kf.split(D)):
        # KFold yields sorted positional indices, so row order matches D.
        X = features[train_index]
        y = target[train_index]
        X_ = features[test_index]
        sgp = SurrogateGaussianProcess(col)
        kernel = GPy.kern.Bias(1) * GPy.kern.RBF(X.shape[1], ARD=True) + GPy.kern.White(1)
        sgp.learn_GP(X, y, kernel, m=m, likelihood_function=likelihood_function, tau=False)
        sgp.make_posterior(X_)
        # Score of the partition returned by CART at the configured depth.
        omega = sgp.CART(depth=depth)
        score_CART = sgp.get_score(omega)
        # Score of the partition returned by find_cluster; only this call is timed.
        t0 = time.time()
        omega = sgp.find_cluster(2 ** depth, depth=depth, time_limit=time_limit, surrogate_model="tree")
        t1 = time.time() - t0
        score_MIQP = sgp.get_score(omega)
        # The recorded time is capped at time_limit.
        row = [data, num, depth, min(time_limit, t1), score_CART, score_MIQP]
        print(*row)
        result.append(row)
    # One CSV per dataset, one row per fold.
    result = pd.DataFrame(
        result,
        columns=["data", "num", "depth", "time", "score_CART", "score_MIQP"],
    )
    result.to_csv("tree" + "_" + data + ".csv", index=False)