"""Interface to construct models."""

from .huggingface_interface import construct_huggingface_model
from .scriptable_bert import construct_scriptable_bert
from .funnel_transformers import construct_scriptable_funnel
from .recurrent_transformers import construct_scriptable_recurrent
from .sanity_check import SanityCheckforPreTraining

import logging
from ..utils import is_main_process

log = logging.getLogger(__name__)


def construct_model(cfg_arch, vocab_size, downstream_classes=None):
    """Build the model named by ``cfg_arch.architectures``.

    The in-house architectures (``ScriptableMaskedLM``, ``ScriptableFunnelLM``,
    ``ScriptableRecurrentLM``, ``SanityCheckLM``) are checked in that order.
    Anything else is handed to ``construct_huggingface_model``.

    Args:
        cfg_arch: Architecture config. Must expose ``architectures``, which is
            tested with ``in`` and indexed with ``[0]`` for logging
            (presumably a list of architecture names — confirm). The
            ``SanityCheckLM`` branch also reads ``width``.
        vocab_size: Vocabulary size, forwarded to the selected constructor.
        downstream_classes: Optional value forwarded to every constructor
            except ``SanityCheckforPreTraining``. Presumably the number of
            downstream classification labels, with None meaning pretraining;
            confirm against the constructors.

    Returns:
        The constructed model. It must provide ``parameters()`` yielding
        objects with ``numel()`` (presumably a ``torch.nn.Module``).

    Raises:
        ValueError: If the HuggingFace fallback raises any exception. The
            original exception is attached as ``__cause__``.
    """
    if "ScriptableMaskedLM" in cfg_arch.architectures:
        model = construct_scriptable_bert(cfg_arch, vocab_size, downstream_classes)
    elif "ScriptableFunnelLM" in cfg_arch.architectures:
        model = construct_scriptable_funnel(cfg_arch, vocab_size, downstream_classes)
    elif "ScriptableRecurrentLM" in cfg_arch.architectures:
        model = construct_scriptable_recurrent(cfg_arch, vocab_size, downstream_classes)
    elif "SanityCheckLM" in cfg_arch.architectures:
        # The sanity-check model takes only width and vocab size;
        # downstream_classes is not passed here.
        model = SanityCheckforPreTraining(cfg_arch.width, vocab_size)
    else:
        try:
            model = construct_huggingface_model(cfg_arch, vocab_size, downstream_classes)
        except Exception as e:
            # Chain explicitly so the traceback reports the HuggingFace error as the
            # direct cause, rather than as an unrelated failure inside this handler.
            raise ValueError(f"Invalid model architecture {cfg_arch.architectures} given. Error: {e}") from e

    # Generator expression: no need to materialize a list just to sum it.
    num_params = sum(p.numel() for p in model.parameters())
    if is_main_process():
        log.info(f"Model with architecture {cfg_arch.architectures[0]} loaded with {num_params:,} parameters.")
    return model
