from typing import List

import torch
from multimolecule import RnaTokenizer, CaLmModel


class CALMModel:
    """Thin inference wrapper around the MultiMolecule CaLM model.

    Loads the tokenizer and model for a supported checkpoint, moves the model
    to ``device`` and switches it to eval mode. Calling the instance on a
    batch of sequences returns the model's final hidden states with the first
    and last token positions removed.
    """

    # Checkpoint identifiers accepted by ``__init__``.
    model_names = [
        "multimolecule/calm"
    ]

    def __init__(self, model_name: str, device: str = "cuda"):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Checkpoint identifier; must be one of ``model_names``.
            device: Device the model and the tokenized inputs are moved to.
                Defaults to ``"cuda"`` (the previously hard-coded value).

        Raises:
            ValueError: If ``model_name`` is not listed in ``model_names``.
        """
        # Explicit check rather than ``assert`` so validation survives ``python -O``;
        # it runs before any (expensive) checkpoint loading.
        if model_name not in self.model_names:
            raise ValueError(f"Model name {model_name} not found in {self.model_names}")

        self.device = device

        # Vocabulary excerpt: special tokens first, then three-nucleotide tokens:
        # {'<pad>': 0, '<cls>': 1, '<eos>': 2, '<unk>': 3, '<mask>': 4, '<null>': 5, 'AAA': 6, 'AAC': 7,
        #  'AAG': 8, 'AAU': 9, 'AAN': 10, 'ACA': 11, ..., 'NNC': 127, 'NNG': 128, 'NNU': 129, 'NNN': 130}
        self.tokenizer = RnaTokenizer.from_pretrained(model_name)
        self.model = CaLmModel.from_pretrained(model_name).to(device)
        self.model.eval()

    def __call__(self, batch_seqs: List[str]) -> torch.Tensor:
        """Embed a batch of sequences.

        Args:
            batch_seqs: Sequences to tokenize and run through the model as one batch.
                No padding is requested from the tokenizer, so sequences in a batch
                presumably need to tokenize to the same length; TODO confirm against
                the tokenizer's defaults.

        Returns:
            ``last_hidden_state`` with the first and last positions of dim 1 removed
            (assumed layout ``(batch, seq_len, hidden)``; TODO confirm). The tensor is
            computed under ``torch.no_grad()`` and carries no autograd history.
        """
        inputs = self.tokenizer(batch_seqs, return_tensors="pt").to(self.device)
        # eval() only changes layer behaviour (e.g. dropout); it does not stop autograd
        # from recording the forward pass. no_grad() avoids keeping activations alive
        # for a backward pass that never happens.
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Drop the first and last positions: presumably the <cls> and <eos> tokens
        # added by the tokenizer.
        return outputs.last_hidden_state[:, 1:-1]
