import numpy as np
import torch
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import concurrent.futures
from multiprocessing import Pool
import os

class AttentionCalculator:
    """Run a transformer model on tokenized sequences and extract attentions/embeddings.

    ``config`` is read with ``.get`` for:
      * ``'device'``      - where the model is moved and where inputs are placed.
      * ``'num_workers'`` - stored on the instance (not used by this class itself).
    """

    def __init__(self, config, model, tokenizer, logger):
        self.config = config
        self.device = config.get('device')
        # Move the model once, up front; inputs are moved per call.
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.logger = logger
        self.num_workers = config.get('num_workers')

    def grab_attention_weights_and_emb_tokens_sparse(self, seq, max_protein_len,
                                                     attention_threshold=0.01):
        """Tokenize ``seq``, run the model, and return sparsified attentions and embeddings.

        Only the FIRST sequence of the tokenized batch is returned, for all
        three outputs.

        Args:
            seq: Passed straight to ``tokenizer.batch_encode_plus``; presumably a
                list of sequence strings (TODO confirm with callers).
            max_protein_len: Inputs are padded/truncated to
                ``max_protein_len + 2`` tokens; after dropping the first and last
                positions, ``max_protein_len`` positions remain in the embeddings.
            attention_threshold: Attention weights strictly below this value
                (compared after conversion to float16) are set to 0.
                Defaults to 0.01.

        Returns:
            Tuple ``(attention, np_embs, np_emb)`` of numpy arrays:
              * ``attention`` - float16, ``[layer, head, seq_len, seq_len]``
                over the full padded length (no positions trimmed), with weak
                weights zeroed.
              * ``np_embs`` - float16, ``[n_hidden_states, seq_len - 2, hidden]``:
                every entry of the model's ``hidden_states`` with the first and
                last positions dropped.
              * ``np_emb`` - ``last_hidden_state`` with the first and last
                positions dropped, ``[seq_len - 2, hidden]``, in the model's
                output dtype.

        NOTE(review): because padding is 'max_length', the trimmed last
        position is the tokenizer's closing special token (if it adds one) only
        when the sequence fills ``max_length``; for shorter sequences a padding
        position is dropped instead. Confirm this is intended.
        """
        # +2 leaves room for the leading/trailing positions trimmed by [1:-1] below.
        # padding='max_length' alone is sufficient; the deprecated
        # pad_to_max_length=True duplicate was ignored and only produced warnings.
        inputs = self.tokenizer.batch_encode_plus(
            seq,
            return_tensors='pt',
            add_special_tokens=True,
            max_length=max_protein_len + 2,
            padding='max_length',
            truncation=True,
        )

        input_ids = inputs['input_ids'].to(self.device)
        attention_mask = inputs['attention_mask'].to(self.device)
        with torch.no_grad():
            # Ask for attentions explicitly: if the model config does not enable
            # them, the output has no 'attentions' entry and the lookup fails.
            seq_repr = self.model(
                input_ids,
                attention_mask,
                output_hidden_states=True,
                output_attentions=True,
            )
        layer_attentions = seq_repr['attentions']
        layer_hidden = seq_repr['hidden_states']

        # Each attention layer is [batch, num_heads, seq_len, seq_len]. Select the
        # first sequence BEFORE copying to the host, then stack to
        # [layer, head, seq_len, seq_len]. Cast to float16 before thresholding.
        attention = torch.stack(
            [layer[0].detach().cpu().half() for layer in layer_attentions]
        )
        # Sparsify: zero out weak attention weights.
        attention = torch.where(attention < attention_threshold, 0.0, attention)

        # Index batch element 0 explicitly instead of squeeze(): squeeze() breaks
        # for batch > 1 (slices the batch axis) and when only one token remains.
        np_embs = torch.stack(
            [layer[0, 1:-1, :] for layer in layer_hidden]
        ).half().detach().cpu().numpy()
        np_emb = seq_repr.last_hidden_state[0, 1:-1, :].detach().cpu().numpy()

        # Drop references to device tensors first, so empty_cache can actually
        # return their memory.
        del seq_repr, layer_attentions, layer_hidden, input_ids, attention_mask
        if self.device != 'cpu':
            torch.cuda.empty_cache()

        return attention.numpy(), np_embs, np_emb
        