import random
from typing import Dict, List, Tuple
from collections import deque
from itertools import product

from tree import tree_pairwise_distances, build_tree_with_facilities, print_tree
from tree import TreeNode as GeneralTreeNode
from binary_tree import BinTreeNode, binary_tree_pairwise_distances_node_num, print_binary_tree
import numpy as np


# ---------------------------------------------------------------------------
# Step 1: Extract parent->children structure from the general tree
# ---------------------------------------------------------------------------

def _extract_children_from_tree(root: GeneralTreeNode) -> Dict[int, List[Tuple[int, float]]]:
    """
    Collect the directed parent -> child edge lists of a GeneralTreeNode tree.

    The tree is walked depth-first with an explicit stack (no recursion),
    starting at ``root``. Every node_number that is reached, leaves included,
    becomes a key of the result.

    Returns
    -------
    Dict[int, List[Tuple[int, float]]]
        children[u] is a list of (v, w) pairs, one per edge u -> v, where w is
        the matching entry of ``node.children_distances`` converted to float.
        Edges are listed in the same order as ``node.children``.
    """
    edges: Dict[int, List[Tuple[int, float]]] = {}
    visited = set()
    pending = [root]
    while pending:
        current = pending.pop()
        if current in visited:
            continue
        visited.add(current)
        out_list = edges.setdefault(current.node_number, [])
        for kid, length in zip(current.children, current.children_distances):
            kid_id = kid.node_number
            out_list.append((kid_id, float(length)))
            pending.append(kid)
            # Register the child even if it turns out to be a leaf.
            edges.setdefault(kid_id, [])
    return edges


# ---------------------------------------------------------------------------
# Step 2: Binarization helper (add dummy nodes with zero edges)
# ---------------------------------------------------------------------------
class _BinaryTreeBuilder:
    """
    Convert an arbitrary rooted tree with children lists into a binary tree
    by adding dummy nodes of degree 2 with zero-length edges.

    Parameters
    ----------
    children : Dict[int, List[Tuple[int, float]]]
        For each node u, children[u] is a list of (v, w) pairs representing
        a directed edge u -> v of length w.
    root_id : int
        Id of the root node in the original tree.

    Notes
    -----
    The transformation preserves all distances between the *original* nodes:
    we only ever insert edges of length 0 and keep every original edge length
    intact. For a node with d children we introduce exactly max(0, d-2)
    dummy nodes, which is the minimum possible if we are not allowed to
    contract or delete original nodes.
    """

    def __init__(self, children: Dict[int, List[Tuple[int, float]]], root_id: int):
        self.original_children = children
        self.root_id = root_id
        self.new_children: Dict[int, List[Tuple[int, float]]] = {}
        # New dummy ids come after all existing node ids
        self.next_id: int = max(children.keys()) + 1 if children else 0

    def _ensure_node(self, u: int) -> None:
        """Ensure u appears as a key in new_children."""
        if u not in self.new_children:
            self.new_children[u] = []

    def _new_dummy(self) -> int:
        """Create a fresh dummy node id and register it."""
        u = self.next_id
        self.next_id += 1
        self.new_children[u] = []
        return u

    def _binarize_subtree(self, u: int) -> int:
        """
        Recursively binarize the subtree rooted at original node u.

        Returns
        -------
        int
            The node id in the new tree representing u (always u itself).
        """
        self._ensure_node(u)
        children = self.original_children.get(u, [])
        if not children:
            # Leaf: nothing to do
            return u

        # First binarize all children recursively
        processed: List[Tuple[int, float]] = []
        for v, w in children:
            new_child = self._binarize_subtree(v)
            processed.append((new_child, float(w)))

        k = len(processed)
        if k <= 2:
            # Already binary (or unary) at this node
            for child, w in processed:
                self.new_children[u].append((child, w))
            return u

        # k >= 3: we keep the first child directly under u
        first_child, first_w = processed[0]
        self.new_children[u].append((first_child, first_w))

        # The remaining k-1 children will be hung off a chain of (k-2) dummies:
        #
        #   u -- first_child
        #   u -- D1
        #   D1 -- child_2, D2
        #   D2 -- child_3, D3
        #   ...
        #   D_{k-2} -- child_{k-1}, child_k
        #
        # All edges u–D1 and Di–D{i+1} have length 0; edges from dummies to
        # real children keep the original edge lengths.
        remaining = processed[1:]
        num_dummies = k - 2
        dummies: List[int] = [self._new_dummy() for _ in range(num_dummies)]

        # Second child of u is the first dummy
        self.new_children[u].append((dummies[0], 0.0))

        # Internal dummies except the last one
        for i in range(num_dummies - 1):
            dummy = dummies[i]
            next_dummy = dummies[i + 1]
            child_id, child_w = remaining[i]
            # Attach a real child and the next dummy
            self.new_children[dummy].append((child_id, child_w))
            self.new_children[dummy].append((next_dummy, 0.0))

        # Last dummy has the last two children
        last_dummy = dummies[-1]
        left_child_id, left_w = remaining[-2]
        right_child_id, right_w = remaining[-1]
        self.new_children[last_dummy].append((left_child_id, left_w))
        self.new_children[last_dummy].append((right_child_id, right_w))

        return u

    def build(self) -> Dict[int, List[Tuple[int, float]]]:
        """
        Perform binarization starting from root_id.
        """
        self._binarize_subtree(self.root_id)
        return self.new_children

