# B+Tree adapted from https://gist.github.com/benben233/2c8a2a8ab44a7beabad0df1b6658232e
import random
import numpy as np
import itertools


class Node(object):
    """Internal (index) node of the B+ tree.

    Separator keys are kept sorted in ``keys`` and the ``len(keys) + 1`` child
    nodes in ``values``; a lookup for ``key`` descends into child ``i`` where
    ``keys[i - 1] <= key < keys[i]`` (see ``index``).

    Attributes:
        keys: Sorted separator keys.
        values: Child nodes (``Node`` or ``Leaf``).
        parent: Parent ``Node``, or ``None`` for the root.
        cost_dict: Counter dict shared by all nodes of a tree; ``split`` and
            ``fusion`` increment its entries in place.
    """

    def __init__(self, cost_dict: dict, parent=None):
        """Create an empty node.

        :param cost_dict: shared operation counters (mutated in place).
        :type parent: Node
        """
        self.keys: list = []
        self.values: list[Node] = []
        self.parent: Node = parent
        self.cost_dict: dict = cost_dict

    def index(self, key):
        """Return the child slot for ``key``: the first ``i`` with ``key < keys[i]``.

        Returns ``len(self.keys)`` when ``key`` is >= every key, so equal keys
        route to the right of their separator.
        """
        for i, item in enumerate(self.keys):
            if key < item:
                return i

        return len(self.keys)

    def __getitem__(self, item):
        """Return the child subtree that would contain ``item``."""
        return self.values[self.index(item)]

    def __setitem__(self, key, value):
        """Replace a child that has just split with its two halves.

        ``value`` is the ``[left, right]`` pair returned by ``split``: the child
        currently at ``index(key)`` is removed and ``left``/``right`` are
        inserted in its place, with ``key`` as the separator between them.
        """
        i = self.index(key)
        self.keys[i:i] = [key]
        self.values.pop(i)
        self.values[i:i] = value

    def split(self):
        """Split this node in two around its middle key.

        The lower half of the keys/children moves to a new left node (children
        are re-parented); ``self`` keeps the upper half. The middle key is
        removed from both halves and returned so it can be pushed up into the
        parent. Counts one ``"splits"`` and one ``"parent_splits"``.

        Returns:
            (key, [left, self]): the separator key and the two halves.
        """
        self.cost_dict["splits"] += 1
        self.cost_dict["parent_splits"] += 1

        left = Node(cost_dict=self.cost_dict, parent=self.parent)

        mid = len(self.keys) // 2

        left.keys = self.keys[:mid]
        left.values = self.values[: mid + 1]
        for child in left.values:
            child.parent = left

        key = self.keys[mid]
        self.keys = self.keys[mid + 1 :]
        self.values = self.values[mid + 1 :]

        return key, [left, self]

    def __delitem__(self, key):
        """Remove the child at ``index(key)`` and one adjacent separator.

        The separator to the child's right is removed, or the one to its left
        when the child is the last one. Used after that child was fused into a
        sibling.
        """
        i = self.index(key)
        del self.values[i]
        if i < len(self.keys):
            del self.keys[i]
        else:
            del self.keys[i - 1]

    def _position_in_parent(self) -> int:
        """Return this node's index within ``self.parent.values``.

        The lookup is by identity rather than by routing ``self.keys[0]``
        through the parent. An underflowing node can be left with no keys
        (e.g. when the minimum is 1), and then ``self.keys[0]`` does not exist.
        """
        for i, child in enumerate(self.parent.values):
            if child is self:
                return i
        raise ValueError("node is not among its parent's children")

    def fusion(self):
        """Merge this node into an adjacent sibling under the same parent.

        This node's keys, plus the parent separator between the two nodes, go
        into the next sibling. When this node is the last child they go into
        the previous sibling instead. Children are re-parented. The parent
        itself is not modified here; the caller removes this node from it.
        Counts one ``"fusions"`` and one ``"parent_fusions"``.
        """
        self.cost_dict["fusions"] += 1
        self.cost_dict["parent_fusions"] += 1

        index = self._position_in_parent()
        # merge this node with the next node
        if index < len(self.parent.keys):
            next_node: Node = self.parent.values[index + 1]
            next_node.keys[0:0] = self.keys + [self.parent.keys[index]]
            for child in self.values:
                child.parent = next_node
            next_node.values[0:0] = self.values
        else:  # self is the last child: merge with the previous sibling
            prev: Node = self.parent.values[index - 1]
            prev.keys += [self.parent.keys[index - 1]] + self.keys
            for child in self.values:
                child.parent = prev
            prev.values += self.values

    def borrow_key(self, minimum: int):
        """Try to rotate one key/child in from a sibling that has more than ``minimum`` keys.

        When this node has a right sibling, only that sibling is tried. The
        left sibling is tried only when this node is the last child. The key
        rotates through the parent separator.

        Returns:
            bool: True if a key was borrowed, False otherwise.
        """
        index = self._position_in_parent()
        if index < len(self.parent.keys):
            next_node: Node = self.parent.values[index + 1]
            if len(next_node.keys) > minimum:
                self.keys += [self.parent.keys[index]]

                borrow_node = next_node.values.pop(0)
                borrow_node.parent = self
                self.values += [borrow_node]
                self.parent.keys[index] = next_node.keys.pop(0)
                return True
        elif index != 0:
            prev: Node = self.parent.values[index - 1]
            if len(prev.keys) > minimum:
                self.keys[0:0] = [self.parent.keys[index - 1]]

                borrow_node = prev.values.pop()
                borrow_node.parent = self
                self.values[0:0] = [borrow_node]
                self.parent.keys[index - 1] = prev.keys.pop()
                return True

        return False


