import torch
from typing import List
from utils.get_embed.lucaone_utils.get_embedding import load_lucaone_model, get_LUCAONE_embeds


class LucaOneModel:
    """Callable wrapper that loads a LucaOne model once and embeds sequences.

    The network, its tokenizer and the associated config objects are loaded
    from ``llm_dir`` via ``load_lucaone_model``; the network is moved to the
    ``'cuda'`` device, so a CUDA-capable GPU is required.
    """

    # Values accepted for ``model_name``; only "LucaOne" is supported.
    model_names = ["LucaOne"]

    def __init__(self, model_name: str, seq_type: str, llm_dir: str) -> None:
        """Validate the model name and load the LucaOne model.

        Args:
            model_name: Must be one of ``model_names``.
            seq_type: Sequence type forwarded unchanged to
                ``get_LUCAONE_embeds`` on every call. Accepted values are
                defined by that helper (presumably nucleotide vs. protein --
                TODO confirm).
            llm_dir: Directory passed to ``load_lucaone_model``.

        Raises:
            ValueError: If ``model_name`` is not in ``model_names``. The check
                runs before any model loading.
        """
        # Explicit check instead of ``assert`` so validation is not stripped
        # when Python runs with ``-O``.
        if model_name not in self.model_names:
            raise ValueError(f"Model {model_name} not found in {self.model_names}.")
        args_info, model_config, dna_model, tokenizer = load_lucaone_model(llm_dir)
        # Keep every artifact get_LUCAONE_embeds needs in one dict.
        self.model = dict(
            args_info=args_info,
            model_config=model_config,
            # NOTE(review): device is hard-coded; get_LUCAONE_embeds may also
            # assume CUDA, so it is not parameterized here.
            model=dna_model.to('cuda'),
            tokenizer=tokenizer,
        )
        self.seq_type = seq_type

    @torch.no_grad()
    def __call__(self, seqs: List[str]):
        """Embed ``seqs`` with the loaded model, without tracking gradients.

        Args:
            seqs: Sequences to embed, interpreted according to ``seq_type``.

        Returns:
            The embedding object returned by ``get_LUCAONE_embeds``. Its layout
            (e.g. per-token vs. pooled) is defined by that helper -- TODO confirm.
            The logits that helper also returns are discarded.
        """
        _logits, emb = get_LUCAONE_embeds(
            self.model['args_info'],
            self.model['model_config'],
            self.model['tokenizer'],
            self.model['model'],
            seqs,
            self.seq_type,
        )
        return emb
