import torch

# from .mi_estimator import *
from .full_model import FullModel


def get_model(args, finetune=False, from_scratch=False, include_top=False):
    """Build a ``FullModel`` configured from ``args``.

    The model is placed on CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise on the CPU.

    Args:
        args: Configuration object. It must expose the attributes
            ``model_name``, ``n_layers``, ``reversible``, ``residual``,
            ``n_nonlinear_transform_blocks`` and ``feature_cache_dir``. The
            object itself is also forwarded as the first argument.
        finetune: Forwarded unchanged to ``FullModel`` as ``finetune``.
        from_scratch: Forwarded unchanged to ``FullModel`` as ``from_scratch``.
        include_top: Forwarded unchanged to ``FullModel`` as ``include_top``.

    Returns:
        The ``FullModel`` instance constructed from the arguments above.
    """
    has_gpu = torch.cuda.is_available()
    target_device = torch.device('cuda' if has_gpu else 'cpu')

    # NOTE(review): 1000 is a fixed third positional argument to FullModel;
    # presumably the number of output classes (e.g. ImageNet-1k). Confirm
    # against FullModel's signature before changing it.
    num_outputs = 1000

    # Order matters: FullModel receives these positionally.
    positional = (
        args,
        args.model_name,
        num_outputs,
        args.n_layers,
        args.reversible,
        args.residual,
        args.n_nonlinear_transform_blocks,
        args.feature_cache_dir,
    )
    options = {
        'finetune': finetune,
        'from_scratch': from_scratch,
        'include_top': include_top,
        'device': target_device,
    }
    return FullModel(*positional, **options)
