import pandas as pd
import numpy as np
import json
import os
import time
import pathlib
from treefarms import TREEFARMS

from gmpy2 import mpz
from gmpy2 import hamdist

class BinSequence:
    """Growable sequence of bits stored in a gmpy2 ``mpz``.

    Element ``i`` of the sequence is bit ``i`` of ``x`` (least-significant
    bit first). ``len`` counts the elements separately, because appended
    zeros leave the integer value unchanged.
    """

    def __init__(self, value=0, length=0):
        """Create a sequence from an integer *value* holding *length* bits."""
        self.x = mpz(value)
        self.len = length

    def copy(self):
        """Return a new ``BinSequence`` with the same bits and length."""
        return BinSequence(self.x, self.len)

    def append_0(self):
        """Append a 0 bit and return ``self``."""
        # Only the length grows: the new top bit is already clear.
        self.len += 1
        return self

    def append_1(self):
        """Append a 1 bit and return ``self``."""
        self.x = self.x.bit_set(self.len)
        self.len += 1
        return self

    def all_bits_set(self):
        """Return True when the sequence is non-empty and every bit is 1."""
        if self.len == 0:
            return False
        # All ``len`` bits set  <=>  x == 2**len - 1  <=>  x + 1 == 2**len.
        return self.x + 1 == mpz(0).bit_set(self.len)

    def to_array(self):
        """Return the bits as a NumPy array of ``len`` zeros and ones."""
        array = np.zeros((self.len,))
        # Jump from one set bit to the next; bit_scan1 yields None once no
        # set bit remains at or above the starting position.
        next_1 = self.x.bit_scan1(0)
        while next_1 is not None:
            array[next_1] = 1
            next_1 = self.x.bit_scan1(next_1 + 1)
        return array

    def to_mpz(self):
        """Return the underlying ``mpz`` value."""
        return self.x

    def from_array(self, arr):
        """Append every element of *arr* (1 -> set bit, anything else -> 0).

        The bits are added after the existing ones; returns ``self``.
        """
        for item in arr:
            if item == 1:
                self.x = self.x.bit_set(self.len)
            self.len += 1
        return self
                
class BinSequenceTests:
    """Assertion-based checks for ``BinSequence``."""

    def correctness_test(self):
        """Check appending, array conversion, copying and direct construction."""
        seq = BinSequence()

        # A single appended zero.
        bits = seq.append_0().to_array()
        assert bits[0] == 0
        assert len(bits) == 1
        assert seq.len == 1

        # A one appended after the zero lands at position 1.
        seq.append_1()
        print(seq.x)
        bits = seq.to_array()
        assert bits[0] == 0
        assert bits[1] == 1
        assert len(bits) == 2
        assert seq.len == 2

        # A copy carries the same value and length.
        duplicate = seq.copy()
        assert duplicate.x == seq.x
        assert duplicate.len == seq.len

        # Value 1 with length 2 reads as [1, 0] (least-significant bit first).
        explicit = BinSequence(1, 2)
        bits = explicit.to_array()
        assert bits[0] == 1
        assert bits[1] == 0

    def all_bits_set_test(self):
        """Check ``all_bits_set`` on empty, all-ones and mixed sequences."""
        seq = BinSequence()
        assert not seq.all_bits_set()
        seq.append_1()
        assert seq.all_bits_set()
        seq.append_1()
        assert seq.all_bits_set()
        seq.append_0()
        assert not seq.all_bits_set()
        print("Success")
        
def compute_diversity(patterns, n):
    """Return the normalized pairwise Hamming diversity of *patterns*.

    The Hamming distances (``hamdist``) of all unordered pairs are summed,
    doubled, and divided by ``len(patterns) ** 2`` and by ``n``.

    Args:
        patterns: non-empty indexable collection of bit patterns accepted by
            ``hamdist``.
        n: normalizing length (the caller passes the number of samples).

    Returns:
        The diversity as a float; 0.0 when there is a single pattern.
    """
    count = len(patterns)
    assert count > 0
    total = 0.
    for first, left in enumerate(patterns):
        # Visit each unordered pair exactly once.
        for second in range(first + 1, count):
            total += hamdist(left, patterns[second])
    return 2.0 * total / count / count / n

from sklearn.model_selection import KFold
from gosdt import GOSDT

