from typing import List

import torch
from multimolecule import RnaTokenizer, UtrBertModel


class UTRBERTModel:
    """Inference wrapper around a pretrained multimolecule UTR-BERT checkpoint.

    Loads the tokenizer and model for one of the supported k-mer checkpoints and,
    when called on a batch of sequences, returns the last hidden states with the
    first and last token positions stripped.
    """

    # Checkpoints accepted by this wrapper; the "<k>mer" suffix names the k-mer size.
    model_names = [
        'multimolecule/utrbert-6mer',
        'multimolecule/utrbert-5mer',
        'multimolecule/utrbert-4mer',
        'multimolecule/utrbert-3mer'
    ]

    def __init__(self, model_name: str, device: str = 'cuda'):
        """Load the tokenizer and model for ``model_name`` and switch to eval mode.

        Args:
            model_name: One of ``UTRBERTModel.model_names``.
            device: Device the model (and, in ``__call__``, the tokenized inputs)
                are moved to. Defaults to ``'cuda'``, the previously hard-coded value;
                anything accepted by ``torch.nn.Module.to`` works.

        Raises:
            ValueError: If ``model_name`` is not one of ``model_names``.
        """
        # Explicit check rather than `assert`, which is removed under `python -O`.
        if model_name not in self.model_names:
            raise ValueError(f"Model name {model_name} not found in {self.model_names}")

        self.device = device

        # Example vocabulary (6-mer checkpoint):
        # {'<pad>': 0, '<cls>' : 1, '<eos>' : 2, '<unk>' : 3, '<mask>' : 4, '<null>' : 5,
        # 'AAAAAA' : 6, 'AAAAAC' : 7, ... 'NNNNNG': 15628, 'NNNNNU': 15629, 'NNNNNN': 15630}
        self.tokenizer = RnaTokenizer.from_pretrained(model_name)
        self.model = UtrBertModel.from_pretrained(model_name).to(device)
        self.model.eval()

    def __call__(self, batch_seqs: List[str]):
        """Embed a batch of RNA sequences.

        Args:
            batch_seqs: Sequences to embed. The tokenizer is called without
                padding, so presumably every sequence in a batch must tokenize to
                the same length (TODO confirm). If padding is ever enabled, the
                ``[:, 1:-1]`` slice below would no longer remove only the boundary
                tokens for the shorter sequences.

        Returns:
            ``last_hidden_state`` with the first and last token positions removed,
            computed without gradient tracking. Per the token-count note below, a
            sequence of length L with k-mer size k leaves L-k+1 positions; assumed
            layout is (batch, tokens, hidden) — TODO confirm.
        """
        inputs = self.tokenizer(batch_seqs, return_tensors="pt").to(self.device)
        # Inference only: without no_grad the returned slice would keep the whole
        # autograd graph (and its activations) alive in device memory.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # K-mer embeddings: L-seq -> 1 + (L-k+1) + 1  -> (L-k+3) tokens;
        # drop the leading <cls> and trailing <eos> positions.
        embeddings = outputs.last_hidden_state[:, 1:-1]
        return embeddings
