from .resnet import *
from .wide_resnet import *
import torch
from autogluon.tabular import TabularDataset, TabularPredictor
import builtins
import os

_MODELS_DIR = os.path.dirname(os.path.abspath(__file__))

def load_model(model_arch, semantic=True):
    """Build a classifier and load its pretrained checkpoint.

    The architecture is selected by substring match on ``model_arch``:
    ``"Res18"`` builds ``RN18_10`` from ``checkpoint/resnet-18.pth`` and
    ``"WRN28"`` builds ``WRN28_10`` from ``checkpoint/wide-resnet-28x10.pth``
    (both paths relative to this package's directory).

    Args:
        model_arch: Architecture identifier; must contain ``"Res18"`` or
            ``"WRN28"``.
        semantic: Forwarded unchanged to the model constructor.

    Returns:
        The model wrapped in ``torch.nn.DataParallel``, with weights loaded
        and switched to eval mode.

    Raises:
        ValueError: If ``model_arch`` matches no supported architecture.
    """
    if "Res18" in model_arch:
        model_path = os.path.join(_MODELS_DIR, "checkpoint", "resnet-18.pth")
        model = RN18_10(semantic=semantic)
        # Wrap *before* loading: the state dict is loaded into the
        # DataParallel wrapper, presumably because the checkpoint keys carry
        # the "module." prefix — keep this order in sync with the checkpoint.
        model = torch.nn.DataParallel(model)
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
        # hosts; load_state_dict copies the tensors into the model's own
        # parameters, so the resulting model is the same either way.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))
    elif "WRN28" in model_arch:
        model_path = os.path.join(
            _MODELS_DIR, "checkpoint", "wide-resnet-28x10.pth"
        )
        model = WRN28_10(semantic=semantic)
        # This checkpoint is a dict whose weights live under the "net" key and
        # are loaded into the bare model, so wrapping happens afterwards.
        checkpoint = torch.load(model_path, map_location="cpu")
        model.load_state_dict(checkpoint["net"])
        model = torch.nn.DataParallel(model)
    else:
        raise ValueError(
            f"Unsupported model_arch {model_arch!r}; expected a name "
            "containing 'Res18' or 'WRN28'."
        )

    model.eval()
    return model