# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Optional, Sequence, Tuple, List
import os
import argparse
import torch
from torch.optim.optimizer import Optimizer
from torch import Tensor, nn
from torch.backends import cudnn
import torch.distributed as dist
import ray
import psutil
from archai.common.config import Config
from archai.common import ml_utils, utils
from archai.common.ordereddict_logger import OrderedDictLogger
from archai.common.multi_optim import MultiOptim
class ApexUtils:
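    """Wrapper over NVIDIA Apex mixed precision (amp) and distributed (DistributedDataParallel) utilities.

    Settings are read from the supplied apex config section; when the corresponding
    switches are disabled the class falls back to plain single-GPU fp32 behavior.

    A minimal usage sketch (``conf``, ``conf_train`` and ``grad_clip`` below are
    assumed to come from the caller and are not part of this class)::

        apex = ApexUtils(conf['apex'], logger)
        model = apex.to_amp(model, multi_optim, batch_size=conf_train['batch_size'])
        ...
        apex.backward(loss, multi_optim)
        apex.clip_grad(grad_clip, model, multi_optim)
    """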
def __init__(self, apex_config:Config, logger:Optional[OrderedDictLogger])->None:
# region conf vars
self._enabled = apex_config['enabled'] # global switch to disable anything apex
self._distributed_enabled = apex_config['distributed_enabled'] # enable/disable distributed mode
        self._mixed_prec_enabled = apex_config['mixed_prec_enabled'] # enable/disable mixed precision (amp) mode
self._opt_level = apex_config['opt_level'] # optimization level for mixed precision
self._bn_fp32 = apex_config['bn_fp32'] # keep BN in fp32
self._loss_scale = apex_config['loss_scale'] # loss scaling mode for mixed prec
        self._sync_bn = apex_config['sync_bn'] # whether to replace BNs with sync BNs in distributed mode
        self._scale_lr = apex_config['scale_lr'] # enable/disable scaling LR by world size in distributed mode
self._min_world_size = apex_config['min_world_size'] # allows to confirm we are indeed in distributed setting
seed = apex_config['seed']
detect_anomaly = apex_config['detect_anomaly']
conf_gpu_ids = apex_config['gpus']
conf_ray = apex_config['ray']
self.ray_enabled = conf_ray['enabled']
self.ray_local_mode = conf_ray['local_mode']
# endregion
        # to avoid circular references with common, logger is passed from outside
self.logger = logger
# defaults for non-distributed mode
self._amp, self._ddp = None, None
self._set_ranks(conf_gpu_ids)
#_log_info({'apex_config': apex_config.to_dict()})
self._log_info({'ray.enabled': self.is_ray(), 'apex.enabled': self._enabled})
self._log_info({'torch.distributed.is_available': dist.is_available(),
'apex.distributed_enabled': self._distributed_enabled,
'apex.mixed_prec_enabled': self._mixed_prec_enabled})
if dist.is_available():
# dist.* properties are otherwise not accessible
self._op_map = {'mean': dist.ReduceOp.SUM, 'sum': dist.ReduceOp.SUM,
'min': dist.ReduceOp.MIN, 'max': dist.ReduceOp.MAX}
self._log_info({'gloo_available': dist.is_gloo_available(),
'mpi_available': dist.is_mpi_available(),
'nccl_available': dist.is_nccl_available()})
if self.is_mixed():
            # initialize and enable mixed precision
assert cudnn.enabled, "Amp requires cudnn backend to be enabled."
from apex import amp
self._amp = amp
# enable distributed processing
if self.is_dist():
assert not self.is_ray(), "Ray is not yet enabled for Apex distributed mode"
from apex import parallel
self._ddp = parallel
assert dist.is_available() # distributed module is available
assert dist.is_nccl_available()
if not dist.is_initialized():
dist.init_process_group(backend='nccl', init_method='env://')
assert dist.is_initialized()
assert dist.get_world_size() == self.world_size
assert dist.get_rank() == self.global_rank
if self.is_ray():
assert not self.is_dist(), "Ray is not yet enabled for Apex distributed mode"
import ray
if not ray.is_initialized():
ray.init(local_mode=self.ray_local_mode, include_dashboard=False,
                         # Ray sometimes detects the wrong number of GPUs, so pass the count explicitly
num_gpus=torch.cuda.device_count())
ray_cpus = ray.nodes()[0]['Resources']['CPU']
ray_gpus = ray.nodes()[0]['Resources']['GPU']
self._log_info({'ray_cpus': ray_cpus, 'ray_gpus':ray_gpus})
assert self.world_size >= 1
assert not self._min_world_size or self.world_size >= self._min_world_size
assert self.local_rank >= 0 and self.local_rank < self.world_size
assert self.global_rank >= 0 and self.global_rank < self.world_size
assert self._gpu < torch.cuda.device_count()
torch.cuda.set_device(self._gpu)
self.device = torch.device('cuda', self._gpu)
self._setup_gpus(seed, detect_anomaly)
self._log_info({'amp_available': self._amp is not None,
'distributed_available': self._ddp is not None})
self._log_info({'dist_initialized': dist.is_initialized() if dist.is_available() else False,
'world_size': self.world_size,
'gpu': self._gpu, 'gpu_ids':self.gpu_ids,
'local_rank': self.local_rank,
'global_rank': self.global_rank})
def _setup_gpus(self, seed:float, detect_anomaly:bool):
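        """Seed CUDA for this process, toggle autograd anomaly detection and log the GPU/CPU/memory environment."""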
utils.setup_cuda(seed, self.local_rank)
torch.autograd.set_detect_anomaly(detect_anomaly)
self._log_info({'set_detect_anomaly': detect_anomaly,
'is_anomaly_enabled': torch.is_anomaly_enabled()})
self._log_info({'gpu_names': utils.cuda_device_names(),
'gpu_count': torch.cuda.device_count(),
'CUDA_VISIBLE_DEVICES': os.environ['CUDA_VISIBLE_DEVICES']
if 'CUDA_VISIBLE_DEVICES' in os.environ else 'NotSet',
'cudnn.enabled': cudnn.enabled,
'cudnn.benchmark': cudnn.benchmark,
'cudnn.deterministic': cudnn.deterministic,
'cudnn.version': cudnn.version()
})
self._log_info({'memory': str(psutil.virtual_memory())})
self._log_info({'CPUs': str(psutil.cpu_count())})
# gpu_usage = os.popen(
# 'nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader'
# ).read().split('\n')
# for i, line in enumerate(gpu_usage):
# vals = line.split(',')
# if len(vals) == 2:
# _log_info('GPU {} mem: {}, used: {}'.format(i, vals[0], vals[1]))
def _set_ranks(self, conf_gpu_ids:str)->None:
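        """Set world_size, local_rank, global_rank and pick the GPU for this process.

        Rank values are taken from the WORLD_SIZE, LOCAL_RANK and RANK environment
        variables when present (typically set by the distributed launcher), otherwise
        single-process defaults are used.
        """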
# this function needs to work even when torch.distributed is not available
if 'WORLD_SIZE' in os.environ:
self.world_size = int(os.environ['WORLD_SIZE'])
else:
self.world_size = 1
if 'LOCAL_RANK' in os.environ:
self.local_rank = int(os.environ['LOCAL_RANK'])
else:
self.local_rank = 0
if 'RANK' in os.environ:
self.global_rank = int(os.environ['RANK'])
else:
self.global_rank = 0
assert self.local_rank < torch.cuda.device_count(), \
            f'local_rank={self.local_rank} but device_count={torch.cuda.device_count()}.' \
            ' A possible cause is that PyTorch is not GPU enabled or there are too few GPUs.'
self.gpu_ids = [int(i) for i in conf_gpu_ids.split(',') if i]
# which GPU to use, we will use only 1 GPU per process to avoid complications with apex
# remap if GPU IDs are specified
if len(self.gpu_ids):
assert len(self.gpu_ids) > self.local_rank
self._gpu = self.gpu_ids[self.local_rank]
else:
self._gpu = self.local_rank % torch.cuda.device_count()
    def is_mixed(self)->bool:
        return self._enabled and self._mixed_prec_enabled
    def is_dist(self)->bool:
        return self._enabled and self._distributed_enabled
    def is_master(self)->bool:
        return self.global_rank == 0
    def is_ray(self)->bool:
        return self.ray_enabled
def _log_info(self, d:dict)->None:
if self.logger is not None:
self.logger.info(d)
    def sync_devices(self)->None:
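        """In distributed mode, wait for all pending CUDA work on this process's device to finish."""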
if self.is_dist():
torch.cuda.synchronize(self.device)
    def barrier(self)->None:
if self.is_dist():
dist.barrier() # wait for all processes to come to this point
    def reduce(self, val, op='mean'):
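        """All-reduce `val` across processes using the given op ('mean', 'sum', 'min', 'max').

        Non-tensor values are converted to a tensor on this device; 'mean' is computed as
        SUM followed by division by world_size. In non-distributed mode `val` is returned unchanged.
        """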
if self.is_dist():
if not isinstance(val, Tensor):
rt = torch.tensor(val).to(self.device)
converted = True
else:
rt = val.clone().to(self.device)
converted = False
r_op = self._op_map[op]
dist.all_reduce(rt, op=r_op)
if op=='mean':
rt /= self.world_size
if converted and len(rt.shape)==0:
return rt.item()
return rt
else:
return val
def _get_optim(self, multi_optim:MultiOptim)->Optimizer:
assert len(multi_optim)==1, \
'Mixed precision is only supported for one optimizer' \
f' but {len(multi_optim)} optimizers were supplied'
return multi_optim[0].optim
    def backward(self, loss:torch.Tensor, multi_optim:MultiOptim)->None:
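        """Backward pass; scales the loss through amp when mixed precision is enabled."""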
if self.is_mixed():
optim = self._get_optim(multi_optim)
with self._amp.scale_loss(loss, optim) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
    def to_amp(self, model:nn.Module, multi_optim:MultiOptim, batch_size:int)\
->nn.Module:
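        """Prepare model and optimizer for mixed precision and/or distributed training.

        Depending on the config this converts BNs to sync BNs, moves the model to this
        process's device, scales the optimizer LR, calls amp.initialize() and wraps the
        model in apex DistributedDataParallel.
        """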
        # convert BNs to sync BNs in distributed mode
if self.is_dist() and self._sync_bn:
model = self._ddp.convert_syncbn_model(model)
self._log_info({'BNs_converted': True})
model = model.to(self.device)
if self.is_mixed():
optim = self._get_optim(multi_optim)
# scale LR
if self.is_dist() and self._scale_lr:
lr = ml_utils.get_optim_lr(optim)
scaled_lr = lr * self.world_size / float(batch_size)
ml_utils.set_optim_lr(optim, scaled_lr)
self._log_info({'lr_scaled': True, 'old_lr': lr, 'new_lr': scaled_lr})
model, optim = self._amp.initialize(
model, optim, opt_level=self._opt_level,
keep_batchnorm_fp32=self._bn_fp32, loss_scale=self._loss_scale
)
# put back amp'd optim
multi_optim[0].optim = optim
if self.is_dist():
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# computation in the backward pass.
# delay_allreduce delays all communication to the end of the backward pass.
model = self._ddp.DistributedDataParallel(model, delay_allreduce=True)
return model
    def clip_grad(self, clip:float, model:nn.Module, multi_optim:MultiOptim)->None:
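        """Clip gradient norm to `clip` (no-op if clip <= 0); uses amp master params when mixed precision is enabled."""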
if clip > 0.0:
if self.is_mixed():
optim = self._get_optim(multi_optim)
nn.utils.clip_grad_norm_(self._amp.master_params(optim), clip)
else:
nn.utils.clip_grad_norm_(model.parameters(), clip)
    def state_dict(self):
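        """Return amp's state for checkpointing, or None when mixed precision is disabled."""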
if self.is_mixed():
return self._amp.state_dict()
else:
return None
    def load_state_dict(self, state_dict):
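        """Restore amp state saved by state_dict(); no-op when mixed precision is disabled."""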
if self.is_mixed():
self._amp.load_state_dict(state_dict)