from typing import List

import torch
from multimolecule import RnaTokenizer, SpliceBertModel


class SpliceBERTModel:
    """Inference wrapper around a pretrained SpliceBERT checkpoint from ``multimolecule``.

    Loads the ``RnaTokenizer`` and ``SpliceBertModel`` for one of the supported
    checkpoints, moves the model to ``device`` and switches it to eval mode.
    Calling the instance on a batch of sequences returns the model's
    ``last_hidden_state`` with the first and last token positions removed.
    """

    # Checkpoints this wrapper accepts; __init__ rejects any other name.
    model_names = [
        'multimolecule/splicebert',
        'multimolecule/splicebert-human.510nt',
        'multimolecule/splicebert.510nt',
    ]

    def __init__(self, model_name: str, device: str = 'cuda'):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: One of ``model_names``.
            device: Torch device the model (and, in ``__call__``, the
                tokenized inputs) are moved to. Defaults to ``'cuda'``, the
                previously hard-coded device.

        Raises:
            ValueError: If ``model_name`` is not listed in ``model_names``.
                An explicit raise is used instead of ``assert`` so the check
                is not stripped under ``python -O``.
        """
        if model_name not in self.model_names:
            raise ValueError(f"Model name {model_name} not found in {self.model_names}")

        self.device = device

        # Tokenizer vocabulary as recorded by the original author (re-check if
        # the checkpoint changes):
        # {'<pad>': 0, '<cls>': 1, '<eos>': 2, '<unk>': 3, '<mask>': 4, '<null>': 5, 'A': 6, 'C': 7,
        # 'G': 8, 'U': 9, 'N': 10, 'R': 11, 'Y': 12, 'S': 13, 'W': 14, 'K': 15, 'M': 16, 'B': 17,
        # 'D': 18, 'H': 19, 'V': 20, '.': 21, 'X': 22, '*': 23, '-': 24, 'I': 25}
        self.tokenizer = RnaTokenizer.from_pretrained(model_name)
        self.model = SpliceBertModel.from_pretrained(model_name).to(device)
        # eval() changes layer behaviour (e.g. dropout) but does not disable
        # gradient tracking; __call__ wraps the forward pass in no_grad for that.
        self.model.eval()

    def __call__(self, batch_seqs: List[str]):
        """Return per-position embeddings for a batch of sequences.

        Args:
            batch_seqs: Sequences to embed. They may differ in length; the
                batch is padded to the longest sequence.

        Returns:
            ``last_hidden_state[:, 1:-1]`` on ``self.device``, i.e. the first
            and last token positions are dropped. Computed under
            ``torch.no_grad()``, so the tensor does not require grad.

            NOTE(review): presumably position ``j`` maps to character ``j`` of
            each input (one token per character, ``<cls>`` first) — confirm
            against the tokenizer config. With mixed lengths the slice only
            removes the longest sequence's ``<eos>``; assuming right padding,
            a shorter sequence's positions from ``len(seq)`` onward hold
            ``<eos>``/``<pad>`` embeddings, so use
            ``embeddings[i, :len(batch_seqs[i])]`` per sequence.
        """
        # Without padding, a mixed-length batch cannot be returned as a single
        # "pt" tensor. For equal-length batches padding=True is a no-op.
        inputs = self.tokenizer(batch_seqs, return_tensors="pt", padding=True).to(self.device)
        # Pure inference: don't build an autograd graph (saves memory).
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 1:-1]
