import torch
from typing import List
from esm.models.esm3 import ESM3
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer


class ESM3Model:
    """Wrapper that turns amino-acid strings into ESM3 per-position embeddings.

    The constructor loads a pretrained ESM3 checkpoint and puts it in eval
    mode; calling the instance tokenizes a batch of sequences and returns the
    model's ``embeddings`` output with the first and last token positions
    removed.
    """

    # Checkpoint names accepted by the constructor. The attribute keeps its
    # original name (which the ``model_name`` argument of ``__init__``
    # shadows locally) so external code reading ``ESM3Model.model_name``
    # keeps working.
    model_name = ["esm3_sm_open_v1"]

    def __init__(self, model_name: str, device: str = "cuda") -> None:
        """Load the pretrained model and its sequence tokenizer.

        Args:
            model_name: Name of the checkpoint; must be listed in
                ``ESM3Model.model_name``.
            device: Device the model is moved to and on which inputs are
                created. Defaults to ``"cuda"``, the previously hard-coded
                value.

        Raises:
            ValueError: If ``model_name`` is not a supported checkpoint.
        """
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O`` and would let an unsupported name through.
        if model_name not in self.model_name:
            raise ValueError(
                f"Unsupported ESM3 model {model_name!r}; "
                f"expected one of {self.model_name}"
            )
        self.device = torch.device(device)
        # The weights are materialised on the CPU first and only then moved
        # to the target device, exactly as the original code did.
        self.model = ESM3.from_pretrained(
            model_name, device=torch.device("cpu")
        ).to(self.device)
        self.model.eval()
        self.tokenizer = EsmSequenceTokenizer()

    @torch.no_grad()
    def __call__(self, seqs: List[str]) -> torch.Tensor:
        """Embed a batch of sequences.

        Args:
            seqs: Sequences to embed; each one is passed to
                ``self.tokenizer.encode``.

        Returns:
            A tensor whose first dimension indexes the input sequences and
            whose second dimension indexes token positions with the first
            and last token of every sequence dropped (presumably the BOS/EOS
            tokens added by the tokenizer -- TODO confirm). When the
            sequences tokenize to different lengths, shorter entries are
            right-padded with zeros up to the longest one.

        Raises:
            ValueError: If ``seqs`` is empty.
        """
        token_lists = [self.tokenizer.encode(seq) for seq in seqs]
        if not token_lists:
            raise ValueError("seqs must contain at least one sequence")

        if len({len(ids) for ids in token_lists}) == 1:
            # Uniform length: a single batched forward pass.
            tokens = torch.tensor(
                token_lists, dtype=torch.long, device=self.device
            )
            return self.model(sequence_tokens=tokens).embeddings[:, 1:-1]

        # Ragged batch: a rectangular token tensor cannot be built without
        # padding tokens, so every sequence is run on its own and the
        # resulting embeddings are zero-padded to a common length instead.
        from torch.nn.utils.rnn import pad_sequence

        per_sequence = []
        for ids in token_lists:
            tokens = torch.tensor([ids], dtype=torch.long, device=self.device)
            per_sequence.append(
                self.model(sequence_tokens=tokens).embeddings[0, 1:-1]
            )
        return pad_sequence(per_sequence, batch_first=True)
