# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Iterator, Mapping, Type, Optional, Tuple, List
import math
import copy
import random
import os
import torch
import tensorwatch as tw
from torch.utils.data.dataloader import DataLoader
import yaml
from overrides import overrides
from archai.common.common import logger
from archai.common.checkpoint import CheckPoint
from archai.common.config import Config
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas.arch_trainer import TArchTrainer
from archai.nas import nas_utils
from archai.nas.model_desc import CellType, ModelDesc
from archai.common.trainer import Trainer
from archai.datasets import data
from archai.nas.model import Model
from archai.common.metrics import EpochMetrics, Metrics
from archai.common import utils
from archai.nas.finalizers import Finalizers
from archai.nas.searcher import ModelMetrics, Searcher, SearchResult
from archai.nas import nas_utils
class SearchCombinations(Searcher):
    """Searcher that sweeps over macro-architecture combinations.

    Every ``(reductions, cells, nodes)`` triple produced by
    ``get_combinations`` is turned into a model description, searched,
    post-trained to obtain metrics and written to disk. The result with the
    best validation top-1 (see ``is_better_metrics``) is returned. Progress is
    checkpointed after each combination so an interrupted run can resume.

    ``build_model_desc``, ``search_model_desc``, ``train_model_desc`` and
    ``clean_log_result`` are not defined here; presumably they are inherited
    from ``Searcher``.
    """

    @overrides
    def search(self, conf_search:Config, model_desc_builder:ModelDescBuilder,
               trainer_class:TArchTrainer, finalizers:Finalizers)->SearchResult:
        """Run the search over all macro combinations and return the best result.

        Args:
            conf_search: search configuration; the keys ``model_desc``,
                ``post_train``, ``checkpoint`` and ``resume`` are read here.
            model_desc_builder: builder used to create each model description.
            trainer_class: architecture trainer passed to ``search_model_desc``.
            finalizers: finalizers passed to ``search_model_desc``.

        Returns:
            The best ``SearchResult`` found (or restored from the checkpoint).
        """
        # region config vars
        conf_model_desc = conf_search['model_desc']
        conf_post_train = conf_search['post_train']
        conf_checkpoint = conf_search['checkpoint']
        resume = conf_search['resume']
        # endregion

        self._checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)

        macro_combinations = list(self.get_combinations(conf_search))
        start_macro_i, best_search_result = self.restore_checkpoint(
            conf_search, macro_combinations)
        # stays at -1,-1,-1 if no combination evaluated in this run becomes the best
        best_macro_comb = -1,-1,-1 # reductions, cells, nodes

        for macro_comb_i in range(start_macro_i, len(macro_combinations)):
            reductions, cells, nodes = macro_combinations[macro_comb_i]
            logger.pushd(f'r{reductions}.c{cells}.n{nodes}')

            # build model description that we will search on
            model_desc = self.build_model_desc(model_desc_builder, conf_model_desc,
                                               reductions, cells, nodes)

            # perform search on model description
            model_desc, search_metrics = self.search_model_desc(
                conf_search, model_desc, trainer_class, finalizers)

            # train searched model for few epochs to get some perf metrics
            model_metrics = self.train_model_desc(model_desc, conf_post_train)
            assert model_metrics is not None, "'post_train' section in yaml should have non-zero epochs if running combinations search"

            # save result
            self.save_trained(conf_search, reductions, cells, nodes, model_metrics)

            # update the best result so far
            # On a fresh run there is no incumbent yet; pass None so that
            # is_better_metrics accepts the first candidate instead of
            # dereferencing a missing result.
            # NOTE(review): the incumbent is represented by `search_metrics`
            # while the candidate uses post-train metrics - confirm this is the
            # intended comparison.
            incumbent_metrics = (None if best_search_result is None
                                 else best_search_result.search_metrics)
            if self.is_better_metrics(incumbent_metrics, model_metrics.metrics):
                best_search_result = SearchResult(model_desc, search_metrics,
                                                  model_metrics.metrics)
                best_macro_comb = reductions, cells, nodes

            # checkpoint
            assert best_search_result is not None
            self.record_checkpoint(macro_comb_i, best_search_result)
            logger.popd() # reductions, cells, nodes

        assert best_search_result is not None
        self.clean_log_result(conf_search, best_search_result)
        logger.info({'best_macro_comb':best_macro_comb})

        return best_search_result

    def is_better_metrics(self, metrics1:Optional[Metrics],
                          metrics2:Optional[Metrics])->bool:
        """Return True when ``metrics2`` is at least as good as ``metrics1``.

        If either argument is None the answer is True, so a candidate always
        replaces a missing incumbent. Otherwise the best validation top-1
        values are compared; ties favor ``metrics2``.
        """
        if metrics1 is None or metrics2 is None:
            return True
        return metrics2.best_val_top1() >= metrics1.best_val_top1()

    def restore_checkpoint(self, conf_search:Config,
                           macro_combinations:List[Tuple[int, int, int]])\
            ->Tuple[int, Optional[SearchResult]]:
        """Work out where to (re)start and recover the best result so far.

        Also stores the pareto summary file path on the instance as
        ``self._summary_filepath`` so that ``save_trained`` can append to it.

        Args:
            conf_search: search configuration; ``pareto.summary_filename`` is read.
            macro_combinations: all combinations that the search iterates over.

        Returns:
            ``(start_macro_i, best_result)``: index of the first combination to
            run and the best result restored from the checkpoint, or
            ``(0, None)`` when there is nothing to resume from.
        """
        conf_pareto = conf_search['pareto']
        pareto_summary_filename = conf_pareto['summary_filename']
        summary_filepath = utils.full_path(pareto_summary_filename)
        # save_trained reads this attribute; nothing else visible in this class
        # assigns it, so record it here before any combination is processed
        self._summary_filepath = summary_filepath

        # if checkpoint is available then restart from last combination we were running
        checkpoint_avail = self._checkpoint is not None
        resumed, state = False, None
        start_macro_i, best_result = 0, None
        if checkpoint_avail:
            state = self._checkpoint.get('search', None)
            if state is not None:
                start_macro_i = state['start_macro_i']
                assert 0 <= start_macro_i < len(macro_combinations)
                # NOTE(security): yaml.Loader can construct arbitrary Python
                # objects; only load checkpoints written by a trusted run.
                best_result = yaml.load(state['best_result'], Loader=yaml.Loader)
                start_macro_i += 1 # resume after the last checkpoint
                resumed = True

        if not resumed:
            # erase previous file left over from run
            utils.zero_file(summary_filepath)

        logger.warn({'resumed': resumed, 'checkpoint_avail': checkpoint_avail,
                     'checkpoint_val': state is not None,
                     'start_macro_i': start_macro_i,
                     'total_macro_combinations': len(macro_combinations)})
        return start_macro_i, best_result

    def record_checkpoint(self, macro_comb_i:int, best_result:SearchResult)->None:
        """Persist the index of the finished combination and the best result.

        Does nothing when no checkpoint object was created. The result is
        stored as a YAML string under the ``search`` key.
        """
        if self._checkpoint is not None:
            state = {'start_macro_i': macro_comb_i,
                     'best_result': yaml.dump(best_result)}
            self._checkpoint.new()
            self._checkpoint['search'] = state
            self._checkpoint.commit()

    def get_combinations(self, conf_search:Config)->Iterator[Tuple[int, int, int]]:
        """Yield every ``(reductions, cells, nodes)`` triple to evaluate.

        Lower bounds come from ``conf_search['model_desc']`` and inclusive
        upper bounds from ``conf_search['pareto']``. Reductions vary slowest
        and nodes vary fastest.
        """
        conf_pareto = conf_search['pareto']
        conf_model_desc = conf_search['model_desc']

        min_cells = conf_model_desc['n_cells']
        min_reductions = conf_model_desc['n_reductions']
        min_nodes = conf_model_desc['cell']['n_nodes']
        max_cells = conf_pareto['max_cells']
        max_reductions = conf_pareto['max_reductions']
        max_nodes = conf_pareto['max_nodes']

        logger.info({'min_reductions': min_reductions,
                     'min_cells': min_cells,
                     'min_nodes': min_nodes,
                     'max_reductions': max_reductions,
                     'max_cells': max_cells,
                     'max_nodes': max_nodes
                     })

        # TODO: what happens when reductions is 3 but cells is 2? have to step
        # through code and check
        for reductions in range(min_reductions, max_reductions+1):
            for cells in range(min_cells, max_cells+1):
                for nodes in range(min_nodes, max_nodes+1):
                    yield reductions, cells, nodes

    def save_trained(self, conf_search:Config, reductions:int, cells:int, nodes:int,
                     model_metrics:ModelMetrics)->None:
        """Save the model stats and metrics of one combination to disk.

        Writes ``model_stats.yaml`` and ``metrics.yaml`` into the directory
        derived from ``conf_search['metrics_dir']`` and, when a summary file
        path is set, appends one row with the key numbers to that CSV file.

        Args:
            conf_search: search configuration; ``metrics_dir`` is read.
            reductions: number of reductions of this combination.
            cells: number of cells of this combination.
            nodes: number of nodes of this combination.
            model_metrics: post-train result; its ``model`` and ``metrics``
                attributes are used.
        """
        metrics_dir = conf_search['metrics_dir']

        # construct path where we will save
        # vars() exposes the locals (reductions, cells, nodes, ...) as format
        # fields, so no new local may be introduced above this line
        subdir = utils.full_path(metrics_dir.format(**vars()), create=True)

        model_stats = nas_utils.get_model_stats(model_metrics.model)

        # save model_stats in its own file
        model_stats_filepath = os.path.join(subdir, 'model_stats.yaml')
        with open(model_stats_filepath, 'w') as f:
            yaml.dump(model_stats, f)

        # save just metrics separately for convenience
        # NOTE(review): this dumps `model_stats.metrics`; confirm that
        # `model_metrics.metrics` was not intended here.
        metrics_filepath = os.path.join(subdir, 'metrics.yaml')
        with open(metrics_filepath, 'w') as f:
            yaml.dump(model_stats.metrics, f)

        logger.info({'model_stats_filepath': model_stats_filepath,
                     'metrics_filepath': metrics_filepath})

        # append key info in root pareto data
        if self._summary_filepath:
            train_top1 = val_top1 = train_epoch = val_epoch = math.nan

            # extract metrics
            if model_metrics.metrics:
                best_metrics = model_metrics.metrics.run_metrics.best_epoch()
                train_top1 = best_metrics[0].top1.avg
                train_epoch = best_metrics[0].index
                # check the length before indexing so that a missing
                # validation entry leaves the NaN defaults in place
                if len(best_metrics) > 1 and best_metrics[1]:
                    val_top1 = best_metrics[1].top1.avg
                    val_epoch = best_metrics[1].index

            # extract model stats
            flops = model_stats.Flops
            parameters = model_stats.parameters
            inference_memory = model_stats.inference_memory
            inference_duration = model_stats.duration

            utils.append_csv_file(self._summary_filepath, [
                ('reductions', reductions),
                ('cells', cells),
                ('nodes', nodes),
                ('train_top1', train_top1),
                ('train_epoch', train_epoch),
                ('val_top1', val_top1),
                ('val_epoch', val_epoch),
                ('flops', flops),
                ('params', parameters),
                ('inference_memory', inference_memory),
                ('inference_duration', inference_duration)
            ])