import torch
import sys
from project.utils.saprot.esm_loader import load_esm_saprot    # Load the SaProt model

class SaProt:
    """Wrapper around a SaProt model loaded via ``load_esm_saprot``.

    ``forward`` returns per-token representations and all-layer/all-head
    attention maps for a single input string.
    """

    def __init__(self,
                 model_dir: str = 'utils/saprot/SaProt_650M_AF2.pt',
                 frozen: bool = True,
                 device: str = 'cpu'):
        """Load the model and alphabet and move the model to ``device``.

        Args:
            model_dir: Checkpoint path passed to ``load_esm_saprot``.
            frozen: If True, put the model in eval mode and disable gradient
                tracking on all of its parameters.
            device: Torch device the model (and input tokens) are moved to.
        """
        self.model, self.alphabet = load_esm_saprot(model_dir)
        self.frozen = frozen

        if frozen:
            self.model.eval()
            # eval() alone only switches dropout/normalization behavior; also
            # stop the parameters from tracking gradients so "frozen" holds.
            self.model.requires_grad_(False)

        self.device = device
        self.model.to(device)

        # Network depth/width. Read from the model when it exposes them
        # (fair-esm's ESM2 defines `num_layers` / `attention_heads` -- TODO
        # confirm for the object load_esm_saprot returns); otherwise fall back
        # to the 650M values this class previously hard-coded.
        self.layers = getattr(self.model, 'num_layers', 33)
        self.heads = getattr(self.model, 'attention_heads', 20)

    def forward(self, struct_seq: str):
        """Embed a single input string.

        Args:
            struct_seq: String handed to the alphabet's batch converter
                (presumably SaProt's structure-aware token string -- TODO
                confirm expected format).

        Returns:
            ``(embeddings, attn)``:
            ``embeddings`` is layer ``self.layers``'s representation with the
            first and last token positions removed, shape ``(L, D)``.
            ``attn`` stacks every layer's and head's attention map on the
            last axis, shape ``(L, L, n_layers * n_heads)``.
        """
        batch_converter = self.alphabet.get_batch_converter()
        _, _, batch_tokens = batch_converter([('seq', struct_seq)])

        # NOTE(review): gradients are disabled even when frozen=False, so this
        # wrapper cannot be used for fine-tuning as written.
        with torch.no_grad():
            results = self.model(batch_tokens.to(self.device),
                                 repr_layers=[self.layers],
                                 need_head_weights=True,
                                 return_contacts=False)

        # Real positions = non-padding tokens minus the leading and trailing
        # special tokens (assumed BOS/EOS added by the batch converter).
        # Converted to a Python int so it slices cleanly.
        seq_len = int((batch_tokens != self.alphabet.padding_idx).sum(1)[0]) - 2

        # Take batch element 0 explicitly; squeeze() would also drop any other
        # size-1 axis.
        embeddings = results["representations"][self.layers][0, 1:seq_len + 1, :]

        # Attentions are indexed as (batch, layer, head, T, T):
        # (layers, heads, L, L) -> (layers * heads, L, L) -> (L, L, layers * heads)
        attn = results['attentions'][0, :, :, 1:seq_len + 1, 1:seq_len + 1]
        attn = attn.flatten(0, 1).permute(1, 2, 0)

        return embeddings, attn
