from transformers import AutoTokenizer, EsmModel
import torch
from typing import List


class ESM2Model:
    """Wrapper around a Hugging Face ESM-2 checkpoint that returns per-residue embeddings.

    The tokenizer and model are loaded with ``from_pretrained`` and the model is
    kept in eval mode. Calling the instance embeds a batch of sequences without
    tracking gradients.
    """

    # Checkpoint names accepted by ``__init__``; any other name is rejected.
    model_names = [
        'facebook/esm2_t33_650M_UR50D',
        'facebook/esm2_t48_15B_UR50D',
        'facebook/esm2_t36_3B_UR50D',
        'facebook/esm2_t30_150M_UR50D',
        'facebook/esm2_t12_35M_UR50D',
        'facebook/esm2_t6_8M_UR50D',
    ]

    def __init__(self, model_name: str, device: str = 'cuda') -> None:
        """Load the tokenizer and model for ``model_name`` and move the model to ``device``.

        Args:
            model_name: One of ``ESM2Model.model_names``.
            device: Torch device the model is placed on. Defaults to ``'cuda'``;
                pass ``'cpu'`` on machines without a GPU.

        Raises:
            ValueError: If ``model_name`` is not in ``model_names``.
        """
        # Explicit raise rather than ``assert``: assertions are stripped under
        # ``python -O``, which would silently skip this check.
        if model_name not in self.model_names:
            raise ValueError(f'Model {model_name} not found in {self.model_names}.')

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = EsmModel.from_pretrained(model_name).to(device)
        self.model.eval()

    @torch.no_grad()
    def __call__(self, seqs: List[str]) -> torch.Tensor:
        """Embed a batch of sequences and return the final hidden state per residue.

        Args:
            seqs: Sequences to embed; they may have different lengths.

        Returns:
            ``last_hidden_state`` with the first and last token positions removed,
            shaped ``(batch, max_residues, hidden)``. For sequences shorter than
            the longest one, the trailing positions (which would otherwise hold
            the end token and padding) are zero-filled.
        """
        # padding=True lets sequences of different lengths be stacked into one
        # rectangular tensor. Inputs are sent to wherever the model lives.
        inputs = self.tokenizer(seqs, return_tensors="pt", padding=True).to(self.model.device)
        output = self.model(**inputs)
        # Drop the first and last positions: the special tokens the tokenizer
        # wraps around each sequence (<cls> / <eos> for ESM tokenizers).
        embed = output.last_hidden_state[:, 1:-1]

        attention_mask = inputs.get('attention_mask')
        if attention_mask is not None:
            # Residues per sequence = attended tokens minus the two special tokens.
            # NOTE: assumes right-side padding (the tokenizer default); positions
            # at or past this count hold the end token or padding in shorter rows.
            residue_lengths = attention_mask.sum(dim=1) - 2
            positions = torch.arange(embed.shape[1], device=embed.device)
            is_padding = positions.unsqueeze(0) >= residue_lengths.unsqueeze(1)
            embed = embed.masked_fill(is_padding.unsqueeze(-1), 0.0)
        return embed
