from typing import List, Optional

import torch
from transformers import AutoTokenizer, AutoModel


class OmniGenomeModel:
    """Inference wrapper that maps nucleotide sequences to per-token embeddings.

    Loads one of the checkpoints listed in ``model_names`` together with its
    tokenizer via ``transformers`` and exposes a callable that returns the
    model's last hidden state with the first and last token positions removed.

    Attributes:
        model_names: Checkpoint identifiers this wrapper is allowed to load.
        tokenizer: Tokenizer loaded for the selected checkpoint.
        model: The loaded model, moved to ``device`` and put in eval mode.
        device: Device the model lives on; inputs are moved there per call.
    """

    model_names = [
        "anonymous8/OmniGenome-418M",
        "anonymous8/OmniGenome-186M",
        "anonymous8/OmniGenome-52M",
    ]

    def __init__(self, model_name: str, device: Optional[str] = None):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: One of the identifiers in ``model_names``.
            device: Device to run on (e.g. ``"cuda"``, ``"cuda:1"``, ``"cpu"``).
                When ``None``, CUDA is used if available, otherwise the CPU.

        Raises:
            AssertionError: If ``model_name`` is not in ``model_names``.
        """
        if model_name not in self.model_names:
            # Raised explicitly instead of via `assert` so the allow-list check
            # still runs under `python -O`; it is the only guard in front of
            # `trust_remote_code=True` below. The exception type is kept as
            # AssertionError for backward compatibility with existing callers.
            raise AssertionError(
                f"Model name {model_name} not found in {self.model_names}"
            )

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device

        # Tokenizer vocabulary, for reference:
        # {'<cls>': 0, '<pad>': 1, '<eos>': 2, '<unk>': 3, 'A': 4, 'C': 5, 'G': 6, 'T': 7, 'N': 8,
        # 'U': 9, 'a': 10, 'c': 11, 'g': 12, 't': 13, 'n': 14, 'u': 15, '(': 16, ')': 17, '.': 18,
        # '*': 19, '1': 20, '2': 21, '3': 22, '<mask>': 23}
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(self.device)
        self.model.eval()

    def __call__(self, batch_seqs: List[str]):
        """Embed a batch of sequences.

        Args:
            batch_seqs: Sequences to embed. They may differ in length; the
                tokenizer pads them to the longest one in the batch.

        Returns:
            ``last_hidden_state`` of the model with the first and last
            positions along dimension 1 dropped (presumably the ``<cls>`` and
            ``<eos>`` specials for a full-length sequence -- confirm against
            the tokenizer). The tensor is detached from autograd.

            NOTE(review): for sequences shorter than the longest one in the
            batch, the retained positions also cover padding (and likely that
            sequence's ``<eos>``); callers that need exact per-sequence
            embeddings should trim each row to its own sequence length.
        """
        # padding=True lets unequal-length sequences be stacked into a single
        # tensor; it is a no-op when all sequences already share one length.
        inputs = self.tokenizer(batch_seqs, padding=True, return_tensors="pt").to(self.device)
        # Inference only: skip building the autograd graph to save memory.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 1:-1]
