# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List
import copy

# the latest version of ray works on Windows as well
import ray
from overrides import overrides

from archai.common import common
from archai.common.common import logger, CommonState
from archai.common.config import Config
from archai.nas.arch_trainer import TArchTrainer
from archai.nas import nas_utils
from archai.nas.model_desc import CellType, OpDesc, EdgeDesc, NodeDesc, ModelDesc
from archai.nas.model import Model
from archai.nas.finalizers import Finalizers
from archai.nas.searcher import SearchResult
from archai.nas.search_combinations import SearchCombinations
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.algos.petridish.petridish_utils import ConvexHullPoint, JobStage, \
    sample_from_hull, plot_frontier, save_hull_frontier, save_hull, plot_pool, \
    plot_seed_model_stats


class SearcherPetridish(SearchCombinations):
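    '''Implements Petridish forward search: a pool of parent models is kept
    on a convex hull of performance vs model size (MAdd), a sampled parent is
    grown by one node, searched and trained, and the resulting point is added
    back to the pool. Search and training jobs are distributed via ray.'''
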
    @overrides
def search(self, conf_search:Config, model_desc_builder:ModelDescBuilder,
trainer_class:TArchTrainer, finalizers:Finalizers)->SearchResult:
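        '''Runs the Petridish search loop until the MAdd or hull-size budget
        is exhausted, then plots and saves the hull and returns the best
        point as the search result.'''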
logger.pushd('search')
# region config vars
self.conf_search = conf_search
conf_checkpoint = conf_search['checkpoint']
resume = conf_search['resume']
conf_post_train = conf_search['post_train']
final_desc_foldername = conf_search['final_desc_foldername']
conf_petridish = conf_search['petridish']
# petridish distributed search related parameters
self._convex_hull_eps = conf_petridish['convex_hull_eps']
self._max_madd = conf_petridish['max_madd']
self._max_hull_points = conf_petridish['max_hull_points']
self._checkpoints_foldername = conf_petridish['checkpoints_foldername']
# endregion
self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
# parent models list
self._hull_points: List[ConvexHullPoint] = []
self._ensure_dataset_download(conf_search)
# checkpoint will restore the hull we had
is_restored = self._restore_checkpoint()
        # if the parent pool could not be restored (i.e., this is the first
        # run of this job), seed the pool with models of different macro
        # parameters: number of cells, reductions, etc.
future_ids = [] if is_restored else self._create_seed_jobs(conf_search,
model_desc_builder)
while not self._is_search_done():
logger.info(f'Ray jobs running: {len(future_ids)}')
if future_ids:
# get first completed job
job_id_done, future_ids = ray.wait(future_ids)
hull_point = ray.get(job_id_done[0])
logger.info(f'Hull point id {hull_point.id} with stage {hull_point.job_stage.name} completed')
if hull_point.is_trained_stage():
self._update_convex_hull(hull_point)
# sample a point and search
sampled_point = sample_from_hull(self._hull_points,
self._convex_hull_eps)
future_id = SearcherPetridish.search_model_desc_dist.remote(self,
conf_search, sampled_point, model_desc_builder, trainer_class,
finalizers, common.get_state())
future_ids.append(future_id)
logger.info(f'Added sampled point {sampled_point.id} for search')
elif hull_point.job_stage==JobStage.SEARCH:
# create the job to train the searched model
future_id = SearcherPetridish.train_model_desc_dist.remote(self,
conf_post_train, hull_point, common.get_state())
future_ids.append(future_id)
logger.info(f'Added sampled point {hull_point.id} for post-search training')
else:
raise RuntimeError(f'Job stage "{hull_point.job_stage}" is not expected in search loop')
# cancel any remaining jobs to free up gpus for the eval phase
for future_id in future_ids:
ray.cancel(future_id, force=True) # without force, main process stops
ray.wait([future_id])
# plot and save the hull
expdir = common.get_expdir()
assert expdir
plot_frontier(self._hull_points, self._convex_hull_eps, expdir)
best_point = save_hull_frontier(self._hull_points, self._convex_hull_eps,
final_desc_foldername, expdir)
save_hull(self._hull_points, expdir)
        plot_pool(self._hull_points, expdir)
# return best point as search result
search_result = SearchResult(best_point.model_desc, search_metrics=None,
train_metrics=best_point.metrics)
self.clean_log_result(conf_search, search_result)
logger.popd()
return search_result

    @staticmethod
@ray.remote(num_gpus=1)
def search_model_desc_dist(searcher:'SearcherPetridish', conf_search:Config,
hull_point:ConvexHullPoint, model_desc_builder:ModelDescBuilder,
trainer_class:TArchTrainer, finalizers:Finalizers, common_state:CommonState)\
->ConvexHullPoint:
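        '''Ray job that grows a trained hull point by one node and runs
        architecture search on the grown model desc, returning the result as
        a new point at the SEARCH stage.'''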
        # as this runs in a different process, initialize globals
        common.init_from(common_state)
        # register ops since we are in a different process now
conf_model_desc = conf_search['model_desc']
model_desc_builder.pre_build(conf_model_desc)
assert hull_point.is_trained_stage()
        # cloning is strictly not needed, but if this function were ever run
        # in the same process it avoids surprises from shared state
model_desc = hull_point.model_desc.clone()
searcher._add_node(model_desc, model_desc_builder)
model_desc, search_metrics = searcher.search_model_desc(conf_search,
model_desc, trainer_class, finalizers)
cells, reductions, nodes = hull_point.cells_reductions_nodes
new_point = ConvexHullPoint(JobStage.SEARCH, hull_point.id,
hull_point.sampling_count, model_desc,
(cells, reductions, nodes+1), # we added a node
metrics=search_metrics)
return new_point

    @staticmethod
@ray.remote(num_gpus=1)
def train_model_desc_dist(searcher:'SearcherPetridish', conf_train:Config,
hull_point:ConvexHullPoint, common_state:CommonState)\
->ConvexHullPoint:
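        '''Ray job that trains the model desc of the given hull point and
        returns a new point at the next job stage, carrying the training
        metrics and model stats.'''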
        # as this runs in a different process, initialize globals
common.init_from(common_state)
assert not hull_point.is_trained_stage()
model_metrics = searcher.train_model_desc(hull_point.model_desc, conf_train)
model_stats = nas_utils.get_model_stats(model_metrics.model)
        new_point = ConvexHullPoint(hull_point.next_stage(), hull_point.id,
                                    hull_point.sampling_count,
                                    hull_point.model_desc,
                                    hull_point.cells_reductions_nodes,
                                    model_metrics.metrics,
                                    model_stats)
return new_point

    def _add_node(self, model_desc:ModelDesc, model_desc_builder:ModelDescBuilder)->None:
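        '''Grows each cell by inserting one new node, just before the last
        node, whose petridish op takes the two stems and all earlier nodes
        as inputs.'''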
for ci, cell_desc in enumerate(model_desc.cell_descs()):
reduction = (cell_desc.cell_type==CellType.Reduction)
nodes = cell_desc.nodes()
            # petridish must be seeded with at least one node
            assert len(nodes) > 0
            # input/output channels are the same for all nodes
            conv_params = nodes[0].conv_params
            # assign input IDs for the new node: s0 and s1 have IDs 0 and 1,
            # and existing nodes follow; since the new node will be inserted
            # before the last node, the last node is not among its inputs
            input_ids = list(range(len(nodes) + 1))
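            # e.g., with 3 existing nodes input_ids == [0, 1, 2, 3]: the two
            # stems plus the first two nodes; the last node (id 4) is
            # excluded because the new node is inserted before it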
assert len(input_ids) >= 2 # 2 stem inputs
op_desc = OpDesc('petridish_reduction_op' if reduction else 'petridish_normal_op',
params={
'conv': conv_params,
# specify strides for each input, later we will
# give this to each primitive
'_strides':[2 if reduction and j < 2 else 1 \
for j in input_ids],
}, in_len=len(input_ids), trainables=None, children=None)
edge = EdgeDesc(op_desc, input_ids=input_ids)
new_node = NodeDesc(edges=[edge], conv_params=conv_params)
nodes.insert(len(nodes)-1, new_node)
            # output shapes of all nodes are the same
node_shapes = cell_desc.node_shapes
new_node_shape = copy.deepcopy(node_shapes[-1])
node_shapes.insert(len(node_shapes)-1, new_node_shape)
            # the post op needs rebuilding because its number of inputs has
            # changed, so its input/output channels may differ
post_op_shape, post_op_desc = model_desc_builder.build_cell_post_op(cell_desc.stem_shapes,
node_shapes, cell_desc.conf_cell, ci)
cell_desc.reset_nodes(nodes, node_shapes,
post_op_desc, post_op_shape)

    def _ensure_dataset_download(self, conf_search:Config)->None:
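        '''Loads the dataset once up front so that it is already downloaded
        before ray jobs start.'''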
conf_loader = conf_search['loader']
self.get_data(conf_loader)

    def _is_search_done(self)->bool:
        '''Returns True once the max MAdd or the max number of hull points is exceeded'''
if not self._hull_points:
return False
max_madd_parent = max(self._hull_points, key=lambda p:p.model_stats.MAdd)
return max_madd_parent.model_stats.MAdd > self._max_madd or \
len(self._hull_points) > self._max_hull_points

    def _create_seed_jobs(self, conf_search:Config, model_desc_builder:ModelDescBuilder)->list:
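        '''Builds a seed model for each feasible macro combination of
        (reductions, cells, nodes) and submits a ray job to pre-train it;
        returns the ray job IDs.'''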
conf_model_desc = conf_search['model_desc']
conf_seed_train = conf_search['seed_train']
future_ids = [] # ray job IDs
seed_model_stats = [] # seed model stats for visualization and debugging
macro_combinations = list(self.get_combinations(conf_search))
for reductions, cells, nodes in macro_combinations:
# if N R N R N R cannot be satisfied, ignore combination
if cells < reductions * 2 + 1:
continue
# create seed model
model_desc = self.build_model_desc(model_desc_builder,
conf_model_desc,
reductions, cells, nodes)
hull_point = ConvexHullPoint(JobStage.SEED, 0, 0, model_desc,
(cells, reductions, nodes))
# pre-train the seed model
future_id = SearcherPetridish.train_model_desc_dist.remote(self,
conf_seed_train, hull_point, common.get_state())
future_ids.append(future_id)
# build a model so we can get its model stats
temp_model = Model(model_desc, droppath=True, affine=True)
seed_model_stats.append(nas_utils.get_model_stats(temp_model))
# save the model stats in a plot and tsv file so we can
# visualize the spread on the x-axis
expdir = common.get_expdir()
assert expdir
plot_seed_model_stats(seed_model_stats, expdir)
return future_ids

    def _update_convex_hull(self, new_point:ConvexHullPoint)->None:
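        '''Adds a newly trained point to the pool of hull candidates and
        checkpoints the updated pool.'''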
assert new_point.is_trained_stage() # only add models for which we have metrics and stats
self._hull_points.append(new_point)
if self._checkpoint is not None:
self._checkpoint.new()
self._checkpoint['convex_hull_points'] = self._hull_points
self._checkpoint.commit()
logger.info(f'Added to convex hull points: MAdd {new_point.model_stats.MAdd}, '
f'num cells {len(new_point.model_desc.cell_descs())}, '
f'num nodes in cell {len(new_point.model_desc.cell_descs()[0].nodes())}')

    def _restore_checkpoint(self)->bool:
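        '''Restores hull points from the checkpoint if one is available;
        returns True if restored.'''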
can_restore = self._checkpoint is not None \
and 'convex_hull_points' in self._checkpoint
if can_restore:
self._hull_points = self._checkpoint['convex_hull_points']
logger.warn({'Hull restored': True})
return can_restore

    @overrides
def build_model_desc(self, model_desc_builder:ModelDescBuilder,
conf_model_desc:Config,
reductions:int, cells:int, nodes:int)->ModelDesc:
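        '''Builds a seed model desc with the given number of reductions,
        cells and nodes per cell.'''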
# reset macro params in copy of config
conf_model_desc = copy.deepcopy(conf_model_desc)
conf_model_desc['n_reductions'] = reductions
conf_model_desc['n_cells'] = cells
conf_model_desc['cell']['n_nodes'] = nodes
        # create the model desc for search using the model config; for
        # pre-training, the model will be built directly from this desc
        # without another call to model_desc_builder
model_desc = model_desc_builder.build(conf_model_desc, template=None)
return model_desc