Source code for archai.algos.gumbelsoftmax.gs_arch_trainer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Mapping, Optional, Union
import copy

import torch
from torch.utils.data import DataLoader
from torch import Tensor, nn, autograd
from torch.nn.modules.loss import _Loss
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import torch.nn.functional as F

from overrides import overrides

from archai.common.config import Config
from archai.nas.arch_trainer import ArchTrainer
from archai.common import utils, ml_utils
from archai.nas.model import Model
from archai.common.checkpoint import CheckPoint
from archai.common.common import logger, get_conf
from archai.algos.gumbelsoftmax.gs_op import GsOp

class GsArchTrainer(ArchTrainer):
    def __init__(self, conf_train: Config, model: nn.Module,
                 checkpoint: Optional[CheckPoint]) -> None:
        super().__init__(conf_train, model, checkpoint)

        conf = get_conf()
        # number of Gumbel-Softmax samples drawn and averaged in pre_step
        self._gs_num_sample = conf['nas']['search']['model_desc']['cell']['gs']['num_sample']

    @overrides
    def create_optimizer(self, conf_optim: Config, params) -> Optimizer:
        # in this case we don't need to differentiate between arch_params and weights
        # as the same optimizer will update both
        arch_params = list(self.model.all_owned().param_by_kind('alphas'))
        nonarch_params = list(self.model.nonarch_params(recurse=True))

        # TODO: do we need different param groups? Check in paper if they are
        # using different optimizers for alphas or not.
        param_groups = [{'params': nonarch_params}, {'params': arch_params}]
        return ml_utils.create_optimizer(conf_optim, param_groups)
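
    # Illustrative aside (an assumption, not part of the original trainer):
    # keeping alphas and weights in separate param groups leaves room for
    # per-group hyperparameters should the TODO above be resolved that way.
    # With a plain PyTorch optimizer, for example,
    #   torch.optim.SGD([{'params': nonarch_params},
    #                    {'params': arch_params, 'lr': 3e-4}],
    #                   lr=0.025, momentum=0.9)
    # would give the alphas group its own learning rate while both groups
    # share the remaining defaults.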

    @overrides
    def pre_step(self, x: Tensor, y: Tensor) -> None:
        super().pre_step(x, y)

        # TODO: is it a good idea to ensure model is in training mode here?

        # for each node in a cell, get the alphas of each incoming edge
        # concatenate them all together, sample from them via GS
        # push the resulting weights to the corresponding edge ops
        # for use in their respective forward
        for _, cell in enumerate(self.model.cells):
            for _, node in enumerate(cell.dag):
                # collect all alphas for all edges in to node
                node_alphas = []
                for edge in node:
                    if hasattr(edge._op, 'PRIMITIVES') and type(edge._op) == GsOp:
                        node_alphas.extend(alpha for op, alpha in edge._op.ops())

                # TODO: will creating a tensor from a list of tensors preserve the graph?
                node_alphas = torch.Tensor(node_alphas)

                if node_alphas.nelement() > 0:
                    # sample ops via gumbel softmax
                    sample_storage = []
                    for _ in range(self._gs_num_sample):
                        sampled = F.gumbel_softmax(node_alphas, tau=1, hard=False,
                                                   eps=1e-10, dim=-1)
                        sample_storage.append(sampled)

                    samples_summed = torch.sum(torch.stack(sample_storage, dim=0), dim=0)
                    samples = samples_summed / torch.sum(samples_summed)

                    # TODO: should we be normalizing the sampled weights?
                    # TODO: do gradients blow up as number of samples increases?

                    # send the sampled op weights to their respective edges
                    # to be used in forward
                    counter = 0
                    for _, edge in enumerate(node):
                        if hasattr(edge._op, 'PRIMITIVES') and type(edge._op) == GsOp:
                            this_edge_sampled_weights = samples[counter:counter+len(edge._op.PRIMITIVES)]
                            edge._op.set_op_sampled_weights(this_edge_sampled_weights)
                            counter += len(edge._op.PRIMITIVES)
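
The core of pre_step above is the multi-sample Gumbel-Softmax averaging applied to each node's concatenated alphas. The following standalone sketch is illustrative only: the alpha length (6) and sample count (4) are made-up stand-ins for a node's collected alphas and self._gs_num_sample, and it reproduces just that sampling step on a toy tensor.

import torch
import torch.nn.functional as F

num_sample = 4                                    # stand-in for self._gs_num_sample
node_alphas = torch.zeros(6, requires_grad=True)  # e.g. 2 incoming edges x 3 primitives

# draw several soft (hard=False) Gumbel-Softmax samples over the op logits
sample_storage = [F.gumbel_softmax(node_alphas, tau=1, hard=False, eps=1e-10, dim=-1)
                  for _ in range(num_sample)]

# average the draws and renormalize so the op weights sum to 1 again
samples_summed = torch.sum(torch.stack(sample_storage, dim=0), dim=0)
samples = samples_summed / torch.sum(samples_summed)

print(samples, samples.sum())  # a single distribution over the 6 op slots

Because the samples are soft, the averaged weights in this sketch stay differentiable with respect to node_alphas; averaging several draws trades extra sampling for lower variance than a single Gumbel-Softmax draw.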