import torch
from evo import Evo
from typing import List, Dict, Any


def get_EVO_embeds(self, x, inference_params_dict=None, padding_mask=None):
    """Run a forward pass and return both the logits and the final hidden states.

    Written as a free function that takes the model as ``self`` so it can be
    applied to an already-constructed model object without subclassing it.

    Args:
        self: Model object exposing ``embedding_layer.embed``,
            ``stateful_forward``, ``stateless_forward``, ``norm`` and
            ``unembed.unembed``.
        x: Input passed to ``self.embedding_layer.embed``
            (presumably token ids shaped (batch, length) -- TODO confirm).
        inference_params_dict: When not ``None``, the stateful forward path is
            used and this value is forwarded to it. The updated state returned
            by that call is discarded.
        padding_mask: Forwarded to ``stateless_forward`` only; it is ignored
            when ``inference_params_dict`` is given.

    Returns:
        A ``(logits, hidden)`` tuple, where ``hidden`` is the output of
        ``self.norm`` and ``logits`` is ``self.unembed.unembed(hidden)``.
    """
    hidden = self.embedding_layer.embed(x)
    if inference_params_dict is not None:
        # The stateful path does not accept a padding mask.
        hidden, _ = self.stateful_forward(
            hidden,
            inference_params_dict=inference_params_dict,
        )
    else:
        hidden, _ = self.stateless_forward(hidden, padding_mask=padding_mask)

    # Keep the normalized activations: they are what callers use as embeddings.
    hidden = self.norm(hidden)
    logits = self.unembed.unembed(hidden)
    return logits, hidden


class EVOModel:
    """Callable wrapper that turns sequences into Evo hidden-state embeddings.

    Loads one of the checkpoints listed in ``model_names`` through ``Evo``,
    moves the model to ``device`` and puts it in eval mode. Calling the
    instance tokenizes a batch of sequences and returns the second value
    produced by ``get_EVO_embeds`` (the normalized hidden states).
    """

    # Checkpoint names accepted by ``__init__``.
    model_names = ['evo-1-8k-base', 'evo-1-131k-base']

    def __init__(self, model_name: str, device: str = 'cuda') -> None:
        """Load the checkpoint ``model_name`` and place it on ``device``.

        Args:
            model_name: One of ``EVOModel.model_names``.
            device: Target passed to ``.to(...)`` for the model and used for
                the input tensors. Defaults to ``'cuda'``.

        Raises:
            AssertionError: If ``model_name`` is not in ``model_names``.
        """
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is kept for callers.
        if model_name not in self.model_names:
            raise AssertionError(f'Model {model_name} not found in {self.model_names}.')
        self.device = device
        evo_model = Evo(model_name)
        self.model, self.tokenizer = evo_model.model.to(device), evo_model.tokenizer
        self.model.eval()

    @torch.no_grad()
    def __call__(self, seqs: List[str]) -> torch.Tensor:
        """Return the model's hidden-state embeddings for ``seqs``.

        Args:
            seqs: Batch of sequences handed to ``tokenizer.tokenize_batch``.
                NOTE(review): the token lists are packed into one tensor, so
                this presumably requires equal-length sequences unless the
                tokenizer pads -- confirm.

        Returns:
            The embeddings tensor from ``get_EVO_embeds``
            (presumably (batch, length, hidden) -- TODO confirm).
        """
        input_ids = torch.tensor(
            self.tokenizer.tokenize_batch(seqs),
            dtype=torch.int,
            device=self.device,
        )
        # Only the hidden states are needed; the logits are discarded.
        _, emb = get_EVO_embeds(self.model, input_ids)
        return emb
