import re
import sys
import numpy as np
from tqdm import tqdm
from pathlib import Path

def grab_attention_weights(model, tokenizer, sentences, MAX_LEN, device='cuda:0'):
    """
    Run `model` on a batch of sentences and return its attention weights.

    Every sentence is cleaned with `text_preprocessing`, then the batch is
    tokenized with special tokens, padded/truncated to `MAX_LEN`, moved to
    `device` and passed through `model` with gradient tracking disabled.

    @param    model: callable taking `input_ids`, `attention_mask` and, when
              the tokenizer produces them, `token_type_ids` keyword arguments.
              Its output must be indexable with 'attentions'.
              NOTE(review): presumably a Hugging Face model configured to
              output attentions and already placed on `device` -- confirm.
    @param    tokenizer: object exposing `batch_encode_plus`.
    @param    sentences: iterable of raw strings.
    @param    MAX_LEN (int): length every sequence is padded/truncated to.
    @param    device: device the input tensors are moved to.
    @return   np.ndarray of dtype float16 stacking the per-layer attention
              tensors (layer x batch x head x seq x seq).
    """
    # Imported locally so the function is self-contained.
    import torch

    inputs = tokenizer.batch_encode_plus(
        [text_preprocessing(s) for s in sentences],
        return_tensors='pt',
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True
    )

    model_kwargs = {
        'input_ids': inputs['input_ids'].to(device),
        'attention_mask': inputs['attention_mask'].to(device),
    }
    # Not every tokenizer emits segment ids; forward them only when present.
    if 'token_type_ids' in inputs:
        model_kwargs['token_type_ids'] = inputs['token_type_ids'].to(device)

    # The result is detached and converted to NumPy below, so building an
    # autograd graph here would only waste memory.
    with torch.no_grad():
        outputs = model(**model_kwargs)

    attention = outputs['attentions']  # layer x batch x head x seq x seq
    # Stored as float16 to halve the footprint of the stacked array.
    attention = np.asarray(
        [layer.detach().cpu().numpy() for layer in attention],
        dtype=np.float16,
    )

    return attention
  
def text_preprocessing(text):
    """
    Clean a raw text string.
    - Remove entity mentions (eg. '@united'), including a mention that is
      the last thing in the text.
    - Correct errors (eg. '&amp;' to '&')
    - Collapse runs of whitespace into one space and strip both ends.
    @param    text (str): a string to be processed.
    @return   text (str): the processed string.
    """
    # Remove '@name' plus the whitespace that follows it. The `$` alternative
    # also catches a mention at the very end of the text, which has no
    # trailing whitespace to anchor on.
    text = re.sub(r'@\S*(?:\s|$)', ' ', text)

    # Replace '&amp;' with '&'
    text = re.sub(r'&amp;', '&', text)

    # Collapse whitespace runs and remove leading/trailing whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text
