import numpy as np
from collections import deque
import pandas as pd
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.tree import DecisionTreeClassifier
import sklearn

from collections import defaultdict
from module.threshold_guess import ThresholdGuess
import json

class Tree:
    """
    Unified representation of a tree classifier in Python.

    This class accepts a dictionary representation of a tree classifier and
        decodes it into an interactive object backed by four parallel numpy
        arrays: ``children_left``, ``children_right``, ``feature`` and
        ``prediction`` (node 0 is the root).
    An optional encoder can be supplied if the feature space of the model
        differs from the feature space of the original data; it is only
        consulted by ``to_json`` to translate binarized splits.

    Array conventions shared by both construction paths:
      * a leaf ``i`` is a self-loop, ``children_left[i] == children_right[i] == i``;
        every traversal below uses this to detect leaves;
      * an internal node stores ``prediction[i] == -1``;
      * a sample whose value for ``feature[i]`` is ``>= 0.5`` moves to
        ``children_left[i]`` (for dict trees, the ``"true"`` child), otherwise
        to ``children_right[i]`` (the ``"false"`` child).
    """
    def __init__(self, tree, encoder: "ThresholdGuess" = None, type="json"):
        """
        Parameters
        ---
        tree : dict or ROCT-N model
            For the default ``type``, a nested dict where a leaf has a
            ``"prediction"`` key and an internal node has ``"feature"``,
            ``"true"`` and ``"false"`` keys. A ``None`` node is treated as a
            leaf predicting 0. For ``type == "ROCT-N"``, a fitted Robust OCT
            model (see ``roct_n_to_tree``).
        encoder : ThresholdGuess, optional
            Feature encoder, only used by ``to_json``.
        type : str
            ``"ROCT-N"`` selects the ROCT-N conversion; any other value parses
            ``tree`` as a dict. (Shadows the builtin name; kept for callers.)
        """
        self.encoder = encoder
        if type == "ROCT-N":
            self.roct_n_to_tree(tree)
            return
        self.tree = tree

        children_left = []
        children_right = []
        feature = []
        prediction = []

        # Breadth-first walk: the children of node i are queued behind every
        # node already waiting, so their final array indices are known now:
        # i + len(queue) + 1 and i + len(queue) + 2.
        queue = deque([tree])
        while queue:
            i = len(children_left)
            node = self._resolve_node(queue.popleft())
            if "prediction" in node:
                children_left.append(i)
                children_right.append(i)
                feature.append(-1)
                prediction.append(node["prediction"])
            else:
                children_left.append(i + len(queue) + 1)
                children_right.append(i + len(queue) + 2)
                feature.append(node["feature"])
                prediction.append(-1)
                queue.append(node["true"])
                queue.append(node["false"])

        self.children_left = np.array(children_left)
        self.children_right = np.array(children_right)
        self.feature = np.array(feature)
        self.prediction = np.array(prediction)

    @staticmethod
    def _resolve_node(node):
        """Return ``node``, substituting a leaf predicting 0 for a null (``None``) node."""
        if node is None:
            return {
                'prediction': 0,
                'name': 'Prediction'
            }
        return node

    def __str__(self):
        return f"Left Children: {self.children_left}\nRight Children: {self.children_right}\nFeature: {self.feature}\nPrediction: {self.prediction}\n"

    def export_text(self):
        """Placeholder: text export is not implemented; returns None."""
        return None

    def _leaf_indices(self, X):
        """
        Route every row of ``X`` from the root down to a leaf.

        Parameters
        ---
        X : (n, p) array-like; column ``feature[i]`` is compared against 0.5.

        Returns
        ---
        numpy array of int : index of the leaf reached by each row.
        """
        X = np.asarray(X)
        n_samples = X.shape[0]
        rows = np.arange(n_samples)
        is_leaf = self.children_left == self.children_right
        current_node = np.zeros(n_samples, dtype=np.intp)
        # Leaves are self-loops, so rows that already reached a leaf stay put
        # while the remaining rows advance one level per iteration.
        while not np.all(is_leaf[current_node]):
            go_left = X[rows, self.feature[current_node]] >= 0.5
            current_node = np.where(
                go_left,
                self.children_left[current_node],
                self.children_right[current_node],
            )
        return current_node

    def predict(self, X):
        """
        Predict the label of each row of ``X``.

        Parameters
        ---
        X : (n, p) array-like of binarized features.

        Returns
        ---
        numpy array : the ``prediction`` stored at the leaf each row reaches.
        """
        return self.prediction[self._leaf_indices(X)]

    def generate_proba(self, X_train, y_train, n_classes=2):
        """
        Estimate per-node class probabilities from training data.

        Each training row is routed to its leaf; ``class_counts[node, c]``
        counts rows with label ``c`` that end at ``node`` and ``class_proba``
        row-normalizes those counts. Nodes that receive no rows (including all
        internal nodes) get an all-zero probability row.

        Parameters
        ---
        X_train : (n, p) array-like of binarized features.
        y_train : length-n array-like of integer labels in ``[0, n_classes)``.
        n_classes : int, default 2
            Number of class columns in ``class_counts`` / ``class_proba``.

        Raises
        ---
        ValueError : if ``X_train`` and ``y_train`` differ in length, or a label
            falls outside ``[0, n_classes)`` (a negative label would otherwise
            silently be counted in the last column).
        """
        leaf_nodes = self._leaf_indices(X_train)
        labels = np.asarray(y_train).astype(int).ravel()
        if labels.shape[0] != leaf_nodes.shape[0]:
            raise ValueError(
                f"X_train has {leaf_nodes.shape[0]} rows but y_train has {labels.shape[0]} labels"
            )
        if labels.size and (labels.min() < 0 or labels.max() >= n_classes):
            raise ValueError(f"labels must lie in [0, {n_classes}), got {np.unique(labels)}")

        self.class_counts = np.zeros((self.children_left.shape[0], n_classes), dtype=np.int32)
        # Unbuffered scatter-add: repeated (leaf, label) pairs all get counted.
        np.add.at(self.class_counts, (leaf_nodes, labels), 1)

        totals = self.class_counts.sum(axis=1, keepdims=True)
        # Rows with no samples stay 0 instead of producing NaN (and a warning).
        self.class_proba = np.divide(
            self.class_counts,
            totals,
            out=np.zeros(self.class_counts.shape, dtype=float),
            where=totals > 0,
        )

    def predict_proba(self, X):
        """
        Return the class-probability row of the leaf each row of ``X`` reaches.

        Raises
        ---
        AttributeError : if ``generate_proba`` has not been called yet.
        """
        if not hasattr(self, "class_proba"):
            raise AttributeError(
                "predict_proba requires generate_proba(X_train, y_train) to be called first"
            )
        return self.class_proba[self._leaf_indices(X)]

    def score(self, X, y, weight=None):
        """Accuracy of ``predict(X)`` against ``y``; ``weight="balanced"`` uses balanced accuracy."""
        y_hat = self.predict(X)
        if weight == "balanced":
            return balanced_accuracy_score(y, y_hat)
        else:
            return accuracy_score(y, y_hat)

    def leaves(self):
        """
        Returns
        ---
        natural number : The number of terminal nodes present in this tree
        """
        return int(np.count_nonzero(self.children_left == self.children_right))

    def nodes(self):
        """
        Returns
        ---
        natural number : The number of nodes present in this tree
        """
        return int(self.children_left.shape[0])

    def maximum_depth(self, node=None):
        """
        Parameters
        ---
        node : dict, optional
            Sub-tree in the dict format accepted by the constructor. When
            omitted, the whole tree is measured.

        Returns
        ---
        natural number : the length of the longest decision path in this tree. A single-node tree will return 1.
        """
        if node is not None:
            return self._dict_depth(node)
        deepest = 0
        stack = [(0, 1)]
        while stack:
            index, depth = stack.pop()
            deepest = max(deepest, depth)
            left, right = self.children_left[index], self.children_right[index]
            if left != right:
                stack.append((left, depth + 1))
                stack.append((right, depth + 1))
        return deepest

    @classmethod
    def _dict_depth(cls, node):
        """Depth of a dict sub-tree, treating ``None`` children as leaves."""
        node = cls._resolve_node(node)
        if "prediction" in node:
            return 1
        return 1 + max(cls._dict_depth(node["true"]), cls._dict_depth(node["false"]))

    def to_json(self, binary=False):
        """
        Export the tree as a one-element list holding the nested root node.

        Internal nodes carry ``nodeid``, ``depth``, ``split``,
        ``split_condition``, ``yes``/``no`` (ids of the left/right child) and
        ``children``; leaves carry ``nodeid`` and ``leaf`` (1.0 when the
        prediction is 1, otherwise -1.0). Ids are assigned in pre-order.
        Values are native Python numbers so the result is ``json.dumps``-able.

        Parameters
        ---
        binary : bool
            If True, ignore the encoder and emit the binarized feature index
            with threshold 0.5.
        """
        return [self.node_to_json(0, 0, 0, binary)[0]]

    def node_to_json(self, node_id, index, depth, binary):
        """
        Serialize the sub-tree rooted at array position ``index``.

        Returns
        ---
        (dict, int) : the serialized node and the largest node id used inside it.
        """
        # Structural leaf test (self-loop), consistent with the traversal in
        # _leaf_indices; a prediction-based test would recurse forever on a
        # leaf whose prediction happens to be -1.
        if self.children_left[index] == self.children_right[index]:
            leaf = 1.0 if self.prediction[index] == 1 else -1.0
            return {
                "nodeid": int(node_id),
                "leaf": leaf,
            }, node_id
        left_id = node_id + 1
        left_child, last_id = self.node_to_json(left_id, self.children_left[index], depth + 1, binary)
        right_id = last_id + 1
        right_child, last_id = self.node_to_json(right_id, self.children_right[index], depth + 1, binary)

        feature = int(self.feature[index])
        if binary or self.encoder is None:
            split = feature
            split_condition = 0.5
        else:
            # Presumably thresholds[f] = (original feature index, threshold);
            # only elements [0] and [1] are read.
            thresholds = self.encoder.thresholds
            split = int(thresholds[feature][0])
            split_condition = float(thresholds[feature][1])
        return {
            "nodeid": int(node_id),
            "depth": int(depth),
            "split": split,
            "split_condition": split_condition,
            "yes": int(left_id),
            "no": int(right_id),
            "children": [left_child, right_child]
        }, last_id

    def roct_n_to_tree(self, roct):
        """
        Converts a Robust OCT model to a tree model.

        Reads ``roct._tree.total_nodes``, ``roct.w_value[(node, label)]`` (the
        node predicts ``label`` when the value is > 0.5), ``roct.b_value``
        keyed by ``(node, f, theta)`` (the node splits on ``f`` when > 0.5) and
        the candidate pairs ``roct._f_theta_indices``. Nodes are numbered
        1..total_nodes in heap order (children of node n are 2n and 2n + 1).
        Feature names are assumed to look like ``"<name>_<index>"``.
        Nodes whose parent is not a split are dropped and the survivors are
        re-indexed compactly. No dict form exists, so ``self.tree`` is None.
        """
        self.tree = None
        num_nodes = roct._tree.total_nodes
        children_left = [-1] * num_nodes
        children_right = [-1] * num_nodes
        feature = [-1] * num_nodes
        prediction = [-1] * num_nodes

        for node in range(1, num_nodes + 1):
            for label in (0, 1):
                if roct.w_value[(node, label)] > 0.5:
                    prediction[node - 1] = label

        for node in range(1, num_nodes + 1):
            idx = node - 1
            if prediction[idx] != -1:
                children_left[idx] = idx
                children_right[idx] = idx
            for f, theta in roct._f_theta_indices:
                if (node, f, theta) not in roct.b_value:
                    continue
                if roct.b_value[node, f, theta] > 0.5:
                    feature[idx] = int(f.split('_')[1])
                    first_child, second_child = 2 * node - 1, 2 * node
                    # theta == 0 swaps which heap child receives ">= 0.5" rows.
                    if theta == 0:
                        children_left[idx], children_right[idx] = second_child, first_child
                    else:
                        children_left[idx], children_right[idx] = first_child, second_child
                    break

        # A node survives only when its parent is a genuine split (distinct
        # children); descendants of leaves and unused nodes map to -1.
        node_map = defaultdict(lambda: -1)
        node_map[0] = 0
        next_index = 1
        for idx in range(1, num_nodes):
            parent = (idx - 1) // 2
            if children_left[parent] != children_right[parent]:
                node_map[idx] = next_index
                next_index += 1

        children_left = np.array([node_map[child] for child in children_left])
        children_right = np.array([node_map[child] for child in children_right])
        feature = np.array(feature)
        prediction = np.array(prediction)

        keep = children_left != -1
        self.children_left = children_left[keep]
        self.children_right = children_right[keep]
        self.feature = feature[keep]
        self.prediction = prediction[keep]
