# Registry of every model identifier that MODEL_MAP accepts; MODEL_MAP rejects
# any name that is not listed here. Each identifier appears exactly once.
# NOTE(review): 'multimolecule/rnabert', 'multimolecule/ernierna.ss' and
# 'NT1B_agro_nt' are registered here but MODEL_MAP has no loader for them,
# so requesting them raises ValueError. Confirm whether loaders are missing
# or these entries should be dropped.
models = [
    # LongSafari HyenaDNA checkpoints
    'LongSafari/hyenadna-tiny-1k-seqlen-hf',
    'LongSafari/hyenadna-tiny-1k-seqlen-d256-hf',
    'LongSafari/hyenadna-tiny-16k-seqlen-d128-hf',
    'LongSafari/hyenadna-small-32k-seqlen-hf',
    'LongSafari/hyenadna-medium-160k-seqlen-hf',
    'LongSafari/hyenadna-medium-450k-seqlen-hf',
    'LongSafari/hyenadna-large-1m-seqlen-hf',
    # OmniGenome checkpoints
    'anonymous8/OmniGenome-418M',
    'anonymous8/OmniGenome-186M',
    'anonymous8/OmniGenome-52M',
    # multimolecule/* checkpoints
    'multimolecule/rnafm',
    'multimolecule/mrnafm',
    'multimolecule/utrbert-6mer',
    'multimolecule/utrbert-5mer',
    'multimolecule/utrbert-4mer',
    'multimolecule/utrbert-3mer',
    'multimolecule/splicebert',
    'multimolecule/splicebert-human.510nt',
    'multimolecule/splicebert.510nt',
    'multimolecule/rnabert',
    'multimolecule/rnamsm',
    'multimolecule/rinalmo',
    'multimolecule/ernierna',
    'multimolecule/ernierna.ss',
    'multimolecule/rnaernie',
    'multimolecule/calm',
    # NT* checkpoints
    'NT2B5_multi_species',
    'NT500M_multi_species_v2',
    'NT50M_multi_species_v2',
    'NT100M_multi_species_v2',
    'NT250M_multi_species_v2',
    'NT500M_human_ref',
    'NT500M_1000G',
    'NT2B5_1000G',
    'NT1B_agro_nt',
    # genslm checkpoints
    'genslm_2.5B_patric',
    'genslm_250M_patric',
    'genslm_25M_patric',
    # evo checkpoints
    'evo-1-8k-base',
    'evo-1-131k-base',
    # LucaOne
    'LucaOne',
    # ESM3 / ESM2 checkpoints
    'esm3_sm_open_v1',
    'facebook/esm2_t33_650M_UR50D',
    'facebook/esm2_t48_15B_UR50D',
    'facebook/esm2_t36_3B_UR50D',
    'facebook/esm2_t30_150M_UR50D',
    'facebook/esm2_t12_35M_UR50D',
    'facebook/esm2_t6_8M_UR50D',
    # Rostlab checkpoints
    'Rostlab/prot_bert',
    'Rostlab/ProstT5',
    'Rostlab/ProstT5_fp16',
    'Rostlab/prot_t5_xl_uniref50',
    'Rostlab/prot_t5_xl_half_uniref50-enc',
    'Rostlab/prot_t5_base_mt_uniref50',
    'Rostlab/prot_bert_bfd_ss3',
    'Rostlab/prot_bert_bfd_membrane',
    'Rostlab/prot_bert_bfd_localization',
    'Rostlab/prot_t5_xxl_uniref50',
    'Rostlab/prot_electra_generator_bfd',
    'Rostlab/prot_electra_discriminator_bfd',
    'Rostlab/prot_t5_xl_bfd',
    'Rostlab/prot_bert_bfd',
    'Rostlab/prot_t5_xxl_bfd',
    'Rostlab/prot_xlnet',
    'Rostlab/prot_albert',
    # zhihan1996 DNABERT checkpoints
    'zhihan1996/DNABERT-2-117M',
    'zhihan1996/DNABERT-S',
    'zhihan1996/DNA_bert_3',
    'zhihan1996/DNA_bert_4',
    'zhihan1996/DNA_bert_5',
    'zhihan1996/DNA_bert_6',
]


