# Source code for archai.algos.gumbelsoftmax.gs_finalizers

from typing import List, Tuple, Optional, Iterator, Dict
from overrides import overrides

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
import os

from archai.common.common import get_conf
from archai.common.common import get_expdir
from archai.common.common import logger
from archai.datasets.data import get_data
from archai.nas.model import Model
from archai.nas.cell import Cell
from archai.nas.model_desc import CellDesc, ModelDesc, NodeDesc, EdgeDesc
from archai.nas.finalizers import Finalizers
from archai.algos.gumbelsoftmax.gs_op import GsOp


class GsFinalizers(Finalizers):
    """Finalizer that selects each node's edges via Gumbel-softmax sampling.

    The alphas of every GsOp edge in a node are pooled into one vector and
    sampled ``gs_num_sample`` times with hard Gumbel-softmax. Every edge whose
    ops were sampled at least once is kept and finalized with its per-op
    sample counts.
    """

    @staticmethod
    def _is_gs_edge(edge) -> bool:
        """Return True if ``edge``'s op is exactly a GsOp exposing PRIMITIVES."""
        op = edge._op
        # Exact type check (not isinstance) is kept deliberately so that
        # subclasses of GsOp are treated the same way as before.
        return hasattr(op, 'PRIMITIVES') and type(op) is GsOp

    @overrides
    def finalize_node(self, node:nn.ModuleList, node_index:int,
                      node_desc:NodeDesc, max_final_edges:int,
                      *args, **kwargs)->NodeDesc:
        """Finalize one node by sampling its GsOp edges' ops.

        Args:
            node: The node's edges. Only edges whose op is a ``GsOp``
                participate; all other edges are dropped from the result.
            node_index: Unused here; kept for interface compatibility.
            node_desc: Source of the ``conv_params`` copied to the result.
            max_final_edges: Upper bound on the number of edges kept.

        Returns:
            A new ``NodeDesc`` holding the finalized, sampled edges.

        Raises:
            ValueError: If the configured ``gs num_sample`` is less than 1.
            AssertionError: If no GsOp edge contributes any alpha.
        """
        conf = get_conf()
        gs_num_sample = conf['nas']['search']['model_desc']['cell']['gs']['num_sample']
        if gs_num_sample < 1:
            raise ValueError(
                f'gs num_sample must be >= 1 to sample ops, got {gs_num_sample}')

        gs_edges = [edge for edge in node if self._is_gs_edge(edge)]

        # Pool the alphas of all GsOp edges, in edge order, into a single
        # vector. Finalization only needs the values, not the autograd graph,
        # so they are converted explicitly to plain floats. This builds a CPU
        # tensor of the default dtype, as torch.Tensor(list) did before.
        node_alphas = torch.tensor([float(alpha)
                                    for edge in gs_edges
                                    for _, alpha in edge._op.ops()])
        assert node_alphas.nelement() > 0, \
            'node has no GsOp edge alphas to sample from'

        # Each hard sample is one-hot over the whole node's alpha vector.
        # Summing the samples yields how many times each op was chosen.
        samples_summed = torch.stack([
            F.gumbel_softmax(node_alphas, tau=1, hard=True, dim=-1)
            for _ in range(gs_num_sample)
        ]).sum(dim=0)

        # Hand each edge its slice of the sample counts for edge-level
        # finalization.
        # NOTE(review): this assumes edge._op.ops() yields exactly
        # len(PRIMITIVES) alphas, in the same order that GsOp.finalize indexes
        # its weights by. If ops() reorders (e.g. sorts by alpha), the slices
        # map to the wrong ops. Verify against GsOp.ops().
        selected_edges = []
        offset = 0
        for edge in gs_edges:
            n_ops = len(edge._op.PRIMITIVES)
            edge_weights = samples_summed[offset:offset + n_ops]
            offset += n_ops
            if not edge_weights.bool().any():
                continue  # none of this edge's ops were ever sampled
            op_desc, _ = edge._op.finalize(edge_weights)
            selected_edges.append(EdgeDesc(op_desc, edge.input_ids))

        # Sampled edges carry no ranking, so any max_final_edges of them are
        # equally valid. Keep the first ones.
        if len(selected_edges) > max_final_edges:
            selected_edges = selected_edges[:max_final_edges]

        return NodeDesc(selected_edges, node_desc.conv_params)