from logging import root
import torch
import random


def mean_dis(sample, data):
    """
    Average Euclidean distance between ``sample`` and every row of ``data``.

    :param sample: Tensor that broadcasts against ``data`` (the callers in this file pass 1xD).
    :param data: 2-D tensor whose rows are compared with ``sample``.
    :return: The mean of the per-row L2 distances as a Python float.
    """
    offsets = sample - data
    per_row = offsets.norm(dim=1)
    return per_row.mean().item()


def min_dis(sample, data):
    """
    Smallest Euclidean distance between ``sample`` and any row of ``data``.

    :param sample: Tensor that broadcasts against ``data`` (the callers in this file pass 1xD).
    :param data: 2-D tensor whose rows are compared with ``sample``.
    :return: The minimum of the per-row L2 distances as a Python float.
    """
    offsets = sample - data
    per_row = offsets.norm(dim=1)
    return per_row.min().item()


def prototype_dis(sample, data):
    """
    Euclidean distance between ``sample`` and a single prototype row derived from ``data``.

    The prototype is the element-wise minimum of ``data`` over its rows.

    :param sample: A 1xD tensor.
    :param data: An NxD tensor.
    :return: The L2 distance between ``sample`` and the prototype as a Python float.
    """
    # ``Tensor.min(dim=...)`` returns a ``(values, indices)`` named tuple, which cannot be
    # subtracted from a tensor; only the ``values`` part is the prototype row.
    # NOTE(review): a prototype is often the row-wise mean rather than the minimum - confirm
    # that the element-wise minimum is really the intended reduction.
    prototype = data.min(dim=0).values
    return torch.norm(sample - prototype, dim=1).item()


