from htssr.primitives import (
    uops,
    bops,
    bijection,
    zops,
    nesting_rules,
    complexity_rules,
)
from htssr.utils import (
    ids2key,
    crude_ids2tree,
    tree2ids,
    size_lex_order,
    root_check_nesting,
    tree_check_complexity,
)

# Size limit for canonical enumeration. Not referenced in this section;
# presumably passed as ``max_size`` to ``size_enumeration`` elsewhere — confirm.
max_canon_size: int = 32

def _admit_candidate(new_ids, domain, all_ids, fringe_update):
    """Offer ``new_ids`` as the representative for its key at the current size.

    The candidate is dropped if it fails ``tree_check_complexity`` or if the
    incumbent already stored under the same key wins ``size_lex_order``.
    Otherwise it becomes the representative in both ``all_ids`` and
    ``fringe_update``, replacing any same-size candidate with that key.
    Mutates ``all_ids`` and ``fringe_update`` in place; returns None.
    """
    # NOTE: a nesting filter, root_check_nesting(new_ids, nesting_rules), is
    # intentionally not applied here.
    if not tree_check_complexity(new_ids, complexity_rules):
        return
    key = ids2key(new_ids, domain)
    incumbent = all_ids.get(key)  # stored values are non-empty lists, never None
    # A truthy size_lex_order(incumbent, candidate) means "keep the incumbent".
    if incumbent is not None and size_lex_order(incumbent, new_ids):
        return
    fringe_update[key] = new_ids
    all_ids[key] = new_ids


def size_enumeration(max_size, domain):
    """Enumerate expression trees by size, keeping one representative per key.

    Trees are encoded as flat lists of op ids in prefix order
    (``[op] + child_ids...``), so a tree's size is ``len(ids)``. Size-1 trees
    are the nullary ops in ``zops``. Each larger size is built in two ways:
    every op in ``uops`` is applied to each tree of size ``size - 1``, and
    every op in ``bops`` is applied to each pair of trees whose sizes sum to
    ``size - 1``. Candidates are deduplicated by ``ids2key(ids, domain)``
    (presumably a fingerprint of the expression's behaviour on ``domain`` —
    confirm in htssr.utils).

    Args:
        max_size: Largest tree size to enumerate. Values below 2 return only
            the size-1 trees.
        domain: Passed through to ``ids2key``.

    Returns:
        A tuple ``(fringe, all_ids, size_ptrs)``:

        * ``fringe`` -- list of id lists grouped by increasing size. Size-1
          entries follow ``zops`` order; within each size >= 2 the lists
          are sorted.
        * ``all_ids`` -- dict mapping each key to its current representative.
        * ``size_ptrs`` -- dict mapping size to an inclusive ``(start, end)``
          index range into ``fringe``. The range is empty (``end < start``)
          when a size produced no new trees.
    """
    fringe = [[zop] for zop in zops]  # size == 1
    all_ids = {ids2key(ids, domain): ids for ids in fringe}
    size_ptrs = {1: (0, len(fringe) - 1)}  # inclusive bounds

    for size in range(2, max_size + 1):
        # key -> best candidate of exactly this size. A dict lets a later,
        # preferred candidate with the same key replace an earlier one.
        fringe_update = {}

        # Unary: every uop applied to every tree of size - 1.
        ustart, uend = size_ptrs[size - 1]
        for child in fringe[ustart:uend + 1]:
            for uop in uops:
                _admit_candidate([uop] + child, domain, all_ids, fringe_update)

        # Binary: every bop applied to (left, right) with sizes summing to
        # size - 1. fringe is not mutated until the end of this size, so the
        # slices are safe to take up front.
        for lsize in range(1, size - 1):
            lstart, lend = size_ptrs[lsize]
            rstart, rend = size_ptrs[size - 1 - lsize]
            rights = fringe[rstart:rend + 1]
            for left in fringe[lstart:lend + 1]:
                for right in rights:
                    for bop in bops:
                        _admit_candidate(
                            [bop] + left + right, domain, all_ids, fringe_update
                        )

        # Append this size's survivors in sorted order, so the output does
        # not depend on discovery order.
        start = len(fringe)
        fringe.extend(sorted(fringe_update.values()))
        size_ptrs[size] = (start, len(fringe) - 1)

    return fringe, all_ids, size_ptrs
