# computes number of trees, patterns and diversity for Monks 3 under different theta

import pandas as pd
import numpy as np
import json
import os
import time
import pathlib
from treefarms import TREEFARMS

from gmpy2 import mpz
from gmpy2 import hamdist

class BinSequence:
    """Append-only sequence of bits packed into a gmpy2 ``mpz`` integer.

    Element ``i`` of the sequence is stored as bit ``i`` of ``x``. ``len``
    counts the appended elements, including zeros (which leave ``x``
    unchanged), so it is needed to tell ``[1]`` apart from ``[1, 0]``.

    The mutating methods (``append_0``, ``append_1``, ``from_array``) return
    ``self`` so calls can be chained.
    """

    def __init__(self, value=0, length=0):
        """Start from the bits of ``value`` and treat ``length`` of them as present."""
        self.x = mpz(value)
        self.len = length

    def copy(self):
        """Return a new BinSequence built from the same value and length."""
        return BinSequence(self.x, self.len)

    def append_0(self):
        """Append a 0: only the length grows, the packed value is untouched."""
        self.len += 1
        return self

    def append_1(self):
        """Append a 1 by setting the bit at the current end of the sequence."""
        self.x = self.x.bit_set(self.len)
        self.len += 1
        return self

    def all_bits_set(self):
        """Return True when the sequence is non-empty and consists of 1s only."""
        if self.len == 0:
            return False
        # `len` ones in a row equal 2**len - 1, i.e. x + 1 has exactly bit `len` set.
        power_of_two = mpz(0).bit_set(self.len)
        return power_of_two == self.x + 1

    def to_array(self):
        """Return the sequence as a float numpy array of 0s and 1s of size ``len``."""
        array = np.zeros((self.len,))
        # Jump from one set bit to the next; the loop relies on bit_scan1
        # returning None once no further set bit exists.
        next_1 = self.x.bit_scan1(0)
        while next_1 is not None:
            array[next_1] = 1
            next_1 = self.x.bit_scan1(next_1 + 1)
        return array

    def to_mpz(self):
        """Return the underlying packed integer."""
        return self.x

    def from_array(self, arr):
        """Append every element of ``arr``; entries equal to 1 become 1-bits, all others 0-bits.

        ``arr`` must support ``len()`` and integer indexing.
        """
        for i in range(len(arr)):
            if arr[i] == 1:
                self.x = self.x.bit_set(self.len)
            self.len += 1
        return self
                
class BinSequenceTests:
    """Assertion-based sanity checks for BinSequence (run the methods by hand)."""

    def correctness_test(self):
        """Exercise appending, array conversion, copying and construction from a value."""
        seq = BinSequence()

        # a single appended 0
        seq.append_0()
        bits = seq.to_array()
        assert bits[0] == 0
        assert len(bits) == 1
        assert seq.len == 1

        # followed by a 1 -> [0, 1]
        seq.append_1()
        print(seq.x)
        bits = seq.to_array()
        assert bits[0] == 0
        assert bits[1] == 1
        assert len(bits) == 2
        assert seq.len == 2

        # a copy carries the same value and length
        duplicate = seq.copy()
        assert duplicate.x == seq.x
        assert duplicate.len == seq.len

        # value 1 with length 2 decodes to [1, 0]
        seq = BinSequence(1, 2)
        bits = seq.to_array()
        assert bits[0] == 1
        assert bits[1] == 0

    def all_bits_set_test(self):
        """Check all_bits_set on an empty, an all-ones and a mixed sequence."""
        seq = BinSequence()
        assert not seq.all_bits_set()
        seq.append_1()
        assert seq.all_bits_set()
        seq.append_1()
        assert seq.all_bits_set()
        seq.append_0()
        assert not seq.all_bits_set()
        print("Success")
        
def compute_diversity(patterns, n):
    """Return the pairwise Hamming-distance diversity of ``patterns``.

    The Hamming distance of every unordered pair of entries is summed once,
    doubled, and divided by ``len(patterns)`` squared and by ``n``.

    patterns: non-empty indexable collection of values accepted by ``hamdist``.
    n: final divisor (the caller passes the number of samples per pattern).
    Raises AssertionError when ``patterns`` is empty.
    """
    count = len(patterns)
    assert count > 0
    pair_total = sum(
        (
            hamdist(patterns[a], patterns[b])
            for a in range(count)
            for b in range(a + 1, count)
        ),
        0.,
    )
    return 2.0 * pair_total / count / count / n

