import torch
import os

def load_best_model(args, save_file, *, map_location=None):
    """Load the object saved at ``<save_file>/best_model/model.pt``.

    Args:
        args: Unused here; accepted so the signature mirrors ``load_model``.
        save_file: Root directory that contains the ``best_model`` folder.
        map_location: Forwarded to ``torch.load``. ``None`` (the default)
            keeps ``torch.load``'s default behaviour. Pass e.g. ``"cpu"`` to
            load a GPU-saved checkpoint on a CPU-only machine.

    Returns:
        A ``(model, path)`` tuple. ``model`` is whatever ``torch.load`` returns
        for the file (presumably a model saved with ``torch.save`` — confirm
        against the saving code). ``path`` is the resolved file path.

    Raises:
        FileNotFoundError: If ``model.pt`` is missing or is not a regular file.
    """
    path = os.path.join(save_file, "best_model", "model.pt")
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if not os.path.isfile(path):
        raise FileNotFoundError("{} does not exist".format(path))
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model = torch.load(path, map_location=map_location)
    return model, path

def load_model(args, save_file, *, map_location=None):
    """Load the object saved at ``<save_file>/checkpoint-<args.t_total>/model.pt``.

    Args:
        args: Object with a ``t_total`` attribute, which selects the checkpoint
            directory name (presumably the total training step count — confirm).
        save_file: Root directory that contains the ``checkpoint-*`` folders.
        map_location: Forwarded to ``torch.load``. ``None`` (the default)
            keeps ``torch.load``'s default behaviour. Pass e.g. ``"cpu"`` to
            load a GPU-saved checkpoint on a CPU-only machine.

    Returns:
        A ``(model, path)`` tuple. ``model`` is whatever ``torch.load`` returns
        for the file. ``path`` is the resolved file path.

    Raises:
        FileNotFoundError: If ``model.pt`` is missing or is not a regular file.
    """
    path = os.path.join(
        save_file,
        "checkpoint-{}".format(args.t_total),
        "model.pt",
    )
    # Explicit raise instead of `assert`, so the check survives `python -O`;
    # include the path so the failure is actionable.
    if not os.path.isfile(path):
        raise FileNotFoundError("{} does not exist".format(path))
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model = torch.load(path, map_location=map_location)
    return model, path