Module utils.checkpointing
#!/usr/bin/env python3
import os
import shutil
import json
import copy

# Import PyTorch root package
import torch

import numpy as np

from utils.utils import get_model_str_from_obj, create_model_dir, create_metrics_dict
from utils.logger import Logger
from utils import execution_context
from utils import model_funcs
from models import mutils
import opts

def save_checkpoint(model, filename, args, is_best, metrics, metric_to_optim):
    """
    Persist a checkpoint to disk.
    Args:
        model: Model to serialize
        filename: Filename to persist the model under
        args: Training setup
        is_best: Whether this model achieves the best metric so far
        metrics: Metrics obtained from evaluation
        metric_to_optim: Metric to optimize, e.g. top-1 accuracy
    """
    # Skip saving evaluation checkpoints entirely when requested
    if args.do_not_save_eval_checkpoints:
        return
    logger = Logger.get(args.run_id)
    result_text = f"avg_loss={metrics['loss'].get_avg()}," \
                  f" avg_{metric_to_optim}={metrics[metric_to_optim].get_avg()}\n"
    metrics_dict = create_metrics_dict(metrics)
    # Persist the model without its DataParallel wrapping.
    # See also: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
    if get_model_str_from_obj(model) == "DataParallel":
        state = model.module
    else:
        state = model
    model_dir = create_model_dir(args)
    model_filename = os.path.join(model_dir, filename)
    result_filename = os.path.join(model_dir, 'results.txt')
    latest_filename = os.path.join(model_dir, 'latest.txt')
    best_filename = os.path.join(model_dir, 'model_best.pth.tar')
    last_filename = os.path.join(model_dir, 'model_last.pth.tar')
    best_metric_filename = os.path.join(model_dir, 'best_metrics.json')
    last_metric_filename = os.path.join(model_dir, 'last_metrics.json')
    os.makedirs(model_dir, exist_ok=True)
    logger.info("Saving checkpoint '{}'".format(model_filename))
    with open(result_filename, 'a') as fout:
        fout.write(result_text)
    # Save the entire model object, not only its state_dict
    torch.save(state, model_filename)
    with open(latest_filename, 'w') as fout:
        fout.write(model_filename)
    save_dict_to_json(metrics_dict, last_metric_filename)
    if is_best:
        logger.info("Found new best.")
        shutil.copyfile(model_filename, best_filename)
        shutil.copyfile(last_metric_filename, best_metric_filename)
    logger.info("Copying to {}".format(last_filename))
    shutil.copyfile(model_filename, last_filename)
    # Remove all previous checkpoints except the best/last ones to save storage
    logger.info("Removing redundant files")
    files_to_keep = ['model_last.pth.tar', 'model_best.pth.tar',
                     'results.txt', 'latest.txt', 'args.json',
                     'last_metrics.json', 'best_metrics.json']
    files_to_delete = [file for file in os.listdir(model_dir) if file not in files_to_keep]
    for f in files_to_delete:
        # Check isdir against the full path: the bare name would be resolved
        # against the current working directory, and a subdirectory of
        # model_dir would then crash os.remove.
        full_path = os.path.join(model_dir, f)
        if not os.path.isdir(full_path):
            os.remove(full_path)

def save_checkpoint_worker(threadCtx, model, model_params, filename, args, metrics, metric_to_optim, current_round, exec_ctx):
    """
    Worker routine that checks whether the model is the best so far and serializes it to the filesystem.
    Args:
        threadCtx(ThreadPool): thread context (can be None)
        model(torch.nn.Module): model instance used for serialization (read, write)
        model_params(torch.Tensor): model parameters
        filename(str): filename under which the model should be serialized
        args: command line arguments
        metrics: evaluation metrics for the model with model_params
        metric_to_optim(str): metric which is used for optimization
        current_round(int): current round
        exec_ctx: execution context that owns the saver thread
    Returns:
        None. The best metric seen so far is updated on exec_ctx.saver_thread as a side effect.
    """
    logger = Logger.get(args.run_id)
    if threadCtx is not None:
        device_used = threadCtx.device
    else:
        device_used = args.device
    # The current saving logic does not require a GPU device for computation
    avg_metric = metrics[metric_to_optim].get_avg()
    mutils.set_params(model, model_params)
    cur_metrics = {"loss": metrics["loss"].get_avg(),
                   "top_1_acc": metrics["top_1_acc"].get_avg(),
                   "top_5_acc": metrics["top_5_acc"].get_avg(),
                   "neq_perplexity": metrics["neq_perplexity"].get_avg()
                   }
    # Compare against and update the best metric under the lock, so that two
    # concurrent saver workers cannot both claim to be the best.
    with exec_ctx.saver_thread.best_metric_lock:
        exec_ctx.saver_thread.eval_metrics.update({metrics['round']: cur_metrics})
        is_best = avg_metric > exec_ctx.saver_thread.best_metric
        if is_best:
            exec_ctx.saver_thread.best_metric = avg_metric
    save_checkpoint(model, filename, args, is_best, metrics, metric_to_optim)

def eval_and_save_checkpoint_worker(threadCtx, model, params, testloader, criterion, args, current_round, is_rnn, metric_to_optim, exec_ctx) -> bool:
    """
    Evaluate the model with the given parameters on the test loader, then
    dispatch checkpoint serialization (to a saver thread when available,
    synchronously otherwise).
    """
    logger = Logger.get(args.run_id)
    if threadCtx is not None:
        device_used = threadCtx.device
    else:
        device_used = args.device
    mutils.print_current_gpu_context(device_used, args)
    mutils.set_params(model, params)
    metrics = model_funcs.evaluate_model(model, testloader, criterion, device_used, current_round,
                                         print_freq=10, is_rnn=is_rnn, metric_to_optim=metric_to_optim)
    avg_metric = metrics[metric_to_optim].get_avg()
    # Save the model checkpoint
    model_filename = '{model}_{run_id}_checkpoint_{round:0>2d}.pth.tar' \
        .format(model=args.model, run_id=args.run_id, round=current_round)
    if exec_ctx.saver_thread.workers() > 0:
        # There are dedicated worker threads for saving
        _model = exec_ctx.saver_thread.next_dispatched_thread().model_copy
        fargs = (_model, params, model_filename, args, metrics, metric_to_optim, current_round, exec_ctx)
        exec_ctx.saver_thread.dispatch(save_checkpoint_worker, fargs)
    else:
        # There are no dedicated saver threads; save synchronously
        fargs = (None, model, params, model_filename, args, metrics, metric_to_optim, current_round, exec_ctx)
        save_checkpoint_worker(*fargs)
    if np.isnan(metrics['loss'].get_avg()):
        logger.critical('NaN loss detected, aborting training procedure.')
        # TODO: HANDLE PROPERLY
    return True

def defered_eval_and_save_checkpoint(model, criterion, args, current_round, is_rnn, metric_to_optim, exec_ctx):
    """
    Request asynchronous evaluation of the model, followed by checkpoint serialization.
    Args:
        model(torch.nn.Module): model to be assessed (read only)
        criterion(torch.nn.modules.loss): loss criterion used as the target function
        args: command line arguments which configured this launch
        current_round(int): communication round that has just completed
        is_rnn(bool): whether the model being handled is an RNN
        metric_to_optim(str): validation metric to optimize
        exec_ctx: execution context that owns the evaluation thread pool
    """
    eval_thread = exec_ctx.eval_thread_pool.next_dispatched_thread()
    _model = eval_thread.model_copy
    _params = mutils.get_params(model)
    _testloader = eval_thread.testloader_copy
    arguments_for_call = (_model, _params, _testloader, criterion, args, current_round, is_rnn, metric_to_optim, exec_ctx)
    function_for_call = eval_and_save_checkpoint_worker
    exec_ctx.eval_thread_pool.dispatch(function_for_call, arguments_for_call)

def load_checkpoint(model_dir, load_best=True):
    """
    Load a model from a checkpoint.
    Args:
        model_dir: Directory to read the model from.
        load_best: Whether to read the best or the latest version of the model.
    Returns:
        Tuple of the loaded model and the communication round it was saved at.
    """
    logger = Logger.get("default")
    if load_best:
        model_filename = os.path.join(model_dir, 'model_best.pth.tar')
        metric_filename = os.path.join(model_dir, 'best_metrics.json')
    else:
        model_filename = os.path.join(model_dir, 'model_last.pth.tar')
        metric_filename = os.path.join(model_dir, 'last_metrics.json')
    logger.info("Loading checkpoint '{}'".format(model_filename))
    state = torch.load(model_filename)
    logger.info("Loaded checkpoint '{}'".format(model_filename))
    # Read the metrics JSON stored next to the checkpoint
    with open(metric_filename, 'r') as f:
        metrics = json.load(f)
    return state, metrics['round']

def save_dict_to_json(d, json_path):
    """Serialize a dict of numeric scalars and iterables of numbers to a JSON file."""
    with open(json_path, 'w') as f:
        d = {k: float(v) if isinstance(v, (float, int)) else [float(e) for e in v]
             for k, v in d.items()}
        json.dump(d, f, indent=4)
Functions
def defered_eval_and_save_checkpoint(model, criterion, args, current_round, is_rnn, metric_to_optim, exec_ctx)
-
Request asynchronous evaluation of the model, followed by checkpoint serialization.
Args
- model(torch.nn.Module): model to be assessed (read only)
- criterion(torch.nn.modules.loss): loss criterion used as the target function
- args: command line arguments which configured this launch
- current_round(int): communication round that has just completed
- is_rnn(bool): whether the model being handled is an RNN
- metric_to_optim(str): validation metric to optimize
- exec_ctx: execution context that owns the evaluation thread pool
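A minimal usage sketch of the deferred path inside a training loop. Here `args`, `model`, `criterion`, and `exec_ctx` are assumed to be prepared by the training driver, `args.rounds` is an assumed config field, and `train_one_round` is a hypothetical stand-in for the local training step.
from utils.checkpointing import defered_eval_and_save_checkpoint

for current_round in range(args.rounds):              # args.rounds: assumed field
    train_one_round(model, args)                      # hypothetical training step
    # Evaluation and serialization run on the pool threads, so the round
    # loop is not blocked while the checkpoint is evaluated and written.
    defered_eval_and_save_checkpoint(model, criterion, args, current_round,
                                     is_rnn=False, metric_to_optim="top_1_acc",
                                     exec_ctx=exec_ctx)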
def eval_and_save_checkpoint_worker(threadCtx, model, params, testloader, criterion, args, current_round, is_rnn, metric_to_optim, exec_ctx) -> bool
-
Evaluate the model with the given parameters on the test loader, then dispatch checkpoint serialization to a dedicated saver thread, or run it synchronously when none is available.
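The worker accepts threadCtx=None and then falls back to args.device, so it can also be invoked synchronously, e.g. for single-threaded debugging. A sketch, assuming `model`, `testloader`, `criterion`, `args`, and `exec_ctx` are prepared as for the deferred path:
from models import mutils
from utils.checkpointing import eval_and_save_checkpoint_worker

params = mutils.get_params(model)                     # flat parameter snapshot
eval_and_save_checkpoint_worker(None, model, params, testloader, criterion,
                                args, current_round=3, is_rnn=False,
                                metric_to_optim="top_1_acc", exec_ctx=exec_ctx)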
def load_checkpoint(model_dir, load_best=True)
-
Load a model from a checkpoint.
Args
model_dir
- Directory to read the model from.
load_best
- Whether to read the best or the latest version of the model.
Returns
Tuple of the loaded model and the communication round it was saved at.
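A minimal restore sketch: load the best checkpoint and continue from the recorded round. `create_model_dir(args)` is the same helper the save path uses, and `args` is assumed to be configured elsewhere. Note that torch.save stored the whole module, so the first return value is a usable model, not a state_dict.
from utils.checkpointing import load_checkpoint
from utils.utils import create_model_dir

model_dir = create_model_dir(args)                    # args prepared elsewhere
model, start_round = load_checkpoint(model_dir, load_best=True)
model.eval()                                          # the full module was serialized
print(f"Resuming from round {start_round}")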
def save_checkpoint(model, filename, args, is_best, metrics, metric_to_optim)
-
Persist a checkpoint to disk.
Args
model
- Model to serialize
filename
- Filename to persist the model under
args
- Training setup
is_best
- Whether this model achieves the best metric so far
metrics
- Metrics obtained from evaluation
metric_to_optim
- Metric to optimize, e.g. top-1 accuracy
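save_checkpoint only calls get_avg() on the metric objects and forwards the dict to create_metrics_dict, so a sketch can fake the accumulators. The _FixedMetric class below is a hypothetical stand-in, and this assumes create_metrics_dict accepts any mapping of such objects; `model` and `args` are prepared elsewhere.
from utils.checkpointing import save_checkpoint

class _FixedMetric:
    """Hypothetical stand-in for the project's metric accumulator."""
    def __init__(self, value):
        self._value = value

    def get_avg(self):
        return self._value

metrics = {"loss": _FixedMetric(0.42), "top_1_acc": _FixedMetric(81.3)}
save_checkpoint(model, "resnet18_myrun_checkpoint_07.pth.tar", args,
                is_best=True, metrics=metrics, metric_to_optim="top_1_acc")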
def save_checkpoint_worker(threadCtx, model, model_params, filename, args, metrics, metric_to_optim, current_round, exec_ctx)
-
Worker routine that checks whether the model is the best so far and serializes it to the filesystem.
Args
- threadCtx(ThreadPool): thread context (can be None)
- model(torch.nn.Module): model instance used for serialization (read, write)
- model_params(torch.Tensor): model parameters
- filename(str): filename under which the model should be serialized
- args: command line arguments
- metrics: evaluation metrics for the model with model_params
- metric_to_optim(str): metric which is used for optimization
- current_round(int): current round
- exec_ctx: execution context that owns the saver thread
Returns
None. The best metric seen so far is updated on exec_ctx.saver_thread as a side effect.
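The worker's best-metric bookkeeping is a lock-guarded compare-and-update. The generic sketch below (illustrative only, not project API) isolates that pattern, performing the comparison under the same lock as the update so two concurrent savers cannot both claim to be the best.
import threading

class BestTracker:
    """Thread-safe 'update if better' tracker (illustrative only)."""
    def __init__(self):
        self.best = float("-inf")
        self._lock = threading.Lock()

    def update(self, value):
        # Return True when `value` improves on the best seen so far.
        with self._lock:
            if value > self.best:
                self.best = value
                return True
            return False

tracker = BestTracker()
print(tracker.update(0.8))   # True: first value becomes the best
print(tracker.update(0.7))   # False: no improvement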
def save_dict_to_json(d, json_path)
-
Serialize a dict of numeric scalars and iterables of numbers to a JSON file.
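A small usage sketch: scalar values are cast to float, and every other value is assumed to be an iterable of numbers and becomes a list of floats (string values are not supported, since float() would fail on their characters).
from utils.checkpointing import save_dict_to_json

save_dict_to_json({"loss": 0.42, "top_1_acc": 81.3, "lr_history": [0.1, 0.01]},
                  "last_metrics.json")
# last_metrics.json now contains:
# {
#     "loss": 0.42,
#     "top_1_acc": 81.3,
#     "lr_history": [0.1, 0.01]
# }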