from multimolecule import RnaTokenizer, RnaBertModel
from typing import List


class RNABERTModel:
    """Wrapper around a pretrained RNABERT checkpoint that returns token embeddings.

    The constructor loads an ``RnaTokenizer`` and an ``RnaBertModel`` for one of
    the names listed in ``model_names``, moves the model to ``device`` and puts
    it in eval mode. Calling the instance with a batch of sequences returns the
    model's last hidden state with the first and last positions removed.
    """

    # Checkpoint identifiers accepted by the constructor.
    model_names = [
        "multimolecule/rnabert"
    ]

    def __init__(self, model_name: str, device: str = "cuda"):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Checkpoint identifier; must be one of ``model_names``.
            device: Device the model and the tokenized inputs are moved to.
                Defaults to ``"cuda"``, the value that was previously
                hard-coded.

        Raises:
            AssertionError: If ``model_name`` is not in ``model_names``.
        """
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept so existing
        # callers see the same error.
        if model_name not in self.model_names:
            raise AssertionError(
                f"Model name {model_name} not found in {self.model_names}"
            )

        self.device = device

        # Tokenizer vocabulary, for reference:
        # {'<pad>': 0, '<cls>': 1, '<eos>': 2, '<unk>': 3, '<mask>': 4, '<null>': 5, 'A': 6, 'C': 7,
        # 'G': 8, 'U': 9, 'N': 10, 'R': 11, 'Y': 12, 'S': 13, 'W': 14, 'K': 15, 'M': 16, 'B': 17,
        # 'D': 18, 'H': 19, 'V': 20, '.': 21, 'X': 22, '*': 23, '-': 24, 'I': 25}
        self.tokenizer = RnaTokenizer.from_pretrained(
            model_name, cls_token=None, eos_token=None, truncate=True
        )
        self.model = RnaBertModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    def __call__(self, batch_seqs: List[str]):
        """Embed a batch of sequences.

        Args:
            batch_seqs: Sequences to tokenize and run through the model.

        Returns:
            ``outputs.last_hidden_state[:, 1:-1]`` -- the last hidden state with
            the first and last positions along dimension 1 dropped.
        """
        inputs = self.tokenizer(batch_seqs, return_tensors="pt").to(self.device)
        # NOTE(review): the forward pass runs with autograd enabled; wrap it in
        # ``torch.no_grad()`` if the embeddings are only used for inference.
        outputs = self.model(**inputs)

        # K-mer embeddings: L-seq -> 1 + (L-k+1) + 1  -> (L-k+3)
        # The slice strips one leading and one trailing position.
        # NOTE(review): the tokenizer is built with cls_token=None and
        # eos_token=None -- confirm the first/last positions are still special
        # tokens, otherwise this slice discards real token embeddings.
        embeddings = outputs.last_hidden_state[:, 1:-1]
        return embeddings
