
from xnas.search_space_common import summarize_search_space_arch
from xnas.search_space_common import nas_bench_201_search_space


def create_search_space(
        input_shape=(12, 12, 1), output_shape=(4,),
        filter_scale_factor=1,
        *args, **kwargs
):
    """Build a search space via ``nas_bench_201_search_space``.

    Args:
        input_shape: Either a single shape (a sequence of integers such as
            ``(12, 12, 1)``) or a sequence of such shapes. A single shape is
            wrapped in a list before being forwarded.
        output_shape: Forwarded unchanged.
        filter_scale_factor: Forwarded unchanged.
        *args, **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        The result of ``nas_bench_201_search_space`` called with
        ``regression=False``, ``n_nodes=4`` and ``stack_size=2``.
    """
    from numbers import Integral

    # A single shape starts with an integer; wrap it so the builder always
    # receives a list of input shapes. ``numbers.Integral`` (rather than
    # ``int``) also recognises NumPy integer scalars, e.g. the elements of
    # ``tuple(np.array([12, 12, 1]))``, which a plain ``int`` check misses.
    if isinstance(input_shape[0], Integral):
        input_shape = [input_shape]

    return nas_bench_201_search_space(
        input_shape,
        output_shape,
        regression=False,
        n_nodes=4,
        stack_size=2,
        filter_scale_factor=filter_scale_factor,
    )


def test_create_search_space(*args):
    """Build the search space and print a summary for one set of ops.

    Args:
        *args: Optional operation choices (e.g. strings from the command
            line); each one is converted with ``int``. When no arguments are
            given, three values are drawn from ``random.random()`` instead.
    """
    from random import random

    if not args:
        # Nothing supplied: sample three values in [0, 1).
        num_layers = 3
        ops = [random() for _ in range(num_layers)]
    else:
        ops = list(map(int, args))
        num_layers = len(ops)
        print('parsed:', ops)

    # NOTE(review): ``num_layers`` is forwarded, but create_search_space in
    # this file only receives it through ``**kwargs`` and does not use it.
    search_space = create_search_space(num_layers=num_layers)
    summarize_search_space_arch(search_space, ops)


if __name__ == '__main__':
    import sys

    # Every command-line argument after the script name is treated as an
    # operation choice for the smoke test.
    cli_ops = sys.argv[1:]
    test_create_search_space(*cli_ops)
