from config.base_config import Config
from model.clip_baseline import CLIPBaseline
from model.clip_transformer_txt_trunc_dm import CLIPTransformer_txt_trunc_dm

class ModelFactory:
    """Build a model instance from a configuration's ``arch`` field."""

    @staticmethod
    def get_model(config: "Config"):
        """Instantiate the model selected by ``config.arch``.

        Args:
            config: Configuration object. Its ``arch`` attribute selects the
                model class, and the whole object is passed to that class's
                constructor.

        Returns:
            ``CLIPBaseline(config)`` when ``config.arch == 'clip_baseline'``,
            or ``CLIPTransformer_txt_trunc_dm(config)`` when
            ``config.arch == 'clip_transformer_txt_trunc_dm'``.

        Raises:
            NotImplementedError: If ``config.arch`` matches no known
                architecture.
        """
        if config.arch == 'clip_baseline':
            return CLIPBaseline(config)
        if config.arch == 'clip_transformer_txt_trunc_dm':
            return CLIPTransformer_txt_trunc_dm(config)
        # `NotImplemented` is a comparison sentinel, not an exception; raising
        # it produced an unrelated TypeError. Raise the real exception and name
        # the offending value so misconfigured runs are easy to diagnose.
        raise NotImplementedError(
            f"Unsupported model architecture: {config.arch!r} "
            "(expected 'clip_baseline' or 'clip_transformer_txt_trunc_dm')"
        )
