# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Iterable, Mapping, Sized, Sequence
import math
import torch
import torch.nn as nn
from collections import OrderedDict
import numpy as np
from numbers import Number
def summary(model, input_size):
    """Print a layer-by-layer summary of *model* and return parameter counts.

    Args:
        model: the ``torch.nn.Module`` to inspect.
        input_size: full shape (including batch dimension) of the random
            dummy input fed through the model.

    Returns:
        Tuple ``(total_params, trainable_params)`` as produced by
        :func:`summary_string`.
    """
    result, params_info = summary_string(model, input_size)
    print(result)
    return params_info
def is_scaler(o):
    """Return True if *o* is a plain scalar-like value: a number, a string
    or ``None``.

    NOTE: the name keeps the original spelling ("scaler") for backward
    compatibility; "scalar" is the intended word.
    """
    return isinstance(o, (Number, str)) or o is None
def get_tensor_stat(tensor):
    """Return basic size/memory statistics for a single tensor.

    Args:
        tensor: a ``torch.Tensor``.

    Returns:
        Tuple ``([size], numel, memory_size)`` where ``size`` is the tensor
        shape as a tuple (normalized to ``(1,)`` for 0-dim scalars),
        ``numel`` is the element count, and ``memory_size`` is the number of
        bytes PyTorch actually allocates for the backing storage, rounded up
        to the 512-byte allocation granularity.
    """
    assert isinstance(tensor, torch.Tensor)

    # PyTorch's low-level allocator never hands out less than this many
    # bytes for a tensor's storage.
    PYTORCH_MIN_ALLOCATE = 2 ** 9

    numel = tensor.numel()
    element_size = tensor.element_size()
    # tensor.storage() reflects the actual backing allocation, which can be
    # larger than numel (e.g. for views/slices sharing a bigger buffer).
    fact_numel = tensor.storage().size()
    fact_memory_size = fact_numel * element_size
    # Since PyTorch allocates at least 512 bytes for any tensor, round up
    # to a multiple of 512.
    memory_size = math.ceil(fact_memory_size / PYTORCH_MIN_ALLOCATE) \
        * PYTORCH_MIN_ALLOCATE

    size = tuple(tensor.size())
    # A 0-dim torch scalar reports an empty size; normalize to (1,).
    if not size:
        size = (1,)

    return ([size], numel, memory_size)
def get_all_tensor_stats(o):
    """Recursively accumulate tensor statistics over an arbitrary object.

    Walks tensors, mappings, iterables and plain objects (via ``__dict__``)
    and sums the ``([sizes], numel, memory_size)`` triples of every tensor
    found. Scalars and unrecognized objects contribute the empty triple
    ``([[]], 0, 0)``.
    """
    if is_scaler(o):
        return ([[]], 0, 0)
    elif isinstance(o, torch.Tensor):
        return get_tensor_stat(o)
    elif isinstance(o, Mapping):
        # only the values can hold tensors; keys are labels
        return get_all_tensor_stats(o.values())
    elif isinstance(o, Iterable):  # tuple, list, dict views, generators
        stats = [[]], 0, 0
        for oi in o:
            tz = get_all_tensor_stats(oi)
            # element-wise accumulate: list concat for the sizes,
            # integer addition for numel and memory
            stats = tuple(acc + cur for acc, cur in zip(stats, tz))
        return stats
    elif hasattr(o, '__dict__'):
        # arbitrary object: inspect its attributes
        return get_all_tensor_stats(o.__dict__)
    else:
        return ([[]], 0, 0)