class Parse_Tree:
    """
    A node of an n-ary tree of named concepts; a whole tree is represented by its root node.

    Each node stores a ``name``, an optional ``depth``/``index`` position (filled in by
    :meth:`from_dict`), an optional ``data`` payload (filled in by :meth:`set_values`), its
    ``children`` and the function used to measure the distance between a sample and ``data``.
    """

    def __init__(self, name, distance='min'):
        """
        Initializes a new node with the given name and no children.

        :param name: The name of the node.
        :param distance: Name of the distance metric: ``'mean'``, ``'min'`` or ``'prototype'``.
        :raises NotImplementedError: If ``distance`` is not one of the supported names.
        """
        self.name = name
        self.depth = None
        self.index = None
        self.data = None
        self.children = []
        self.distance_name = distance
        # A single if/elif chain, so exactly one branch runs and an unknown metric is
        # rejected immediately instead of leaving the node without a ``distance`` attribute.
        if distance == 'mean':
            self.distance = mean_dis
        elif distance == 'min':
            self.distance = min_dis
        elif distance == 'prototype':
            self.distance = prototype_dis
        else:
            raise NotImplementedError(
                "Unsupported distance {!r}; expected 'mean', 'min' or 'prototype'".format(distance))

    def get_all_nodes(self):
        """
        Returns a list of all nodes in the tree in depth-first pre-order (this node first).

        :return: A list of all nodes in the tree.
        """
        nodes = [self]
        for child in self.children:
            nodes.extend(child.get_all_nodes())
        return nodes

    def get_all_attrnodes(self):
        """
        Returns all nodes in the tree that carry data (``data`` is not None).

        :return: A list of the nodes with data, in depth-first pre-order.
        """
        return [node for node in self.get_all_nodes() if node.data is not None]

    def get_all_leafnodes(self):
        """
        Returns all leaf nodes of the tree (nodes without children).

        :return: A list of the leaf nodes, in depth-first pre-order.
        """
        return [node for node in self.get_all_nodes() if not node.children]

    def max_depth(self):
        """
        Returns the maximum depth of the tree, counting this node as depth 1.

        :return: The maximum depth of the tree.
        """
        return 1 + max((child.max_depth() for child in self.children), default=0)

    def add_child(self, child):
        """
        Adds the given node as a child of this node.

        :param child: The node to add as a child.
        """
        self.children.append(child)

    def __repr__(self):
        """
        Returns a string representation of the node and its children.

        :return: A string representation of the node.
        """
        return self._repr_helper(0)

    def _repr_helper(self, indent):
        """
        Returns a string representation of the node and its children, with the given indentation.

        Each node is rendered on its own line as ``- <name>:d<depth>_i<index><payload>`` where
        the payload is ``None`` for nodes without data and ``data.shape`` otherwise.

        :param indent: The number of spaces to indent the string.
        :return: A string representation of the node and its children.
        """
        payload = None if self.data is None else self.data.shape
        # An f-string keeps the output identical for string names and also tolerates the
        # non-string names that from_dict can produce from scalar values.
        lines = [f"{' ' * indent}- {self.name}:d{self.depth}_i{self.index}{payload}\n"]
        lines.extend(child._repr_helper(indent + 2) for child in self.children)
        return "".join(lines)

    def to_dict(self):
        """
        Converts this Parse_Tree object to a dictionary representation.

        A leaf is represented by its bare name. A node whose children are all leaves becomes
        ``{name: [child names]}``. A node with at least one non-leaf child becomes
        ``{name: {merged child dictionaries}}``, except that a node named ``'root'`` returns
        the merged dictionary directly.

        NOTE(review): when a node has both leaf and non-leaf children, the leaf children are
        omitted from the result, so the conversion is lossy for such nodes - confirm intended.

        :return: A dictionary representation of this Parse_Tree node (or its name for a leaf).
        """
        if not self.children:
            return self.name

        child_dict = {}
        child_list = []

        for child in self.children:
            child_data = child.to_dict()
            if isinstance(child_data, dict):
                child_dict.update(child_data)
            else:
                child_list.append(child_data)

        if not child_dict:
            return {self.name: child_list}
        if self.name == 'root':
            return child_dict
        return {self.name: child_dict}

    @classmethod
    def from_dict(cls, d, root=None, depth=0, index=0):
        """
        Creates a new tree from the given dictionary.

        Every key becomes a child node. A dictionary value is expanded recursively, every
        item of a list value becomes a leaf below the key's node, and any other value becomes
        a single leaf below the key's node. Sibling indices start at 1.

        :param d: A dictionary representing a tree.
        :param root: The root node of the tree; a node named ``'root'`` is created when omitted.
        :param depth: The depth assigned to the root node.
        :param index: The index assigned to the root node.
        :return: The root node of the tree, or None when ``d`` is empty.
        """
        if not d:
            return None

        node = cls("root") if root is None else root
        node.depth = depth
        node.index = index

        def make_node(name, node_depth, node_index):
            # Build with ``cls`` and the root's metric so subclasses and non-default
            # distances carry through the whole tree.
            created = cls(name, node.distance_name)
            created.depth = node_depth
            created.index = node_index
            return created

        # Explicit stack of (parent node, dictionary describing its children).
        stack = [(node, d)]
        while stack:
            current_node, current_dict = stack.pop()

            for position, (key, value) in enumerate(current_dict.items(), start=1):
                child = make_node(key, current_node.depth + 1, position)
                current_node.add_child(child)

                if isinstance(value, dict):
                    # Nested dictionary: its keys are expanded in a later iteration.
                    stack.append((child, value))
                elif isinstance(value, list):
                    for leaf_position, item in enumerate(value, start=1):
                        child.add_child(make_node(item, child.depth + 1, leaf_position))
                else:
                    child.add_child(make_node(value, child.depth + 1, 1))

        return node

    def set_values(self, values_dict):
        """
        Sets the data of every node in the tree whose name appears in the input dictionary.

        :param values_dict: A dictionary with node names as keys and tensors as values.
        """
        assert isinstance(values_dict, dict)
        if self.name in values_dict:
            self.data = values_dict[self.name]

        for child in self.children:
            child.set_values(values_dict)

    def distance_to_leaf_nodes(self, sample):
        """
        Computes the distance between an input sample and the data of every leaf node.

        The metric of the node the method is called on (``self.distance``) is used for all
        leaves. Leaves without data are skipped. Results are keyed by leaf name, so leaves
        that share a name collapse into one entry (the last one visited wins).

        :param sample: The input sample (1xD tensor) to compare against the leaf data (NxD tensors).
        :return: A dictionary mapping leaf node names to their distance to ``sample``.
        """
        return {
            leaf.name: self.distance(sample, leaf.data)
            for leaf in self.get_all_leafnodes()
            if leaf.data is not None
        }

    def _get_matching_paths(self, node, match_node_names, current_path):
        """
        A helper function for top_matches_nomerge. Recursively collects the index paths
        leading to the matching nodes below (and including) the given node.

        A path is the list of ``index`` values from the starting node down to the match. The
        search does not descend below a matching node.

        :param node: The current node in the tree.
        :param match_node_names: The set of names of the top-k match nodes.
        :param current_path: The path to the parent of the current node.
        :return: A list of paths to the matching nodes.
        """
        assert isinstance(node, Parse_Tree)
        path_here = current_path + [node.index]
        if node.name in match_node_names:
            return [path_here]

        matching_paths = []
        for child in node.children:
            matching_paths.extend(
                self._get_matching_paths(child, match_node_names, path_here))
        return matching_paths

    def _top_match_names(self, sample, k):
        """
        Returns the names of the ``k`` leaf nodes whose data is closest to ``sample``.

        :param sample: The input sample (1xD tensor).
        :param k: The maximum number of names to return.
        :return: A set with the names of the closest leaf nodes.
        """
        ranked = sorted(
            self.distance_to_leaf_nodes(sample).items(), key=lambda item: item[1])
        return {name for name, _ in ranked[:k]}

    def top_matches_nomerge(self, sample, k=5):
        """
        Returns the index paths to the top-k leaf matches for the given sample.

        :param sample: The input sample (1xD tensor) for which the top-k matches are to be found.
        :param k: The number of top matches to return (default: 5).
        :return: A list of paths (lists of node indices from the root to each match).
        """
        match_node_names = self._top_match_names(sample, k)
        return self._get_matching_paths(self, match_node_names, [])

    def build_and_merge_trees(self, node, match_node_names):
        """
        A helper function for top_matches. Recursively builds the sub-tree of ``node`` that
        only keeps the branches leading to matching nodes.

        Matching nodes are returned as-is (the original objects, including all of their
        children); the ancestors on the way to them are shallow copies.

        :param node: The current node in the tree.
        :param match_node_names: The names of the nodes to keep.
        :return: The pruned node, or None when no matching node exists at or below ``node``.
        """
        if node.name in match_node_names:
            return node

        candidates = (
            self.build_and_merge_trees(child, match_node_names) for child in node.children)
        matching_children = [child for child in candidates if child is not None]
        if not matching_children:
            return None

        new_node = Parse_Tree(node.name, node.distance_name)
        new_node.data = node.data
        new_node.depth = node.depth
        new_node.index = node.index
        new_node.children = matching_children
        return new_node

    def top_matches(self, sample, k=5):
        """
        Returns the tree pruned to the top-k leaf matches for the given sample, preserving
        the tree structure above the matches.

        :param sample: The input sample (1xD tensor) for which the top-k matches are to be found.
        :param k: The number of top matches to keep (default: 5).
        :return: The merged tree containing the top-k matches, or None when nothing matches.
        """
        match_node_names = self._top_match_names(sample, k)
        return self.build_and_merge_trees(self, match_node_names)

    def random_subtree(self, k=5):
        """
        Returns the tree pruned to ``k`` randomly chosen nodes that carry data.

        The candidates are all nodes with data (not only leaves); they are shuffled with the
        module-level ``random`` generator and the first ``k`` are kept.

        :param k: The number of nodes to pick (default: 5).
        :return: The merged tree containing the picked nodes, or None when no node has data.
        """
        all_nodes = self.get_all_attrnodes()
        random.shuffle(all_nodes)
        match_node_names = {node.name for node in all_nodes[:k]}
        return self.build_and_merge_trees(self, match_node_names)

    @classmethod
    def edit_distance(cls, self, other):
        """
        Calculates an edit distance between two trees with dynamic programming.

        The trees are flattened into their depth-first pre-order sequences of node names and
        the Levenshtein distance (unit-cost insert/delete/substitute) between the two
        sequences is returned, so the parent/child structure is only reflected through the
        traversal order.

        :param self: The first tree.
        :param other: The other tree.
        :return: The edit distance between the two trees.
        """
        assert isinstance(
            other, Parse_Tree), "Other object must be a Parse_Tree object"
        assert isinstance(
            self, Parse_Tree), "First object must be a Parse_Tree object"

        # Flatten once instead of walking the tree again for every matrix cell.
        names_a = [node.name for node in self.get_all_nodes()]
        names_b = [node.name for node in other.get_all_nodes()]

        # Rolling-row Levenshtein: ``previous[j]`` is the distance between the first i-1
        # names of ``self`` and the first j names of ``other``.
        previous = list(range(len(names_b) + 1))
        for i, name_a in enumerate(names_a, start=1):
            current = [i]
            for j, name_b in enumerate(names_b, start=1):
                substitution_cost = 0 if name_a == name_b else 1
                current.append(min(
                    previous[j] + 1,                      # deletion
                    current[j - 1] + 1,                   # insertion
                    previous[j - 1] + substitution_cost,  # substitution or match
                ))
            previous = current

        return previous[-1]

    def _count_nodes(self):
        """
        Counts the number of nodes in the tree.

        :return: The number of nodes in the tree.
        """
        return 1 + sum(child._count_nodes() for child in self.children)

    def _get_node(self, i):
        """
        Returns the ith node in the tree in a depth-first traversal order.

        :param i: The index of the node to retrieve (starting from 1).
        :return: The ith node in the tree, or None when ``i`` is out of range.
        """
        nodes = self.get_all_nodes()
        if 1 <= i <= len(nodes):
            return nodes[i - 1]
        return None

    @classmethod
    def mcs_distance(cls, tree1, tree2):
        """
        Computes the Maximum Common Subgraph (MCS) score between two trees.

        The score is the number of nodes in the common subtree found by
        :meth:`find_largest_common_subtree`; it is 0 when the roots have different names.

        :param tree1: The first tree.
        :param tree2: The second tree.
        :return: The number of nodes the two trees have in common.
        """
        assert isinstance(tree1, Parse_Tree) and isinstance(tree2, Parse_Tree)
        return tree1._mcs_distance_helper(tree1, tree2)

    @classmethod
    def normalize_mcs_distance(cls, tree1, tree2):
        """
        Computes the MCS score between two trees, normalized by the geometric mean of the
        two tree sizes.

        :param tree1: The first tree.
        :param tree2: The second tree.
        :return: The normalized MCS score between the two trees.
        """
        assert isinstance(tree1, Parse_Tree) and isinstance(tree2, Parse_Tree)
        normalize_factor = (tree1._count_nodes() * tree2._count_nodes()) ** 0.5
        return tree1._mcs_distance_helper(tree1, tree2) / normalize_factor

    @classmethod
    def find_largest_common_subtree(cls, root1, root2):
        """
        Builds the subtree shared by two trees, starting at their roots.

        Two nodes are common when their names are equal; their children are then paired up
        by name and compared recursively. Children with the same name are paired in their
        original order. The returned nodes are new objects that only carry the name and the
        distance metric of ``root1``'s nodes.

        :param root1: The root of the first tree (or None).
        :param root2: The root of the second tree (or None).
        :return: The root of the common subtree, or None when the roots do not match.
        """
        if root1 is None or root2 is None:
            return None
        if root1.name != root2.name:
            return None

        subtree = cls(root1.name, root1.distance_name)

        # Pair children by name rather than by position: a positional zip of the two sorted
        # child lists misses shared children as soon as one side has an extra child.
        unmatched = {}
        for child2 in root2.children:
            unmatched.setdefault(child2.name, []).append(child2)

        for child1 in sorted(root1.children, key=lambda child: child.name):
            same_named = unmatched.get(child1.name)
            if not same_named:
                continue
            common_child = cls.find_largest_common_subtree(child1, same_named.pop(0))
            if common_child is not None:
                subtree.add_child(common_child)
        return subtree

    def _mcs_distance_helper(self, root1, root2):
        """
        Helper function for computing the MCS score between two trees.

        :param root1: The root of the first tree.
        :param root2: The root of the second tree.
        :return: The number of nodes in the common subtree (0 when there is none).
        """
        common_subtree = Parse_Tree.find_largest_common_subtree(root1, root2)
        if common_subtree is None:
            return 0
        return common_subtree._count_nodes()

    @classmethod
    def normalize_tree_kernel_distance(cls, tree1, tree2, lambda_param=0.5):
        """
        Calculates the normalized tree kernel value between two trees.

        The kernel of the pair is divided by the geometric mean of the two self-kernels.

        :param tree1: The first tree.
        :param tree2: The second tree.
        :param lambda_param: The decay factor, which controls the contribution of subtrees of different depths.
        :return: The normalized tree kernel value, or 0.0 when the normalization factor is 0.
        """
        assert isinstance(tree1, Parse_Tree) and isinstance(tree2, Parse_Tree)
        kernel_distance = tree1._tree_kernel_helper(tree1, tree2, lambda_param)
        normalization_factor = (tree1._tree_kernel_helper(tree1, tree1, lambda_param) *
                                tree1._tree_kernel_helper(tree2, tree2, lambda_param)) ** 0.5
        if normalization_factor == 0:
            # Happens for lambda_param == 0, where every kernel value is 0 as well.
            return 0.0
        return kernel_distance / normalization_factor

    @classmethod
    def tree_kernel_distance(cls, tree1, tree2, lambda_param=0.5):
        """
        Calculates the tree kernel value between two trees.

        :param tree1: The first tree.
        :param tree2: The second tree.
        :param lambda_param: The decay factor, which controls the contribution of subtrees of different depths.
        :return: The tree kernel value between the two trees (0 when the roots differ).
        """
        assert isinstance(tree1, Parse_Tree) and isinstance(tree2, Parse_Tree)
        return tree1._tree_kernel_helper(tree1, tree2, lambda_param)

    def _tree_kernel_helper(self, node1, node2, lambda_param):
        """
        Helper function to recursively calculate the tree kernel value between two nodes.

        Nodes with different names contribute 0. Nodes with equal names contribute
        ``lambda_param * (1 + sum of the kernel over every pair of their children)``.

        :param node1: The first node.
        :param node2: The second node.
        :param lambda_param: The decay factor.
        :return: The tree kernel value between the two nodes.
        """
        if node1 is None or node2 is None:
            return 0
        if node1.name != node2.name:
            return 0

        num_common_subtrees = 1
        for child1 in node1.children:
            for child2 in node2.children:
                num_common_subtrees += self._tree_kernel_helper(
                    child1, child2, lambda_param)
        return lambda_param * num_common_subtrees


