import os
import importlib
import torch
from .backbones import resnet18, resnet50, resnet18_apd, mobilenetv2
from .backbones import swin, vit, vit_apd
from .backbones import simmim, simsiam, mae
from .backbones import supervised
import re

def get_all_models():
    """Return the module names of the model source files in ``models/``.

    The ``models`` directory is resolved relative to the current working
    directory. An entry is kept only if its name ends with ``.py`` and does
    not contain a double underscore (which skips ``__init__.py`` and
    ``__pycache__``). Matching on the ``.py`` suffix, rather than on the
    substring ``py``, keeps compiled files, backups and unrelated files whose
    name merely contains "py" out of the result.

    Returns:
        list[str]: each kept file name truncated at its first ``.``, in the
        order produced by ``os.listdir``.
    """
    return [model.split('.')[0] for model in os.listdir('models')
            if '__' not in model and model.endswith('.py')]

def get_model(args, device, transform, logger):
    """Build the configured backbone and wrap it in the requested model.

    Args:
        args: configuration object. ``args.model.backbone`` selects the
            backbone factory and ``args.cl_model`` selects the model class;
            the whole object is also forwarded to both constructors.
        device: not used by this function; kept so existing callers that
            pass it keep working.
        transform: forwarded unchanged to the model constructor.
        logger: object exposing ``info(msg)``; also forwarded to the model
            constructor.

    Returns:
        The result of calling the selected model class with
        ``(backbone, loss, args, transform, logger)``.

    Raises:
        NotImplementedError: if ``args.model.backbone`` is not one of the
            supported backbone names.
        KeyError: if a module in ``models/`` exposes no attribute whose
            lower-cased name equals the module name without underscores, or
            if ``args.cl_model`` does not name a discovered model.
    """
    supported_backbones = ('supervised', 'simmim', 'simsiam', 'mae')
    backbone_name = args.model.backbone
    if backbone_name not in supported_backbones:
        # Fail right away with a clear error instead of continuing with an
        # unbound ``backbone`` variable.
        raise NotImplementedError(
            f"Unsupported backbone {backbone_name!r}; "
            f"expected one of {supported_backbones}"
        )

    loss = torch.nn.CrossEntropyLoss()
    print('backbone info: %s'%(backbone_name))

    # The name has been checked against the whitelist above, so a plain
    # module-namespace lookup is enough; no need to eval a string.
    backbone_factory = globals()[backbone_name]
    try:
        backbone = backbone_factory(args)
    except Exception as exc:
        # Best-effort fallback: some backbone factories take no arguments.
        # Log the real cause so a genuine construction error is not hidden.
        logger.info(
            f'{backbone_name} backbone could not be built with args '
            f'({exc!r}); retrying without arguments'
        )
        backbone = backbone_factory()

    # Import every module under ``models/`` and map its file name to the
    # attribute whose lower-cased name matches the file name without
    # underscores.
    names = {}
    for model in get_all_models():
        mod = importlib.import_module('models.' + model)
        # ``mod.__dir__()`` is kept on purpose: its ordering decides which
        # attribute wins when two names differ only by case.
        class_name = {x.lower(): x for x in mod.__dir__()}[model.replace('_', '')]
        names[model] = getattr(mod, class_name)

    return names[args.cl_model.lower()](backbone, loss, args, transform, logger)