class Leaf(Node):
    """Leaf node holding the actual key/value pairs.

    ``keys`` and ``values`` are parallel lists sorted by key. Leaves are linked
    through ``prev``/``next`` in key order; neighbours may belong to a
    different parent.
    """

    def __init__(self, cost_dict: dict, parent=None, prev_node=None, next_node=None):
        """Create an empty leaf and splice it into the leaf chain.

        :param prev_node: leaf to link before this one (its ``next`` is updated).
        :param next_node: leaf to link after this one (its ``prev`` is updated).
        :type prev_node: Leaf
        :type next_node: Leaf
        """
        super(Leaf, self).__init__(cost_dict, parent)
        self.next: Leaf = next_node
        if next_node is not None:
            next_node.prev = self
        self.prev: Leaf = prev_node
        if prev_node is not None:
            prev_node.next = self

    def __getitem__(self, item):
        """Return the value stored under ``item``; raises ``ValueError`` if absent."""
        return self.values[self.keys.index(item)]

    def __setitem__(self, key, value):
        """Insert ``key`` in sorted position, or overwrite its value if present."""
        i = self.index(key)
        if key not in self.keys:
            self.keys[i:i] = [key]
            self.values[i:i] = [value]
        else:
            # index() returns the slot after an equal key, so it sits at i - 1.
            self.values[i - 1] = value

    def split(self):
        """Move the lower half of the pairs into a new leaf linked before this one.

        Counts one ``"splits"``; leaf splits do not count ``"parent_splits"``.

        Returns:
            (key, [left, self]): ``key`` is this leaf's new first key, which is
            copied (not moved) into the parent as the separator.
        """
        self.cost_dict["splits"] += 1

        left = Leaf(
            cost_dict=self.cost_dict,
            parent=self.parent,
            prev_node=self.prev,
            next_node=self,
        )
        mid = len(self.keys) // 2

        left.keys = self.keys[:mid]
        left.values = self.values[:mid]

        self.keys: list = self.keys[mid:]
        self.values: list = self.values[mid:]

        # When the leaf node is split, set the parent key to the left-most key of the right child node.
        return self.keys[0], [left, self]

    def __delitem__(self, key):
        """Remove ``key`` and its value; raises ``ValueError`` if absent."""
        i = self.keys.index(key)
        del self.keys[i]
        del self.values[i]

    def fusion(self):
        """Merge this leaf's pairs into a sibling and unlink it from the leaf chain.

        The pairs go to the next leaf when it has the same parent, otherwise to
        the previous leaf. The parent is not modified here; the caller removes
        this leaf from it. Counts one ``"fusions"``.
        """
        self.cost_dict["fusions"] += 1

        if self.next is not None and self.next.parent == self.parent:
            self.next.keys[0:0] = self.keys
            self.next.values[0:0] = self.values
        else:
            self.prev.keys += self.keys
            self.prev.values += self.values

        if self.next is not None:
            self.next.prev = self.prev
        if self.prev is not None:
            self.prev.next = self.next

    def borrow_key(self, minimum: int):
        """Try to take one pair from an adjacent sibling under the same parent.

        The right sibling is tried first, then the left one. A sibling lends
        only if it has more than ``minimum`` keys. The parent separator is then
        set to the right-hand leaf's new first key.

        Returns:
            bool: True if a pair was borrowed, False otherwise.
        """
        # Find this leaf among its parent's children by identity. An
        # underflowing leaf can be empty, so ``self.keys[0]`` may not exist.
        index = next(
            i for i, child in enumerate(self.parent.values) if child is self
        )
        if index < len(self.parent.keys) and len(self.next.keys) > minimum:
            self.keys += [self.next.keys.pop(0)]
            self.values += [self.next.values.pop(0)]
            self.parent.keys[index] = self.next.keys[0]
            return True
        elif index != 0 and len(self.prev.keys) > minimum:
            self.keys[0:0] = [self.prev.keys.pop()]
            self.values[0:0] = [self.prev.values.pop()]
            self.parent.keys[index - 1] = self.keys[0]
            return True

        return False