def test_edit_distance():
    """
    Prints the edit distance computed by Parse_Tree next to the one computed by ``zss``.

    NOTE(review): nothing is asserted here, and the Parse_Tree dictionaries list ``'a'`` and
    ``'c'`` as top-level keys next to ``'root'``, so they do not describe the same shape as
    the zss trees - the final message is printed regardless of the two results.
    """
    from zss import Node, simple_distance

    def reference_tree(first_leaf, nested_leaf):
        # Shape: f -> (a -> (first_leaf, c -> (nested_leaf)), e)
        inner = Node("c").addkid(Node(nested_leaf))
        branch = Node("a").addkid(Node(first_leaf)).addkid(inner)
        return Node("f").addkid(branch).addkid(Node("e"))

    # Reference trees for the zss package
    zss_first = reference_tree("h", "l")
    zss_second = reference_tree("d", "b")

    # Dictionary descriptions consumed by Parse_Tree.from_dict
    first_spec = {'root': {'f': ['a', 'e']}, 'a': ['h', 'c'], 'c': ['l']}
    second_spec = {'root': {'f': ['a', 'e']}, 'a': ['d', 'c'], 'c': ['b']}
    own_first = Parse_Tree.from_dict(first_spec)
    own_second = Parse_Tree.from_dict(second_spec)

    # Run both implementations and report their results side by side
    own_result = Parse_Tree.edit_distance(own_first, own_second)
    reference_result = simple_distance(zss_first, zss_second)

    print(f"My implementation: {own_result}")
    print(f"zss implementation: {reference_result}")

    print("All edit distance tests passed!")


