
from xnas.search_space_common import nas_bench_201_search_space
from xnas.search_space_common import summarize_search_space_arch


def _as_shape_list(input_shape):
    """Wrap a single shape in a list; return an existing list of shapes as-is.

    A value is treated as a single shape when its first entry is an integer.
    ``numbers.Integral`` is used rather than ``int`` so that integer scalars
    that are not ``int`` subclasses (e.g. ``numpy.int64``, which NumPy
    registers with the ``numbers`` ABCs) are still recognized as dimensions.
    """
    import numbers

    if isinstance(input_shape[0], numbers.Integral):
        return [input_shape]
    return input_shape


def create_search_space(
        input_shape=(16, 16, 3), output_shape=(120,),
        filter_scale_factor=1,
        *args, **kwargs
):
    """Build the NAS-Bench-201 search space with this module's fixed settings.

    The space is created by ``nas_bench_201_search_space`` with
    ``regression=False``, ``n_nodes=4`` and ``stack_size=5``.

    Args:
        input_shape: either a single shape (a sequence whose first entry is
            an integer, e.g. ``(16, 16, 3)``) or a list of such shapes. A
            single shape is wrapped in a one-element list before being passed
            on.
        output_shape: forwarded unchanged as the output shape.
        filter_scale_factor: forwarded unchanged to the search-space builder.
        *args, **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        Whatever ``nas_bench_201_search_space`` returns.
    """
    return nas_bench_201_search_space(
        _as_shape_list(input_shape),
        output_shape,
        regression=False,
        n_nodes=4,
        stack_size=5,
        filter_scale_factor=filter_scale_factor,
    )


def test_create_search_space(*args):
    """Build the default search space and print a summary of one architecture.

    Args:
        *args: optional operation choices (e.g. strings from the command
            line); each one is converted with ``int`` and the parsed list is
            printed. When no arguments are given, one random float in
            ``[0, 1)`` is drawn per node of the search space instead.
    """
    from random import random

    space = create_search_space()

    if not args:
        # No explicit choices: sample one value for every node.
        chosen_ops = [random() for _ in range(space.num_nodes)]
    else:
        chosen_ops = list(map(int, args))
        print('parsed:', chosen_ops)

    summarize_search_space_arch(space, chosen_ops)


if __name__ == '__main__':
    import sys

    # Every argument after the script name is forwarded as an operation choice.
    cli_ops = sys.argv[1:]
    test_create_search_space(*cli_ops)
