import numpy as np
from numpy import genfromtxt
from sklearn.preprocessing import OneHotEncoder
from time import time
from pystreed import STreeDClassifier
from sklearn.exceptions import NotFittedError



def _labels_from_column(cells):
    """Return the int32 class labels parsed from raw label cells.

    Only the second character of each cell is converted to an integer.
    NOTE(review): the first character is presumably the opening quote left
    behind by the quote-comma delimiter used when reading the files --
    confirm against the data files.
    """
    return np.array([int(cell[1]) for cell in cells], dtype=np.int32)


def _accuracy(predicted, expected):
    """Return the fraction of positions where *predicted* equals *expected*."""
    return np.sum(predicted == expected) / len(expected)


# Names of the datasets to benchmark. Each one is read from
# data/<name>.txt and data/<name>_test_binarised.txt.
DATASET_NAMES = (
    "avila",
    "bank",
    "bean",
    "bidding",
    "eeg",
    "fault",
    "htru",
    "magic",
    "occupancy",
    "page",
    "raisin",
    "rice",
    "room",
    "segment",
    "skin",
    "wilt",
)

for dataset in DATASET_NAMES:
    print(dataset)

    # The delimiter is the two-character sequence: double quote, comma.
    train_path = "data/" + "{}.txt".format(dataset)
    held_out_path = "data/" + dataset + "_test_binarised.txt"
    raw_train = genfromtxt(train_path, delimiter="\",", dtype=str)
    raw_held_out = genfromtxt(held_out_path, delimiter="\",", dtype=str)

    # Last column holds the label, everything before it is a feature.
    train_features = raw_train[:, :-1]
    train_label_cells = raw_train[:, -1]
    # The first row of the held-out file is dropped, the training file is
    # used in full. NOTE(review): presumably only the held-out file has a
    # header row -- confirm.
    held_out_rows = raw_held_out[1:]
    held_out_features = held_out_rows[:, :-1]
    held_out_label_cells = held_out_rows[:, -1]

    train_labels = _labels_from_column(train_label_cells)
    held_out_labels = _labels_from_column(held_out_label_cells)

    # The encoder is fitted on training and held-out rows together, then
    # each part is transformed separately into a dense array.
    train_rows = len(train_features)
    stacked = np.vstack((train_features, held_out_features))
    encoder = OneHotEncoder(handle_unknown='ignore').fit(stacked)
    train_encoded = encoder.transform(stacked[:train_rows]).toarray()
    print(train_encoded.shape)
    held_out_encoded = encoder.transform(stacked[train_rows:]).toarray()

    # First half of the held-out rows picks the best model; the accuracy on
    # the second half is the number that gets printed at the end.
    half = len(held_out_encoded) // 2
    select_X, report_X = held_out_encoded[:half], held_out_encoded[half:]
    select_y, report_y = held_out_labels[:half], held_out_labels[half:]

    # Sweep the node budget for depth-5 trees, timing the whole sweep
    # (fitting and prediction included).
    started = time()
    selection_scores = []
    reported_scores = []
    for node_budget in range(5, 2**5):
        tree = STreeDClassifier(
            max_depth=5, max_num_nodes=node_budget, time_limit=300
        )
        tree.fit(train_encoded, train_labels)
        try:
            select_pred = tree.predict(select_X)
        except NotFittedError:
            # No usable model for this budget: report it and give up on the
            # remaining budgets for this dataset.
            print("No res for {}".format(dataset))
            break
        report_pred = tree.predict(report_X)
        select_score = _accuracy(select_pred, select_y)
        report_score = _accuracy(report_pred, report_y)
        selection_scores.append(select_score)
        reported_scores.append(report_score)
    elapsed = time() - started

    print(elapsed)
    if selection_scores:
        # Report the second-half accuracy of the model that scored best on
        # the first half (first such model on ties).
        winner = np.argmax(selection_scores)
        print(reported_scores[winner])
    