# ---------------------------------------------------------------------------
# Step 3: Main embedding function: general tree -> binary tree
# ---------------------------------------------------------------------------

def embed_tree_to_binary_tree(
    root: GeneralTreeNode,
    nodes: Dict[int, GeneralTreeNode],
):
    """
    Embed a general TreeNode tree into a BinTreeNode tree without distorting
    distances between the original nodes.

    Phases
    ------
      1. Read the parent->children edge lists (with lengths) off ``root``.
      2. Binarise them with ``_BinaryTreeBuilder``. A node with d > 2
         children gets a chain of d-2 dummy nodes joined by zero-length
         edges. This is the fewest dummies possible when original nodes may
         not be contracted or deleted.
      3. Create one BinTreeNode per id and wire its parent/left/right
         pointers and edge lengths. Then copy client/facility attributes
         from ``nodes`` onto every node whose id is a key of ``nodes``.

    Original nodes keep their node_number in the binary tree. Dummy nodes
    get the ids allocated by the builder and carry no demand or capacity
    (is_client False, capacity 0, facility_type -1, empty
    facility_type_vector).

    Parameters
    ----------
    root : TreeNode
        Root of the original tree.
    nodes : Dict[int, TreeNode]
        Mapping original node_number -> TreeNode, used for attribute lookup.

    Returns
    -------
    bin_root : BinTreeNode
        Root of the binary tree (same node_number as ``root``).
    bin_nodes : Dict[int, BinTreeNode]
        Every node of the binary tree (original + dummy), keyed by id.
    tree_to_bintree_idx : Dict[int, int]
        Identity mapping u -> u over the keys of ``nodes``.

    Raises
    ------
    RuntimeError
        If a node of the binarised structure has more than 2 children.
    """
    # Phases 1 + 2: extract edge lists and binarise them.
    bin_children = _BinaryTreeBuilder(
        children=_extract_children_from_tree(root),
        root_id=root.node_number,
    ).build()

    # Safety net: every referenced child id must also have its own entry.
    for edge_list in list(bin_children.values()):
        for child_id, _ in edge_list:
            bin_children.setdefault(child_id, [])

    # Phase 3a: one BinTreeNode per id (original + dummy).
    bin_nodes: Dict[int, BinTreeNode] = {}
    for node_id in bin_children:
        bin_nodes[node_id] = BinTreeNode(node_number=node_id)

    # Phase 3b: iterative DFS from the root to find each node's parent and
    # the length of the edge to it; -1 is the "no parent" marker.
    root_id = root.node_number
    parent: Dict[int, int] = {root_id: -1}
    parent_dist: Dict[int, float] = {root_id: 0.0}
    to_visit = [root_id]
    while to_visit:
        current = to_visit.pop()
        for child_id, length in bin_children[current]:
            if child_id not in parent:
                parent[child_id] = current
                parent_dist[child_id] = float(length)
                to_visit.append(child_id)

    # Phase 3c: reset structural and attribute fields on every node.
    for node_id, bnode in bin_nodes.items():
        parent_id = parent[node_id]
        bnode.parent = None if parent_id == -1 else bin_nodes[parent_id]
        bnode.parent_id = parent_id
        bnode.parent_distance = float(parent_dist.get(node_id, 0.0))
        bnode.left = bnode.right = None
        bnode.left_distance = bnode.right_distance = 0.0
        bnode.capacity = 0
        bnode.facility_type = -1
        bnode.facility_type_vector = tuple()
        bnode.is_client = False
        bnode.is_leaf = False

    # Phase 3d: first listed child -> left, second -> right.
    for node_id, edge_list in bin_children.items():
        bnode = bin_nodes[node_id]
        if len(edge_list) > 2:
            raise RuntimeError("Binarisation failed: node has more than 2 children")
        if edge_list:
            left_id, left_len = edge_list[0]
            bnode.left = bin_nodes[left_id]
            bnode.left_distance = float(left_len)
            if len(edge_list) == 2:
                right_id, right_len = edge_list[1]
                bnode.right = bin_nodes[right_id]
                bnode.right_distance = float(right_len)

    # Phase 3e: mark leaves; copy attributes for ids present in `nodes`,
    # use the neutral defaults for everything else (dummies).
    orig_ids = set(nodes.keys())
    copied_attrs = (
        ("is_client", False),
        ("capacity", 0),
        ("facility_type_vector", tuple()),
        ("facility_type", -1),
    )
    for node_id, bnode in bin_nodes.items():
        bnode.is_leaf = bnode.left is None and bnode.right is None
        if node_id in orig_ids:
            source = nodes[node_id]
            for attr_name, default in copied_attrs:
                setattr(bnode, attr_name, getattr(source, attr_name, default))
        else:
            for attr_name, default in copied_attrs:
                setattr(bnode, attr_name, default)

    # Original ids keep their node_number, so the mapping is the identity.
    tree_to_bintree_idx: Dict[int, int] = {node_id: node_id for node_id in orig_ids}

    return bin_nodes[root_id], bin_nodes, tree_to_bintree_idx

