# Source code for archai.algos.divnas.divnas_cell

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import defaultdict
from typing import Callable, Iterable, List, Optional, Tuple, Dict
from abc import ABC, abstractmethod

from overrides import overrides, EnforceOverrides

import numpy as np

import torch
from torch import nn, tensor
from overrides import overrides, EnforceOverrides

import archai.algos.divnas.analyse_activations as aa

from archai.nas.cell import Cell
from archai.nas.operations import Zero
from archai.nas.operations import Op


class Divnas_Cell():
    ''' Wrapper cell class for divnas specific modifications.

    Wraps a ``Cell`` and, for every node of ``cell.dag``, keeps a running mean
    (over the batches passed to ``update_covs``) of the covariance matrix that
    ``aa.compute_rbf_kernel_covariance`` computes from the activations of the
    ops on that node's incoming edges.
    '''

    def __init__(self, cell:'Cell'):
        self._cell = cell
        self._collect_activations = False
        self._edgeoptype = None
        self._sigma = None
        # number of batches already folded into node_covs (running-mean weight)
        self._counter = 0
        # id(node) -> covariance matrix for that node of cell.dag
        self.node_covs:Dict[int, np.ndarray] = {}
        # node index in cell.dag -> (non-Zero op -> row/col index in that node's covariance)
        self.node_num_to_node_op_to_cov_ind:Dict[int, Dict[Op, int]] = {}

    def _is_tracked_edge(self, edge)->bool:
        ''' True if the edge's op exposes ``PRIMITIVES`` and is exactly of type ``self._edgeoptype``. '''
        return hasattr(edge._op, 'PRIMITIVES') and type(edge._op) == self._edgeoptype

    def collect_activations(self, edgeoptype, sigma:float)->None:
        ''' Turn on activation collection for every edge whose op is exactly ``edgeoptype``.

        Also rebuilds the per-node op -> covariance-index bookkeeping and resets
        every node's covariance matrix (and the running-mean counter) to zero.

        Args:
            edgeoptype: op class to track; compared with ``type(op) == edgeoptype``,
                so subclasses are not matched.
            sigma: kernel bandwidth forwarded to ``aa.compute_rbf_kernel_covariance``.
        '''
        self._collect_activations = True
        self._edgeoptype = edgeoptype
        self._sigma = sigma
        # covariances restart from zero below, so the running mean must restart too
        self._counter = 0

        # collect bookkeeping info: enumerate the non-Zero ops of all incoming
        # edges of each node, in edge order
        for i, node in enumerate(self._cell.dag):
            non_zero_ops = (op for edge in node
                               for op, _ in edge._op.ops()
                               if not isinstance(op, Zero))
            self.node_num_to_node_op_to_cov_ind[i] = {op: ind for ind, op in enumerate(non_zero_ops)}

        # go through all edges in the DAG and if they are of edgeoptype
        # type then set them to collect activations
        for node in self._cell.dag:
            # initialize the covariance matrix for this node
            num_ops = 0
            for edge in node:
                if self._is_tracked_edge(edge):
                    # presumably "- 1" drops the Zero primitive, matching the
                    # bookkeeping above -- TODO confirm against the op class
                    num_ops += edge._op.num_primitive_ops - 1
                    edge._op.collect_activations = True
            self.node_covs[id(node)] = np.zeros((num_ops, num_ops))

    def update_covs(self):
        ''' Fold the activations from the latest forward pass into each node's covariance.

        For every node, gathers ``edge._op.activations`` from incoming edges whose
        op is exactly ``edgeoptype``, computes a covariance with
        ``aa.compute_rbf_kernel_covariance`` and updates the node's running mean.
        Nodes with no such edges are left unchanged. Call once per forward pass,
        after ``collect_activations``.
        '''
        assert self._collect_activations

        for node in self._cell.dag:
            # TODO: convert to explicit ordering
            all_activs = [edge._op.activations for edge in node
                          if type(edge._op) == self._edgeoptype]
            if not all_activs:
                # nothing was collected for this node
                continue
            # update covariance matrix
            activs_converted = self._convert_activations(all_activs)
            new_cov = aa.compute_rbf_kernel_covariance(activs_converted, sigma=self._sigma)
            old_cov = self.node_covs[id(node)]
            self.node_covs[id(node)] = (self._counter * old_cov + new_cov) / (self._counter + 1)

        # one more batch is now included in every node's mean; without this the
        # "running mean" above would only ever hold the latest batch
        self._counter += 1

    def clear_collect_activations(self):
        ''' Stop activation collection on tracked edges and reset accumulated state.

        Discards ``node_covs`` and the running-mean counter, so read the
        covariances before calling this.
        '''
        # must run before _edgeoptype is reset, since the predicate reads it
        for node in self._cell.dag:
            for edge in node:
                if self._is_tracked_edge(edge):
                    edge._op.collect_activations = False
        self._collect_activations = False
        self._edgeoptype = None
        self._sigma = None
        self.node_covs = {}
        self._counter = 0

    def _convert_activations(self, all_activs:List[List[np.ndarray]])->List[np.ndarray]:
        ''' Converts to the format needed by covariance computing functions

        Input all_activs: List[List[np.array]]. Outer list len is num_edges.
        Inner list is of num_ops length. Each element in inner list is [batch_size, x, y, z]

        Returns a flat, edge-major list of num_edges * num_ops arrays; each is
        [batch_size, x*y*z] with one flattened feature row per sample. Returns an
        empty list when there are no edges.
        '''
        if not all_activs:
            return []
        num_ops = len(all_activs[0])
        for activs in all_activs:
            assert num_ops == len(activs)

        all_edge_list = []
        for edge in all_activs:
            # assumption edge_np will be (num_ops, batch_size, x, y, z)
            edge_np = np.array(edge)
            # iterating edge_np walks the ops axis; flatten everything but the batch axis
            all_edge_list.extend(op_activs.reshape(op_activs.shape[0], -1) for op_activs in edge_np)
        return all_edge_list