import re
from typing import List

import torch
from transformers import AutoTokenizer, AutoModel
from transformers import T5EncoderModel, T5Tokenizer, XLNetTokenizer, AlbertTokenizer


class ProtTransModel:
    """Load a Rostlab ProtTrans checkpoint and embed batches of protein sequences.

    Instances are callable: pass a list of amino-acid strings and get back the
    model's ``last_hidden_state`` with the tokenizer's special-token positions
    sliced off.
    """

    model_names = [
        "Rostlab/prot_bert",
        "Rostlab/ProstT5",
        "Rostlab/ProstT5_fp16",
        "Rostlab/prot_t5_xl_uniref50",
        "Rostlab/prot_t5_xl_half_uniref50-enc",
        "Rostlab/prot_t5_base_mt_uniref50",
        "Rostlab/prot_bert_bfd_ss3",
        "Rostlab/prot_bert_bfd_membrane",
        "Rostlab/prot_bert_bfd_localization",
        "Rostlab/prot_t5_xxl_uniref50",
        "Rostlab/prot_electra_generator_bfd",
        "Rostlab/prot_electra_discriminator_bfd",
        "Rostlab/prot_t5_xl_bfd",
        "Rostlab/prot_bert_bfd",
        "Rostlab/prot_t5_xxl_bfd",
        "Rostlab/prot_xlnet",
        "Rostlab/prot_albert",
    ]

    # Residue letters U, Z, O and B are replaced with X before tokenization.
    _RARE_RESIDUES = re.compile(r"[UZOB]")

    def __init__(self, model_name: str, device: str = "cuda"):
        """Load the tokenizer and model for ``model_name`` and move the model to ``device``.

        Args:
            model_name: one of :attr:`model_names`.
            device: torch device for the model (and, in ``__call__``, the
                inputs). Defaults to ``"cuda"``.

        Raises:
            AssertionError: if ``model_name`` is not in :attr:`model_names`.
        """
        assert model_name in self.model_names, f"Model name {model_name} not found in {self.model_names}"
        self.model_name = model_name

        if self._is_t5(model_name):
            self.tokenizer = T5Tokenizer.from_pretrained(model_name, do_lower_case=False)
            # Only encoder states are used for embeddings. T5EncoderModel skips
            # the decoder entirely, which matters for checkpoints such as
            # "...-enc" (presumably encoder-only weights) and saves memory.
            model = T5EncoderModel.from_pretrained(model_name)
        else:
            if model_name == "Rostlab/prot_xlnet":
                tokenizer_cls = XLNetTokenizer
            elif model_name == "Rostlab/prot_albert":
                tokenizer_cls = AlbertTokenizer
            else:
                tokenizer_cls = AutoTokenizer
            self.tokenizer = tokenizer_cls.from_pretrained(model_name, do_lower_case=False)
            model = AutoModel.from_pretrained(model_name)
        self.model = model.to(device)
        self.model.eval()

    @staticmethod
    def _is_t5(model_name: str) -> bool:
        """Return True for T5-family checkpoints (including ProstT5)."""
        return "T5" in model_name.upper()

    @classmethod
    def _special_token_slice(cls, model_name: str) -> slice:
        """Return the sequence-axis slice that drops special-token positions.

        Each slice is exact for the longest sequence in a padded batch.
        """
        if cls._is_t5(model_name):
            # T5 tokenizers append a trailing </s>.
            return slice(None, -1)
        if model_name == "Rostlab/prot_xlnet":
            # XLNet appends "<sep> <cls>" at the END (and pads on the left),
            # so the last two positions are special tokens and position 0 is
            # a real residue (or left padding), not a [CLS].
            return slice(None, -2)
        # BERT/ALBERT/ELECTRA-style: leading [CLS], trailing [SEP].
        return slice(1, -1)

    @classmethod
    def _preprocess(cls, seq: str) -> str:
        """Map U/Z/O/B to X and space-separate residues for the tokenizer."""
        return " ".join(cls._RARE_RESIDUES.sub("X", seq))

    def __call__(self, batch_seqs: List[str]):
        """Embed a batch of protein sequences.

        Args:
            batch_seqs: amino-acid sequences, one residue per character.

        Returns:
            The model's ``last_hidden_state`` with special-token positions
            removed, as a tensor on the model's device. The forward pass runs
            under ``torch.no_grad()``, so the result carries no autograd graph.
            Sequences are padded to the longest one in the batch. For shorter
            sequences the trimmed tensor therefore still contains padding (and,
            for right-padded tokenizers, the end-of-sequence token). Trim
            each row to its sequence's length when exact per-residue
            embeddings are needed.
        """
        # NOTE(review): the ProstT5 model card prefixes amino-acid input with
        # "<AA2fold>"; no prefix is added here -- confirm this is intended.
        spaced = [self._preprocess(seq) for seq in batch_seqs]
        # padding=True is required so sequences of different lengths can be
        # stacked into one tensor.
        inputs = self.tokenizer(
            spaced, add_special_tokens=True, padding=True, return_tensors="pt"
        ).to(self.model.device)
        # Inference only (model is in eval mode): skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, self._special_token_slice(self.model_name)]