# One row per wrapper: (module name under utils.get_embed, class name, model ids).
_MODEL_LOADERS = (
    ('esm3', 'ESM3Model', (
        'esm3_sm_open_v1',
    )),
    ('esm2', 'ESM2Model', (
        'facebook/esm2_t33_650M_UR50D',
        'facebook/esm2_t48_15B_UR50D',
        'facebook/esm2_t36_3B_UR50D',
        'facebook/esm2_t30_150M_UR50D',
        'facebook/esm2_t12_35M_UR50D',
        'facebook/esm2_t6_8M_UR50D',
    )),
    ('prottrans', 'ProtTransModel', (
        'Rostlab/prot_bert',
        'Rostlab/ProstT5',
        'Rostlab/ProstT5_fp16',
        'Rostlab/prot_t5_xl_uniref50',
        'Rostlab/prot_t5_xl_half_uniref50-enc',
        'Rostlab/prot_t5_base_mt_uniref50',
        'Rostlab/prot_bert_bfd_ss3',
        'Rostlab/prot_bert_bfd_membrane',
        'Rostlab/prot_bert_bfd_localization',
        'Rostlab/prot_t5_xxl_uniref50',
        'Rostlab/prot_electra_generator_bfd',
        'Rostlab/prot_electra_discriminator_bfd',
        'Rostlab/prot_t5_xl_bfd',
        'Rostlab/prot_bert_bfd',
        'Rostlab/prot_t5_xxl_bfd',
        'Rostlab/prot_xlnet',
        'Rostlab/prot_albert',
    )),
    ('lucaone', 'LucaOneModel', (
        'LucaOne',
    )),
    ('Evo', 'EVOModel', (
        'evo-1-8k-base',
        'evo-1-131k-base',
    )),
    ('Genslm', 'GenslmModel', (
        'genslm_2.5B_patric',
        'genslm_250M_patric',
        'genslm_25M_patric',
    )),
    ('nucleotidetransformer', 'NucleotideTransformerModel', (
        'NT2B5_multi_species',
        'NT500M_multi_species_v2',
        'NT50M_multi_species_v2',
        'NT100M_multi_species_v2',
        'NT250M_multi_species_v2',
        'NT500M_human_ref',
        'NT500M_1000G',
        'NT2B5_1000G',
    )),
    ('dnabert', 'DNABERTModel', (
        'zhihan1996/DNABERT-2-117M',
        'zhihan1996/DNABERT-S',
        'zhihan1996/DNA_bert_3',
        'zhihan1996/DNA_bert_4',
        'zhihan1996/DNA_bert_5',
        'zhihan1996/DNA_bert_6',
    )),
    ('hyenadna', 'HyenaDNAModel', (
        'LongSafari/hyenadna-tiny-1k-seqlen-hf',
        'LongSafari/hyenadna-tiny-1k-seqlen-d256-hf',
        'LongSafari/hyenadna-tiny-16k-seqlen-d128-hf',
        'LongSafari/hyenadna-small-32k-seqlen-hf',
        'LongSafari/hyenadna-medium-160k-seqlen-hf',
        'LongSafari/hyenadna-medium-450k-seqlen-hf',
        'LongSafari/hyenadna-large-1m-seqlen-hf',
    )),
    ('omnigenome', 'OmniGenomeModel', (
        'anonymous8/OmniGenome-418M',
        'anonymous8/OmniGenome-186M',
        'anonymous8/OmniGenome-52M',
    )),
    ('rnafm', 'RNAFMModel', (
        'multimolecule/rnafm',
        'multimolecule/mrnafm',
    )),
    ('utrbert', 'UTRBERTModel', (
        'multimolecule/utrbert-6mer',
        'multimolecule/utrbert-5mer',
        'multimolecule/utrbert-4mer',
        'multimolecule/utrbert-3mer',
    )),
    ('splicebert', 'SpliceBERTModel', (
        'multimolecule/splicebert',
        'multimolecule/splicebert-human.510nt',
        'multimolecule/splicebert.510nt',
    )),
    ('rnamsm', 'RNAMSMModel', (
        'multimolecule/rnamsm',
    )),
    ('rinalmo', 'RINALMOModel', (
        'multimolecule/rinalmo',
    )),
    ('ernierna', 'ERNIERNAModel', (
        'multimolecule/ernierna',
    )),
    ('rnaernie', 'RNAERNIEModel', (
        'multimolecule/rnaernie',
    )),
    ('calm', 'CALMModel', (
        'multimolecule/calm',
    )),
)

# Flattened view of _MODEL_LOADERS: model id -> (module name, class name).
_LOADER_BY_MODEL = {
    model_id: (module_name, class_name)
    for module_name, class_name, model_ids in _MODEL_LOADERS
    for model_id in model_ids
}


def MODEL_MAP(model_name, seq_type, cfg):
    """Instantiate the embedding wrapper associated with ``model_name``.

    The wrapper's module is imported only when it is requested, so asking for
    one model does not import the modules backing the others.

    Args:
        model_name: Model identifier; must be listed in ``models``.
        seq_type: Forwarded to ``LucaOneModel`` only; ignored for every other
            wrapper.
        cfg: Object whose ``llm_dir`` attribute is forwarded to
            ``LucaOneModel``; not accessed for any other model.

    Returns:
        A new instance of the wrapper class, constructed with ``model_name``
        (plus ``seq_type`` and ``cfg.llm_dir`` for ``'LucaOne'``).

    Raises:
        ValueError: If ``model_name`` is not in ``models``, or if it is in
            ``models`` but no wrapper is wired up for it.
        ImportError: If the wrapper module or class cannot be imported.
    """
    import importlib

    # Raise explicitly instead of asserting: asserts are removed under
    # `python -O`, and ValueError matches the function's other failure path.
    if model_name not in models:
        raise ValueError(f'Model {model_name} not found in {models}.')

    loader = _LOADER_BY_MODEL.get(model_name)
    if loader is None:
        # The name is registered, so "not found" would be misleading here.
        raise ValueError(
            f'Model {model_name} is listed in models but has no loader in MODEL_MAP.'
        )

    module_name, class_name = loader
    module = importlib.import_module(f'utils.get_embed.{module_name}')
    try:
        loader_cls = getattr(module, class_name)
    except AttributeError as err:
        # Keep the failure type that `from module import name` would produce.
        raise ImportError(
            f'cannot import name {class_name!r} from {module.__name__!r}'
        ) from err

    # LucaOne is the only wrapper whose constructor takes extra arguments.
    if class_name == 'LucaOneModel':
        return loader_cls(model_name, seq_type, cfg.llm_dir)
    return loader_cls(model_name)
