Source code for archai.algos.gumbelsoftmax.gs_op

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Iterable, Optional, Tuple, List, Iterator

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

from overrides import overrides

from archai.nas.model_desc import OpDesc
from archai.nas.operations import Op
from archai.nas.arch_params import ArchParams
from archai.common.utils import zip_eq

# TODO: reduction cell might have output reduced by 2^1=2X due to
#   stride 2 through input nodes however FactorizedReduce does only
#   4X reduction. Is this correct?


class GsOp(Op):
    """The output of GsOp is the weighted output of all allowed primitives.
    (A usage sketch showing how the weights are sampled follows the class.)
    """

    PRIMITIVES = [
        'max_pool_3x3',
        'avg_pool_3x3',
        'skip_connect',  # identity
        'sep_conv_3x3',
        'sep_conv_5x5',
        'dil_conv_3x3',
        'dil_conv_5x5',
        'none'  # this must be at the end so top1 doesn't choose it
    ]

    def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
                 affine:bool):
        super().__init__()

        # assume last PRIMITIVE is 'none'
        assert GsOp.PRIMITIVES[-1] == 'none'

        self._ops = nn.ModuleList()
        for primitive in GsOp.PRIMITIVES:
            op = Op.create(
                OpDesc(primitive, op_desc.params, in_len=1, trainables=None),
                affine=affine, arch_params=None)
            self._ops.append(op)

        # sampled weights must be set via set_op_sampled_weights() before forward()
        self._sampled_weights: Optional[Tensor] = None

        # we do this at the end so that we can capture all arch params registered by
        # any previous child modules
        self._setup_arch_params(arch_params)
    def set_op_sampled_weights(self, sampled_weights:Tensor):
        ''' Sets the weight for each op '''
        assert sampled_weights.shape[0] == len(GsOp.PRIMITIVES)
        self._sampled_weights = sampled_weights
    @overrides
    def forward(self, x):
        assert self._sampled_weights is not None
        return sum(w * op(x) for w, op in zip_eq(self._sampled_weights, self._ops))
    @overrides
    def finalize(self, sampled_weights) -> Tuple[OpDesc, Optional[float]]:
        # finalization where each edge gets to keep as many unique
        # operations as are sampled at the node level
        # (see the finalization sketch after this class)
        assert sampled_weights.shape[0] == len(GsOp.PRIMITIVES)
        # we can't handle an empty op
        assert sampled_weights.bool().any()

        greater_than_0 = sampled_weights > 0
        children = []
        children_ins = []
        selected_alphas = []
        for i in range(greater_than_0.size()[0]):
            if greater_than_0[i]:
                children.append(self._ops[i].finalize()[0])
                selected_alphas.append(self._alphas[0][i].item())
                # all the ops will operate on the single node input
                children_ins.append(0)

        final_op_desc = OpDesc(name='multi_op',
                               params={
                                   # copy convolution parameters
                                   'conv': self.desc.params['conv']
                               },
                               # number of inputs remains the same and in this
                               # case should be 1
                               in_len=self.desc.in_len,
                               trainables=None,
                               # a primitive's finalize call also records its
                               # weights in the description. finalize returns
                               # (desc, rank) where rank for a primitive is None
                               children=children,
                               children_ins=children_ins
                               )

        max_alpha = 0.0
        if selected_alphas:
            max_alpha = max(selected_alphas)
        return final_op_desc, max_alpha
    @overrides
    def can_drop_path(self) -> bool:
        return False
    @overrides
    def ops(self) -> Iterator[Tuple['Op', float]]:  # type: ignore
        return iter(sorted(zip_eq(self._ops, self._alphas[0]),
                           key=lambda t: t[1], reverse=True))
    def _setup_arch_params(self, arch_params:Optional[ArchParams]) -> None:
        # do we have shared arch params?
        if arch_params is None:
            # create our own arch params
            new_p = nn.Parameter(  # TODO: use a better init than small random values?
                1.0e-3*torch.randn(len(GsOp.PRIMITIVES)), requires_grad=True)
            self.create_arch_params([('alphas', new_p)])
        else:
            assert arch_params.has_kind('alphas')
            self.set_arch_params(arch_params)

        # we store alphas in a list so PyTorch doesn't register them
        self._alphas = list(self.arch_params().param_by_kind('alphas'))
        assert len(self._alphas) == 1
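
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). In
# Gumbel-softmax search the architecture trainer, not GsOp itself, draws a
# weight vector from the op's alphas and hands it to the op before the
# forward pass. The helper below is a minimal sketch of that hand-off; the
# name `_sample_and_forward`, the default `temperature`, and sampling per-op
# rather than per-node are assumptions made here for illustration only.

def _sample_and_forward(gs_op: GsOp, x: Tensor, temperature: float = 1.0) -> Tensor:
    # the single 'alphas' parameter registered by _setup_arch_params
    alphas = list(gs_op.arch_params().param_by_kind('alphas'))[0]
    # differentiable draw over the primitives: hard=True returns a one-hot
    # vector in the forward pass while keeping soft gradients for backprop
    sampled = F.gumbel_softmax(alphas, tau=temperature, hard=True, dim=-1)
    gs_op.set_op_sampled_weights(sampled)
    return gs_op(x)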
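
# ---------------------------------------------------------------------------
# Finalization sketch (illustrative, not part of the original module). At the
# end of search, finalize() receives sampled weights and emits a 'multi_op'
# OpDesc containing every primitive whose sampled weight is non-zero, with
# the largest selected alpha returned as the rank. The helper name
# `_finalize_from_sample` and the hard one-hot draw are assumptions made here
# for illustration only.

def _finalize_from_sample(gs_op: GsOp,
                          temperature: float = 1.0) -> Tuple[OpDesc, Optional[float]]:
    alphas = list(gs_op.arch_params().param_by_kind('alphas'))[0]
    # a hard (one-hot) draw keeps exactly one primitive; finalize() then
    # builds a 'multi_op' description with that single child
    with torch.no_grad():
        hard_sample = F.gumbel_softmax(alphas, tau=temperature, hard=True, dim=-1)
    return gs_op.finalize(hard_sample)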