from dataclasses import dataclass, field
from typing import Dict, List, Set, Tuple
import itertools, hashlib, random

@dataclass
class Node:
    """A single vertex of a :class:`PrefixDAG`."""

    nid: str  # unique node identifier; also its key in PrefixDAG.nodes
    children: List[str] = field(default_factory=list)  # child IDs, in insertion order
    # Canonical leaf IDs under this node. Despite the Set annotation, builders
    # may store a list here (PrefixDAG.set_leaves stores a list copy).
    leaves: Set[int] = field(default_factory=set)
    depth: int = 0  # depth as supplied by the builder (root is added at 0)


class PrefixDAG:
    """Adjacency-list prefix-DAG whose nodes carry the leaf IDs beneath them.

    The structural invariant, checked by :meth:`child_partition_ok`, is that
    the children of an internal node partition that node's leaf IDs.
    """

    def __init__(self):
        self.nodes: Dict[str, Node] = {}

    def add_node(self, nid: str, depth: int = 0):
        """Create node *nid* at *depth*; a no-op if *nid* already exists."""
        if nid not in self.nodes:
            self.nodes[nid] = Node(nid=nid, depth=depth)

    def add_edge(self, parent: str, child: str):
        """Append *child* to *parent*'s child list.

        Raises KeyError if *parent* has not been added; *child* is not checked.
        """
        self.nodes[parent].children.append(child)

    def N(self, nid: str) -> int:
        """Return the number of leaf IDs stored on node *nid*."""
        return len(self.nodes[nid].leaves)

    def child_partition_ok(self, nid: str) -> bool:
        """Return True iff *nid*'s children partition its leaf IDs.

        The children's leaf collections must not overlap, and their union must
        equal *nid*'s leaves. Works whether leaves are stored as sets or lists;
        a duplicated ID inside one child's list also counts as an overlap.
        A childless node passes only if its own leaf collection is empty.
        """
        parts = [self.nodes[c].leaves for c in self.nodes[nid].children]
        union = set().union(*parts)  # set.union accepts any iterables
        # Every ID counted exactly once <=> total size equals size of the union;
        # this is O(total leaves) instead of an O(B^2) pairwise comparison.
        disjoint = sum(len(part) for part in parts) == len(union)
        return disjoint and union == set(self.nodes[nid].leaves)

    def add_child(self, parent, child):
        """Alias of :meth:`add_edge`, kept for callers that use this name."""
        self.add_edge(parent, child)

    def set_leaves(self, nid, leaves):
        """Store a list copy of *leaves* on node *nid*.

        A list (not a set) is stored because callers shuffle and slice the
        stored collection.
        """
        self.nodes[nid].leaves = list(leaves)

def build_demo_prefix_dag(D: int=3, B: int=3, seed:int=42) -> PrefixDAG:
    """Build a demo prefix-DAG: a complete B-ary tree of depth D with random splits.

    The leaf IDs 0..B**D-1 are shuffled with the module-level ``random``
    generator. That generator is seeded here with *seed*, which changes global
    RNG state. The shuffled IDs are dealt round-robin into B children, level by
    level, and every node's ``leaves`` ends up as a set.

    Node IDs embed the parent's ID (``v_{depth}_{parent_id}_{b}``), so the
    graph stays tree-shaped and the child-partition property holds exactly.
    It is re-checked for every internal node before returning.
    """
    random.seed(seed)
    dag = PrefixDAG()
    dag.add_node("root", depth=0)

    def deal(pool):
        # Shuffle in place, then deal round-robin into B disjoint sets.
        random.shuffle(pool)
        return [set(pool[k::B]) for k in range(B)]

    level = ["root"]
    for depth in range(D):
        next_level = []
        for parent in level:
            # The root draws from the full leaf universe; deeper nodes re-split
            # the set they were assigned one level up.
            pool = list(range(B**D)) if depth == 0 else list(dag.nodes[parent].leaves)
            for idx, part in enumerate(deal(pool)):
                child = f"v_{depth + 1}_{parent}_{idx}"
                dag.add_node(child, depth=depth + 1)
                dag.add_edge(parent, child)
                dag.nodes[child].leaves = part
                next_level.append(child)
            # The parent's leaf set becomes exactly the union of its children's.
            dag.nodes[parent].leaves = set().union(
                *(dag.nodes[c].leaves for c in dag.nodes[parent].children)
            )
        level = next_level

    # Sanity check: the children of every internal node partition its leaves.
    for name, node in dag.nodes.items():
        if node.children:
            assert dag.child_partition_ok(name), f"Partition fails at {name}"
    return dag

