import torch
import contextlib

# Workaround for an intermittent `RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED`
# (seen once): run a dummy convolution up front so cuDNN gets initialized early.
# https://stackoverflow.com/questions/66588715/runtimeerror-cudnn-error-cudnn-status-not-initialized-using-pytorch
def force_cudnn_initialization(device='cuda', size=32):
    """Run a throwaway ``conv2d`` on ``device`` to force early library setup.

    Intended as a workaround for an intermittent
    ``CUDNN_STATUS_NOT_INITIALIZED`` error. The idea is to trigger a
    convolution before the real workload starts.

    Args:
        device: Anything accepted by ``torch.device``. The default is
            ``'cuda'``, the current CUDA device. Pass e.g. ``'cuda:1'`` to
            warm up a specific GPU. ``'cpu'`` also works, but does not
            involve cuDNN.
        size: Edge length of the all-zero input and weight tensors. Both
            have shape ``(size, size, size, size)``. The default is 32.

    Returns:
        None. The convolution result is discarded.
    """
    dev = torch.device(device)
    # Input (N, C_in, H, W) and weight (C_out, C_in, kH, kW) are both
    # size^4 zeros, so the kernel covers the whole input and the output is
    # (size, size, 1, 1). Only the side effect of running the conv matters.
    inputs = torch.zeros(size, size, size, size, device=dev)
    weight = torch.zeros(size, size, size, size, device=dev)
    torch.nn.functional.conv2d(inputs, weight)

# Write the memory report produced by the library below to a file instead of stdout.
# https://github.com/Stonesjtu/pytorch_memlab
def memreport_to_file(reporter, fname, **report_kwargs):
    """Write the output of ``reporter.report()`` to ``fname`` instead of stdout.

    ``reporter.report()`` prints its report. While it runs, ``sys.stdout`` is
    temporarily redirected into the file, so the report lands there.

    Args:
        reporter: Object whose ``report()`` method prints to stdout, e.g. a
            ``pytorch_memlab.MemReporter``.
        fname: Path of the output file. It is created, or truncated if it
            already exists.
        **report_kwargs: Forwarded unchanged to ``reporter.report()``.
            Presumably ``verbose``/``device`` for ``MemReporter``; confirm
            against the installed pytorch_memlab version.

    Note:
        Only writes that go through Python's ``sys.stdout`` are captured.
        Output written directly to file descriptor 1 by native code is not.
    """
    # Explicit UTF-8 so the file's encoding does not depend on the platform locale.
    with open(fname, "w", encoding="utf-8") as f, contextlib.redirect_stdout(f):
        reporter.report(**report_kwargs)
