import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns; sns.set_theme(style='white')
from matplotlib.colors import LogNorm



def get_tokens(
    input_ids = None,
    tokenizer = None,
) -> np.ndarray:
    """
    Decode every token id individually and return the resulting strings.

    Args:
        input_ids: 2-D collection of token ids laid out as (batch, seq_len).
            Anything that can be iterated row by row works (nested lists,
            numpy arrays, torch tensors). Rows should have equal length so
            the result forms a rectangular array.
        tokenizer: object exposing ``decode(token_id)`` that returns the
            string for a single id (presumably a Hugging Face tokenizer --
            confirm against the caller).

    Returns:
        np.ndarray of strings with shape (batch, seq_len).
    """
    # Iterate the rows directly instead of indexing through ``.shape`` so that
    # plain nested lists are accepted as well as arrays/tensors.
    decoded_rows = [
        [tokenizer.decode(token_id) for token_id in row]
        for row in input_ids
    ]

    return np.array(decoded_rows)



def see_heatmap(
    salience = None,
    tokens: np.ndarray = None,
    fig_title: str = "Salience Pattern",
    y_labels = None,
    figsize = (50,12),
):
    """
    Draw a log-scaled heatmap of salience values, annotating each cell with
    its token.

    Args:
        salience: 2-D salience values; a torch.Tensor is detached, moved to
            CPU and converted to numpy before plotting.
        tokens: per-cell annotation strings, passed to seaborn as ``annot``
            (so it should have the same shape as ``salience``).
        fig_title: title placed above the heatmap.
        y_labels: y axis label for each batch (row), e.g. question index.
            When None, seaborn's default row labels are kept.
        figsize: (width, height) of the new matplotlib figure.

    Returns:
        None. A new figure is created and left as the current figure.
    """

    if isinstance(salience, torch.Tensor):
        salience = salience.detach().cpu().numpy()

    plt.figure(figsize=figsize)
    sns.heatmap(
        salience,
        annot=tokens,
        fmt='',  # annotations are already strings; do not apply number formatting
        annot_kws={"size": 8, 'rotation': 90},
        cbar=True,
        # Colors are mapped on a log scale.
        # NOTE(review): LogNorm needs strictly positive values -- confirm that
        # salience never contains zeros or negatives.
        norm=LogNorm(),
    )
    if y_labels is not None:
        # Materialize once so any iterable (not only sized sequences) works.
        row_labels = list(y_labels)
        # Heatmap row i spans [i, i + 1] on the y axis, so the label belongs
        # at i + 0.5 (the cell centre) rather than on the cell boundary.
        plt.yticks(
            ticks=[row + 0.5 for row in range(len(row_labels))],
            labels=row_labels,
            rotation='horizontal',
        )
    plt.title(fig_title)



def get_top_tokens(
        salience = None,
        tokens = None,
        n_top_tokens = 10,
        remove_empty: bool = False,
        remove_special = None,
        unique_only: bool = False,
):
    """
    View the top tokens according to the salience map.

    Args:
        salience: 1-D salience values, one per token. May be a torch.Tensor
            (detached and moved to CPU), a numpy array, or any sequence that
            numpy can convert to a 1-D array.
        tokens: token strings aligned with ``salience``; used as the index of
            the returned Series.
        n_top_tokens: how many of the highest-salience entries to keep.
        remove_empty: drop tokens equal to '' or '<0x0A>' (line feed).
        remove_special: optional collection of tokens to drop; you may want to
            consider setting this to tokenizer.all_special_tokens.
        unique_only: when a token occurs several times, keep a single entry
            holding its maximum salience value.

    Returns:
        pd.Series named 'salience_value', indexed by token, sorted by
        descending salience and truncated to ``n_top_tokens`` entries.

    Raises:
        ValueError: if ``salience`` is not one-dimensional.
    """

    if isinstance(salience, torch.Tensor):
        salience = salience.detach().cpu().numpy()
    # Coerce so plain Python sequences are accepted and ``ndim`` always exists.
    salience = np.asarray(salience)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if salience.ndim != 1:
        raise ValueError(
            f"salience must be 1-dimensional, got {salience.ndim} dimension(s)"
        )

    top_tokens = pd.Series(index=tokens, data=salience)
    top_tokens.name = 'salience_value'

    if remove_empty:
        # '<0x0A>' is the line feed token
        top_tokens = top_tokens[~top_tokens.index.isin(['', '<0x0A>'])]

    if remove_special is not None:
        top_tokens = top_tokens[~top_tokens.index.isin(remove_special)]

    if unique_only:
        # Keep the max salience value of token where there exist repeats
        top_tokens = top_tokens.groupby(level=0).max()

    top_tokens = top_tokens.sort_values(ascending=False)

    # Positional slice: take the first n rows regardless of the index type.
    return top_tokens.iloc[:n_top_tokens]