"""Stuff related to expression trees."""


def generate_bin_dist(max_ops: int):
    """
    Count the binary trees that can be grown from a number of empty nodes.

    `max_ops`: maximum number of operators

    Returns a table `D` (list of lists) where `D[e][n]` is the number of
    distinct binary trees obtained by filling `e` empty nodes with exactly `n`
    binary operators, for 0 <= e <= max_ops and 0 <= n <= 2 * max_ops - e.
    The table is filled with:
        D(0, n) = 0
        D(1, n) = C_n                                  (n-th Catalan number)
        D(e, n) = D(e - 1, n + 1) - D(e - 2, n + 1)    for e >= 2
    """
    limit = 2 * max_ops

    # Catalan numbers C_0 .. C_limit via C_k = (4k - 2) * C_{k-1} / (k + 1).
    # The product is always a multiple of (k + 1), so floor division is exact.
    catalan = [1]
    for k in range(1, limit + 1):
        catalan.append(catalan[-1] * (4 * k - 2) // (k + 1))

    table = []
    for e in range(max_ops + 1):
        # Row e needs column n + 1 of the two previous rows, so each row is
        # one entry shorter than the one before it.
        width = limit - e + 1
        if e == 0:
            row = [0] * width
        elif e == 1:
            row = catalan[:width]
        else:
            above, above2 = table[e - 1], table[e - 2]
            row = [above[n + 1] - above2[n + 1] for n in range(width)]
        table.append(row)
    return table


def generate_ubi_dist(max_ops: int, nl: int, p1: int, p2: int):
    """
    `max_ops`: maximum number of operators
    Enumerate the number of possible unary-binary trees that can be generated from empty nodes.
    D[e][n] represents the number of different binary trees with n nodes that
    can be generated from e empty nodes, using the following recursion:
        D(0, n) = 0
        D(e, 0) = L ** e
        D(e, n) = L * D(e - 1, n) + p_1 * D(e, n - 1) + p_2 * D(e + 1, n - 1)
    """
    # enumerate possible trees
    # first generate the tranposed version of D, then transpose it
    D = []
    D.append([0] + ([nl ** i for i in range(1, 2 * max_ops + 1)]))
    for n in range(1, 2 * max_ops + 1):  # number of operators
        s = [0]
        for e in range(1, 2 * max_ops - n + 1):  # number of empty nodes
            s.append(nl * s[e - 1] + p1 * D[n - 1][e] + p2 * D[n - 1][e + 1])
        D.append(s)
    assert all(len(D[i]) >= len(D[i + 1]) for i in range(len(D) - 1))
    D = [[D[j][i] for j in range(len(D)) if i < len(D[j])] for i in range(max(len(x) for x in D))]
    return D
