archai.nas package

Submodules

archai.nas.arch_module module

class archai.nas.arch_module.ArchModule[source]

Bases: torch.nn.modules.module.Module, abc.ABC, overrides.enforce.EnforceOverrides

ArchModule enhances nn.Module by making a clear separation between regular weights and architecture weights. Architecture parameters can be added using the create_arch_params() method and then accessed using the arch_params() method.

all_owned() → archai.nas.arch_params.ArchParams[source]
arch_params(recurse=False, only_owned=False) → archai.nas.arch_params.ArchParams[source]
create_arch_params(named_params: Iterable[Tuple[str, Union[torch.nn.parameter.Parameter, torch.nn.modules.container.ParameterDict, torch.nn.modules.container.ParameterList]]]) → None[source]
nonarch_params(recurse: bool) → Iterator[torch.nn.parameter.Parameter][source]
set_arch_params(arch_params: archai.nas.arch_params.ArchParams) → None[source]
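
Example (a minimal sketch of the intended usage; MyMixedOp is a hypothetical subclass for illustration, not part of archai):

import torch
from torch import nn
from archai.nas.arch_module import ArchModule

class MyMixedOp(ArchModule):
    def __init__(self, n_primitives: int = 4):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)               # regular weight
        alphas = nn.Parameter(1e-3 * torch.randn(n_primitives))
        self.create_arch_params([('alphas', alphas)])           # architecture weight

op = MyMixedOp()
arch = op.arch_params(only_owned=True)             # ArchParams holding 'alphas'
regular = list(op.nonarch_params(recurse=True))    # conv weight/bias only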

archai.nas.arch_params module

class archai.nas.arch_params.ArchParams(arch_params: Iterable[Tuple[str, Union[torch.nn.parameter.Parameter, torch.nn.modules.container.ParameterDict, torch.nn.modules.container.ParameterList]]], registrar: Optional[torch.nn.modules.module.Module] = None)[source]

Bases: collections.UserDict

This class holds the set of learnable architecture parameters for a given module. For example, one instance of this class would hold the alphas for one instance of MixedOp. To share parameters, an instance of this class can be passed around. Different algorithms may add learnable parameters for their needs.

static empty() → archai.nas.arch_params.ArchParams[source]
static from_module(module: torch.nn.modules.module.Module, recurse: bool = False) → archai.nas.arch_params.ArchParams[source]
has_kind(kind: str) → bool[source]
static nonarch_from_module(module: torch.nn.modules.module.Module, recurse: bool = False) → Iterator[torch.nn.parameter.Parameter][source]
param_by_kind(kind: Optional[str]) → Iterator[torch.nn.parameter.Parameter][source]
paramdict_by_kind(kind: Optional[str]) → Iterator[torch.nn.modules.container.ParameterDict][source]
paramlist_by_kind(kind: Optional[str]) → Iterator[torch.nn.modules.container.ParameterList][source]
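
Example (continuing the sketch above; assumes 'alphas' is the kind under which the parameter was registered):

from archai.nas.arch_params import ArchParams

ap = ArchParams.from_module(op)        # op: the MyMixedOp instance from above
if ap.has_kind('alphas'):
    for p in ap.param_by_kind('alphas'):
        print(p.shape)

# regular (non-architecture) parameters can be pulled out the same way
regular = list(ArchParams.nonarch_from_module(op, recurse=True))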

archai.nas.arch_trainer module

class archai.nas.arch_trainer.ArchTrainer(conf_train: archai.common.config.Config, model: archai.nas.model.Model, checkpoint: Optional[archai.common.checkpoint.CheckPoint])[source]

Bases: archai.common.trainer.Trainer, overrides.enforce.EnforceOverrides

compute_loss(lossfn: Callable, y: torch.Tensor, logits: torch.Tensor, aux_weight: float, aux_logits: Optional[torch.Tensor]) → torch.Tensor[source]
post_epoch(train_dl: torch.utils.data.dataloader.DataLoader, val_dl: Optional[torch.utils.data.dataloader.DataLoader]) → None[source]

archai.nas.cell module

class archai.nas.cell.Cell(desc: archai.nas.model_desc.CellDesc, affine: bool, droppath: bool, trainables_from: Optional[Cell])[source]

Bases: archai.nas.arch_module.ArchModule, overrides.enforce.EnforceOverrides

forward(s0, s1)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

ops() → Iterable[archai.nas.operations.Op][source]

archai.nas.dag_edge module

class archai.nas.dag_edge.DagEdge(desc: archai.nas.model_desc.EdgeDesc, affine: bool, droppath: bool, template_edge: Optional[DagEdge])[source]

Bases: archai.nas.arch_module.ArchModule

forward(inputs: List[torch.Tensor])[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

op() → archai.nas.operations.Op[source]

archai.nas.evaluater module

class archai.nas.evaluater.EvalResult(train_metrics: archai.common.metrics.Metrics)[source]

Bases: object

class archai.nas.evaluater.Evaluater[source]

Bases: overrides.enforce.EnforceOverrides

create_model(conf_eval: archai.common.config.Config, model_desc_builder: archai.nas.model_desc_builder.ModelDescBuilder, final_desc_filename=None, full_desc_filename=None) → torch.nn.modules.module.Module[source]
evaluate(conf_eval: archai.common.config.Config, model_desc_builder: archai.nas.model_desc_builder.ModelDescBuilder) → archai.nas.evaluater.EvalResult[source]
get_data(conf_loader: archai.common.config.Config) → Tuple[torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader][source]
model_from_desc(model_desc) → archai.nas.model.Model[source]
train_model(conf_train: archai.common.config.Config, model: torch.nn.modules.module.Module, checkpoint: Optional[archai.common.checkpoint.CheckPoint]) → archai.common.metrics.Metrics[source]

archai.nas.exp_runner module

class archai.nas.exp_runner.ExperimentRunner(config_filename: str, base_name: str, clean_expdir=False)[source]

Bases: abc.ABC, overrides.enforce.EnforceOverrides

copy_search_to_eval() → None[source]
evaluater() → archai.nas.evaluater.Evaluater[source]
finalizers() → archai.nas.finalizers.Finalizers[source]
get_conf(is_search_or_eval: bool) → archai.common.config.Config[source]
get_expname(is_search_or_eval: bool) → str[source]
model_desc_builder() → Optional[archai.nas.model_desc_builder.ModelDescBuilder][source]
run(search=True, eval=True) → Tuple[Optional[archai.nas.searcher.SearchResult], Optional[archai.nas.evaluater.EvalResult]][source]
run_eval(conf_eval: archai.common.config.Config) → archai.nas.evaluater.EvalResult[source]
searcher() → archai.nas.searcher.Searcher[source]
abstract trainer_class() → Optional[Type[archai.nas.arch_trainer.ArchTrainer]][source]

archai.nas.finalizers module

class archai.nas.finalizers.Finalizers[source]

Bases: overrides.enforce.EnforceOverrides

Provides base algorithms for finalizing the model, cell, and edge, which can be overridden.

For op-level finalize, just put logic in op’s finalize.

For model/cell/edge level finalize, you can override the methods in this class to customize the behavior. To override any of these methods, create a new class that derives from Finalizers in your algo's folder, for example diversity/diversity_finalizers.py. Then, in your algo's exp_runner.py, return an instance of that class from its finalizers() method, as sketched below.
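
A sketch of this recipe, with hypothetical class and file names (diversity is just the example used above):

from typing import Optional, Type
from overrides import overrides
from archai.nas.finalizers import Finalizers
from archai.nas.exp_runner import ExperimentRunner
from archai.nas.arch_trainer import ArchTrainer

# e.g. algos/diversity/diversity_finalizers.py
class DiversityFinalizers(Finalizers):
    @overrides
    def finalize_edge(self, edge):
        # custom edge-ranking logic would go here; defer to base behavior for now
        return super().finalize_edge(edge)

# e.g. algos/diversity/exp_runner.py
class DiversityExperimentRunner(ExperimentRunner):
    @overrides
    def finalizers(self) -> Finalizers:
        return DiversityFinalizers()

    @overrides
    def trainer_class(self) -> Optional[Type[ArchTrainer]]:
        return ArchTrainer  # or an algo-specific trainer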

finalize_cell(cell: archai.nas.cell.Cell, cell_index: int, model_desc: archai.nas.model_desc.ModelDesc, *args, **kwargs) → archai.nas.model_desc.CellDesc[source]
finalize_cells(model: archai.nas.model.Model) → List[archai.nas.model_desc.CellDesc][source]
finalize_edge(edge) → Tuple[archai.nas.model_desc.EdgeDesc, Optional[float]][source]
finalize_model(model: archai.nas.model.Model, to_cpu=True, restore_device=True) → archai.nas.model_desc.ModelDesc[source]
finalize_node(node: torch.nn.modules.container.ModuleList, node_index: int, node_desc: archai.nas.model_desc.NodeDesc, max_final_edges: int, *args, **kwargs) → archai.nas.model_desc.NodeDesc[source]
get_edge_ranks(node: torch.nn.modules.container.ModuleList) → Tuple[List[archai.nas.model_desc.EdgeDesc], List[Tuple[archai.nas.model_desc.EdgeDesc, float]]][source]
select_edges(edge_desc_ranks: List[Tuple[archai.nas.model_desc.EdgeDesc, float]], max_final_edges: int) → List[archai.nas.model_desc.EdgeDesc][source]

archai.nas.model module

class archai.nas.model.AuxTower(aux_tower_desc: archai.nas.model_desc.AuxTowerDesc)[source]

Bases: torch.nn.modules.module.Module

forward(x: torch.Tensor)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.model.Model(model_desc: archai.nas.model_desc.ModelDesc, droppath: bool, affine: bool)[source]

Bases: archai.nas.arch_module.ArchModule

device_type() → str[source]
drop_path_prob(p: float)[source]

Sets the drop path probability. This will be called externally so that any DropPath_ modules get the new probability. Typically, every epoch we will reduce this probability.
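
Example (an illustrative schedule only; drop_path_max and epochs stand in for config values, and the actual trainer's schedule may differ):

drop_path_max = 0.2     # assumed config value
epochs = 600            # assumed config value
for epoch in range(epochs):
    model.drop_path_prob(drop_path_max * (1.0 - epoch / epochs))
    # ... train one epoch ...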

forward(x) → Tuple[torch.Tensor, Optional[torch.Tensor]][source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

ops() → Iterable[archai.nas.operations.Op][source]
summary() → dict[source]

archai.nas.model_desc module

class archai.nas.model_desc.AuxTowerDesc(ch_in: int, n_classes: int, stride: int)[source]

Bases: object

class archai.nas.model_desc.CellDesc(id: int, cell_type: archai.nas.model_desc.CellType, conf_cell: archai.common.config.Config, stems: List[archai.nas.model_desc.OpDesc], stem_shapes: List[List[int]], nodes: List[archai.nas.model_desc.NodeDesc], node_shapes: List[List[int]], post_op: archai.nas.model_desc.OpDesc, out_shape: List[int], trainables_from: int)[source]

Bases: object

all_empty() → bool[source]
all_full() → bool[source]
clear_trainables() → None[source]
clone(id: int) → archai.nas.model_desc.CellDesc[source]
load_state_dict(state_dict) → None[source]
nodes() → List[archai.nas.model_desc.NodeDesc][source]
reset_nodes(nodes: List[archai.nas.model_desc.NodeDesc], node_shapes: List[List[int]], post_op: archai.nas.model_desc.OpDesc, out_shape: List[int]) → None[source]
state_dict() → dict[source]
class archai.nas.model_desc.CellType[source]

Bases: enum.Enum

An enumeration.

Reduction = 'reduction'
Regular = 'regular'
class archai.nas.model_desc.ConvMacroParams(ch_in: int, ch_out: int)[source]

Bases: object

Holds parameters that may be altered by macro architecture

clone() → archai.nas.model_desc.ConvMacroParams[source]
class archai.nas.model_desc.EdgeDesc(op_desc: archai.nas.model_desc.OpDesc, input_ids: List[int])[source]

Bases: object

Edge description between two nodes in the cell

clear_trainables() → None[source]
clone(conv_params: Optional[archai.nas.model_desc.ConvMacroParams], clear_trainables: bool) → archai.nas.model_desc.EdgeDesc[source]
load_state_dict(state_dict) → None[source]
state_dict() → dict[source]
class archai.nas.model_desc.ModelDesc(conf_model_desc: archai.common.config.Config, model_stems: List[archai.nas.model_desc.OpDesc], pool_op: archai.nas.model_desc.OpDesc, cell_descs: List[archai.nas.model_desc.CellDesc], aux_tower_descs: List[Optional[archai.nas.model_desc.AuxTowerDesc]], logits_op: archai.nas.model_desc.OpDesc)[source]

Bases: object

all_empty() → bool[source]
all_full() → bool[source]
cell_descs() → List[archai.nas.model_desc.CellDesc][source]
cell_type_count(cell_type: archai.nas.model_desc.CellType) → int[source]
clear_trainables() → None[source]
clone() → archai.nas.model_desc.ModelDesc[source]
has_aux_tower() → bool[source]
static load(filename: str, load_trainables=False) → archai.nas.model_desc.ModelDesc[source]
load_state_dict(state_dict) → None[source]
reset_cells(cell_descs: List[archai.nas.model_desc.CellDesc], aux_tower_descs: List[Optional[archai.nas.model_desc.AuxTowerDesc]]) → None[source]
save(filename: str, save_trainables=False) → Optional[str][source]
state_dict() → dict[source]
class archai.nas.model_desc.NodeDesc(edges: List[archai.nas.model_desc.EdgeDesc], conv_params: archai.nas.model_desc.ConvMacroParams)[source]

Bases: object

clear_trainables() → None[source]
clone()[source]
load_state_dict(state_dict) → None[source]
state_dict() → dict[source]
class archai.nas.model_desc.OpDesc(name: str, params: dict, in_len: int, trainables: Optional[Mapping], children: Optional[List[OpDesc]] = None, children_ins: Optional[List[int]] = None)[source]

Bases: object

The op description contained in each edge

clear_trainables() → None[source]
clone(clone_trainables=True) → archai.nas.model_desc.OpDesc[source]
load_state_dict(state_dict) → None[source]
state_dict() → dict[source]

archai.nas.model_desc_builder module

class archai.nas.model_desc_builder.ModelDescBuilder[source]

Bases: overrides.enforce.EnforceOverrides

build(conf_model_desc: archai.common.config.Config, template: Optional[archai.nas.model_desc.ModelDesc] = None) → archai.nas.model_desc.ModelDesc[source]
build_aux_tower(out_shape: List[int], conf_model_desc: archai.common.config.Config, cell_index: int) → Optional[archai.nas.model_desc.AuxTowerDesc][source]
build_cell(in_shapes: List[List[List[int]]], conf_cell: archai.common.config.Config, cell_index: int) → archai.nas.model_desc.CellDesc[source]
build_cell_post_op(stem_shapes: List[List[int]], node_shapes: List[List[int]], conf_cell: archai.common.config.Config, cell_index: int) → Tuple[List[int], archai.nas.model_desc.OpDesc][source]
build_cell_stems(in_shapes: List[List[List[int]]], conf_cell: archai.common.config.Config, cell_index: int) → Tuple[List[List[int]], List[archai.nas.model_desc.OpDesc]][source]
build_cells(in_shapes: List[List[List[int]]], conf_model_desc: archai.common.config.Config) → Tuple[List[archai.nas.model_desc.CellDesc], List[Optional[archai.nas.model_desc.AuxTowerDesc]]][source]
build_logits_op(in_shapes: List[List[List[int]]], conf_model_desc: archai.common.config.Config) → archai.nas.model_desc.OpDesc[source]
build_model_pool(in_shapes: List[List[List[int]]], conf_model_desc: archai.common.config.Config) → archai.nas.model_desc.OpDesc[source]
build_model_stems(in_shapes: List[List[List[int]]], conf_model_desc: archai.common.config.Config) → List[archai.nas.model_desc.OpDesc][source]
build_nodes(stem_shapes: List[List[int]], conf_cell: archai.common.config.Config, cell_index: int, cell_type: archai.nas.model_desc.CellType, node_count: int, in_shape: List[int], out_shape: List[int]) → Tuple[List[List[int]], List[archai.nas.model_desc.NodeDesc]][source]
build_nodes_from_template(stem_shapes: List[List[int]], conf_cell: archai.common.config.Config, cell_index: int) → Tuple[List[List[int]], List[archai.nas.model_desc.NodeDesc]][source]
create_cell_templates(template: Optional[archai.nas.model_desc.ModelDesc]) → List[Optional[archai.nas.model_desc.CellDesc]][source]
get_cell_template(cell_index: int) → Optional[archai.nas.model_desc.CellDesc][source]
get_cell_type(cell_index: int) → archai.nas.model_desc.CellType[source]
get_conf_cell() → archai.common.config.Config[source]
get_conf_dataset() → archai.common.config.Config[source]
get_conf_model_stems() → archai.common.config.Config[source]
get_node_channels(conf_model_desc: archai.common.config.Config) → List[List[int]][source]

Returns an array of channels for each node in each cell. All nodes are assumed to have the same number of output channels as input channels.

get_node_count(cell_index: int) → int[source]
get_reduction_indices(conf_model_desc: archai.common.config.Config) → List[int][source]

Returns the cell indices at which H×W is reduced and the channel count is doubled.
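
For intuition, a common convention in DARTS-style models places the reductions at one-third and two-thirds of the depth; the sketch below shows that rule only, not necessarily archai's exact computation:

from typing import List

def reduction_indices_sketch(n_cells: int, n_reductions: int = 2) -> List[int]:
    # e.g. n_cells=9, n_reductions=2 -> [3, 6]
    return [(i + 1) * n_cells // (n_reductions + 1) for i in range(n_reductions)]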

get_trainables_from(cell_index: int) → int[source]
pre_build(conf_model_desc: archai.common.config.Config) → None[source]
seed_cell(model_desc: archai.nas.model_desc.ModelDesc) → None[source]

archai.nas.nas_utils module

archai.nas.nas_utils.checkpoint_empty(checkpoint: Optional[archai.common.checkpoint.CheckPoint]) → bool[source]
archai.nas.nas_utils.create_checkpoint(conf_checkpoint: archai.common.config.Config, resume: bool) → Optional[archai.common.checkpoint.CheckPoint][source]

Creates a checkpoint given its config. If resume is True then an attempt is made to load the existing checkpoint; otherwise an empty checkpoint is created.

archai.nas.nas_utils.get_model_stats(model: archai.nas.model.Model, input_tensor_shape=[1, 3, 32, 32], clone_model=True) → tensorwatch.model_graph.torchstat_utils.ModelStats[source]

archai.nas.operations module

class archai.nas.operations.AvgPool2d7x7[source]

Bases: archai.nas.operations.Op

can_drop_path() → bool[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.ConcateChannelsOp(op_desc: archai.nas.model_desc.OpDesc, affine: bool)[source]

Bases: archai.nas.operations.MergeOp

forward(states: List[torch.Tensor])[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.ConvBNReLU(op_desc: archai.nas.model_desc.OpDesc, kernel_size: int, stride: int, padding: int, affine: bool)[source]

Bases: archai.nas.operations.Op

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.DilConv(op_desc: archai.nas.model_desc.OpDesc, kernel_size: int, stride: int, padding: int, dilation: int, affine: bool)[source]

Bases: archai.nas.operations.Op

(Dilated) depthwise separable conv: ReLU - (dilated) depthwise conv - pointwise conv - BN.

If dilation == 2, a 3x3 conv yields a 5x5 receptive field and a 5x5 conv yields a 9x9 receptive field.
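
A minimal sketch of this structure (assumes equal in/out channels; the real Op takes them from its OpDesc):

from torch import nn

def dil_conv_sketch(ch: int, kernel_size: int, stride: int, padding: int,
                    dilation: int, affine: bool = True) -> nn.Sequential:
    return nn.Sequential(
        nn.ReLU(inplace=False),
        nn.Conv2d(ch, ch, kernel_size, stride=stride, padding=padding,
                  dilation=dilation, groups=ch, bias=False),  # dilated depthwise
        nn.Conv2d(ch, ch, kernel_size=1, bias=False),         # pointwise
        nn.BatchNorm2d(ch, affine=affine),
    )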

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.DropPath_(p: float = 0.0)[source]

Bases: torch.nn.modules.module.Module

Replaces values in the tensor with 0 with probability p. Ref: https://arxiv.org/abs/1605.07648
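
A functional sketch of the idea, in the per-sample variant common in NAS implementations (the module's exact behavior may differ):

import torch

def drop_path_sketch(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    if not training or p <= 0.0:
        return x
    keep = 1.0 - p
    # one Bernoulli draw per sample, broadcast over the remaining dims
    shape = (x.size(0),) + (1,) * (x.dim() - 1)
    mask = torch.empty(shape, device=x.device, dtype=x.dtype).bernoulli_(keep)
    return x * mask / keep   # rescale survivors to preserve the expectation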

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.FacConv(op_desc: archai.nas.model_desc.OpDesc, kernel_length: int, padding: int, affine: bool)[source]

Bases: archai.nas.operations.Op

Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
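
A sketch of the factorization (assumes equal in/out channels; the real Op takes them from its OpDesc):

from torch import nn

def fac_conv_sketch(ch: int, kernel_length: int, padding: int,
                    affine: bool = True) -> nn.Sequential:
    return nn.Sequential(
        nn.ReLU(inplace=False),
        nn.Conv2d(ch, ch, (kernel_length, 1), padding=(padding, 0), bias=False),
        nn.Conv2d(ch, ch, (1, kernel_length), padding=(0, padding), bias=False),
        nn.BatchNorm2d(ch, affine=affine),
    )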

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.FactorizedReduce(op_desc: archai.nas.model_desc.OpDesc, affine: bool)[source]

Bases: archai.nas.operations.Op

Reduces feature map height/width by 2X while doubling channels, using two 1x1 convs each with stride=2.
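
An illustrative module in the style this description implies (assumed layout; the real Op reads channels from its OpDesc):

import torch
from torch import nn

class FactorizedReduceSketch(nn.Module):
    def __init__(self, ch_in: int, ch_out: int, affine: bool = True):
        super().__init__()
        assert ch_out % 2 == 0
        self.relu = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv2d(ch_in, ch_out // 2, 1, stride=2, bias=False)
        self.conv2 = nn.Conv2d(ch_in, ch_out // 2, 1, stride=2, bias=False)
        self.bn = nn.BatchNorm2d(ch_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        # the second path is shifted by one pixel so both pixel grids contribute
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        return self.bn(out)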

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.Identity(op_desc: archai.nas.model_desc.OpDesc)[source]

Bases: archai.nas.operations.Op

can_drop_path() → bool[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.LinearOp(op_desc: archai.nas.model_desc.OpDesc)[source]

Bases: archai.nas.operations.Op

can_drop_path() → bool[source]
forward(x: torch.Tensor)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.MergeOp(op_desc: archai.nas.model_desc.OpDesc, affine: bool)[source]

Bases: archai.nas.operations.Op, abc.ABC

can_drop_path() → bool[source]
forward(states: List[torch.Tensor])[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.MultiOp(op_desc: archai.nas.model_desc.OpDesc, affine: bool)[source]

Bases: archai.nas.operations.Op

forward(x: Union[torch.Tensor, List[torch.Tensor]]) → torch.Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.Op[source]

Bases: archai.nas.arch_module.ArchModule, abc.ABC, overrides.enforce.EnforceOverrides

can_drop_path() → bool[source]
static create(op_desc: archai.nas.model_desc.OpDesc, affine: bool, arch_params: Optional[archai.nas.arch_params.ArchParams] = None) → archai.nas.operations.Op[source]
finalize() → Tuple[archai.nas.model_desc.OpDesc, Optional[float]][source]

For a trainable op, returns the final op and its rank.

get_trainables() → Mapping[source]
ops() → Iterator[Tuple[archai.nas.operations.Op, float]][source]

Returns constituent ops; if this op is primitive, just returns self.

static register_op(name: str, factory_fn: Callable, exists_ok=True) → None[source]
set_trainables(state_dict) → None[source]
class archai.nas.operations.PoolAdaptiveAvg2D[source]

Bases: archai.nas.operations.Op

can_drop_path() → bool[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.PoolBN(pool_type: str, op_desc: archai.nas.model_desc.OpDesc, affine: bool)[source]

Bases: archai.nas.operations.Op

AvgPool or MaxPool - BN

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.ProjectChannelsOp(op_desc: archai.nas.model_desc.OpDesc, affine: bool)[source]

Bases: archai.nas.operations.MergeOp

forward(states: List[torch.Tensor])[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.ReLUConvBN(op_desc: archai.nas.model_desc.OpDesc, kernel_size: int, stride: int, padding: int, affine: bool)[source]

Bases: archai.nas.operations.Op

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.SepConv(op_desc: archai.nas.model_desc.OpDesc, kernel_size: int, padding: int, affine: bool)[source]

Bases: archai.nas.operations.Op

Depthwise separable conv: DilConv(dilation=1) * 2.

This is the same as two DilConv blocks stacked, each with dilation=1, as sketched below.
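
Following that description, a sketch of the stacking (assumes equal in/out channels and stride 1, matching this Op's signature):

from torch import nn

def sep_conv_sketch(ch: int, kernel_size: int, padding: int,
                    affine: bool = True) -> nn.Sequential:
    def dil_conv():
        return nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(ch, ch, kernel_size, padding=padding,
                      dilation=1, groups=ch, bias=False),  # depthwise
            nn.Conv2d(ch, ch, kernel_size=1, bias=False),  # pointwise
            nn.BatchNorm2d(ch, affine=affine),
        )
    return nn.Sequential(dil_conv(), dil_conv())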

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.SkipConnect(op_desc: archai.nas.model_desc.OpDesc, affine)[source]

Bases: archai.nas.operations.Op

can_drop_path() → bool[source]
forward(x: torch.Tensor) → torch.Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.StemBase(reduction: int)[source]

Bases: archai.nas.operations.Op

Abstract base class for model stems that enforces the reduction property, indicating the amount of spatial map reduction performed by the stem, i.e., reduction=2 for each stride=2.

class archai.nas.operations.StemConv3x3(op_desc: archai.nas.model_desc.OpDesc, affine: bool)[source]

Bases: archai.nas.operations.StemBase

can_drop_path() → bool[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.StemConv3x3S4(op_desc, affine: bool)[source]

Bases: archai.nas.operations.StemBase

can_drop_path() → bool[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.StemConv3x3S4S2(op_desc, affine: bool)[source]

Bases: archai.nas.operations.StemBase

can_drop_path() → bool[source]
forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class archai.nas.operations.Zero(op_desc: archai.nas.model_desc.OpDesc)[source]

Bases: archai.nas.operations.Op

Represents no connection. The Zero op can be thought of as a 1x1 kernel with a fixed zero weight. For stride=1, it produces output of the same dimensions as the input but with all 0s. With stride=2, it additionally drops every other pixel, halving the output's spatial dimensions.
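
An illustrative stand-in for this op (the real Op takes its stride from an OpDesc):

import torch
from torch import nn

class ZeroSketch(nn.Module):
    def __init__(self, stride: int):
        super().__init__()
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.stride == 1:
            return x.mul(0.0)
        # stride > 1: drop every other pixel, then zero out
        return x[:, :, ::self.stride, ::self.stride].mul(0.0)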

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.


archai.nas.random_finalizers module

class archai.nas.random_finalizers.RandomFinalizers[source]

Bases: archai.nas.finalizers.Finalizers

finalize_node(node: torch.nn.modules.container.ModuleList, node_index: int, node_desc: archai.nas.model_desc.NodeDesc, max_final_edges: int, *args, **kwargs) → archai.nas.model_desc.NodeDesc[source]

archai.nas.search_combinations module

class archai.nas.search_combinations.SearchCombinations[source]

Bases: archai.nas.searcher.Searcher

get_combinations(conf_search: archai.common.config.Config) → Iterator[Tuple[int, int, int]][source]
is_better_metrics(metrics1: Optional[archai.common.metrics.Metrics], metrics2: Optional[archai.common.metrics.Metrics]) → bool[source]
record_checkpoint(macro_comb_i: int, best_result: archai.nas.searcher.SearchResult) → None[source]
restore_checkpoint(conf_search: archai.common.config.Config, macro_combinations) → Tuple[int, Optional[archai.nas.searcher.SearchResult]][source]
save_trained(conf_search: archai.common.config.Config, reductions: int, cells: int, nodes: int, model_metrics: archai.nas.searcher.ModelMetrics) → None[source]

Save the model and metric info into a log file

search(conf_search: archai.common.config.Config, model_desc_builder: archai.nas.model_desc_builder.ModelDescBuilder, trainer_class: Optional[Type[ArchTrainer]], finalizers: archai.nas.finalizers.Finalizers) → archai.nas.searcher.SearchResult[source]

archai.nas.searcher module

class archai.nas.searcher.ModelMetrics(model: archai.nas.model.Model, metrics: archai.common.metrics.Metrics)[source]

Bases: object

class archai.nas.searcher.SearchResult(model_desc: Optional[archai.nas.model_desc.ModelDesc], search_metrics: Optional[archai.common.metrics.Metrics], train_metrics: Optional[archai.common.metrics.Metrics])[source]

Bases: object

class archai.nas.searcher.Searcher[source]

Bases: overrides.enforce.EnforceOverrides

build_model_desc(model_desc_builder: archai.nas.model_desc_builder.ModelDescBuilder, conf_model_desc: archai.common.config.Config, reductions: int, cells: int, nodes: int) → archai.nas.model_desc.ModelDesc[source]
clean_log_result(conf_search: archai.common.config.Config, search_result: archai.nas.searcher.SearchResult) → None[source]
finalize_model(model: archai.nas.model.Model, finalizers: archai.nas.finalizers.Finalizers) → archai.nas.model_desc.ModelDesc[source]
get_data(conf_loader: archai.common.config.Config) → Tuple[Optional[torch.utils.data.dataloader.DataLoader], Optional[torch.utils.data.dataloader.DataLoader]][source]
search(conf_search: archai.common.config.Config, model_desc_builder: Optional[archai.nas.model_desc_builder.ModelDescBuilder], trainer_class: Optional[Type[ArchTrainer]], finalizers: archai.nas.finalizers.Finalizers) → archai.nas.searcher.SearchResult[source]
search_model_desc(conf_search: archai.common.config.Config, model_desc: archai.nas.model_desc.ModelDesc, trainer_class: Optional[Type[ArchTrainer]], finalizers: archai.nas.finalizers.Finalizers) → Tuple[archai.nas.model_desc.ModelDesc, Optional[archai.common.metrics.Metrics]][source]
train_model_desc(model_desc: archai.nas.model_desc.ModelDesc, conf_train: archai.common.config.Config) → Optional[archai.nas.searcher.ModelMetrics][source]

Trains the given model description.

archai.nas.vis_model_desc module

Network architecture visualizer using graphviz

archai.nas.vis_model_desc.draw_cell_desc(cell_desc: archai.nas.model_desc.CellDesc, filepath: str = None, caption: str = None) → graphviz.dot.Digraph[source]

Makes a DAG plot and optionally saves it to filepath as .png.

archai.nas.vis_model_desc.draw_model_desc(model_desc: archai.nas.model_desc.ModelDesc, filepath: str = None, caption: str = None) → Tuple[Optional[graphviz.dot.Digraph], Optional[graphviz.dot.Digraph]][source]
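
Example (hypothetical usage; the filename is illustrative, and the two returned digraphs are left unnamed here since the docs do not say which cell type each covers):

from archai.nas.model_desc import ModelDesc
from archai.nas.vis_model_desc import draw_model_desc

model_desc = ModelDesc.load('final_model_desc.yaml')   # assumed filename
g1, g2 = draw_model_desc(model_desc, filepath='model', caption='Final model')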

Module contents