def test_mcs_distance():
    """
    Prints the MCS distance for two pairs of trees: one pair whose top-level keys differ
    and one pair built from identical dictionaries. Nothing is asserted.
    """
    # Two trees with different structures and different node names
    different_pair = (
        {"A": {"B": {"D": {}, "E": {}}, "C": {"F": {}, "G": {}}}},
        {"X": {"Y": {"Z": {}}}},
    )
    # Two trees with the same structure and the same node names
    identical_pair = (
        {"A": {"B": {"D": {}}}},
        {"A": {"B": {"D": {}}}},
    )

    for first_spec, second_spec in (different_pair, identical_pair):
        # Convert the dictionaries to trees, then report the score of the pair
        first_tree = Parse_Tree.from_dict(first_spec)
        second_tree = Parse_Tree.from_dict(second_spec)
        print(Parse_Tree.mcs_distance(first_tree, second_tree))


def test_tree_kernel_distance():
    """
    Builds three small two-branch trees, prints them and their pairwise tree kernel
    distances, and checks that the kernel is symmetric for the first two trees.
    """
    def two_branch_tree(top, left, left_leaf, right, right_leaf):
        # Shape: top -> (left -> left_leaf, right -> right_leaf)
        parent = Parse_Tree(top)
        for branch_name, leaf_name in ((left, left_leaf), (right, right_leaf)):
            branch = Parse_Tree(branch_name)
            parent.add_child(branch)
            branch.add_child(Parse_Tree(leaf_name))
        return parent

    root1 = two_branch_tree('A', 'B', 'D', 'C', 'E')
    root2 = two_branch_tree('A', 'B', 'D', 'C', 'F')
    root3 = two_branch_tree('F', 'G', 'I', 'H', 'J')

    for tree in (root1, root2, root3):
        print(tree)

    kernel = Parse_Tree.tree_kernel_distance
    dist_1_2 = kernel(root1, root2)
    dist_1_3 = kernel(root1, root3)
    dist_2_3 = kernel(root2, root3)

    print("Tree kernel distance between tree 1 and tree 2:", dist_1_2)
    print("Tree kernel distance between tree 1 and tree 3:", dist_1_3)
    print("Tree kernel distance between tree 2 and tree 3:", dist_2_3)

    # Self-similarity of tree 1 and the reversed pair used for the symmetry check
    dist_1_1 = kernel(root1, root1)
    dist_2_1 = kernel(root2, root1)

    print(dist_1_1)
    assert dist_1_2 == dist_2_1, f"Expected distance between tree 1 and tree 2 to be equal to distance between tree 2 and tree 1, but got {dist_1_2} and {dist_2_1} respectively"


