# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Iterable, Tuple, Optional, Any, List
from collections import OrderedDict
import numpy as np
import yaml
import os
import torch
from torch import nn, Tensor
from overrides import overrides
from archai.nas.arch_params import ArchParams
from archai.nas.cell import Cell
from archai.nas.operations import Op, DropPath_
from archai.nas.model_desc import ModelDesc, AuxTowerDesc, CellDesc
from archai.common.common import logger
from archai.common import utils, ml_utils
from archai.nas.arch_module import ArchModule


class Model(ArchModule):
def __init__(self, model_desc:ModelDesc, droppath:bool, affine:bool):
super().__init__()
# some of these fields are public as finalizer needs access to them
self.desc = model_desc
# TODO: support any number of stems
assert len(model_desc.model_stems)==2, "Model compiler currently only supports 2 stems"
stem0_op = Op.create(model_desc.model_stems[0], affine=affine)
stem1_op = Op.create(model_desc.model_stems[1], affine=affine)
self.model_stems = nn.ModuleList((stem0_op, stem1_op))
self.cells = nn.ModuleList()
self._aux_towers = nn.ModuleList()
for i, (cell_desc, aux_tower_desc) in \
enumerate(zip(model_desc.cell_descs(), model_desc.aux_tower_descs)):
self._build_cell(cell_desc, aux_tower_desc, droppath, affine)
# adaptive pooling output size to 1x1
self.pool_op = Op.create(model_desc.pool_op, affine=affine)
        # ch_p records the last cell's output channels, so it gives the
        # input channel count here
self.logits_op = Op.create(model_desc.logits_op, affine=affine)
# for i,cell in enumerate(self.cells):
# print(i, ml_utils.param_size(cell))
#logger.info({'model_summary': self.summary()})

    def _build_cell(self, cell_desc:CellDesc,
aux_tower_desc:Optional[AuxTowerDesc],
droppath:bool, affine:bool)->None:
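        # if this cell's trainables_from points to a different cell id, reuse
        # that earlier cell's trainable weights (weight sharing); otherwise
        # this cell owns its own trainables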
trainables_from = None if cell_desc.trainables_from==cell_desc.id \
else self.cells[cell_desc.trainables_from]
cell = Cell(cell_desc, affine=affine, droppath=droppath,
trainables_from=trainables_from)
self.cells.append(cell)
self._aux_towers.append(AuxTower(aux_tower_desc) \
if aux_tower_desc else None)

    def summary(self)->dict:
all_arch_params = list(self.all_owned()
.param_by_kind(kind=None))
return {
'cell_count': len(self.cells),
#'cell_params': [ml_utils.param_size(c) for c in self.cells]
'params': ml_utils.param_size(self),
'arch_params_len': len(all_arch_params),
            'arch_params_numel': sum(a.numel() for a in all_arch_params),
            'ops': sum(len(n.edges) for c in self.desc.cell_descs() for n in c.nodes()),
}

    def ops(self)->Iterable[Op]:
for cell in self.cells:
for op in cell.ops():
yield op

    @overrides
    def forward(self, x)->Tuple[Tensor, Optional[Tensor]]:
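        # s0 and s1 hold the outputs of the two most recent cells (initialized
        # from the two stems); each cell consumes (s0, s1) and its output
        # becomes the new s1 while the previous s1 shifts into s0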
#print(torch.cuda.memory_allocated()/1.0e6)
s0 = self.model_stems[0](x)
#print(torch.cuda.memory_allocated()/1.0e6)
s1 = self.model_stems[1](x)
#print(-1, s0.shape, s1.shape, torch.cuda.memory_allocated()/1.0e6)
logits_aux = None
for ci, (cell, aux_tower) in enumerate(zip(self.cells, self._aux_towers)):
#print(s0.shape, s1.shape, end='')
s0, s1 = s1, cell.forward(s0, s1)
#print(ci, s0.shape, s1.shape, torch.cuda.memory_allocated()/1.0e6)
# TODO: this mimics darts but won't work for multiple aux towers
if aux_tower is not None and self.training:
logits_aux = aux_tower(s1)
#print(ci, 'aux', logits_aux.shape)
# s1 is now the last cell's output
out = self.pool_op(s1)
logits = self.logits_op(out) # flatten
#print(-1, 'out', out.shape)
#print(-1, 'logits', logits.shape)
return logits, logits_aux

    def device_type(self)->str:
return next(self.parameters()).device.type

    def drop_path_prob(self, p:float):
""" Set drop path probability
This will be called externally so any DropPath_ modules get
new probability. Typically, every epoch we will reduce this probability.
"""
for module in self.modules():
if isinstance(module, DropPath_):
module.p = p
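
# A minimal sketch (not part of archai) of how a trainer might call
# drop_path_prob() each epoch; `schedule` and `epochs` are hypothetical,
# chosen by the caller:
#
#   model = Model(model_desc, droppath=True, affine=True)
#   for epoch in range(epochs):
#       model.drop_path_prob(schedule(epoch))  # e.g. lower p as training progresses
#       ...  # train one epoch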


class AuxTower(nn.Module):
def __init__(self, aux_tower_desc:AuxTowerDesc):
"""assuming input size 14x14"""
# TODO: assert input size?
super().__init__()
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=aux_tower_desc.stride, padding=0, count_include_pad=False),
nn.Conv2d(aux_tower_desc.ch_in, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
            # TODO: this batchnorm was omitted in the original implementation due to a typo
nn.BatchNorm2d(768),
nn.ReLU(inplace=True),
)
self.logits_op = nn.Linear(768, aux_tower_desc.n_classes)

    def forward(self, x:torch.Tensor):
x = self.features(x)
x = self.logits_op(x.view(x.size(0), -1))
return x
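
# A minimal usage sketch, assuming `model_desc` is a ModelDesc built elsewhere
# by archai (the exact builder is not shown here) and that the stems accept
# 3-channel images:
#
#   model = Model(model_desc, droppath=False, affine=True)
#   x = torch.randn(2, 3, 32, 32)    # hypothetical CIFAR-sized batch
#   logits, logits_aux = model(x)    # logits_aux stays None unless an aux tower
#                                    # runs while model.training is True
#   print(model.summary())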