from .deiqt_cl_topl2 import build_vit


def build_model(config, args):
    """Construct the model selected by ``config.MODEL.TYPE``.

    Only the ``"purevit"`` type is currently supported. It is built by
    ``build_vit`` using hyper-parameters read from ``config.MODEL.VIT``.

    Args:
        config: Configuration object. Reads ``MODEL.TYPE``. For ``"purevit"``
            it also reads ``MODEL.VIT.PATCH_SIZE``, ``EMBED_DIM``, ``DEPTH``,
            ``NUM_HEADS``, ``MLP_RATIO``, ``QKV_BIAS`` and
            ``PRETRAINED_MODEL_PATH``.
        args: Namespace-like object whose ``lda`` attribute is forwarded
            unchanged to ``build_vit``.

    Returns:
        The object returned by ``build_vit`` for the ``"purevit"`` type.

    Raises:
        NotImplementedError: If ``config.MODEL.TYPE`` is not a supported type.
    """
    model_type = config.MODEL.TYPE

    # Guard clause: reject unsupported model types up front.
    if model_type != "purevit":
        raise NotImplementedError(f"Unknown model: {model_type}")

    vit_cfg = config.MODEL.VIT
    return build_vit(
        # NOTE(review): image size is hard-coded here rather than read from
        # config; the meaning of the two entries is defined by build_vit.
        img_size=[224, 384],
        patch_size=vit_cfg.PATCH_SIZE,
        embed_dim=vit_cfg.EMBED_DIM,
        depth=vit_cfg.DEPTH,
        num_heads=vit_cfg.NUM_HEADS,
        mlp_ratio=vit_cfg.MLP_RATIO,
        qkv_bias=vit_cfg.QKV_BIAS,
        # Always requests pretrained weights; presumably loaded from
        # pretrained_model_path inside build_vit. TODO confirm.
        pretrained=True,
        pretrained_model_path=vit_cfg.PRETRAINED_MODEL_PATH,
        lda=args.lda,
    )