def generate_random_tensor(n, m):
    """
    Creates an ``n`` x ``m`` tensor filled by ``torch.rand``.

    :param n: Number of rows.
    :param m: Number of columns.
    :return: The random tensor.
    """
    shape = (n, m)
    return torch.rand(shape)


def test_construction():
    """
    Builds a tree from a nested dictionary that mixes list, dict and scalar values and
    prints the result.
    """
    head = {
        "shape": "elongated",
        "Nose": ["moist", "black"],
        "Ears": ["pointy", "sensitive"],
        "Eyes": ["large", "brown"],
    }
    tree_dict = {
        "Concepts": ["Animal", "Mammal", "Wild", "Herbivore"],
        "Substance": [
            "four legs", "two ears", "two eyes", "a nose",
            "a tail", "fur", "antlers (in males)",
        ],
        "Attributes": {"Head": head},
        "Environment": ["forests", "meadows", "grasslands"],
    }

    print(Parse_Tree.from_dict(tree_dict))


def test_more():
    """
    Builds a tree, attaches random N x M data to the named nodes, and prints the top
    matches for a random sample both as a merged tree and as zero-padded index paths.
    """
    N = 2
    M = 2
    head = {
        "shape": "elongated",
        "Nose": ["moist", "black"],
        "Ears": ["pointy", "sensitive"],
        "Eyes": ["large", "brown"],
    }
    tree_dict = {
        "Concepts": ["Animal", "Mammal", "Wild", "Herbivore"],
        "Substance": [
            "four legs", "two ears", "two eyes", "a nose",
            "a tail", "fur", "antlers (in males)",
        ],
        "Attributes": {"Head": head},
        "Environment": ["forests", "meadows", "grasslands"],
    }
    # Names that receive data; the order fixes the order in which the tensors are drawn.
    valued_names = [
        "Animal", "Mammal", "Wild", "Herbivore",
        "four legs", "two ears", "two eyes", "a nose", "a tail", "fur",
        "antlers (in males)",
        "shape", "elongated", "moist", "black", "pointy", "sensitive",
        "large", "brown",
        "forests", "meadows", "grasslands",
    ]
    values = {name: generate_random_tensor(N, M) for name in valued_names}

    tree = Parse_Tree.from_dict(tree_dict)
    tree.set_values(values)
    print(tree)

    sample = generate_random_tensor(1, M)
    paths = tree.top_matches_nomerge(sample)
    paths_names = tree.top_matches(sample)
    print(paths_names)
    print(paths)

    max_depth = tree.max_depth()
    print(max_depth)

    # Right-pad every path with zeros so all paths have max_depth entries.
    new_paths = [path + [0] * int(max_depth - len(path)) for path in paths]

    print(new_paths)
    print(torch.tensor(new_paths).shape)


