

import json

import numpy as np
from pred_pattern import BinSequence, pred_distances
from collections import defaultdict
class SklearnTree:
    """Adapter that presents a tree object through sklearn-style methods.

    ``predict``, ``predict_proba`` and ``score`` are forwarded verbatim to the
    wrapped object, which is stored on the ``tree_`` attribute.
    """

    def __init__(self, tree):
        # Trailing underscore follows sklearn's naming for fitted attributes.
        self.tree_ = tree

    def predict(self, X):
        """Return ``self.tree_.predict(X)`` unchanged."""
        wrapped = self.tree_
        result = wrapped.predict(X)
        return result

    def predict_proba(self, X):
        """Return ``self.tree_.predict_proba(X)`` unchanged."""
        wrapped = self.tree_
        result = wrapped.predict_proba(X)
        return result

    def score(self, X, y):
        """Return ``self.tree_.score(X, y)`` unchanged."""
        wrapped = self.tree_
        result = wrapped.score(X, y)
        return result

class RFWrapper:
    """Random-forest-style ensemble assembled from trees of a rule set.

    ``rset`` must provide ``ntrees()``, ``get_tree(i)``,
    ``get_tree_num_leaves(i)``, ``predict_all(X)``, a ``special_tree`` dict
    with an ``"optimal_tree"`` index, and the ``n_features_``,
    ``n_features_in_`` and ``n_classes_`` attributes.

    Tree-selection strategies (``rules``):

    * ``"sparsest"`` / ``"densest"``: the ``n_estimators`` trees with the
      fewest / most leaves.
    * ``"random"``: the optimal tree plus ``n_estimators - 1`` other trees
      drawn without replacement via ``rng.choice``.
    * ``"closest"``: the optimal tree plus the trees whose training-set
      predictions have the smallest ``pred_distances`` to the optimal tree's.
    * ``"increment"``: the optimal tree plus trees taken at an even stride
      along that same distance ranking.
    * ``"farthest"``: the optimal tree plus trees chosen greedily, each one
      maximising its summed distance to the trees already selected.

    Any other ``rules`` value raises ``NotImplementedError`` when more than
    one estimator is requested.

    Args:
        rset: Tree collection to draw trees from.
        n_estimators: Number of trees to keep; must be >= 1 and no larger
            than ``rset.ntrees()``.
        rules: Selection strategy name (see above).
        rng: Random generator exposing ``choice`` (used by ``"random"``).
        X_train, y_train: Training data passed to each selected tree's
            ``generate_proba``. ``X_train`` is also fed to
            ``rset.predict_all`` by the distance-based rules.

    Raises:
        ValueError: If ``n_estimators`` < 1 or exceeds the number of trees.
        NotImplementedError: For an unknown ``rules`` value.
    """

    def __init__(self, rset, n_estimators, rules, rng, X_train, y_train):
        self.rset = rset
        self.estimators_ = []
        self.n_features_ = self.rset.n_features_
        self.n_features_in_ = self.rset.n_features_in_
        self.n_classes_ = self.rset.n_classes_

        # Validate for every rule, so that no strategy silently returns fewer
        # trees, repeats a tree, or runs out of candidates mid-selection.
        if n_estimators < 1:
            raise ValueError("n_estimators must be at least 1")
        if self.rset.ntrees() < n_estimators:
            raise ValueError("Not enough trees to select from")

        if rules in ("sparsest", "densest"):
            tree_indices = self._select_by_leaves(
                n_estimators, densest=(rules == "densest")
            )
        else:
            tree_indices = self._select_around_optimal(
                n_estimators, rules, rng, X_train
            )
        self.init_estimators(tree_indices, X_train, y_train)

    def _select_by_leaves(self, n_estimators, densest):
        """Return indices of the ``n_estimators`` trees with the fewest
        (or, when ``densest`` is true, the most) leaves."""
        leaves = np.array(
            [self.rset.get_tree_num_leaves(i) for i in range(self.rset.ntrees())]
        )
        order = np.argsort(-leaves if densest else leaves)
        return list(order[:n_estimators])

    def _select_around_optimal(self, n_estimators, rules, rng, X_train):
        """Return the optimal tree's index followed by ``n_estimators - 1``
        further tree indices chosen according to ``rules``."""
        n_trees = self.rset.ntrees()
        optimal_idx = self.rset.special_tree["optimal_tree"]
        tree_indices = [optimal_idx]
        if n_estimators == 1:
            return tree_indices

        if rules == "random":
            # With np.arange, position == tree index, so deleting by position
            # removes exactly the optimal tree.
            candidates = np.delete(np.arange(n_trees), optimal_idx)
            tree_indices.extend(
                rng.choice(candidates, n_estimators - 1, replace=False)
            )
            return tree_indices

        # Reject unknown rules before paying for predict_all on the train set.
        if rules not in ("closest", "increment", "farthest"):
            raise NotImplementedError("Not Implemented")

        predictions = self.rset.predict_all(X_train)
        predictions = np.array(
            [BinSequence(0, 0).from_array(pred).x for pred in predictions]
        )

        if rules == "farthest":
            tree_indices.extend(
                self._greedy_farthest(predictions, optimal_idx, n_estimators - 1)
            )
            return tree_indices

        distance_to_opt = [
            pred_distances(p, predictions[optimal_idx]) for p in predictions
        ]
        ranked = self._ranked_excluding(distance_to_opt, optimal_idx)
        if rules == "closest":
            tree_indices.extend(ranked[:n_estimators - 1])
        else:  # "increment"
            step = len(ranked) // n_estimators
            if step == 0:
                # Only possible when n_trees == n_estimators: the pool holds
                # exactly n_estimators - 1 candidates, so every one is used
                # instead of repeating ranked[0].
                tree_indices.extend(ranked)
            else:
                # Positions step, 2*step, ..., (n_estimators-1)*step are all
                # < len(ranked) because step <= len(ranked) / n_estimators.
                tree_indices.extend(
                    ranked[step * k] for k in range(1, n_estimators)
                )
        return tree_indices

    @staticmethod
    def _ranked_excluding(distances, excluded_idx):
        """Return tree indices sorted by ascending distance, with the index
        ``excluded_idx`` removed by value (not by position)."""
        order = np.argsort(distances)
        return order[order != excluded_idx]

    @staticmethod
    def _greedy_farthest(predictions, start_idx, n_pick):
        """Greedily pick ``n_pick`` indices other than ``start_idx``.

        Each round takes the candidate whose summed ``pred_distances`` to all
        trees selected so far (``start_idx`` included) is largest. Ties go to
        the lowest index.
        """
        candidates = [i for i in range(len(predictions)) if i != start_idx]
        # Running total of each candidate's distance to the selected set.
        # Updating it with only the newest pick avoids recomputing every
        # pairwise distance on each round.
        total_dist = dict.fromkeys(candidates, 0)
        picked = []
        newest = start_idx
        for _ in range(n_pick):
            for c in candidates:
                total_dist[c] += pred_distances(predictions[c], predictions[newest])
            # max() returns the first maximal element; candidates stay in
            # ascending order, so ties resolve to the lowest index.
            newest = max(candidates, key=total_dist.__getitem__)
            picked.append(newest)
            candidates.remove(newest)
        return picked

    def init_estimators(self, idices, X, y):
        """Initialize the estimators for the Random Forest.

        For each tree index in ``idices``, fetch the tree from ``rset``, call
        its ``generate_proba(X, y)``, and wrap it in a ``SklearnTree`` that
        carries this ensemble's ``n_features_in_``. Any previously stored
        estimators are replaced.
        """
        self.estimators_ = []
        for i in idices:
            tree = self.rset.get_tree(i)
            tree.generate_proba(X, y)
            sklearn_tree = SklearnTree(tree)
            sklearn_tree.n_features_in_ = self.n_features_in_
            self.estimators_.append(sklearn_tree)

    def predict(self, X):
        """Predict class labels by majority vote over the member trees.

        Each tree's ``predict`` output is used directly as a column index
        into an ``(n_samples, n_classes_)`` vote table. Labels are therefore
        assumed to be integers in ``[0, n_classes_)``. Ties go to the lowest
        class index (``np.argmax``).
        """
        n_samples = X.shape[0]
        votes = np.zeros((n_samples, self.n_classes_))
        rows = np.arange(n_samples)
        for tree in self.estimators_:
            votes[rows, tree.predict(X)] += 1
        return np.argmax(votes, axis=1)

    def predict_proba(self, X):
        """Return the element-wise mean of the member trees' ``predict_proba``
        outputs, shaped ``(n_samples, n_classes_)``."""
        proba = np.zeros((X.shape[0], self.n_classes_))
        for tree in self.estimators_:
            proba += tree.predict_proba(X)
        proba /= len(self.estimators_)
        return proba

    def score(self, X, y):
        """Return the accuracy (fraction of correct predictions) on ``X``."""
        predictions = self.predict(X)
        return np.mean(predictions == y)
    
