import torch
import torch.nn as nn
import esm


class ESM2(nn.Module):
    """Wrapper around a pretrained fair-esm model for embedding one sequence.

    ``forward`` returns the representations of layer ``self.layer`` and the
    attention maps of every (layer, head) pair, restricted to residue
    positions (the BOS/EOS special tokens are sliced off).
    """

    # Fallback (path substring, num_layers, attention_heads) for known
    # checkpoints; only consulted when the loaded model does not expose its
    # own ``num_layers`` / ``attention_heads`` attributes.
    _KNOWN_DIMS = (
        ("650M", 33, 20),
        ("3B", 36, 40),
    )

    def __init__(
        self,
        ckp_path: str,
        frozen: bool = True,
        device: str = "cpu",
    ):
        """Load the model and alphabet from ``ckp_path``.

        Args:
            ckp_path: checkpoint name or path passed to
                ``esm.pretrained.load_model_and_alphabet``.
            frozen: if True, put the model in eval mode, disable gradients on
                its parameters, and run ``forward`` without autograd.
            device: device the model is moved to.

        Raises:
            ValueError: if the layer/head count cannot be determined.
        """
        super().__init__()

        self.model, self.alphabet = esm.pretrained.load_model_and_alphabet(ckp_path)

        self.frozen = frozen
        if frozen:
            self.model.eval()
            # eval() only changes layer behaviour; it does not stop gradient
            # tracking, so freeze the weights explicitly as well.
            self.model.requires_grad_(False)

        # Kept for callers that read it; forward() follows the parameters' device.
        self.device = device
        self.model = self.model.to(device)

        self.batch_converter = self.alphabet.get_batch_converter()

        self.layer, self.heads = self._infer_layers_and_heads(self.model, ckp_path)

    @staticmethod
    def _infer_layers_and_heads(model, ckp_path):
        """Return ``(num_layers, attention_heads)`` for ``model``.

        Prefers the model's own ``num_layers`` / ``attention_heads`` attributes
        and falls back to matching known checkpoint names in ``ckp_path``.

        Raises:
            ValueError: if neither source yields the dimensions.
        """
        layers = getattr(model, "num_layers", None)
        heads = getattr(model, "attention_heads", None)
        if layers is not None and heads is not None:
            return int(layers), int(heads)

        path = str(ckp_path)
        for tag, known_layers, known_heads in ESM2._KNOWN_DIMS:
            if tag in path:
                return known_layers, known_heads

        raise ValueError(
            f"cannot determine number of layers/heads for checkpoint {path!r}"
        )

    def train(self, mode: bool = True):
        """Set training mode, but keep a frozen backbone in eval mode."""
        super().train(mode)
        if self.frozen:
            self.model.eval()
        return self

    def forward(
        self,
        seq: str,
    ):
        """Embed a single protein sequence.

        Args:
            seq: sequence string handed to the alphabet's batch converter.

        Returns:
            ``(residue_repr, attn)`` where
              residue_repr: layer ``self.layer`` representations of the residue
                  positions, shape (L, D).
              attn: attention weights between residue positions for every
                  layer and head, shape (L, L, num_layers * num_heads); the
                  last axis is layer-major (index = layer * heads + head).
        """
        _, _, batch_tokens = self.batch_converter([("seq", seq)])

        # Send tokens to wherever the model currently lives (it may have been
        # moved with .to() after construction).
        param = next(self.model.parameters(), None)
        target_device = param.device if param is not None else self.device

        # Frozen: no autograd. Unfrozen: allow gradients, but still honour an
        # enclosing torch.no_grad() from the caller.
        grad_enabled = torch.is_grad_enabled() and not self.frozen
        with torch.set_grad_enabled(grad_enabled):
            results = self.model(
                batch_tokens.to(target_device),
                repr_layers=[self.layer],
                need_head_weights=True,
                return_contacts=False,
            )

        # Non-padding tokens minus one BOS and one EOS marker (the layout the
        # ESM-2 alphabet produces); residues occupy positions 1..n_res.
        n_res = int((batch_tokens != self.alphabet.padding_idx).sum(1)[0]) - 2

        residue_repr = results["representations"][self.layer][0, 1:n_res + 1, :]

        # attentions: (1, layers, heads, T, T) -> (layers * heads, L, L) -> (L, L, layers * heads)
        attn = results["attentions"][0, :, :, 1:n_res + 1, 1:n_res + 1]
        attn = attn.flatten(0, 1).permute(1, 2, 0)

        return residue_repr, attn

