# Source code for archai.algos.gumbelsoftmax.gs_model_desc_builder

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Iterable, Sequence, Tuple, List, Set, Optional
import copy

from overrides import overrides

from archai.common.common import get_conf
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas.operations import Op, MultiOp
from archai.nas.model_desc import ConvMacroParams, CellDesc, CellType, OpDesc, \
                                  EdgeDesc, TensorShape, TensorShapes, NodeDesc, ModelDesc
from archai.common.config import Config
from .gs_op import GsOp

class GsModelDescBuilder(ModelDescBuilder):
    """Model description builder that places a single ``'gs_op'`` on every edge.

    ``pre_build`` registers the ``'gs_op'`` factory with ``Op`` and
    ``build_nodes`` emits one ``'gs_op'`` edge per (node, input) pair.
    """

    @overrides
    def pre_build(self, conf_model_desc: Config) -> None:
        """Register the ``'gs_op'`` factory so ``OpDesc('gs_op', ...)`` can be built.

        Args:
            conf_model_desc: Model description config; not read by this method.
        """
        Op.register_op(
            'gs_op',
            lambda op_desc, arch_params, affine: GsOp(op_desc, arch_params, affine))

    @overrides
    def build_nodes(self, stem_shapes: TensorShapes, conf_cell: Config,
                    cell_index: int, cell_type: CellType, node_count: int,
                    in_shape: TensorShape, out_shape: TensorShape
                    ) -> Tuple[TensorShapes, List[NodeDesc]]:
        """Build the node descriptions for one cell.

        Node ``i`` gets ``i + 2`` incoming edges, one for each input id in
        ``range(i + 2)``; every edge carries a single ``'gs_op'``.

        Args:
            stem_shapes: Not used by this builder.
            conf_cell: Cell config; ``conf_cell['gs']['num_sample']`` is read
                and forwarded to every op as ``'gs_num_sample'``.
            cell_index: Not used by this builder.
            cell_type: Edges from input ids 0 and 1 get stride 2 when this is
                ``CellType.Reduction``; every other edge gets stride 1.
            node_count: Number of nodes to create.
            in_shape: Input shape; only ``in_shape[0]`` is read.
            out_shape: Output shape; ``out_shape[0]`` is read and the whole
                shape is deep-copied once per node for the returned shapes.

        Returns:
            A tuple ``(out_shapes, nodes)`` where ``out_shapes`` holds
            ``node_count`` independent deep copies of ``out_shape`` and
            ``nodes`` holds ``node_count`` ``NodeDesc`` objects.

        Raises:
            AssertionError: If ``in_shape[0] != out_shape[0]``.
        """
        # The builder requires the leading dimension to be unchanged across
        # the cell (presumably the channel count -- confirm against
        # TensorShape's layout).
        assert in_shape[0]==out_shape[0]

        reduction = (cell_type == CellType.Reduction)
        nodes: List[NodeDesc] = []
        # A single ConvMacroParams instance is shared by every op and node
        # description built below.
        conv_params = ConvMacroParams(in_shape[0], out_shape[0])
        gs_num_sample = conf_cell['gs']['num_sample']

        # add gs op for each edge
        for i in range(node_count):
            edges = []
            for j in range(i + 2):
                op_desc = OpDesc(
                    'gs_op',
                    params={
                        'conv': conv_params,
                        # Only input ids 0 and 1 are strided in a reduction
                        # cell (presumably these are the cell's two external
                        # inputs -- confirm against the cell builder).
                        'stride': 2 if reduction and j < 2 else 1,
                        'gs_num_sample': gs_num_sample,
                    },
                    in_len=1, trainables=None, children=None)
                edge = EdgeDesc(op_desc, input_ids=[j])
                edges.append(edge)
            nodes.append(NodeDesc(edges=edges, conv_params=conv_params))

        # Deep-copy so each node owns its own shape object.
        out_shapes = [copy.deepcopy(out_shape) for _ in range(node_count)]

        return out_shapes, nodes