from sklearn.model_selection import KFold
from gosdt import GOSDT

def tune_hyperparameters(df, reg_array, depth_array, k):
    """
    Grid-search GOSDT's regularization and depth budget with k-fold cross-validation.

    Each (regularization, depth_budget) pair from ``reg_array`` x ``depth_array``
    is fitted on every training fold and scored with ``model.score`` on the
    held-out fold. The pair with the highest mean score wins; on ties the pair
    evaluated first is kept.

    Parameters
    ----------
    df : pandas.DataFrame
        All columns except the last are features; the last column is the target.
    reg_array : sequence
        Candidate values for the "regularization" setting.
    depth_array : sequence
        Candidate values for the "depth_budget" setting (re-iterated for every
        regularization value).
    k : int
        Number of folds given to ``KFold``.

    Returns
    -------
    tuple
        ``(best_hyperparameters, best_score)`` with ``best_hyperparameters`` a
        ``(regularization, depth_budget)`` tuple. If either grid is empty no
        configuration is evaluated and ``(None, 0)`` is returned.
    """
    best_score = 0
    best_hyperparameters = None

    X, y = df.iloc[:, :-1].values, df.iloc[:, -1].values
    feature_names = df.columns[:-1]

    for reg in reg_array:
        for depth_budget in depth_array:
            config = {
                "regularization": reg,
                "depth_budget": depth_budget,
                "allow_small_reg": True,
            }

            # random_state is fixed, so every configuration is scored on the
            # same shuffled folds.
            kf = KFold(n_splits=k, shuffle=True, random_state=0)

            fold_scores = []
            for train_index, test_index in kf.split(X):
                # Rebuild DataFrames so the feature columns keep their names.
                X_train = pd.DataFrame(X[train_index], columns=feature_names)
                X_test = pd.DataFrame(X[test_index], columns=feature_names)
                y_train = pd.DataFrame(y[train_index])
                y_test = pd.DataFrame(y[test_index])

                model = GOSDT(config)
                model.fit(X_train, y_train)
                fold_scores.append(model.score(X_test, y_test))

            avg_score = np.mean(fold_scores)
            # Always accept the first evaluated configuration: with the old
            # "avg_score > 0" start, a grid whose scores never exceeded 0 left
            # best_hyperparameters as None and broke callers that index it.
            if best_hyperparameters is None or avg_score > best_score:
                best_score = avg_score
                best_hyperparameters = (reg, depth_budget)

    return best_hyperparameters, best_score

# The SLURM array task id, shifted down by one, selects the theta for this job.
job_id = int(os.getenv('SLURM_ARRAY_TASK_ID')) - 1

theta_array = [0.01, 0.03, 0.05, 0.07, 0.09, 0.11]

name = 'monks3'
dataset = '../datasets/{}.csv'.format(name)
df = pd.read_csv(dataset)
# Single-point grids: the call only measures the 10-fold score of (0.01, 5).
params, acc = tune_hyperparameters(df, [0.01], [5], 10)
print("---- params", params, acc)

# parameters to define experiment
regularization, depth = params
# noise rates above this limit are skipped in the experiment loop
maximum_noise_limit = acc - 0.55

theta = theta_array[job_id]

# features are all columns but the last; the last column is the label
X = df.iloc[:, :-1]
Y = df.iloc[:, -1]
N = X.shape[0]

config = {
    # regularization penalizes the tree with more leaves. We recommend to set
    # it to relative high value to find a sparse tree.
    "regularization": regularization,
    # rashomon bound multiplier indicates how large of a Rashomon set would
    # you like to get
    "rashomon_bound_multiplier": theta,
    "depth_budget": depth,
    "allow_small_reg": True,
    "verbose": False,
}

rho_array = np.linspace(0, 0.25, 20)
noise_array = list(range(1, 26))

# one cell per (noise rate, noise seed) pair in every result grid
grid_shape = (len(rho_array), len(noise_array))
num_models = np.zeros(grid_shape)
num_patterns = np.zeros(grid_shape)
diversity = np.zeros(grid_shape)
average_accuracy = np.zeros(grid_shape)


