Source code for archai.algos.manual.manual_evaluater

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional
import importlib
import sys
import string
import os

from overrides import overrides

import torch
from torch import nn

from overrides import overrides, EnforceOverrides

from archai.common.trainer import Trainer
from archai.common.config import Config
from archai.common.common import logger
from archai.datasets import data
from archai.nas.model_desc import ModelDesc
from archai.nas.model_desc_builder import ModelDescBuilder
from archai.nas import nas_utils
from archai.common import ml_utils, utils
from archai.common.metrics import EpochMetrics, Metrics
from archai.nas.model import Model
from archai.common.checkpoint import CheckPoint
from archai.nas.evaluater import Evaluater



[docs]class ManualEvaluater(Evaluater):
[docs] @overrides def create_model(self, conf_eval:Config, model_desc_builder:ModelDescBuilder, final_desc_filename=None, full_desc_filename=None)->nn.Module: # region conf vars dataset_name = conf_eval['loader']['dataset']['name'] # if explicitly passed in then don't get from conf if not final_desc_filename: final_desc_filename = conf_eval['final_desc_filename'] model_factory_spec = conf_eval['model_factory_spec'] # endregion assert model_factory_spec return self._model_from_factory(model_factory_spec, dataset_name)
def _model_from_factory(self, model_factory_spec:str, dataset_name:str)->Model: splitted = model_factory_spec.rsplit('.', 1) function_name = splitted[-1] if len(splitted) > 1: module_name = splitted[0] else: module_name = self._default_module_name(dataset_name, function_name) module = importlib.import_module(module_name) if module_name else sys.modules[__name__] function = getattr(module, function_name) model = function() logger.info({'model_factory':True, 'module_name': module_name, 'function_name': function_name, 'params': ml_utils.param_size(model)}) return model