def _mean_cv_score(config, X, y, columns, k):
    """Return the mean GOSDT test score of *config* over k shuffled folds."""
    # Fixed random_state: every configuration is scored on the same splits.
    kf = KFold(n_splits=k, shuffle=True, random_state=0)

    test_scores = []
    for train_index, test_index in kf.split(X):
        X_train = pd.DataFrame(X[train_index], columns=columns)
        X_test = pd.DataFrame(X[test_index], columns=columns)
        y_train = pd.DataFrame(y[train_index])
        y_test = pd.DataFrame(y[test_index])

        model = GOSDT(config)
        model.fit(X_train, y_train)
        test_scores.append(model.score(X_test, y_test))

    return np.mean(test_scores)


def tune_hyperparameters(df, reg_array, depth_array, k):
    """Grid-search GOSDT hyperparameters with k-fold cross-validation.

    Every (regularization, depth_budget) pair from ``reg_array`` x
    ``depth_array`` is scored by the mean of ``model.score`` over ``k``
    shuffled folds.

    Args:
        df: DataFrame whose last column is the label; all other columns are
            features.
        reg_array: candidate values for the "regularization" config entry.
        depth_array: candidate values for the "depth_budget" config entry.
        k: number of cross-validation folds.

    Returns:
        ``(best_hyperparameters, best_score)``. ``best_hyperparameters`` is
        the ``(regularization, depth_budget)`` pair with the highest mean
        score; the earliest pair wins ties. ``(None, 0)`` is returned when
        the grid is empty or no pair produced a score >= 0.
    """
    best_score = 0
    best_hyperparameters = None

    X, y = df.iloc[:, :-1].values, df.iloc[:, -1].values
    h = df.columns[:-1]

    for a in reg_array:
        for b in depth_array:
            config = {
                "regularization": a,
                "depth_budget": b,
                "allow_small_reg": True,
            }
            avg_score = _mean_cv_score(config, X, y, h, k)

            # A candidate scoring exactly 0 must still be selectable while
            # nothing has been chosen yet; otherwise an all-zero grid would
            # return None and the caller would fail when unpacking it.
            nothing_chosen_yet = (
                best_hyperparameters is None and avg_score == best_score
            )
            if avg_score > best_score or nothing_chosen_yet:
                best_score = avg_score
                best_hyperparameters = (a, b)

    return best_hyperparameters, best_score




# The SLURM array task id is treated as 1-based; job_id is the resulting
# 0-based index into data_list.
_task_id = os.getenv('SLURM_ARRAY_TASK_ID')
if _task_id is None:
    raise RuntimeError(
        "SLURM_ARRAY_TASK_ID is not set; run this script as a SLURM array task."
    )
job_id = int(_task_id) - 1

data_list = [
    'car_evaluation',
    'monks2',
    'monks1',
    'monks3',
    'bar7',
    'compas',
    'fico',
    'bcw_bin',
    'carryout_takeaway',
    'restaurant_20',
    'bar',
    'coffee_house',
]

# NOTE(review): a task id of 0 gives job_id == -1, which selects the last
# entry of data_list instead of failing -- confirm the array range in use.
name = data_list[job_id]
dataset = '../datasets/' + name + '.csv'
df = pd.read_csv(dataset)

# Cross-validate a single (regularization, depth) candidate with 10 folds.
params, acc = tune_hyperparameters(df, [0.01], [3], 10)
print("---- params", params, acc)
if params is None:
    # Fail here with a clear message rather than later on params[0].
    raise RuntimeError(
        "hyperparameter tuning returned no configuration for " + name
    )


# Experiment parameters, taken from the tuning step.
regularization, depth = params[0], params[1]
# Noise levels above this threshold are skipped by the sweep.
maximum_noise_limit = acc - 0.55

# Additive Rashomon bound passed to TREEFARMS.
theta = 0.05

# Features are every column but the last; the last column is the label.
X = df.iloc[:, :-1]
Y = df.iloc[:, -1]
N = X.shape[0]

config = {
    "regularization": regularization,
    "rashomon_bound_adder": theta,
    "rashomon_bound_multiplier": 0,
    "depth_budget": depth,
    "allow_small_reg": True,
    "verbose": False,
}

# 40 evenly spaced noise levels in [0, 0.25] and the seeds 1..50.
rho_array = np.linspace(0, 0.25, 40)
noise_array = list(range(1, 51))

# One result grid per metric: rows follow rho_array, columns noise_array.
grid_shape = (len(rho_array), len(noise_array))
num_models = np.zeros(grid_shape)
num_patterns = np.zeros(grid_shape)
diversity = np.zeros(grid_shape)
average_accuracy = np.zeros(grid_shape)


