import argparse
import os
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from utils import ROOT_PATH
import torch
from torch import nn
from Normalize import Normalize, TfNormalize
# timm model identifiers that get_model() is allowed to construct.
MODEL_NAMES = [
    'vit_base_patch16_224',  # ViT-B/16
    'resnet50',              # ResNet-50
]


def get_model(model_name):
    """Create a pretrained timm model by name.

    Args:
        model_name: A timm model identifier. It must be listed in
            ``MODEL_NAMES``.

    Returns:
        The model object returned by ``timm.models.create_model``, built with
        pretrained weights, a 1000-class head and 3 input channels.

    Raises:
        ValueError: If ``model_name`` is not listed in ``MODEL_NAMES``.
    """
    if model_name not in MODEL_NAMES:
        # Without this guard the function fell through to ``return model``
        # with ``model`` unbound and crashed with an UnboundLocalError.
        # Fail early with an actionable message instead.
        raise ValueError(
            f'Unsupported model {model_name!r}; expected one of {MODEL_NAMES}.'
        )

    model = create_model(
        model_name,
        pretrained=True,
        num_classes=1000,
        in_chans=3,
        # NOTE(review): presumably None leaves the architecture's default
        # pooling in place -- confirm against the timm version in use.
        global_pool=None,
        scriptable=False,
    )

    print(f'Loading Model {model_name}.')
    return model