def test_maxdepth():
    """
    Builds a single-chain tree (root -> child1 -> grandchild1 -> greatgrandchild1) and
    prints the tree followed by its maximum depth.
    """
    chain = {"child1": {"grandchild1": "greatgrandchild1"}}
    tree = Parse_Tree.from_dict(chain)

    print(tree)

    # Four nodes on the only path, so this prints 4
    deepest = tree.max_depth()
    print(deepest)
    
def test_ranomtree():
    """
    Builds a tree, attaches random N x M data to the named nodes, and prints a random
    sub-tree pruned to two of the nodes that carry data.
    """
    N = 2
    M = 2
    head = {
        "shape": "elongated",
        "Nose": ["moist", "black"],
        "Ears": ["pointy", "sensitive"],
        "Eyes": ["large", "brown"],
    }
    tree_dict = {
        "Concepts": ["Animal", "Mammal", "Wild", "Herbivore"],
        "Substance": [
            "four legs", "two ears", "two eyes", "a nose",
            "a tail", "fur", "antlers (in males)",
        ],
        "Attributes": {"Head": head},
        "Environment": ["forests", "meadows", "grasslands"],
    }
    # Names that receive data; the order fixes the order in which the tensors are drawn.
    valued_names = [
        "Animal", "Mammal", "Wild", "Herbivore",
        "four legs", "two ears", "two eyes", "a nose", "a tail", "fur",
        "antlers (in males)",
        "shape", "elongated", "moist", "black", "pointy", "sensitive",
        "large", "brown",
        "forests", "meadows", "grasslands",
    ]
    values = {name: generate_random_tensor(N, M) for name in valued_names}

    tree = Parse_Tree.from_dict(tree_dict)
    tree.set_values(values)

    # Keep only the branches leading to two randomly picked data-carrying nodes
    subtree = tree.random_subtree(k=2)

    print(subtree)


if __name__ == "__main__":
    # Manual entry point: only the uncommented driver below runs; enable any of the
    # other drivers by uncommenting its call.
    # test_mcs_distance()
    # test_tree_kernel_distance()


    # test_more()
    test_ranomtree()
    # test_maxdepth()
