from .resnet18 import OurModel as OurResnet
from .mlp import OurModel as OurMLP
from .modelnet import OurModel as OurModelnet

# Registry: method name -> {model name -> model class}.
# Every supported method currently uses the same set of model classes, so the
# inner mapping is written once. The comprehension builds a fresh inner dict
# per method, so each method's mapping remains an independent object.
models_dict = {
    method: {
        "resnet18": OurResnet,
        "modelnet": OurModelnet,
        "mlp": OurMLP,
    }
    for method in ("ours", "ours_mnar")
}


def get_model(method, model_name, dataset, args, config):
    """Construct the model registered under ``method`` / ``model_name``.

    Args:
        method: Outer key of ``models_dict`` (e.g. ``"ours"`` or ``"ours_mnar"``).
        model_name: Inner key under that method (e.g. ``"resnet18"``).
        dataset: Forwarded as the first positional constructor argument.
        args: Must provide ``num_clients``, ``cuda_id`` and ``ours_agg``
            (forwarded to the constructor) and ``device`` (passed to ``.to``).
        config: Currently unused; kept so existing callers keep working.

    Returns:
        A one-element list holding the result of ``Model(...).to(args.device)``.

    Raises:
        ValueError: If ``method`` / ``model_name`` is not in ``models_dict``,
            or if the method is registered but has no construction recipe here.
    """
    # Keep the try narrow: only the registry lookup should become an
    # "unknown name" error. A KeyError raised inside the model constructor is
    # a genuine bug and must propagate instead of being misreported.
    try:
        Model = models_dict[method][model_name]
    except KeyError as err:
        raise ValueError(
            f"Unknown model name ({model_name}) or method name ({method})"
        ) from err

    if method not in ("ours", "ours_mnar"):
        # Guards against a method added to models_dict without a matching
        # construction recipe below.
        raise ValueError(f"Unknown method: {method}.")

    model = Model(dataset, args.num_clients, args.cuda_id, args.ours_agg)
    return [model.to(args.device)]
