import torch
import os


def save_model(model, save_model_path):
    """Serialize the parameters and buffers of ``model`` to ``save_model_path``.

    Only the state dict is written (not the pickled module object), so the
    file can later be restored into a freshly constructed model of the same
    architecture.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        save_model_path: destination path passed straight to ``torch.save``.
    """
    weights = model.state_dict()
    torch.save(weights, save_model_path)


def load_model(model, save_model_path, strict_load=True, map_location=None):
    """Load weights from ``save_model_path`` into ``model`` in place.

    The file may hold either a bare state dict or a dict that wraps it under a
    ``'state_dict'`` key; any other entries in such a wrapper are ignored.

    Args:
        model: ``torch.nn.Module`` whose parameters/buffers are overwritten.
        save_model_path: path to a file written by ``torch.save``.
        strict_load: forwarded as ``strict`` to ``Module.load_state_dict``;
            when True, missing or unexpected keys raise ``RuntimeError``.
        map_location: forwarded to ``torch.load``. Passing e.g. ``'cpu'``
            lets a checkpoint saved from GPU tensors load on a CPU-only
            machine. ``None`` (the default) keeps ``torch.load``'s default
            device placement.

    Returns:
        The same ``model`` object, for chaining.
    """
    loaded_file = torch.load(save_model_path, map_location=map_location)
    # Unwrap checkpoints that nest the weights under 'state_dict'. The
    # isinstance guard avoids calling dict methods on a non-dict payload.
    if isinstance(loaded_file, dict) and 'state_dict' in loaded_file:
        loaded_file = loaded_file['state_dict']
    result = model.load_state_dict(loaded_file, strict=strict_load)
    # With strict_load=False, key mismatches are not errors. Surface them so
    # a partially loaded model does not go unnoticed.
    if result.missing_keys or result.unexpected_keys:
        print(f'missing keys: {result.missing_keys}, '
              f'unexpected keys: {result.unexpected_keys}')
    print(f'model loaded from {save_model_path}')
    return model


def save_checkpoint(model, directory, filename='ckpt.pth'):
    """Save ``model``'s weights to ``directory/filename`` via ``save_model``.

    ``directory`` (including any missing parents) is created if it does not
    exist yet, so the first checkpoint of a run does not fail with
    ``FileNotFoundError``.

    Args:
        model: model passed through to ``save_model``.
        directory: target directory; an empty string means the current
            working directory.
        filename: checkpoint file name inside ``directory``.

    Returns:
        The full path the checkpoint was written to.
    """
    # os.makedirs('') raises, so only create a real directory component.
    if directory:
        os.makedirs(directory, exist_ok=True)
    checkpoint_path = os.path.join(directory, filename)
    save_model(model, checkpoint_path)
    return checkpoint_path