# Sweep over label-noise levels (rows) and random seeds (columns), fitting a
# TREEFARMS Rashomon set on the noisy labels and recording its statistics.
for rho_id, rho in enumerate(rho_array):
    # The skip decision depends on rho only, so it is taken once per row.
    if rho > maximum_noise_limit:
        # -1 marks noise levels that were not evaluated.
        num_models[rho_id, :] = -1
        num_patterns[rho_id, :] = -1
        diversity[rho_id, :] = -1
        average_accuracy[rho_id, :] = -1
        continue

    for noise_id, noise_seed in enumerate(noise_array):
        # Flip each 0/1 label independently with probability rho; seeding
        # makes the flipped set reproducible for every (rho, seed) pair.
        np.random.seed(noise_seed)
        random_numbers = np.random.random(size=len(Y))
        flip_mask = random_numbers < rho
        # Positional, vectorized flip: keep Y where the mask is False.
        new_Y = Y.where(~flip_mask, (Y + 1) % 2)

        model = TREEFARMS(config)
        model.fit(X, new_Y)
        tree_size = model.get_tree_count()
        num_models[rho_id, noise_id] = tree_size

        # Row i holds the predictions of tree i on the training features.
        rset_pred = np.zeros((tree_size, N))

        train_accuracy = 0
        max_accuracy = 0

        for i in range(tree_size):
            rtree = model[i]
            rset_pred[i, :] = rtree.predict(X)
            # Separate name so the tuned cross-validation accuracy `acc`
            # is not overwritten by per-tree scores.
            tree_acc = rtree.score(X, new_Y)
            train_accuracy += tree_acc
            if tree_acc > max_accuracy:
                max_accuracy = tree_acc

        # Encode every prediction vector as an integer bit pattern so that
        # trees with identical predictions collapse to one entry.
        patterns = [BinSequence(0, 0).from_array(row).x for row in rset_pred]
        distinct_patterns = list(set(patterns))

        # NOTE(review): assumes tree_size > 0 -- compute_diversity asserts on
        # an empty list and the mean below divides by tree_size.
        num_patterns[rho_id, noise_id] = len(distinct_patterns)
        diversity[rho_id, noise_id] = compute_diversity(distinct_patterns, N)
        average_accuracy[rho_id, noise_id] = train_accuracy / tree_size

 

        
results = [num_models, num_patterns, diversity, average_accuracy]

import pickle

# Store the four result grids in 'noise_trees_<dataset name>' in the working
# directory. The context manager closes the file even if pickling fails.
with open('noise_trees_' + name, 'wb') as file:
    pickle.dump(results, file)
 
import numpy as np
import matplotlib.pyplot as plt

def to_plot_metric(rho_array, metric, name_fig, name_file):
    """Plot the row-wise mean of *metric* with a +/- one-std band and save it.

    Args:
        rho_array: x-axis values, one per row of ``metric``.
        metric: 2-D array; rows from the first one containing -1 onward are
            left out of the plot.
        name_fig: title of the figure.
        name_file: path the figure is saved to.
    """
    plt.figure()

    # Keep every row unless the -1 marker occurs; then cut at the row of
    # its first occurrence.
    cutoff = metric.shape[0]
    marker_hits = np.argwhere(metric == -1)
    if len(marker_hits) > 0:
        cutoff = marker_hits[0][0]

    kept_rho = rho_array[0:cutoff]
    kept_rows = metric[0:cutoff]

    center = np.mean(kept_rows, 1)
    spread = np.std(kept_rows, 1)

    plt.plot(kept_rho, center, 'k-')
    plt.fill_between(kept_rho, center - spread, center + spread)
    plt.title(name_fig)
    plt.savefig(name_file, bbox_inches='tight')
    
# to_plot_metric(rho_array
#               , diversity
#               , "Diversity " +name +" "+str(regularization) + " "+str(depth)
#               , "div_" +name+'.png')

# to_plot_metric(rho_array
#               , num_patterns
#               , "num_patterns " +name +" "+str(regularization) + " "+str(depth)
#               , "num_patterns_" +name+'.png')

# to_plot_metric(rho_array
#               , num_models
#               , "num_models " +name +" "+str(regularization) + " "+str(depth)
#               , "num_models_" +name+'.png')

# to_plot_metric(rho_array
#               , average_accuracy
#               , "average_accuracy " +name +" "+str(regularization) + " "+str(depth)
#               , "average_accuracy_" +name+'.png')