# Source code for archai.nas.arch_trainer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional, Callable, Type
import os

import torch
from torch.utils.data import DataLoader
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from overrides import overrides, EnforceOverrides

from ..common.config import Config
from ..common import common, utils
from ..nas.model import Model
from ..nas.model_desc import ModelDesc
from ..common.trainer import Trainer
from ..nas.vis_model_desc import draw_model_desc
from ..common.checkpoint import CheckPoint

# Type alias for an optional trainer *class*: either None or the ArchTrainer
# class object itself (or a subclass of it), as opposed to an instance.
TArchTrainer = Optional[Type['ArchTrainer']]

class ArchTrainer(Trainer, EnforceOverrides):
    """Trainer that can add an L1 penalty on 'alphas' and plot the model.

    Two settings are read from the training config:

    * ``l1_alphas`` -- weight of the L1 penalty added to the loss; the
      penalty is applied only when this value is greater than zero.
    * ``plotsdir`` -- directory for model plots; when falsy no plots are
      produced.
    """

    def __init__(self, conf_train: Config, model: Model,
                 checkpoint: Optional[CheckPoint]) -> None:
        """Initialize the base trainer and read this trainer's settings.

        Args:
            conf_train: training config; must provide ``'l1_alphas'`` and
                ``'plotsdir'``.
            model: model being trained.
            checkpoint: optional checkpoint, forwarded to the base class.
        """
        super().__init__(conf_train, model, checkpoint)

        self._l1_alphas = conf_train['l1_alphas']
        self._plotsdir = conf_train['plotsdir']

        # The alphas are cached only when the L1 penalty is enabled, so
        # self._alphas does not exist otherwise.
        l1_enabled = self._l1_alphas > 0.0
        if l1_enabled:
            owned_params = self.model.all_owned()
            self._alphas = list(owned_params.param_by_kind('alphas'))

    @overrides
    def compute_loss(self, lossfn: Callable, y: Tensor, logits: Tensor,
                     aux_weight: float,
                     aux_logits: Optional[Tensor]) -> Tensor:
        """Return the base loss, plus the weighted L1 norm of the alphas.

        The base class computes the loss from the given arguments. When
        ``l1_alphas`` is greater than zero, the sum of absolute values of
        all cached alphas, scaled by ``l1_alphas``, is added to it in place.
        """
        loss = super().compute_loss(lossfn, y, logits, aux_weight, aux_logits)

        # Nothing to add unless L1 alpha regularization is switched on.
        if not self._l1_alphas > 0.0:
            return loss

        penalty = sum(torch.sum(alpha.abs()) for alpha in self._alphas)
        loss += self._l1_alphas * penalty
        return loss

    @overrides
    def post_epoch(self, train_dl: DataLoader,
                   val_dl: Optional[DataLoader]) -> None:
        """Run the base class post-epoch step, then try to plot the model."""
        super().post_epoch(train_dl, val_dl)
        self._draw_model()

    # TODO: move this outside as utility
    def _draw_model(self) -> None:
        """Plot the finalized model if the current epoch is a best epoch.

        Does nothing when no plots directory is configured or when no
        metrics are available. A plot is drawn only when the current epoch
        equals the best validation epoch or the best training epoch as
        reported by ``run_metrics.best_epoch()``.
        """
        if not self._plotsdir:
            return

        metrics = self.get_metrics()
        if not metrics:
            return

        best_train, best_val = metrics.run_metrics.best_epoch()

        # best_val may be falsy (presumably when there is no validation
        # data -- TODO confirm); in that case fall back to the train check.
        is_best = best_val and best_val == metrics.cur_epoch()
        if not is_best:
            is_best = best_train == metrics.cur_epoch()
        if not is_best:
            return

        # log model_desc as a image
        plot_name = f"EP{metrics.cur_epoch().index:03d}"
        plot_filepath = utils.full_path(
            os.path.join(self._plotsdir, plot_name), create=True)
        draw_model_desc(self.model.finalize(),
                        filepath=plot_filepath,
                        caption=f"Epoch {metrics.cur_epoch().index}")