import torch
import shutil
import os
import torchvision.utils as tvu


def save_image(img, file_directory):
    """Save ``img`` to the file path ``file_directory``, creating parent dirs.

    Args:
        img: Image data, forwarded unchanged to
            ``torchvision.utils.save_image``.
        file_directory: Destination *file* path (despite the parameter name,
            this is the path of the image file itself, not a directory).
    """
    parent = os.path.dirname(file_directory)
    # A bare filename has no parent component (dirname returns ''), and
    # os.makedirs('') raises, so only create a directory when there is one.
    # exist_ok=True avoids a check-then-create race between processes.
    if parent:
        os.makedirs(parent, exist_ok=True)
    tvu.save_image(img, file_directory)


def save_checkpoint(state, filename):
    """Serialize ``state`` with ``torch.save`` to ``filename + '.pth'``.

    Parent directories of ``filename`` are created if they do not exist.

    Args:
        state: Object to serialize (typically a dict of state dicts).
        filename: Destination path WITHOUT the ``.pth`` extension, which is
            appended here.
    """
    parent = os.path.dirname(filename)
    # dirname of a bare filename is '' and os.makedirs('') raises, so skip
    # directory creation in that case; exist_ok=True makes concurrent saves
    # into the same new directory safe.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, filename + '.pth')


def load_checkpoint(path, device):
    """Load an object previously written with ``torch.save``.

    Args:
        path: Path of the checkpoint file.
        device: Target for tensor storage, passed to ``torch.load`` as
            ``map_location`` (e.g. ``'cpu'``, ``'cuda:0'`` or a
            ``torch.device``). ``None`` maps everything to CPU.

    Returns:
        The deserialized object stored at ``path``.

    Note:
        ``torch.load`` unpickles the file; only load checkpoints from
        trusted sources.
    """
    # Default to CPU when no device is requested so checkpoints saved on a
    # GPU still load on CPU-only machines; otherwise honour the caller's
    # device instead of ignoring it.
    map_location = torch.device('cpu') if device is None else device
    return torch.load(path, map_location=map_location)