import os
import torch


def load_model(args, save_file):
    if args.max_steps:
        path = os.path.join(
            args.model_dir,
            args.model_type,
            save_file,
            "checkpoint-{}".format(args.max_steps),
            "model.pt",
        )
    else:
        root = os.path.join(args.model_dir, args.model_type, save_file)
        for item in os.listdir(root):
            if os.path.isdir(os.path.join(root, item)):
                path = os.path.join(
                    root,
                    item,
                    "model.pt",
                )
    assert os.path.exists(path), "Path is not exist"
    model = torch.load(path)
    return model, path
