import json
import warnings
from collections import defaultdict
from collections import deque

import numpy as np
import pandas as pd
import sklearn
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.tree import DecisionTreeClassifier

class Tree:
    """
    Unified Python representation of a binary tree classifier.

    The constructor accepts a nested-dict (JSON-style) description of a tree and
    flattens it into parallel arrays in the style of scikit-learn's ``tree_``
    attributes (``children_left``, ``children_right``, ``feature``,
    ``threshold``, plus ``prediction``) so the tree can be queried and exported.

    Node format read by this class:
      * leaf:     ``{"prediction": <label>, ...}``
      * internal: ``{"feature": <column index>, "true": <node>, "false": <node>}``
      * ``None``: tolerated and treated as a leaf predicting ``0``.

    Every internal node splits at threshold 0.5: a sample goes to the "false"
    (left) child when ``X[:, feature] <= 0.5`` and to the "true" (right) child
    otherwise, i.e. features are treated as binary 0/1 indicators.

    NOTE: the original design mentions optional encoding/decoding layers for
    when the model's feature space differs from the original data's feature
    space; no such layer is implemented in this class.
    """

    def __init__(self, tree):
        # tree is the nested-dict (JSON) representation described above.
        self.tree_ = tree

        children_left = []
        children_right = []
        feature = []
        threshold = []
        prediction = []
        # Depth-first flattening; entries are (json node, parent index, is_left).
        # The "false" child is pushed last so it is popped first and becomes
        # the left child with the lower index.
        stack = [(self.tree_, None, None)]
        while stack:
            node, parent_index, is_left = stack.pop()
            if node is None:
                warnings.warn(
                    "Null node encountered. Treating it as a leaf that predicts 0.",
                    stacklevel=2,
                )
                node = {"prediction": 0, "name": "Prediction"}
            current_index = len(children_left)
            is_leaf = "prediction" in node

            if is_leaf:
                children_left.append(-1)
                children_right.append(-1)
                feature.append(-2)  # sklearn-style sentinel: no split feature
                threshold.append(-2)
                prediction.append(node["prediction"])
            else:
                feature.append(node["feature"])
                threshold.append(0.5)
                prediction.append(-1)
                # Placeholders; filled in when the children are popped.
                children_left.append(None)
                children_right.append(None)
            if parent_index is not None:
                if is_left:
                    children_left[parent_index] = current_index
                else:
                    children_right[parent_index] = current_index
            if not is_leaf:
                stack.append((node["true"], current_index, False))
                stack.append((node["false"], current_index, True))

        self.children_left = np.array(children_left, dtype=np.intp)
        self.children_right = np.array(children_right, dtype=np.intp)
        self.feature = np.array(feature)
        self.prediction = np.array(prediction)
        self.threshold = np.array(threshold)
        # Depth counted in edges (a single leaf has max_depth 0).
        self.max_depth = self.maximum_depth() - 1

    @staticmethod
    def _is_leaf(node):
        """Return True for a leaf dict or a ``None`` node (treated as a leaf)."""
        return node is None or "prediction" in node

    def __str__(self):
        # Uses only attributes set in __init__, so it works before generate_proba().
        return (
            f"Left Children: {self.children_left}\n"
            f"Right Children: {self.children_right}\n"
            f"Feature: {self.feature}\n"
            f"Prediction: {self.prediction}\n"
            f"Depth: {self.max_depth}\n"
            f"Threshold: {self.threshold}\n"
        )

    def export_text(self):
        """Placeholder for a text export; not implemented and returns None."""
        pass

    def _leaf_indices(self, X):
        """
        Route every row of ``X`` from the root to a leaf.

        Only rows still sitting at an internal node are evaluated at each step,
        so the leaf sentinel feature (-2) is never used as a column index.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        numpy.ndarray of shape (n_samples,)
            Flat index of the leaf each row lands in.
        """
        X = np.asarray(X)
        node = np.zeros(X.shape[0], dtype=np.intp)
        active = np.flatnonzero(self.children_left[node] != -1)
        while active.size:
            current = node[active]
            go_left = X[active, self.feature[current]] <= self.threshold[current]
            node[active] = np.where(
                go_left, self.children_left[current], self.children_right[current]
            )
            # Keep only rows that have not reached a leaf yet.
            active = active[self.children_left[node[active]] != -1]
        return node

    def predict(self, X):
        """
        Predict the stored leaf label for each row of ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)

        Returns
        -------
        numpy.ndarray of shape (n_samples,)
        """
        return self.prediction[self._leaf_indices(X)]

    def generate_proba(self, X_train, y_train):
        """
        Compute per-node class counts and class probabilities from labelled data.

        Each sample adds one count for its label to every node on its
        root-to-leaf path (internal nodes and the final leaf). Labels are cast
        to ``int`` and used as column indices; two classes (0 and 1) are
        supported.

        Sets ``class_counts`` (n_nodes, 2), ``class_proba`` (row-normalized
        counts, all-zero rows for nodes no sample reached),
        ``weighted_n_node_samples``, ``value`` (n_nodes, 1, 2), ``node_count``,
        ``n_features_`` and ``n_features_in_``.
        """
        X = np.asarray(X_train)
        labels = np.asarray(y_train).astype(int)
        n_samples = X.shape[0]
        n_nodes = self.children_left.shape[0]
        n_classes = 2
        class_counts = np.zeros((n_nodes, n_classes), dtype=np.int32)

        node = np.zeros(n_samples, dtype=np.intp)
        active = np.arange(n_samples)
        while active.size:
            current = node[active]
            # Count the node every active sample is currently visiting.
            np.add.at(class_counts, (current, labels[active]), 1)
            # Samples at a leaf are done; advance the rest one level.
            active = active[self.children_left[current] != -1]
            current = node[active]
            go_left = X[active, self.feature[current]] <= self.threshold[current]
            node[active] = np.where(
                go_left, self.children_left[current], self.children_right[current]
            )

        totals = class_counts.sum(axis=1, keepdims=True)
        class_proba = np.zeros(class_counts.shape, dtype=np.float64)
        # Avoid 0/0 for nodes no training sample reached; they keep all-zero rows.
        np.divide(class_counts, totals, out=class_proba, where=totals > 0)

        self.class_counts = class_counts
        self.class_proba = class_proba
        self.weighted_n_node_samples = totals.ravel()
        self.value = class_proba[:, np.newaxis, :].copy()
        self.node_count = n_nodes
        self.n_features_ = X.shape[1]
        self.n_features_in_ = X.shape[1]

    def predict_proba(self, X):
        """
        Return the class-probability row of the leaf each sample of ``X`` reaches.

        Uses the same routing as ``predict``. Requires ``generate_proba`` to
        have been called.

        Raises
        ------
        AttributeError
            If ``generate_proba`` has not been called yet.
        """
        if not hasattr(self, "class_proba"):
            raise AttributeError(
                "predict_proba requires generate_proba(X_train, y_train) to be called first"
            )
        return self.class_proba[self._leaf_indices(X)]

    def score(self, X, y, weight=None):
        """Accuracy of ``predict(X)`` against ``y``; balanced accuracy if ``weight == "balanced"``."""
        y_hat = self.predict(X)
        # isinstance guard: comparing an array to a string would be ambiguous in `if`.
        if isinstance(weight, str) and weight == "balanced":
            return balanced_accuracy_score(y, y_hat)
        return accuracy_score(y, y_hat)

    def leaves(self):
        """
        Returns
        ---
        natural number : The number of terminal nodes present in this tree
            (``None`` nodes count as leaves).
        """
        leaves_counter = 0
        stack = [self.tree_]
        while stack:
            node = stack.pop()
            if self._is_leaf(node):
                leaves_counter += 1
            else:
                stack.append(node["true"])
                stack.append(node["false"])
        return leaves_counter

    def nodes(self):
        """
        Returns
        ---
        natural number : The number of nodes present in this tree
            (``None`` nodes count as leaves).
        """
        nodes_counter = 0
        stack = [self.tree_]
        while stack:
            node = stack.pop()
            nodes_counter += 1
            if not self._is_leaf(node):
                stack.append(node["true"])
                stack.append(node["false"])
        return nodes_counter

    def maximum_depth(self, node=None):
        """
        Returns
        ---
        natural number : the length (in nodes) of the longest decision path in
            the subtree rooted at ``node`` (the whole tree when ``node`` is None).
            A single-node tree returns 1; ``None`` children count as leaves.
        """
        if node is None:
            node = self.tree_
        deepest = 0
        stack = [(node, 1)]
        while stack:
            current, depth = stack.pop()
            if self._is_leaf(current):
                deepest = max(deepest, depth)
            else:
                stack.append((current["true"], depth + 1))
                stack.append((current["false"], depth + 1))
        return deepest

    def to_json(self, binary=False):
        """Export the tree as a one-element list holding the nested JSON-serializable root dict."""
        return [self.node_to_json(0, 0, 0, binary)[0]]

    def node_to_json(self, node_id, index, depth, binary):
        """
        Recursively export the subtree at flat ``index``.

        Leaves become ``{"nodeid", "leaf"}`` with ``leaf`` = 1.0 when the
        prediction is 1, else -1.0. Internal nodes carry ``split`` (feature
        index), ``split_condition``, ``yes`` (left / "false" child id), ``no``
        (right / "true" child id) and ``children``. Node ids are assigned in
        pre-order starting at ``node_id``.

        Returns
        -------
        (dict, int)
            The exported subtree and the largest node id it used.
        """
        # Structural leaf test: a prediction value of -1 is a legitimate label.
        if self.children_left[index] == -1:
            leaf = 1.0 if self.prediction[index] == 1 else -1.0
            return {
                "nodeid": int(node_id),
                "leaf": leaf,
            }, node_id
        left_id = node_id + 1
        left_child, last_id = self.node_to_json(
            left_id, self.children_left[index], depth + 1, binary
        )
        right_id = last_id + 1
        right_child, last_id = self.node_to_json(
            right_id, self.children_right[index], depth + 1, binary
        )
        # Cast numpy scalars to builtins so the result is json.dumps-able.
        split = int(self.feature[index])
        split_condition = 0.5 if binary else float(self.threshold[index])
        return {
            "nodeid": int(node_id),
            "depth": int(depth),
            "split": split,
            "split_condition": split_condition,
            "yes": int(left_id),
            "no": int(right_id),
            "children": [left_child, right_child],
        }, last_id
