import torch

def save_model_state(model, path):
    """Serialize a model's ``state_dict`` to ``path`` with ``torch.save``.

    Only the object returned by ``model.state_dict()`` is written, not the
    model object itself, so restoring it requires constructing the model
    first and then calling ``load_state_dict`` on the loaded mapping.

    Args:
        model: Object exposing ``state_dict()`` (presumably a
            ``torch.nn.Module`` — confirm against callers).
        path: Destination handed straight to ``torch.save``.
    """
    snapshot = model.state_dict()
    torch.save(snapshot, path)
    
def load_model_state(path, target_device="cuda:0", *, weights_only=True):
    """Load a checkpoint written by ``save_model_state`` onto a device.

    Args:
        path: File path (or file-like object) passed to ``torch.load``.
        target_device: Value forwarded as ``map_location``, so every loaded
            tensor is remapped onto this device. The default ``"cuda:0"``
            needs a CUDA-enabled PyTorch with a visible GPU; pass ``"cpu"``
            on CPU-only hosts.
        weights_only: When ``True`` (the default), ``torch.load`` restricts
            unpickling to tensors and plain containers. A tampered checkpoint
            therefore cannot run arbitrary code. Set to ``False`` only for
            fully trusted files that contain other Python objects.

    Returns:
        The deserialized object. For files produced by ``save_model_state``
        this is the model's state dict, ready for ``model.load_state_dict``.
    """
    # torch.load's own ``weights_only`` default differs between releases
    # (False before 2.6, True from 2.6). Passing it explicitly keeps loading
    # safe and consistent whatever PyTorch version is installed.
    return torch.load(path, map_location=target_device, weights_only=weights_only)