import torch
import numpy as np

def grad_rollout(attentions, gradients, discard_ratio):
    """Compute a gradient-weighted attention rollout relevance mask.

    For every layer the attention map is multiplied element-wise by its
    gradient, averaged over heads, clipped at zero, sparsified, mixed with the
    identity (residual connection) and row-normalised.  The per-layer
    matrices are then multiplied together.

    Args:
        attentions: Sequence of per-layer attention tensors, each shaped
            ``(heads, tokens, tokens)`` or ``(1, heads, tokens, tokens)``.
        gradients: Sequence of gradients, in the same layer order as
            ``attentions``.  Each one must be reshapeable to the matching
            attention tensor; layers where that fails are skipped.
        discard_ratio: Fraction (0..1) of the smallest fused attention
            entries of each layer that are zeroed before the rollout.

    Returns:
        A 1-D ``numpy`` array with ``tokens - 1`` entries: the rolled-out
        relevance of row 0 (the class token, per the original comments)
        towards every other token, scaled so the largest entry is 1.  If no
        layer contributed, the mask is returned as all zeros rather than NaN.

    Raises:
        IndexError: If ``attentions`` is empty.
    """
    # The rollout matrix is token-by-token, so it must be sized by the last
    # attention axis; the leading axis is the head (or batch) axis.
    num_tokens = attentions[0].size(-1)
    identity = torch.eye(num_tokens)
    result = identity

    with torch.no_grad():
        for attention, grad in zip(attentions, gradients):
            attention = attention.float()
            # Only strip a leading batch axis.  An unconditional squeeze(0)
            # would also remove the head axis of a single-head layer.
            if attention.dim() == 4:
                attention = attention.squeeze(0)

            try:
                grad = grad.reshape(attention.shape).float()
            except RuntimeError:
                # Gradient does not match this layer's attention: best-effort
                # skip, as before, but without hiding unrelated exceptions.
                continue

            # Fuse heads and keep only positive relevance.
            fused = (attention * grad).mean(dim=0).clamp(min=0)

            if fused.shape != identity.shape:
                continue

            # Drop the lowest fused attentions of the whole layer.  `flat` is
            # a view, so the zeroing is applied to `fused` itself.
            flat = fused.view(-1)
            num_discard = int(flat.numel() * discard_ratio)
            num_discard = max(0, min(num_discard, flat.numel()))
            if num_discard > 0:
                _, lowest = flat.topk(num_discard, largest=False)
                flat[lowest] = 0

            # Account for the residual connection, then normalise each row.
            # keepdim=True is required so row i is divided by its own sum.
            a = (fused + identity) / 2
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)

    # Look at the total attention between the class token,
    # and the remaining tokens
    mask = result[0, 1:].flatten().numpy()
    if mask.size and np.max(mask) > 0:
        mask = mask / np.max(mask)
    return mask

class BERTAttentionGradRollout:
    """Capture attention maps and their gradients, then roll them out.

    Forward and backward hooks are attached to every submodule of
    ``model.base_model.encoder`` whose qualified name contains
    ``attention_layer_name``.  Calling the instance runs one forward and
    backward pass and passes what the hooks captured to ``grad_rollout``.

    Args:
        model: Model exposing ``base_model.encoder`` and returning an object
            with a ``logits`` attribute when called with keyword inputs.
        attention_layer_name: Substring selecting the modules to hook.
        discard_ratio: Forwarded unchanged to ``grad_rollout``.
    """

    def __init__(self, model, attention_layer_name='attention.self.dropout',
        discard_ratio=0.9):
        self.model = model
        self.discard_ratio = discard_ratio
        # Create the buffers before registering hooks so a hook can never
        # run against a missing attribute.
        self.attentions = []
        self.attention_gradients = []
        # Handles are kept so the hooks can be detached again.
        self._hook_handles = []

        for name, module in self.model.base_model.encoder.named_modules():
            if attention_layer_name in name:
                self._hook_handles.append(
                    module.register_forward_hook(self.get_attention))
                # Prefer the non-deprecated full backward hook, which reports
                # gradients w.r.t. the module inputs reliably; fall back to
                # the legacy hook on PyTorch versions that lack it.
                register_backward = getattr(
                    module, 'register_full_backward_hook',
                    module.register_backward_hook)
                self._hook_handles.append(
                    register_backward(self.get_attention_gradient))

    def get_attention(self, module, input, output):
        """Forward hook: store the first batch element of the module output."""
        # detach() so the stored copy does not keep the autograd graph alive.
        self.attentions.append(output[0].detach().cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the module's first input."""
        # Backward visits the layers in the opposite order to forward, so
        # prepend to keep this list aligned with ``self.attentions``.
        self.attention_gradients.insert(0, grad_input[0].detach().cpu())

    def remove_hooks(self):
        """Detach every hook registered by this instance."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def __call__(self, input_tensor, category_index):
        """Return the rollout mask for ``category_index``.

        Args:
            input_tensor: Mapping of keyword arguments passed to the model.
            category_index: Column(s) of ``logits`` to explain.

        Returns:
            Whatever ``grad_rollout`` returns for the captured tensors.
        """
        # Start from empty buffers; otherwise a second call would also roll
        # out the layers captured during the previous one.
        self.attentions = []
        self.attention_gradients = []

        self.model.zero_grad()
        output = self.model(**input_tensor)
        # Same scalar as masking the logits and summing, but without creating
        # a CPU mask that breaks when the logits live on another device.
        loss = output.logits[:, category_index].sum()
        loss.backward()

        return grad_rollout(self.attentions, self.attention_gradients,
            self.discard_ratio)