import pymurtree
import numpy as np
from numpy import genfromtxt
from sklearn.preprocessing import OneHotEncoder
from time import time

# Datasets to evaluate. Each needs data/<name>.txt (training rows) and
# data/<name>_test_binarised.txt (held-out rows, split in half below).
DATASETS = ["avila"]
MAX_DEPTH = 5       # passed to pymurtree as max_depth
MIN_NUM_NODES = 5   # smallest max_num_nodes budget tried
TIME_LIMIT = 300    # passed to pymurtree as `time` (presumably seconds -- confirm)


def _parse_labels(column):
    """Convert a column of quoted class-label strings to an int32 array.

    Splitting rows on '",' presumably leaves each label field wrapped in
    double quotes (e.g. '"3"').  Strip whitespace and quotes and parse the
    whole remainder, so multi-digit labels such as '"12"' are not truncated
    to their first digit.
    """
    return np.array([int(y.strip().strip('"')) for y in column], dtype=np.int32)


def _accuracy(predicted, expected):
    """Return the fraction of positions where *predicted* equals *expected*."""
    return np.sum(predicted == expected) / len(expected)


for f in DATASETS:
    print(f)
    filename = f"{f}.txt"
    data = genfromtxt(f"data/{filename}", delimiter="\",", dtype=str)
    data_test = genfromtxt(f"data/{f}_test_binarised.txt", delimiter="\",", dtype=str)
    X, Y = data[:, :-1], data[:, -1]
    # NOTE(review): the first row of the test file is dropped (presumably a
    # header) while the training file is read in full -- confirm both layouts.
    X_test, Y_test = data_test[1:, :-1], data_test[1:, -1]
    Y = _parse_labels(Y)
    Y_test = _parse_labels(Y_test)

    # Fit the encoder on train + held-out rows so both share one column layout.
    # Since it sees every row it later transforms, handle_unknown='ignore'
    # never actually triggers here.
    X_all = np.vstack((X, X_test))
    one_hot_data = OneHotEncoder(handle_unknown='ignore').fit(X_all)
    one_hot_data_X_train = one_hot_data.transform(X_all[:len(X)]).toarray()
    print(one_hot_data_X_train.shape)
    one_hot_data_X_test = one_hot_data.transform(X_all[len(X):]).toarray()

    # Split held-out rows in file order (no shuffling): the first half selects
    # the node budget, the second half reports the selected model's accuracy.
    half = len(Y_test) // 2
    one_hot_data_X_test_ = one_hot_data_X_test[:half]
    one_hot_data_X_val_ = one_hot_data_X_test[half:]
    Y_test_, Y_val_ = Y_test[:half], Y_test[half:]

    # Train one tree per node budget and score it on both halves.
    ts = time()
    all_tests = []
    all_val = []
    # A binary tree of depth d has at most 2**d - 1 internal nodes, so the
    # exclusive bound 2**MAX_DEPTH tries every budget up to that maximum.
    for n in range(MIN_NUM_NODES, 2 ** MAX_DEPTH):
        model = pymurtree.OptimalDecisionTreeClassifier()
        model.fit(one_hot_data_X_train, Y, max_num_nodes=n,
                  max_depth=MAX_DEPTH, time=TIME_LIMIT)
        test_acc = _accuracy(model.predict(one_hot_data_X_test_), Y_test_)
        val_acc = _accuracy(model.predict(one_hot_data_X_val_), Y_val_)
        all_tests.append(test_acc)
        all_val.append(val_acc)
    time_ = time() - ts

    print(time_)
    # best_test is the index of the budget with the highest selection-half
    # accuracy; report that model's accuracy on the untouched second half.
    best_test = np.argmax(all_tests)
    print(all_val[best_test])