def get_shape(o):
    """Return a human-readable description of the shape of *o*.

    Handles scalars, array-likes (``shape`` attribute or ``size()`` method),
    sequences and mappings recursively; anything else yields ``'N/A'``.
    """
    if is_scaler(o):
        return str(o)
    elif hasattr(o, 'shape'):
        return f'shape{o.shape}'
    elif hasattr(o, 'size'):
        return f'size{o.size()}'
    elif isinstance(o, Sequence):
        if len(o) == 0:
            return 'seq[]'
        elif is_scaler(o[0]):
            # homogeneous scalar sequence: just report its length
            return f'seq[{len(o)}]'
        return f'seq{[get_shape(oi) for oi in o]}'
    elif isinstance(o, Mapping):
        if len(o) == 0:
            return 'map[]'
        # BUG FIX: next(o) raised TypeError for dicts (a Mapping is not an
        # iterator); peek at the first key via iter() instead.
        elif is_scaler(next(iter(o))):
            return f'map[{len(o)}]'
        # BUG FIX: iterating a Mapping yields keys only, so the original
        # (ki, vi) unpacking failed; iterate .items() instead.
        arr = [(get_shape(ki), get_shape(vi)) for ki, vi in o.items()]
        return f'map{arr}'
    else:
        return 'N/A'
def summary_string(model, input_size, dtype=torch.float32):
    """Run a dummy forward pass through *model* and build a text summary.

    A forward hook is attached to every non-container module; each hook
    records the module's input/output tensor statistics and parameter
    count. The collected data is then formatted into a table.

    Args:
        model: the ``torch.nn.Module`` to inspect.
        input_size: full shape (including batch dimension) of the random
            dummy input used for the forward pass.
        dtype: dtype of the dummy input (default ``torch.float32``).

    Returns:
        Tuple ``(summary_str, (total_params, trainable_params))``.
    """
    summary_str = ''

    # per-layer stats keyed by "<ClassName>-<index>"
    summary = OrderedDict()
    hooks = []

    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input"] = get_all_tensor_stats(input)
            summary[m_key]["output"] = get_all_tensor_stats(output)

            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size()))).item()
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size()))).item()
            summary[m_key]["nb_params"] = params

        # skip pure containers; only modules that do real work get hooks
        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
        ):
            hooks.append(module.register_forward_hook(hook))

    # random dummy input on the same device as the model's parameters;
    # input_size is used as-is, so it must include the batch dimension
    x = torch.rand(input_size, dtype=dtype,
                   device=next(model.parameters()).device)

    # register hooks, run one forward pass, then detach all hooks
    model.apply(register_hook)
    model(x)
    for h in hooks:
        h.remove()

    summary_str += "----------------------------------------------------------------" + "\n"
    line_new = "{:>20} {:>25} {:>15}".format(
        "Layer (type)", "Output (elments, mem)", "Param #")
    summary_str += line_new + "\n"
    summary_str += "================================================================" + "\n"

    total_params = 0
    total_input = get_tensor_stat(x)
    total_output = ([[], 0, 0])
    trainable_params = 0
    for layer in summary:
        # one row per layer: name, output (numel, mem), parameter count
        line_new = "{:>20} {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output"][1:]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        # element-wise accumulate the ([sizes], numel, mem) triple
        total_output = tuple(acc + cur for acc, cur in
                             zip(total_output, summary[layer]["output"]))
        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
        summary_str += line_new + "\n"

    total_numel = total_params + total_output[1] + total_input[1]

    summary_str += "================================================================" + "\n"
    summary_str += "Total params: {0:,}".format(total_params) + "\n"
    summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n"
    summary_str += "Non-trainable params: {0:,}".format(total_params -
                                                        trainable_params) + "\n"
    summary_str += "----------------------------------------------------------------" + "\n"
    summary_str += f"Input Elments: {total_input[1]:.4e}\n"
    summary_str += f"Input Mem: {total_input[2]:.4e}\n"
    summary_str += f"Layer Output Elements: {total_output[1]:.4e}\n"
    summary_str += f"Layer Output Mem: {total_output[2]:.4e}\n"
    summary_str += f"Params {total_params:.4e}\n"
    summary_str += f"Total Elements {total_numel:.4e}\n"
    summary_str += "----------------------------------------------------------------" + "\n"

    return summary_str, (total_params, trainable_params)