import numpy as np
from collections import defaultdict

def get_dyadic_cover(start, end, sort_output=False):
    ''' Returns a list of strings that represent the dyadic intervals that exactly cover [start, end]

    Each string is a binary prefix followed by zero or more '*'; every '*' stands for a free bit,
    so a string with k trailing '*' denotes an aligned block of 2**k consecutive leaves.

    Parameters:
    start: Start of interval (0-based, inclusive). Must be non-negative.
    end: End of interval (0-based, inclusive). Must be at least start - 1, where
        end == start - 1 denotes the empty interval.
    sort_output: Optional, default: False. True if the output dyadic covering should be returned
        in left-to-right order. When False and start > 0, nodes are emitted level by level from
        the coarsest to the finest instead.

    Returns:
    List of binary strings that terminate with zero or more '*'. Use 'interpret_nodestring' to
        translate a given string to a dyadic interval. The list is empty for the empty interval.

    Raises:
    ValueError: If start is negative, or if end is smaller than start - 1.
    '''
    # Reject inputs for which the bit-string construction below would silently produce garbage
    # (negative numbers have no plain binary representation) or fail with an unrelated error.
    if start < 0:
        raise ValueError(f"start must be non-negative, got {start}")
    if end < start - 1:
        raise ValueError(f"end ({end}) must not be smaller than start ({start})")
    if end < start:
        # Empty interval: nothing to cover
        return []

    if start == end:
        # The cover is just the single leaf
        return [np.binary_repr(end)]

    # The construction walks down from the leaf just right of the range
    b_str = np.binary_repr(end + 1)
    n = len(b_str)

    if start == 0:
        # Special case: same nodes as the binary mechanism. Every set bit of end + 1 contributes
        # the left sibling of the node on the path to that leaf.
        return [b_str[:j] + '0' + (n - j - 1) * '*'
                for j, bit in enumerate(b_str) if bit == '1']

    # General case: we also need the leaf just left of the range, padded to the same width
    a_str = np.binary_repr(start - 1, width=n)

    # First position where the two paths diverge; it exists because start - 1 < end + 1
    split = next(j for j in range(n) if a_str[j] != b_str[j])

    # Below the split, collect right siblings along the left path and left siblings along the
    # right path; together they tile exactly the leaves strictly between the two boundary leaves.
    output = []
    for j in range(split + 1, n):
        stars = (n - j - 1) * '*'
        if a_str[j] == '0':
            output.append(a_str[:j] + '1' + stars)
        if b_str[j] == '1':
            output.append(b_str[:j] + '0' + stars)

    # All strings have equal length and describe disjoint intervals, so plain string sorting
    # yields the left-to-right order
    if sort_output:
        output.sort()

    return output

def get_dyadic_cover_levels(start, end):
    ''' Counts how many nodes of each dyadic level appear in the cover of [start, end].

    Parameters:
    start: Start of interval (0-based)
    end: End of interval (0-based)

    Returns:
    A defaultdict(int) mapping a level to the number of cover nodes on that level. The level of
        a node is the number of '*' in its string, so leaves have level 0.
    '''
    level_counts = defaultdict(int)
    # The number of wildcard characters in a node string is exactly its level
    star_counts = [node.count('*') for node in get_dyadic_cover(start, end)]
    for level in star_counts:
        level_counts[level] += 1
    return level_counts

def interpret_nodestring(node_string):
    ''' Translates the string representation of a dyadic interval into its integer bounds.

    Parameters:
    node_string: A binary string with zero or more '*' appended to it

    Returns:
    A tuple of integers (start, end) representing the dyadic interval (both inclusive)
    '''
    # Filling every wildcard with 0 gives the smallest leaf, filling with 1 the largest
    lowest_leaf = node_string.replace('*', '0')
    highest_leaf = node_string.replace('*', '1')
    return int(lowest_leaf, 2), int(highest_leaf, 2)


# Sanity check: the sorted cover of [start, end] must tile the interval without gaps or overlaps
def _sanity_check(start, end):
    print(f"Covering [{start}, {end}]")
    # Leaf that the next interval of the _SORTED_ decomposition has to begin at
    expected_start = start
    for node in get_dyadic_cover(start, end, sort_output=True):
        lo, hi = interpret_nodestring(node)
        print(f"[{lo} , {hi}]")
        assert lo == expected_start, f"{lo} != {expected_start}"
        expected_start = hi + 1

    # The last interval must stop exactly at the end of the range
    assert end == expected_start - 1

if __name__ == '__main__':
    # Run the sanity check on every sub-interval [start, end] with 0 <= start <= end <= 9,
    # iterating over increasing end and, for each end, over increasing start.
    max_end = 9
    intervals = ((lo, hi) for hi in range(max_end + 1) for lo in range(hi + 1))
    for start, end in intervals:
        print(f"Testing covering {start}-{end}")
        _sanity_check(start, end)
    print('Passed!')