Source code for archai.nas.evaluater

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional, Tuple
import importlib
import sys
import string
import os

import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader

from overrides import overrides, EnforceOverrides

from archai.common.trainer import Trainer
from archai.common.config import Config
from archai.common.common import logger
from archai.datasets import data
from archai.nas.model_desc import ModelDesc
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas import nas_utils
from archai.common import ml_utils, utils
from archai.common.metrics import EpochMetrics, Metrics
from archai.nas.model import Model
from archai.common.checkpoint import CheckPoint


class EvalResult:
    def __init__(self, train_metrics:Metrics) -> None:
        self.train_metrics = train_metrics

class Evaluater(EnforceOverrides):
    def evaluate(self, conf_eval:Config, model_desc_builder:ModelDescBuilder)->EvalResult:
        logger.pushd('eval_arch')

        # region conf vars
        conf_checkpoint = conf_eval['checkpoint']
        resume = conf_eval['resume']
        model_filename = conf_eval['model_filename']
        metric_filename = conf_eval['metric_filename']
        # endregion

        model = self.create_model(conf_eval, model_desc_builder)

        checkpoint = nas_utils.create_checkpoint(conf_checkpoint, resume)
        train_metrics = self.train_model(conf_eval, model, checkpoint)
        train_metrics.save(metric_filename)

        # save model
        if model_filename:
            model_filename = utils.full_path(model_filename)
            ml_utils.save_model(model, model_filename)
            logger.info({'model_save_path': model_filename})

        logger.popd()
        return EvalResult(train_metrics)

    def train_model(self, conf_train:Config, model:nn.Module,
                    checkpoint:Optional[CheckPoint])->Metrics:
        conf_loader = conf_train['loader']
        conf_train = conf_train['trainer']

        # get data
        train_dl, test_dl = self.get_data(conf_loader)

        trainer = Trainer(conf_train, model, checkpoint)
        train_metrics = trainer.fit(train_dl, test_dl)
        return train_metrics

    def get_data(self, conf_loader:Config)->Tuple[DataLoader, DataLoader]:
        # this dict caches the dataset objects per dataset config so we don't have to reload;
        # the reason we use a dynamic attribute is so that any dependent methods
        # can do ray.remote
        if not hasattr(self, '_data_cache'):
            self._data_cache = {}

        # first try to get loaders from the cache
        train_dl, test_dl = self._data_cache.get(id(conf_loader), (None, None))

        # if not found in cache then create
        if train_dl is None:
            train_dl, _, test_dl = data.get_data(conf_loader)
            self._data_cache[id(conf_loader)] = (train_dl, test_dl)

        assert train_dl is not None and test_dl is not None
        return train_dl, test_dl
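
    # Note (added for clarity, not in the original source): the cache above is keyed by
    # id(conf_loader), i.e. by object identity, so a cache hit only occurs when the very
    # same Config instance is passed again; an equal but distinct Config reloads the data.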

    def _default_module_name(self, dataset_name:str, function_name:str)->str:
        """Select PyTorch pre-defined network to support manual mode"""
        module_name = ''
        # TODO: below detection code is too weak, need to improve;
        # possibly encode image size in yaml and use that instead
        if dataset_name.startswith('cifar'):
            if function_name.startswith('res'): # support resnext as well
                module_name = 'archai.cifar10_models.resnet'
            elif function_name.startswith('dense'):
                module_name = 'archai.cifar10_models.densenet'
        elif dataset_name.startswith('imagenet') or dataset_name.startswith('sport8'):
            module_name = 'torchvision.models'
        if not module_name:
            raise NotImplementedError(
                f'Cannot get default module for {function_name} and dataset {dataset_name} '
                f'because it is not supported yet')
        return module_name
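
    # Illustration (not part of the original source): with the prefix matching above,
    # _default_module_name('cifar10', 'resnet18') resolves to 'archai.cifar10_models.resnet'
    # and _default_module_name('imagenet', 'resnet50') resolves to 'torchvision.models';
    # presumably the caller imports the resolved module (e.g. via importlib) and looks up
    # function_name on it to construct the network.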

    def create_model(self, conf_eval:Config, model_desc_builder:ModelDescBuilder,
                     final_desc_filename=None, full_desc_filename=None)->nn.Module:
        assert model_desc_builder is not None, 'Default evaluater requires model_desc_builder'

        # region conf vars
        # if explicitly passed in then don't get from conf
        if not final_desc_filename:
            final_desc_filename = conf_eval['final_desc_filename']
            full_desc_filename = conf_eval['full_desc_filename']
        conf_model_desc = conf_eval['model_desc']
        # endregion

        # load model desc file to get template model
        template_model_desc = ModelDesc.load(final_desc_filename)
        model_desc = model_desc_builder.build(conf_model_desc,
                                              template=template_model_desc)

        # save desc for reference
        model_desc.save(full_desc_filename)

        model = self.model_from_desc(model_desc)

        logger.info({'model_factory': False,
                     'cells_len': len(model.desc.cell_descs()),
                     'init_node_ch': conf_model_desc['model_stems']['init_node_ch'],
                     'n_cells': conf_model_desc['n_cells'],
                     'n_reductions': conf_model_desc['n_reductions'],
                     'n_nodes': conf_model_desc['cell']['n_nodes']})

        return model
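
    # Note (added for clarity, not in the original source): final_desc_filename is expected
    # to point at a final architecture description, typically the one produced by a search
    # phase; the builder uses it as a template and builds the full ModelDesc according to
    # conf_model_desc (n_cells, n_reductions, n_nodes, stem channels, etc.), which is then
    # saved to full_desc_filename for reference.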

    def model_from_desc(self, model_desc)->Model:
        return Model(model_desc, droppath=True, affine=True)
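
# --- Usage sketch (illustrative; not part of the original module) -------------
# A minimal driver, assuming `conf_eval` is the evaluation section of an Archai
# YAML config providing the keys read above ('checkpoint', 'resume',
# 'model_filename', 'metric_filename', 'final_desc_filename',
# 'full_desc_filename', 'model_desc', 'loader', 'trainer'). How the Config is
# constructed (e.g. via archai.common.common.common_init) depends on the
# surrounding experiment setup and is assumed here.
#
#   from archai.nas.model_desc_builder import ModelDescBuilder
#   from archai.nas.evaluater import Evaluater
#
#   evaluater = Evaluater()
#   result = evaluater.evaluate(conf_eval, model_desc_builder=ModelDescBuilder())
#   train_metrics = result.train_metrics   # Metrics collected by Trainer.fit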