import os

import torch


def save_model_weights(model, filename):
    """Save ``model.state_dict()`` to ``filename`` with ``torch.save``.

    Missing parent directories are created first.

    Args:
        model: Object exposing ``state_dict()`` (typically a
            ``torch.nn.Module``).
        filename: Destination path. A bare file name (no directory part)
            is written relative to the current working directory.
    """
    parent = os.path.dirname(filename)
    # dirname() is '' for a bare file name, and os.makedirs('') raises, so
    # only create a directory when there is one. exist_ok=True also removes
    # the race between checking for the directory and creating it.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), filename)


def load_model_weights(model, filename, map_location=None, strict=True):
    """Load a state dict from ``filename`` into ``model`` and set eval mode.

    Args:
        model: Module to populate; it is modified in place.
        filename: Path to a file holding a state dict written by
            ``torch.save(model.state_dict(), ...)``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to
            load weights that were saved on a GPU onto a CPU-only machine.
            ``None`` keeps ``torch.load``'s default behavior.
        strict: Forwarded to ``model.load_state_dict``. When True, the keys
            in the file must match the model's keys exactly.

    Returns:
        The same ``model`` object, to allow call chaining.

    Raises:
        RuntimeError: From ``load_state_dict`` when ``strict`` is True and
            the keys or tensor shapes do not match.
    """
    state_dict = torch.load(filename, map_location=map_location)
    model.load_state_dict(state_dict, strict=strict)
    # Switch off training-only behavior (dropout, batch-norm statistic
    # updates) since freshly loaded weights are usually used for inference.
    model.eval()
    return model


def save_model(model, filename):
    """Serialize the whole ``model`` object to ``filename`` with ``torch.save``.

    Missing parent directories are created first.

    Args:
        model: Object to pickle, typically a ``torch.nn.Module``.
        filename: Destination path. A bare file name (no directory part)
            is written relative to the current working directory.
    """
    parent = os.path.dirname(filename)
    # dirname() is '' for a bare file name, and os.makedirs('') raises, so
    # only create a directory when there is one. exist_ok=True also removes
    # the race between checking for the directory and creating it.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model, filename)


def load_model(filename, map_location=None, **kwargs):
    """Load a whole object previously saved with ``torch.save(model, ...)``.

    SECURITY: this unpickles arbitrary Python objects, which can execute
    code. Only load files from trusted sources.

    Args:
        filename: Path to the serialized model.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to
            load a model saved on a GPU onto a CPU-only machine.
        **kwargs: Any other ``torch.load`` keyword arguments. An explicit
            ``weights_only`` value supplied here is honored as given.

    Returns:
        The deserialized object.
    """
    import inspect

    # PyTorch 2.6 changed torch.load's default to weights_only=True, which
    # refuses to unpickle nn.Module objects. This function exists to load
    # entire models, so opt out explicitly -- unless the caller chose a
    # value, or the installed torch predates the argument.
    if ("weights_only" not in kwargs
            and "weights_only" in inspect.signature(torch.load).parameters):
        kwargs["weights_only"] = False
    return torch.load(filename, map_location=map_location, **kwargs)
