# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Tuple
import copy
from overrides import EnforceOverrides
from torch.utils.data.dataloader import DataLoader
from archai.common.common import logger
from archai.common.config import Config
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas.arch_trainer import TArchTrainer
from archai.common.trainer import Trainer
from archai.nas.model_desc import CellType, ModelDesc
from archai.datasets import data
from archai.nas.model import Model
from archai.common.metrics import EpochMetrics, Metrics
from archai.common import utils
from archai.nas.finalizers import Finalizers


class ModelMetrics:
    """Bundles a trained model together with the metrics collected while training it."""
    def __init__(self, model:Model, metrics:Metrics) -> None:
        self.model = model
        self.metrics = metrics


class SearchResult:
    """Holds the searched model description along with search-time and post-train metrics."""
    def __init__(self, model_desc:Optional[ModelDesc],
                 search_metrics:Optional[Metrics],
                 train_metrics:Optional[Metrics]) -> None:
        self.model_desc = model_desc
        self.search_metrics = search_metrics
        self.train_metrics = train_metrics


class Searcher(EnforceOverrides):
    def search(self, conf_search:Config, model_desc_builder:Optional[ModelDescBuilder],
               trainer_class:TArchTrainer, finalizers:Finalizers)->SearchResult:
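        """Default search pipeline: build a model description from the config,
        search over it with the given trainer, then briefly train the result
        to obtain post-search metrics."""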
# region config vars
conf_model_desc = conf_search['model_desc']
conf_post_train = conf_search['post_train']
cells = conf_model_desc['n_cells']
reductions = conf_model_desc['n_reductions']
nodes = conf_model_desc['cell']['n_nodes']
# endregion
assert model_desc_builder is not None, 'Default search implementation requires model_desc_builder'
# build model description that we will search on
model_desc = self.build_model_desc(model_desc_builder, conf_model_desc,
reductions, cells, nodes)
# perform search on model description
model_desc, search_metrics = self.search_model_desc(conf_search, model_desc,
trainer_class, finalizers)
        # train the searched model for a few epochs to get some perf metrics
model_metrics = self.train_model_desc(model_desc, conf_post_train)
search_result = SearchResult(model_desc, search_metrics,
model_metrics.metrics if model_metrics is not None else None)
self.clean_log_result(conf_search, search_result)
return search_result

    def clean_log_result(self, conf_search:Config, search_result:SearchResult)->None:
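        """Strip weights from the searched description, save it if a filename is
        configured, and log the headline top-1 metrics."""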
final_desc_filename = conf_search['final_desc_filename']
        # remove weights info from model_desc so it's more readable
search_result.model_desc.clear_trainables()
# if file name was specified then save the model desc
if final_desc_filename:
search_result.model_desc.save(final_desc_filename)
if search_result.search_metrics is not None:
logger.info({'search_top1_val':
search_result.search_metrics.best_val_top1()})
if search_result.train_metrics is not None:
logger.info({'train_top1_val':
search_result.train_metrics.best_val_top1()})

    def build_model_desc(self, model_desc_builder:ModelDescBuilder,
                         conf_model_desc:Config,
                         reductions:int, cells:int, nodes:int)->ModelDesc:
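        """Build the ModelDesc to search over, overriding the macro parameters
        (n_reductions, n_cells, n_nodes) in a copy of the config."""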
# reset macro params in copy of config
conf_model_desc = copy.deepcopy(conf_model_desc)
        conf_model_desc['n_reductions'] = reductions
        conf_model_desc['n_cells'] = cells
        conf_model_desc['cell']['n_nodes'] = nodes
# create model desc for search using model config
        # we build the desc here once; later model construction (e.g. for post-search
        # training) reuses this desc without another call to model_desc_builder
model_desc = model_desc_builder.build(conf_model_desc, template=None)
return model_desc

    def get_data(self, conf_loader:Config)->Tuple[Optional[DataLoader], Optional[DataLoader]]:
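        """Return (train, val) dataloaders for the given loader config. Results
        are cached per config object: the cache is keyed on id(conf_loader), so a
        structurally identical but distinct Config triggers a fresh load."""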
        # this dict caches the dataset objects per loader config so we don't have to
        # reload them; the attribute is created lazily (rather than in __init__) so
        # that dependent methods can still be used with ray.remote
if not hasattr(self, '_data_cache'):
self._data_cache = {}
# first get from cache
train_ds, val_ds = self._data_cache.get(id(conf_loader), (None, None))
# if not found in cache then create
if train_ds is None:
train_ds, val_ds, _ = data.get_data(conf_loader)
self._data_cache[id(conf_loader)] = (train_ds, val_ds)
return train_ds, val_ds

    def finalize_model(self, model:Model, finalizers:Finalizers)->ModelDesc:
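        """Finalize the searched model into a concrete ModelDesc using the given finalizers."""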
return finalizers.finalize_model(model, restore_device=False)

    def search_model_desc(self, conf_search:Config, model_desc:ModelDesc,
trainer_class:TArchTrainer, finalizers:Finalizers)\
->Tuple[ModelDesc, Optional[Metrics]]:
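        """Run architecture search: fit the description with the architecture
        trainer, then finalize the trained model back into a ModelDesc."""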
        # if no trainer is specified (e.g. for algos like random search) we return the desc unchanged
if trainer_class is None:
return model_desc, None
logger.pushd('arch_search')
conf_trainer = conf_search['trainer']
conf_loader = conf_search['loader']
model = Model(model_desc, droppath=False, affine=False)
# get data
train_dl, val_dl = self.get_data(conf_loader)
assert train_dl is not None
# search arch
arch_trainer = trainer_class(conf_trainer, model, checkpoint=None)
search_metrics = arch_trainer.fit(train_dl, val_dl)
# finalize
found_desc = self.finalize_model(model, finalizers)
logger.popd()
return found_desc, search_metrics

    def train_model_desc(self, model_desc:ModelDesc, conf_train:Config)\
->Optional[ModelMetrics]:
"""Train given description"""
# region conf vars
conf_trainer = conf_train['trainer']
conf_loader = conf_train['loader']
trainer_title = conf_trainer['title']
epochs = conf_trainer['epochs']
drop_path_prob = conf_trainer['drop_path_prob']
# endregion
        # if epochs <= 0 then there is nothing to train, so save time
if epochs <= 0:
return None
logger.pushd(trainer_title)
model = Model(model_desc, droppath=drop_path_prob>0.0, affine=True)
# get data
train_dl, val_dl = self.get_data(conf_loader)
assert train_dl is not None
trainer = Trainer(conf_trainer, model, checkpoint=None)
train_metrics = trainer.fit(train_dl, val_dl)
logger.popd()
return ModelMetrics(model, train_metrics)
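

# A minimal usage sketch (not part of the library API): the names `conf`,
# `MyModelDescBuilder`, and `MyArchTrainer` below are hypothetical placeholders;
# a real run needs a Config whose search section carries the 'model_desc',
# 'trainer', 'loader', 'post_train', and 'final_desc_filename' keys read above.
#
#   searcher = Searcher()
#   result = searcher.search(conf_search=conf['nas']['search'],
#                            model_desc_builder=MyModelDescBuilder(),
#                            trainer_class=MyArchTrainer,
#                            finalizers=Finalizers())
#   if result.search_metrics is not None:
#       print(result.search_metrics.best_val_top1())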