# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import numpy as np
import os
from typing import List, Iterable, Union, Optional, Tuple
import atexit
import subprocess
import datetime
import yaml
import sys
import torch
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn
import psutil
from .config import Config
from . import utils
from .ordereddict_logger import OrderedDictLogger
from .apex_utils import ApexUtils
from send2trash import send2trash

class SummaryWriterDummy:
    def __init__(self, log_dir):
        pass

    def add_scalar(self, *args, **kwargs):
        pass

    def flush(self):
        pass

SummaryWriterAny = Union[SummaryWriterDummy, SummaryWriter]
logger = OrderedDictLogger(None, None, yaml_log=False)
_tb_writer: Optional[SummaryWriterAny] = None
_atexit_reg = False # is hook for atexit registered?

def get_conf(conf:Optional[Config]=None)->Config:
    if conf is not None:
        return conf
    return Config.get_inst()

def get_conf_common(conf:Optional[Config]=None)->Config:
    return get_conf(conf)['common']

def get_conf_dataset(conf:Optional[Config]=None)->Config:
    return get_conf(conf)['dataset']

def get_experiment_name(conf:Optional[Config]=None)->str:
    return get_conf_common(conf)['experiment_name']

def get_expdir(conf:Optional[Config]=None)->Optional[str]:
    return get_conf_common(conf)['expdir']

def get_datadir(conf:Optional[Config]=None)->Optional[str]:
    return get_conf_dataset(conf)['dataroot']

def get_logger() -> OrderedDictLogger:
    global logger
    if logger is None:
        raise RuntimeError('get_logger call made before logger was setup!')
    return logger

def get_tb_writer() -> Optional[SummaryWriterAny]:
    global _tb_writer
    return _tb_writer
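
# Example (a sketch, not part of the module): training code can log metrics to
# tensorboard through the shared writer; the tag and variables are hypothetical.
#
#   writer = get_tb_writer()
#   writer.add_scalar('train/loss', loss_val, global_step)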

class CommonState:
    def __init__(self) -> None:
        global logger, _tb_writer
        self.logger = logger
        self.tb_writer = _tb_writer
        self.conf = get_conf()

def on_app_exit():
    print('Process exit:', os.getpid(), flush=True)
    writer = get_tb_writer()
    if writer is not None:
        writer.flush()
    if isinstance(logger, OrderedDictLogger):
        logger.close()

def _pt_dirs()->Tuple[str, str]:
    # dirs for pt infrastructure are supplied in env vars
    pt_data_dir = os.environ.get('PT_DATA_DIR', '')

    # currently yaml should be copying the dataset folder to a local dir,
    # so the below is not needed. The hope is that fewer reads from cloud
    # storage will reduce overall latency.
    # if pt_data_dir:
    #     param_args = ['--nas.eval.loader.dataset.dataroot', pt_data_dir,
    #                   '--nas.search.loader.dataset.dataroot', pt_data_dir,
    #                   '--nas.search.seed_train.loader.dataset.dataroot', pt_data_dir,
    #                   '--nas.search.post_train.loader.dataset.dataroot', pt_data_dir,
    #                   '--autoaug.loader.dataset.dataroot', pt_data_dir] + param_args

    pt_output_dir = os.environ.get('PT_OUTPUT_DIR', '')
    return pt_data_dir, pt_output_dir
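
# Example (a sketch): the pt infrastructure is detected purely through these env
# vars, so a launcher can opt in before common_init() runs; the paths below are
# hypothetical.
#
#   os.environ['PT_DATA_DIR'] = '/mnt/pt/data'
#   os.environ['PT_OUTPUT_DIR'] = '/mnt/pt/output'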

def _pt_params(param_args: list)->list:
    pt_data_dir, pt_output_dir = _pt_dirs()
    if pt_output_dir:
        # prepend the default so that a value supplied from outside overrides it
        param_args = ['--common.logdir', pt_output_dir] + param_args
    return param_args
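
# Example (a sketch, assuming Config applies later duplicate overrides over
# earlier ones and PT_OUTPUT_DIR='/mnt/pt/output'; both values hypothetical):
#
#   _pt_params(['--common.logdir', '~/logdir'])
#   # -> ['--common.logdir', '/mnt/pt/output', '--common.logdir', '~/logdir']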

def get_state()->CommonState:
    return CommonState()

def init_from(state:CommonState, recreate_logger=True)->None:
    global logger, _tb_writer

    Config.set_inst(state.conf)

    if recreate_logger:
        create_logger(state.conf)
    else:
        logger = state.logger
    logger.info({'common_init_from_state': True})

    _tb_writer = state.tb_writer
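
# Example (a sketch): a typical parent/child handoff for multiprocessing, e.g.
# with torch.multiprocessing; 'state' is pickled across the process boundary.
#
#   # in the parent, after common_init():
#   state = get_state()
#   # in the spawned child:
#   init_from(state, recreate_logger=True)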

def create_conf(config_filepath: Optional[str]=None,
                param_args: list = [], use_args=True)->Config:
    # modify passed args for pt infrastructure;
    # if pt infrastructure doesn't exist then param_overrides == param_args
    param_overrides = _pt_params(param_args)

    conf = Config(config_filepath=config_filepath,
                  param_args=param_overrides,
                  use_args=use_args)
    _update_conf(conf)
    return conf

# TODO: rename this simply as init
# initializes random number gen, debugging etc
def common_init(config_filepath: Optional[str]=None,
                param_args: list = [], use_args=True,
                clean_expdir=False)->Config:
    if not utils.is_main_process():
        raise RuntimeError('common_init should not be called from a child process. Please use common.init_from() instead.')

    conf = create_conf(config_filepath, param_args, use_args)

    # setup global instance
    Config.set_inst(conf)

    # setup env vars which might be used in paths
    update_envvars(conf)

    # create experiment dir
    create_dirs(conf, clean_expdir)

    # create global logger
    create_logger(conf)

    _create_sysinfo(conf)

    # create apex to query distributed processing parameters
    conf_apex = get_conf_common(conf)['apex']
    apex = ApexUtils(conf_apex, logger=logger)

    # setup tensorboard
    global _tb_writer
    _tb_writer = create_tb_writer(conf, apex.is_master())

    # create hook to execute code when the script exits
    global _atexit_reg
    if not _atexit_reg:
        atexit.register(on_app_exit)
        _atexit_reg = True

    return conf
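
# Example (a sketch): typical startup from a script's main(); the yaml path and
# override value below are hypothetical.
#
#   conf = common_init(config_filepath='confs/my_experiment.yaml',
#                      param_args=['--common.experiment_name', 'my_exp'])
#   logger = get_logger()
#   expdir = get_expdir(conf)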

def _create_sysinfo(conf:Config)->None:
    expdir = get_expdir(conf)

    if expdir and not utils.is_debugging():
        # copy net config to experiment folder for reference
        with open(expdir_abspath('config_used.yaml'), 'w') as f:
            yaml.dump(conf.to_dict(), f)

        # create info file for the current system
        sysinfo_filepath = expdir_abspath('sysinfo.txt')
        subprocess.Popen(f'./scripts/sysinfo.sh "{expdir}" > "{sysinfo_filepath}"',
                         stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                         shell=True)

def expdir_abspath(path:str, create=False)->str:
    """Returns the full path for a given relative path within the experiment directory."""
    return utils.full_path(os.path.join('$expdir', path), create=create)
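
# Example (a sketch): '$expdir' resolves from the env var set by
# update_envvars(), so after common_init() this yields a path inside the
# experiment folder, e.g. ~/logdir/my_exp/checkpoint.pth (hypothetical name).
#
#   ckpt_path = expdir_abspath('checkpoint.pth')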

def create_tb_writer(conf:Config, is_master=True)-> SummaryWriterAny:
    conf_common = get_conf_common(conf)

    tb_dir, conf_enable_tb = utils.full_path(conf_common['tb_dir']), conf_common['tb_enable']
    tb_enable = conf_enable_tb and is_master and tb_dir is not None and len(tb_dir) > 0

    logger.info({'conf_enable_tb': conf_enable_tb,
                 'tb_enable': tb_enable,
                 'tb_dir': tb_dir})

    WriterClass = SummaryWriter if tb_enable else SummaryWriterDummy
    return WriterClass(log_dir=tb_dir)

def _update_conf(conf:Config)->None:
    conf_common = get_conf_common(conf)
    conf_dataset = get_conf_dataset(conf)
    experiment_name = conf_common['experiment_name']

    # expand dataroot to a full path
    dataroot = utils.full_path(conf_dataset['dataroot'])

    # derive expdir and distdir from logdir
    logdir = conf_common['logdir']
    if logdir:
        logdir = utils.full_path(logdir)
        expdir = os.path.join(logdir, experiment_name)
        # directory for non-master replica logs
        distdir = os.path.join(expdir, 'dist')
    else:
        expdir = distdir = logdir

    # update conf so everyone gets expanded full paths from here on
    conf_common['logdir'] = logdir
    conf_dataset['dataroot'] = dataroot
    conf_common['expdir'] = expdir
    conf_common['distdir'] = distdir

def update_envvars(conf)->None:
    conf_common = get_conf_common(conf)
    logdir = conf_common['logdir']
    expdir = conf_common['expdir']
    distdir = conf_common['distdir']

    conf_dataset = get_conf_dataset(conf)
    dataroot = conf_dataset['dataroot']

    # set environment variables so they can be referenced in paths used in config
    os.environ['logdir'] = logdir
    os.environ['dataroot'] = dataroot
    os.environ['expdir'] = expdir
    os.environ['distdir'] = distdir

def clean_ensure_expdir(conf:Optional[Config], clean_dir:bool, ensure_dir:bool)->None:
    expdir = get_expdir(conf)
    if clean_dir and os.path.exists(expdir):
        send2trash(expdir)
    if ensure_dir:
        os.makedirs(expdir, exist_ok=True)

def create_dirs(conf:Config, clean_expdir:bool)->None:
    conf_common = get_conf_common(conf)
    logdir = conf_common['logdir']
    expdir = conf_common['expdir']
    distdir = conf_common['distdir']

    conf_dataset = get_conf_dataset(conf)
    dataroot = conf_dataset['dataroot']

    # make sure dataroot exists
    os.makedirs(dataroot, exist_ok=True)

    # make sure logdir and expdir exist
    if logdir:
        clean_ensure_expdir(conf, clean_dir=clean_expdir, ensure_dir=True)
        os.makedirs(distdir, exist_ok=True)
    else:
        raise RuntimeError('The common.logdir setting must be specified in yaml for the output directory')

    # get cloud dirs if any
    pt_data_dir, pt_output_dir = _pt_dirs()

    # validate dirs
    assert not pt_output_dir or not expdir.startswith(utils.full_path('~/logdir'))

    logger.info({'expdir': expdir,
                 'PT_DATA_DIR': pt_data_dir, 'PT_OUTPUT_DIR': pt_output_dir})

def create_logger(conf:Config):
    global logger
    logger.close()  # close any previous instance

    conf_common = get_conf_common(conf)
    expdir = conf_common['expdir']
    distdir = conf_common['distdir']
    log_prefix = conf_common['log_prefix']
    yaml_log = conf_common['yaml_log']
    log_level = conf_common['log_level']

    if utils.is_main_process():
        logdir, log_suffix = expdir, ''
    else:
        logdir, log_suffix = distdir, '_' + str(os.getpid())

    # ensure folders
    os.makedirs(logdir, exist_ok=True)

    # files where the logger will write messages
    sys_log_filepath = utils.full_path(os.path.join(logdir, f'{log_prefix}{log_suffix}.log'))
    logs_yaml_filepath = utils.full_path(os.path.join(logdir, f'{log_prefix}{log_suffix}.yaml'))
    experiment_name = get_experiment_name(conf) + log_suffix
    #print(f'experiment_name={experiment_name}, log_stdout={sys_log_filepath}, log_file={sys_log_filepath}')

    sys_logger = utils.create_logger(filepath=sys_log_filepath,
                                     name=experiment_name, level=log_level,
                                     enable_stdout=True)
    if not sys_log_filepath:
        sys_logger.warning('log_prefix not specified, logs will go to stdout only')

    # reset the global logger to the new file paths
    logger.reset(logs_yaml_filepath, sys_logger, yaml_log=yaml_log,
                 backup_existing_file=False)

    logger.info({'command_line': ' '.join(sys.argv) if utils.is_main_process()
                 else f'Child process: {utils.process_name()}-{os.getpid()}'})
    logger.info({'process_name': utils.process_name(), 'is_main_process': utils.is_main_process(),
                 'main_process_pid': utils.main_process_pid(), 'pid': os.getpid(),
                 'ppid': os.getppid(), 'is_debugging': utils.is_debugging()})
    logger.info({'experiment_name': experiment_name, 'datetime': datetime.datetime.now()})
    logger.info({'logs_yaml_filepath': logs_yaml_filepath, 'sys_log_filepath': sys_log_filepath})
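
# Example (a sketch): OrderedDictLogger accepts dicts, as seen throughout this
# module, so callers log key/value pairs that land in both the .log file and the
# structured .yaml log; the keys and values below are hypothetical.
#
#   logger.info({'epoch': epoch, 'train_top1': top1.avg})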