# For every (noise rate, noise seed) pair: flip labels at random, fit TREEFARMS
# on the noisy labels, and record the number of trees, the number of distinct
# prediction patterns, their diversity and the mean training accuracy.
for rho_id, rho in enumerate(rho_array):
    for noise_id, noise_seed in enumerate(noise_array):

        if rho > maximum_noise_limit:
            # -1 is the sentinel for cells that were skipped.
            num_models[rho_id, noise_id] = -1
            num_patterns[rho_id, noise_id] = -1
            diversity[rho_id, noise_id] = -1
            average_accuracy[rho_id, noise_id] = -1
            continue

        # Seed numpy's global RNG so the same seed flips the same draws.
        np.random.seed(noise_seed)
        random_numbers = np.random.random(size=len(Y))

        # Flip each 0/1 label whose draw falls below the noise rate.
        flip = random_numbers < rho
        new_Y = Y.copy()
        new_Y.loc[flip] = (new_Y.loc[flip] + 1) % 2

        model = TREEFARMS(config)
        model.fit(X, new_Y)
        tree_size = model.get_tree_count()
        num_models[rho_id, noise_id] = tree_size

        # One row of predictions per tree in the fitted set.
        rset_pred = np.zeros((tree_size, N))
        train_accuracy = 0
        for i in range(tree_size):
            rtree = model[i]
            rset_pred[i, :] = rtree.predict(X)
            # Accumulated directly so the module-level cross-validation `acc`
            # is no longer overwritten inside this loop.
            train_accuracy += rtree.score(X, new_Y)

        # Pack each prediction row into an integer so equal rows collapse in a set.
        patterns = [BinSequence(0, 0).from_array(row).x for row in rset_pred]
        distinct_patterns = list(set(patterns))

        # NOTE: tree_size == 0 leaves distinct_patterns empty, which trips the
        # assert in compute_diversity.
        num_patterns[rho_id, noise_id] = len(distinct_patterns)
        diversity[rho_id, noise_id] = compute_diversity(distinct_patterns, N)
        average_accuracy[rho_id, noise_id] = train_accuracy / tree_size

 

        
results = [num_models, num_patterns, diversity, average_accuracy]

import pickle

# Store the four result grids in a file named after this job's theta. The
# context manager closes the file even if pickling raises.
with open('./monks3_' + str(theta), 'wb') as file:
    pickle.dump(results, file)
 
import numpy as np
import matplotlib.pyplot as plt
# if using a Jupyter notebook, also run the `%matplotlib inline` magic

def to_plot_metric(rho_array, metric, name_fig, name_file):
    """Plot the row-wise mean of ``metric`` against ``rho_array`` and save the figure.

    ``metric`` is averaged along axis 1, so it needs at least two dimensions.
    Rows are kept up to, but not including, the row index of the first -1
    entry found; without any -1 all rows are kept. A band of one standard
    deviation around the mean is filled in.

    rho_array: x values, sliced to the same number of kept rows.
    metric: array of measurements, one row per x value.
    name_fig: title of the figure.
    name_file: path handed to ``plt.savefig``.
    """
    plt.figure()

    # Locate the -1 sentinels once and cut at the first hit's row.
    sentinel_hits = np.argwhere(metric == -1)
    if len(sentinel_hits) > 0:
        idx = sentinel_hits[0][0]
    else:
        idx = metric.shape[0]

    kept_rho = rho_array[0:idx]
    kept_rows = metric[0:idx]
    mean = np.mean(kept_rows, 1)
    std = np.std(kept_rows, 1)

    plt.plot(kept_rho, mean, 'k-')
    plt.fill_between(kept_rho, mean - std, mean + std)
    plt.title(name_fig)
    plt.savefig(name_file, bbox_inches='tight')
    
# to_plot_metric(rho_array
#               , diversity
#               , "Diversity " +str(theta) +" "+str(regularization) + " "+str(depth)
#               , "./div_" +str(theta)+'.png')

# to_plot_metric(rho_array
#               , num_patterns
#               , "num_patterns " +str(theta) +" "+str(regularization) + " "+str(depth)
#               , "./num_patterns_" +str(theta)+'.png')

# to_plot_metric(rho_array
#               , num_models
#               , "num_models " +str(theta) +" "+str(regularization) + " "+str(depth)
#               , "./num_models_" +str(theta)+'.png')

# to_plot_metric(rho_array
#               , average_accuracy
#               , "average_accuracy " +str(theta) +" "+str(regularization) + " "+str(depth)
#               , "./average_accuracy_" +str(theta)+'.png')