from typing import List

import torch
from genslm import GenSLM


class GenslmModel:
    """Wrapper around a pretrained GenSLM model that returns last-layer hidden states.

    Instances are callable: pass a list of nucleotide sequences and get back
    the final hidden-state tensor produced by the model for that batch.
    """

    # Checkpoint names accepted by ``__init__``.
    model_names = [
        'genslm_2.5B_patric',
        'genslm_250M_patric',
        'genslm_25M_patric',
    ]

    def __init__(
        self,
        model_name: str,
        device: str = 'cuda',
        model_cache_dir: str = "path/to/check/point",
    ):
        """Load the requested checkpoint, move it to ``device`` and set eval mode.

        Args:
            model_name: One of ``GenslmModel.model_names``.
            device: Device the model and the input tensors are moved to.
                Defaults to ``'cuda'``.
            model_cache_dir: Directory handed to ``GenSLM`` as
                ``model_cache_dir``. The default is a placeholder path and
                normally has to be overridden.

        Raises:
            ValueError: If ``model_name`` is not in ``model_names``.
        """
        # Raise explicitly instead of using ``assert`` so the check survives
        # ``python -O``.
        if model_name not in self.model_names:
            raise ValueError(f"model_name must be one of {self.model_names}")
        self.device = device
        self.model = GenSLM(model_name, model_cache_dir=model_cache_dir).to(device)
        self.model.eval()
        self.tokenizer = self.model.tokenizer

    def __call__(self, sequences: List[str], kmer: int = 3):
        """Embed a batch of sequences.

        Each sequence is cut into consecutive, non-overlapping chunks of
        ``kmer`` characters (the final chunk may be shorter), the chunks are
        joined with single spaces and upper-cased, and the result is tokenized
        and passed through the model.

        Args:
            sequences: Raw sequences to embed.
            kmer: Chunk length used when splitting each sequence; must be >= 1.

        Returns:
            The last entry of ``outputs.hidden_states``, detached from the
            autograd graph. Sequences in a batch are padded to a common
            length, so positions belonging to padding are included in the
            result.

        Raises:
            ValueError: If ``kmer`` is smaller than 1.
        """
        if kmer < 1:
            raise ValueError(f"kmer must be a positive integer, got {kmer}")

        seqs = [
            " ".join(seq[i : i + kmer] for i in range(0, len(seq), kmer)).upper()
            for seq in sequences
        ]
        # Padding lets sequences of different lengths share one tensor batch;
        # without it, tensor conversion fails for ragged batches.
        batch = self.tokenizer(seqs, return_tensors='pt', padding=True)

        # Inference only: skip building the autograd graph to save memory.
        with torch.no_grad():
            outputs = self.model(
                batch["input_ids"].to(self.device),
                batch["attention_mask"].to(self.device),
                output_hidden_states=True,
            )

        # hidden_states holds one entry per layer; [-1] is the final layer,
        # expected to be (batch_size, sequence_length, hidden_size) -- TODO confirm.
        embeds = outputs.hidden_states[-1].detach()
        return embeds
