from typing import List

import torch
from transformers import AutoTokenizer, AutoModel


class HyenaDNAModel:
    """Inference wrapper that turns DNA strings into HyenaDNA hidden states.

    Loads one of the Hugging Face checkpoints listed in ``model_names``
    together with its tokenizer, moves the model to ``device`` and switches it
    to evaluation mode. Calling the instance on a batch of sequences returns
    the model's ``last_hidden_state`` with the final position removed.
    """

    # Checkpoint identifiers accepted by ``__init__``.
    model_names = [
        'LongSafari/hyenadna-tiny-1k-seqlen-hf',
        'LongSafari/hyenadna-tiny-1k-seqlen-d256-hf',
        'LongSafari/hyenadna-tiny-16k-seqlen-d128-hf',
        'LongSafari/hyenadna-small-32k-seqlen-hf',
        'LongSafari/hyenadna-medium-160k-seqlen-hf',
        'LongSafari/hyenadna-medium-450k-seqlen-hf',
        'LongSafari/hyenadna-large-1m-seqlen-hf',
    ]

    def __init__(self, model_name: str, device: str = 'cuda'):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: One of the identifiers in ``model_names``.
            device: Torch device string the model and tokenized inputs are
                moved to. Defaults to ``'cuda'``, the previously hard-coded
                value.

        Raises:
            AssertionError: If ``model_name`` is not in ``model_names``.
        """
        # Raised explicitly (not via ``assert``) so the check still runs under
        # ``python -O``; the exception type is kept for existing callers.
        if model_name not in self.model_names:
            raise AssertionError(f"Model name {model_name} not found in {self.model_names}")

        self.device = device

        # Tokenizer vocabulary, for reference:
        # {'[CLS]': 0, '[SEP]': 1, '[BOS]': 2, '[MASK]': 3, '[PAD]': 4, '[RESERVED]': 5,
        # '[UNK]': 6, 'A': 7, 'C': 8, 'G': 9, 'T': 10, 'N': 11}
        # NOTE: trust_remote_code=True executes Python code shipped with the
        # checkpoint repository; only use it with sources you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
        self.model.eval()

    def __call__(self, batch_seqs: List[str]) -> torch.Tensor:
        """Embed a batch of sequences.

        Args:
            batch_seqs: Sequences to tokenize and embed. No padding is
                requested from the tokenizer, so the sequences presumably
                have to tokenize to the same length -- TODO confirm.

        Returns:
            ``last_hidden_state`` of the model with the last position along
            dimension 1 dropped (presumably the trailing special token added
            by the tokenizer -- TODO confirm). The tensor is detached from
            autograd because the forward pass runs under ``torch.no_grad()``.
        """
        inputs = self.tokenizer(batch_seqs, return_tensors="pt").to(self.device)
        # Inference only: skip building the autograd graph so activations are
        # not kept alive for a backward pass that never happens.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, :-1]