class BPlusTree(object):
    """B+ tree object, consisting of nodes.

    A node is split in two once it holds more than ``maximum`` keys, and a
    separator key 'floats' up into the parent as a pivot. When a deletion
    leaves a non-root node with fewer than ``minimum`` keys, the node first
    tries to borrow from a sibling and otherwise fuses with one.

    Every split and fusion is counted in ``cost_dict``. ``calculate_reward``
    turns those counts into a weighted sum and resets them.

    Attributes:
        maximum (int): The maximum number of keys each node can hold (at least 2).
        minimum (int): ``maximum // 2``; a non-root node below this underflows.
        depth (int): Number of internal levels above the leaves (0 while the root is a leaf).
        cost_dict (dict): Operation counters shared with every node.
    """

    root: Node
    cost_dict: dict

    def __init__(self, maximum=4):
        self.cost_dict = {
            "splits": 0,
            "parent_splits": 0,
            "fusions": 0,
            "parent_fusions": 0,
        }
        self.root = Leaf(cost_dict=self.cost_dict)
        self.maximum: int = maximum if maximum > 2 else 2
        self.minimum: int = self.maximum // 2
        self.depth = 0

    def find(self, key) -> Leaf:
        """Find the leaf that holds, or would hold, ``key``.

        Returns:
            Leaf: the leaf which should have the key
        """
        node = self.root
        # Traverse tree until leaf node is reached.
        while type(node) is not Leaf:
            node = node[key]

        return node

    def __getitem__(self, item):
        """Return the value for ``item``; raises ``ValueError`` if it is absent."""
        return self.find(item)[item]

    def query(self, key):
        """Return the value for ``key``, or None if the key does not exist."""
        leaf = self.find(key)
        return leaf[key] if key in leaf.keys else None

    def change(self, key, value):
        """Overwrite the value of an existing key.

        Returns:
            (bool, Leaf): whether the key existed, and the leaf where it is or
            would be.
        """
        leaf = self.find(key)
        if key not in leaf.keys:
            return False, leaf
        else:
            leaf[key] = value
            return True, leaf

    def __setitem__(self, key, value, leaf=None):
        """Insert or overwrite ``key``, splitting the target leaf if it overflows.

        ``leaf`` lets callers that already located the leaf skip the lookup.
        """
        if leaf is None:
            leaf = self.find(key)
        leaf[key] = value
        if len(leaf.keys) > self.maximum:
            self.insert_index(*leaf.split())

    def insert(self, key, value):
        """Insert ``key`` only if it is not already present.

        Returns:
            (bool, Leaf): True and the leaf that now holds the key, or False
            and the leaf that already holds it.
        """
        leaf = self.find(key)
        if key in leaf.keys:
            return False, leaf
        self.__setitem__(key, value, leaf)
        # A split keeps ``leaf`` as the right half, so the key may have moved
        # to the new left leaf; look it up again.
        return True, self.find(key)

    def insert_index(self, key, values: list[Node]):
        """Install the two halves of a split node, with ``key`` as the separator between them.

        If the split node was the root, a new root is created and ``depth``
        grows by one. Otherwise the parent replaces the old child with both
        halves, and the parent is split in turn if it overflows.
        """
        parent = values[1].parent
        if parent is None:
            values[0].parent = values[1].parent = self.root = Node(
                cost_dict=self.cost_dict
            )
            self.depth += 1
            self.root.keys = [key]
            self.root.values = values
            return

        parent[key] = values
        # If the node is full, split the node into two.
        if len(parent.keys) > self.maximum:
            self.insert_index(*parent.split())

    def delete(self, key, node: Node = None):
        """Delete ``key`` and rebalance the tree.

        ``node`` is used internally: after a fusion, the method recurses into
        the parent to drop the fused child. An absent key raises ``ValueError``
        from the leaf.
        """
        if node is None:
            node = self.find(key)
        del node[key]

        if len(node.keys) < self.minimum:
            if node == self.root:
                # An internal root left with a single child is replaced by it.
                if len(self.root.keys) == 0 and len(self.root.values) > 0:
                    self.root = self.root.values[0]
                    self.root.parent = None
                    self.depth -= 1
                return

            elif not node.borrow_key(self.minimum):
                node.fusion()
                self.delete(key, node.parent)

    def show(self, node=None, file=None, _prefix="", _last=True):
        """Print the keys at each level as an indented tree."""
        if node is None:
            node = self.root
        print(_prefix, "`- " if _last else "|- ", node.keys, sep="", file=file)
        _prefix += "   " if _last else "|  "

        if type(node) is Node:
            # Recursively print the key of child nodes (if these exist).
            for i, child in enumerate(node.values):
                _last = i == len(node.values) - 1
                self.show(child, file, _prefix, _last)

    def output(self):
        """Return ``(tuple of cost counter values, depth)``."""
        return tuple(self.cost_dict.values()), self.depth

    def readfile(self, reader):
        """Insert one ``key value`` pair per line of ``reader``; return the line count.

        Each line is ``.decode()``-d, so ``reader`` presumably yields ``bytes``
        (e.g. a file opened in binary mode). The line is split at the first
        whitespace run: the first token is the key and the rest (including any
        trailing newline) is the value. Progress is printed every 1000 lines.
        """
        count = 0
        for i, line in enumerate(reader):
            s = line.decode().split(maxsplit=1)
            self[s[0]] = s[1]
            count = i + 1
            if i % 1000 == 0:
                print("Insert " + str(i) + " items")
        return count

    def leftmost_leaf(self) -> Leaf:
        """Return the first leaf of the leaf chain."""
        node = self.root
        while type(node) is not Leaf:
            node = node.values[0]
        return node

    def _padded_level_encoding(self, max_depth, encode_node, width):
        """Encode the tree level by level as if it were a complete ``(maximum + 1)``-ary tree.

        The root is level 0 and only levels ``0 .. max_depth - 1`` are kept.
        Each existing node contributes ``encode_node(node)`` (exactly ``width``
        entries) to its level. Each absent slot (a missing child, or anything
        below a leaf) contributes ``width`` zeros. Level ``d`` therefore always
        holds ``(maximum + 1) ** d * width`` entries.

        Returns:
            np.ndarray: all levels concatenated, root level first.
        """
        fanout = self.maximum + 1
        levels = [[] for _ in range(max_depth)]
        padding = [0] * width

        def visit(node, depth):
            if depth == max_depth:
                return

            if node is None:
                levels[depth].extend(padding)
                children = []
            else:
                encoded = encode_node(node)
                assert len(encoded) == width
                levels[depth].extend(encoded)
                children = [] if isinstance(node, Leaf) else list(node.values)

            children.extend([None] * (fanout - len(children)))
            assert len(children) == fanout

            for child in children:
                visit(child, depth + 1)

        visit(self.root, 0)

        # Make sure the layers are filled correctly
        expected_nodes = 1
        for level in levels[1:]:
            expected_nodes *= fanout
            assert len(level) == expected_nodes * width

        return np.array(list(itertools.chain.from_iterable(levels))).flatten()

    def get_obs_space_representation(self, max_depth):
        """Return the tree's keys as a flat, fixed-length 1D array.

        The root plus ``max_depth`` levels below it are laid out as a complete
        ``(maximum + 1)``-ary tree. Each node contributes its keys padded with
        zeros to ``maximum`` entries, and each absent node contributes
        ``maximum`` zeros. Deeper levels are ignored.
        """

        def encode_keys(node):
            return node.keys + [0] * (self.maximum - len(node.keys))

        # +1 so the root level is counted in addition to ``max_depth`` levels.
        return self._padded_level_encoding(max_depth + 1, encode_keys, self.maximum)

    def get_obs_space_feature_representation(self, max_depth):
        """Return per-node features as a flat, fixed-length 1D array.

        Same layout as ``get_obs_space_representation``, but each node is
        represented by 3 features: ``[min key, max key, len(keys) / maximum]``
        (``[0, 0, 0.0]`` for a node without keys). Absent nodes are ``[0, 0, 0]``.
        """

        def encode_features(node):
            keys = node.keys
            return [
                min(keys) if keys else 0,
                max(keys) if keys else 0,
                len(keys) / self.maximum,
            ]

        # +1 so the root level is counted in addition to ``max_depth`` levels.
        return self._padded_level_encoding(max_depth + 1, encode_features, 3)

    def reset_cost_dict(self):
        """Zero every counter in ``cost_dict`` (the dict object is kept)."""
        for key in self.cost_dict.keys():
            self.cost_dict[key] = 0

    def calculate_reward(self):
        """Return the weighted sum of the counted operations, then reset the counters.

        Leaf-or-internal splits/fusions weigh 2, and the additional
        ``parent_*`` counts (internal nodes only) weigh 1.
        """
        cost_factors = {
            "splits": 2,
            "parent_splits": 1,
            "fusions": 2,
            "parent_fusions": 1,
        }
        reward = 0
        for key in self.cost_dict.keys():
            reward += cost_factors[key] * self.cost_dict[key]
        self.reset_cost_dict()
        return reward