def check_embedding_correctness(
    tree_root: GeneralTreeNode,
    binary_tree_root: BinTreeNode,
    tree_to_bintree_idx: Dict[int, int],
    num_pairs: int = 50,
    tol: float = 1e-9,
) -> None:
    """
    Check that the embedding preserves distances between original nodes.

    Every ordered pair (u, v) of original node ids (the keys of
    ``tree_to_bintree_idx``) is checked exhaustively. The u-v distance in the
    original tree is compared with the distance between the two
    representatives in the binary tree. Each pair whose distances differ by
    more than ``tol`` (or where either distance is NaN) is printed, followed
    by a one-line summary.

    This is intended as a sanity check for the embedding.

    Parameters
    ----------
    tree_root : TreeNode
        Root of the original tree.
    binary_tree_root : BinTreeNode
        Root of the embedded binary tree.
    tree_to_bintree_idx : Dict[int, int]
        Original node id -> id of its representative in the binary tree.
    num_pairs : int
        Currently unused, because all pairs are checked. It is kept so that
        existing callers passing it continue to work.
    tol : float
        Absolute tolerance for the distance comparison.
    """
    orig_ids = list(tree_to_bintree_idx.keys())
    if len(orig_ids) < 2:
        print("Not enough nodes to test distances.")
        return

    # NOTE(review): both matrices are indexed directly by node number below,
    # and the id lists returned alongside them are not used. Confirm that
    # these helpers return matrices indexed by node_number, not by position
    # in the returned id list.
    _tree_node_ids, tree_dist_matrix = tree_pairwise_distances(tree_root)
    _binary_tree_ids, binary_tree_dist_matrix = binary_tree_pairwise_distances_node_num(binary_tree_root)

    checked = 0
    mismatch_found = False
    for u, v in product(orig_ids, orig_ids):
        dist_orig = tree_dist_matrix[u][v]
        bin_u = tree_to_bintree_idx[u]
        bin_v = tree_to_bintree_idx[v]
        dist_bin = binary_tree_dist_matrix[bin_u][bin_v]
        # Written as "not <=" so a NaN distance is reported as a mismatch:
        # every comparison with NaN is False, so "> tol" would let it pass.
        if not abs(dist_orig - dist_bin) <= tol:
            print(f"Mismatch for pair ({u}, {v}): original={dist_orig}, binary={dist_bin}")
            mismatch_found = True
        checked += 1

    if mismatch_found:
        print("Distance mismatches found.")
    else:
        print(f"Checked {checked} pairs: all distances match within tolerance {tol}.")


# Example usage and testing
def test_stub():
    """
    Smoke test for the embedding on a small randomly generated facility tree.

    Calls build_tree_with_facilities with n=20, t=4, facility_prob=0.4,
    max_facility_types=4 and a fixed seed, then embeds the result into a
    binary tree, runs the distance check, and prints both node counts.
    """
    num_types = 4
    tree_root, tree_nodes = build_tree_with_facilities(
        20,         # n
        num_types,  # t
        0.4,        # facility_prob
        num_types,  # max_facility_types (same as t)
        123456789,  # seed
    )

    embedded = embed_tree_to_binary_tree(tree_root, tree_nodes)
    bin_root, bin_nodes, tree_to_bintree_idx = embedded

    check_embedding_correctness(tree_root, bin_root, tree_to_bintree_idx)
    print("Tree num nodes =", len(tree_nodes))
    print("Binary tree num nodes =", len(bin_nodes))


if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    test_stub()
