from transformers import BertConfig, BertModel, BertTokenizer
import torch
from pipeline.registry import registry


# Location of the pretrained bert-base-uncased checkpoint shared by the tokenizer
# and text-encoder factories in this module (both load it with local_files_only=True).
# NOTE(review): this is a relative path, so it is resolved against the process's
# current working directory, not this file's location. Presumably it points at a
# directory holding the config/vocab/weight files; confirm.
BERT_MODEL_PATH = "pretrain_weights/bert-base-uncased"


@registry.register_language_model("bert_tokenizer")
def get_bert_tokenizer() -> BertTokenizer:
    """Load the lower-casing BERT tokenizer from the local checkpoint.

    Registered under the name ``"bert_tokenizer"``.

    Returns:
        A ``BertTokenizer`` read from ``BERT_MODEL_PATH``. ``local_files_only``
        is set, so transformers will not try to fetch files from the Hub.
    """
    # The checkpoint is the *uncased* variant, so input text is lower-cased
    # before tokenization to match its vocabulary.
    load_options = {"do_lower_case": True, "local_files_only": True}
    return BertTokenizer.from_pretrained(BERT_MODEL_PATH, **load_options)

@registry.register_language_model("bert_lang_encoder")
def get_bert_lang_encoder(num_hidden_layer: int = 3) -> BertModel:
    """Build a BERT text encoder with ``num_hidden_layer`` transformer layers.

    Registered under the name ``"bert_lang_encoder"``.

    The architecture comes from a ``BertConfig`` constructed here, not from
    the checkpoint's own config file. Weights are then loaded from
    ``BERT_MODEL_PATH`` for every parameter the smaller model shares with the
    checkpoint.

    Args:
        num_hidden_layer: Number of encoder layers to instantiate. Checkpoint
            layers beyond this count are not loaded. NOTE(review): presumably
            the checkpoint is the 12-layer bert-base model, so the default
            keeps only its first 3 layers; confirm. A value larger than the
            checkpoint's depth would leave the extra layers freshly initialized.

    Returns:
        A ``BertModel`` built without the pooler head (``add_pooling_layer=False``).
    """
    # Only these fields are set explicitly; every other field (vocab size,
    # intermediate size, dropout, ...) takes the BertConfig default.
    encoder_config = BertConfig(
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=num_hidden_layer,
        type_vocab_size=2,
    )

    # add_pooling_layer is not a config field; from_pretrained forwards it to
    # the BertModel constructor.
    return BertModel.from_pretrained(
        BERT_MODEL_PATH,
        config=encoder_config,
        add_pooling_layer=False,
        local_files_only=True,
    )