def leaves_under(g: PrefixDAG, nid: str):
    """Return the leaf collection stored on node *nid*.

    This is the stored object itself, not a copy, so mutating the result
    mutates the node.
    """
    node = g.nodes[nid]
    return node.leaves

# --- helper: build a balanced prefix-DAG with child partition (Suite A)
def build_balanced_suiteA(D=3, B=3, seed=0):
    """Build a complete B-ary prefix-DAG of depth D (Suite A).

    Leaf IDs 0..B**D-1 are split into contiguous, equal-sized blocks. Every
    node stores exactly the IDs of the depth-D leaves beneath it, so the
    children of each internal node partition its leaf list.

    Args:
        D: depth of the tree; nodes at depth D are the leaves.
        B: branching factor.
        seed: passed to ``random.seed``. Construction itself is deterministic;
            the global seeding is kept unchanged because callers may depend on
            it (NOTE(review): confirm whether anything does).

    Returns:
        The populated PrefixDAG, with node IDs of the form
        ``v_{depth}_{parent_id}_{b}``.
    """
    random.seed(seed)
    g = PrefixDAG()
    g.add_node("root", 0)
    leaf_ids = list(range(B**D))

    def build(nid, depth, ids):
        # Every node (not only depth-D ones) stores the IDs beneath it. This
        # keeps child_partition_ok true at internal nodes and makes N() count
        # the leaves under any node.
        g.set_leaves(nid, ids)
        if depth == D:
            return
        size = len(ids) // B
        for b in range(B):
            sub = ids[b * size:(b + 1) * size]  # contiguous block for child b
            cid = f"v_{depth + 1}_{nid}_{b}"
            g.add_node(cid, depth + 1)
            g.add_child(nid, cid)
            build(cid, depth + 1, sub)

    build("root", 0, leaf_ids)
    return g

# --- helper: random prefix-DAG with partition (Suite B)
def build_random_suiteB(layers=3, B=3, seed=0):
    """Build a random B-ary prefix-DAG with *layers* levels below the root (Suite B).

    The root holds leaf IDs 0..B**layers-1. At each level, every frontier
    node's leaves are shuffled and dealt round-robin into B children, so the
    children of every internal node partition its leaf list.

    Args:
        layers: number of levels to create under the root.
        B: branching factor.
        seed: seeds the module-level ``random`` generator (global state)
            before any shuffling.

    Returns:
        The populated PrefixDAG, with node IDs of the form
        ``v_{depth}_{parent_id}_{b}``.
    """
    random.seed(seed)
    g = PrefixDAG()
    g.add_node("root", 0)
    g.set_leaves("root", list(range(B**layers)))
    frontier = ["root"]
    for d in range(1, layers + 1):
        new_frontier = []
        for p in frontier:
            # Shuffle a private copy. leaves_under returns the node's stored
            # collection, so shuffling it directly would reorder the parent's
            # leaves in place, and random.shuffle fails on a set.
            ids = list(leaves_under(g, p))
            random.shuffle(ids)
            for b in range(B):
                cid = f"v_{d}_{p}_{b}"
                g.add_node(cid, d)
                g.add_child(p, cid)
                g.set_leaves(cid, ids[b::B])  # round-robin deal -> disjoint cover
                new_frontier.append(cid)
        frontier = new_frontier
    return g

# --- tiny toolchain DAG resembling retrieval→summary→calc
def build_toolchain_suiteC(seed=0):
    """Return the Suite C prefix-DAG, which currently has the same shape as Suite A.

    It is meant to evoke a retrieval -> summary -> calc tool chain: depth 1
    stands for retrieval, depth 2 for summary, depth 3 for calc. No tool labels
    are attached to the nodes yet.
    """
    random.seed(seed)
    return build_balanced_suiteA(D=3, B=3, seed=seed)