def calculate_length_max_depth_of_tree(max_tree_values, max_keys):
    """Return the feature-observation length and the depth budget for a tree.

    ``max_depth`` is ``1 + floor(log_{max_keys + 1}(max_tree_values))``.
    ``total_keys`` is the length of the array that
    ``BPlusTree.get_obs_space_feature_representation(max_depth)`` produces for
    a tree whose ``maximum`` is ``max_keys``. That is 3 features for the root
    plus 3 per slot on each of the ``max_depth`` padded levels below it, where
    level ``d`` has ``(max_keys + 1) ** d`` slots.

    The floor of the logarithm is computed by repeated multiplication rather
    than as a ratio of floating-point logs. The float ratio rounds exact powers
    down (``log(1000) / log(10) == 2.9999999999999996``), which would
    undercount the depth.

    Raises:
        ValueError: if ``max_tree_values < 1`` or ``max_keys < 1``.

    Returns:
        tuple: ``(total_keys, max_depth)``.
    """
    if max_tree_values < 1:
        raise ValueError("max_tree_values must be >= 1")
    if max_keys < 1:
        raise ValueError("max_keys must be >= 1")

    features_per_node = 3
    fanout = max_keys + 1

    # Largest d with fanout ** d <= max_tree_values, i.e. floor(log_fanout(n)).
    levels_below_root = 0
    capacity = fanout
    while capacity <= max_tree_values:
        levels_below_root += 1
        capacity *= fanout
    max_depth = 1 + levels_below_root

    total_keys = features_per_node  # the root node
    nodes_in_level = 1
    for _ in range(max_depth):
        nodes_in_level *= fanout
        total_keys += nodes_in_level * features_per_node

    return total_keys, max_depth


def printTree(tree):
    """Dump every node of ``tree`` to stdout in depth-first pre-order.

    For each node, prints one line each for its ``values``, its ``keys``, the
    next leaf in the leaf chain (``None`` for internal nodes), its ``parent``,
    and whether it is a ``Leaf``, followed by a blank separator.
    """
    root = tree.root
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        is_leaf = isinstance(node, Leaf)
        print(node.values)
        print(node.keys)
        print(node.next if is_leaf else None)
        print(node.parent)
        print(is_leaf)
        print("\n")
        if not is_leaf:
            # Push children reversed so they are visited left to right.
            stack.extend(reversed(node.values))


