from .swin_transformer import build_swin
from .vision_transformer import build_vit
from .vision_transformer_1 import vit_small, vit_large, vit_base, vit_tiny, vit_small_head
from .mfm import build_mfm
from .student_mfm import build_mfm_student, build_resnet
from .student_mfm_multi import build_mfm_multi_head_student
from .classifier import build_classifier


def build_model(config, is_pretrain=True, is_student=True, get_classifier=False, logger=None, model_type="none"):
    """Build and return the network selected by the arguments and ``config``.

    Dispatch order (first match wins):

    1. ``model_type`` is ``"ibot"`` or ``"attmask"``: return ``vit_small_head()``.
       ``config`` and all other flags are ignored.
    2. ``get_classifier`` is true: return ``build_classifier(config)``.
    3. ``is_pretrain`` is true: build an MFM model.
       - ``config.TRAIN.STUDENT_STRATEGY`` set -> ``build_mfm_student``.
       - else ``config.TRAIN.MULTI_HEAD_STUDENT_STRATEGY`` set ->
         ``build_mfm_multi_head_student``.
       - otherwise -> ``build_mfm``.
    4. Otherwise build a backbone chosen by ``config.MODEL.TYPE``:
       ``'swin'`` -> ``build_swin``, ``'vit'`` -> ``build_vit``, and any other
       value falls back to ``build_resnet``.

    Args:
        config: Configuration object (presumably a yacs-style node; TODO
            confirm). Only the fields needed by the chosen branch are read:
            ``TRAIN.STUDENT_STRATEGY``, ``TRAIN.MULTI_HEAD_STUDENT_STRATEGY``,
            ``MODEL.TYPE``, ``MODEL.SWIN.QKV_BIAS``, ``MODEL.VIT.QKV_BIAS``.
        is_pretrain: Select the MFM pretraining builders (step 3).
        is_student: Forwarded to the MFM student builders only.
        get_classifier: Build the classifier instead; takes precedence over
            ``is_pretrain``.
        logger: Forwarded only to ``build_mfm_multi_head_student``.
        model_type: Only ``"ibot"`` / ``"attmask"`` are inspected; any other
            value (including the default ``"none"``) defers to the config.
            It does NOT override ``config.MODEL.TYPE`` in step 4.

    Returns:
        The model object produced by the selected builder.
    """
    # These two model types always use the fixed ViT-small-with-head builder,
    # independent of config.
    if model_type in ("ibot", "attmask"):
        return vit_small_head()

    if get_classifier:
        return build_classifier(config)

    if is_pretrain:
        # STUDENT_STRATEGY wins over MULTI_HEAD_STUDENT_STRATEGY if both are set.
        if config.TRAIN.STUDENT_STRATEGY:
            return build_mfm_student(config, is_student)
        if config.TRAIN.MULTI_HEAD_STUDENT_STRATEGY:
            return build_mfm_multi_head_student(config, is_student, logger)
        return build_mfm(config)

    # Non-pretraining backbone: the architecture comes from the config. A
    # separate local is used so the ``model_type`` parameter isn't silently
    # overwritten.
    arch = config.MODEL.TYPE
    if arch == 'swin':
        return build_swin(config, qkv_bias=config.MODEL.SWIN.QKV_BIAS)
    if arch == 'vit':
        return build_vit(config, qkv_bias=config.MODEL.VIT.QKV_BIAS)
    # Any other MODEL.TYPE (including typos) falls back to the ResNet builder.
    return build_resnet(config)
