from typing import Tuple, List
import sys
import torch
import scipy.sparse as sp_sparse
import numpy as np
import networkx as nx

def get_input_tokens(input_str: str, tokenizer) -> Tuple[dict, List[str]]:
    """
    Tokenizes the input string using the given tokenizer.

    Args:
        input_str (str): The input string to tokenize.
        tokenizer (PreTrainedTokenizer): The tokenizer to use.

    Returns:
        Tuple[List[str], dict]: A tuple containing the list of the tokenized inputs and the input tokens.
    """

    tokenized_inputs = tokenizer(input_str, return_tensors='pt')
    input_tokens = tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[0])
    return tokenized_inputs, input_tokens


def get_modified_input_tokens(input_tokens: List[str], removed_indices: List[int]):
    """
        Generates the modified input tokens by removing elements at specified indices.

        Args:
            input_tokens (List[Any]): The list of input tokens.
            removed_indices (List[int]): The list of indices to be removed.

        Returns:
            List[Any]: The modified list of input tokens.
    """
    modified_input_tokens = [input_tokens[i] for i in range(len(input_tokens)) if i not in removed_indices]
    return modified_input_tokens

def get_zsc_input_tokens(input_str: str, labels:List[str], tokenizer) -> Tuple[dict, List[str]]:
    """
    Generate the input tokens for zero-shot classification.

    Args:
        input_str (str): The input string to tokenize.
        labels (List[str]): The list of labels to tokenize.
        tokenizer (Tokenizer): The tokenizer to use.

    Returns:
        Tuple[List[str], dict]: A tuple containing the tokenized inputs and the input tokens.
    """
    tokenized_inputs = tokenizer.batch_encode_plus([input_str] + labels,
                                     return_tensors='pt',
                                     padding=True)

    input_tokens = tokenizer.convert_ids_to_tokens(tokenized_inputs.input_ids[0])
    return tokenized_inputs, input_tokens


def get_model_outputs(model, tokenizer, input_str: str) -> Tuple[dict, List[torch.Tensor], List[torch.Tensor], dict, List[str]]:
    """
    Tokenize ``input_str`` and run ``model`` on it without recording gradients.

    Args:
        model (PreTrainedModel): The model to use. It is called with the
            tokenizer's encoding unpacked as keyword arguments.
        tokenizer (PreTrainedTokenizer): The tokenizer to use.
        input_str (str): The input string to run the model on.

    Returns:
        Tuple[dict, List[torch.Tensor], List[torch.Tensor], dict, List[str]]:
            ``(model_outputs, hidden_states, attentions, tokenized_inputs, input_tokens)``.
            ``hidden_states`` and ``attentions`` are the values stored under those
            keys in the model output, or ``[]`` when the key is absent. Presumably
            the model must be configured to emit them (output_hidden_states /
            output_attentions) — confirm.
    """
    tokenized_inputs, input_tokens = get_input_tokens(input_str, tokenizer)

    # Inference only: no autograd graph is recorded for the forward pass.
    with torch.no_grad():
        model_outputs = model(**tokenized_inputs)

    hidden_states = model_outputs.get('hidden_states', [])
    attentions = model_outputs.get('attentions', [])

    return model_outputs, hidden_states, attentions, tokenized_inputs, input_tokens

def get_model_outputs_v0(model, tokenizer, input_str: str) -> Tuple[dict, List[torch.Tensor], List[torch.Tensor], dict, List[str]]:
    """
    Tokenize ``input_str`` and run ``model`` on it with autograd enabled.

    Unlike ``get_model_outputs``, the forward pass is not wrapped in
    ``torch.no_grad()``. The returned tensors therefore remain part of the
    autograd graph and can be backpropagated through.

    Args:
        model (Model): The model used for prediction; it receives the encoding as keyword arguments.
        tokenizer (Tokenizer): The tokenizer used to tokenize the input string.
        input_str (str): The input string to be processed.

    Returns:
        Tuple[dict, List[torch.Tensor], List[torch.Tensor], dict, List[str]]:
            ``(model_outputs, hidden_states, attentions, tokenized_inputs, input_tokens)``,
            where the hidden states / attentions default to ``[]`` if the output lacks them.
    """
    encoded, tokens = get_input_tokens(input_str, tokenizer)
    outputs = model(**encoded)
    return (
        outputs,
        outputs.get('hidden_states', []),
        outputs.get('attentions', []),
        encoded,
        tokens,
    )


def get_attention_hidden_states_grad(model, tokenizer, input_str, objective:str="probs", class_index:int=0, grad_index:int=0):
    """
    Backpropagate one class score and return the gradients at one hidden-state / attention layer.

    The model is run with autograd enabled (via ``get_model_outputs_v0``). The
    scalar that is differentiated is element ``class_index`` of the flattened
    softmax probabilities (``objective="probs"``) or of the flattened raw
    logits (``objective="logits"``).

    Args:
    - model (Model): The model to run. Its output must provide ``logits`` and
      should provide ``hidden_states`` and ``attentions``. Presumably this
      requires output_hidden_states / output_attentions to be enabled — confirm.
    - tokenizer (Tokenizer): The tokenizer corresponding to ``model``.
    - input_str (str): The input string to be processed.
    - objective (str): "probs" or "logits". Default is "probs".
    - class_index (int): Index into the flattened probabilities/logits of the scalar to differentiate. Default is 0.
    - grad_index (int): Which entry of hidden_states / attentions to read gradients from. Default is 0.

    Returns:
    - hidden_states_grad (torch.Tensor): Gradient of the objective w.r.t. ``hidden_states[grad_index]``.
    - attention_grads (torch.Tensor): Gradient of the objective w.r.t. ``attentions[grad_index]``.
    - class_probabilities (torch.Tensor): Softmax of the logits over the last dimension.
    - class_logits (torch.Tensor): The raw logits of the model output.

    Raises:
    - ValueError: If ``objective`` is neither "probs" nor "logits". This is checked before the model runs.
    """
    # Validate up front so a bad argument does not cost a full forward pass.
    if objective not in ("probs", "logits"):
        raise ValueError("The objective must be either 'probs' or 'logits'")

    model_outputs, hidden_states, attentions, _, _ = get_model_outputs_v0(model, tokenizer, input_str)

    # retain_grad() makes autograd populate .grad on these intermediate tensors after backward().
    for tensor in (*hidden_states, *attentions):
        tensor.retain_grad()

    class_logits = model_outputs.get('logits')
    class_probabilities = class_logits.softmax(dim=-1)

    target = class_probabilities if objective == "probs" else class_logits
    target.flatten()[class_index].backward(retain_graph=True)

    hidden_states_grad = hidden_states[grad_index].grad
    attention_grads = attentions[grad_index].grad

    return hidden_states_grad, attention_grads, class_probabilities, class_logits

def get_zsc_model_outputs(model, tokenizer, input_str: str, labels:List[str]) -> Tuple[dict, List[torch.Tensor], List[torch.Tensor], dict, List[str]]:
    """
    Run ``model`` on ``input_str`` batched with ``labels`` (zero-shot classification), without gradients.

    The batch is produced by ``get_zsc_input_tokens``. The returned per-layer
    hidden states and attentions keep only batch row 0, which is the entry
    for ``input_str``.

    Args:
        model (PreTrainedModel): The model to use.
        tokenizer (PreTrainedTokenizer): The tokenizer to use.
        input_str (str): The input string to run the model on.
        labels (List[str]): The candidate labels encoded in the same batch.

    Returns:
        Tuple[dict, List[torch.Tensor], List[torch.Tensor], dict, List[str]]:
            ``(model_outputs, hidden_states, attentions, tokenized_inputs, input_tokens)``.
            ``model_outputs`` and ``tokenized_inputs`` cover the whole batch.
            ``hidden_states`` and ``attentions`` are sliced to batch row 0, keeping
            a batch axis of size 1; they are ``[]`` if the output lacks them.
            ``input_tokens`` are the tokens of row 0 and may include padding — confirm.
    """
    tokenized_inputs, input_tokens = get_zsc_input_tokens(input_str, labels, tokenizer)

    with torch.no_grad():
        model_outputs = model(**tokenized_inputs)

        hidden_states_list = model_outputs.get('hidden_states', [])
        attentions_list = model_outputs.get('attentions', [])

        # Keep only the input_str entry (batch row 0); slicing 0:1 preserves the batch axis.
        attentions = [attention[0:1] for attention in attentions_list]
        hidden_states = [hidden_state[0:1] for hidden_state in hidden_states_list]

    return model_outputs, hidden_states, attentions, tokenized_inputs, input_tokens

def add_res_block(attentions: List[torch.Tensor], add_res: bool = True, lambda_res: float = 0.5) -> torch.Tensor:
    """
    Average per-layer attentions over heads and optionally blend in the residual (identity) path.

    Args:
        attentions (List[torch.Tensor]): Per-layer attention tensors. They are
            concatenated along dim 0 and averaged over dim 1. This presumably
            assumes shape (1, heads, T, T) per layer, e.g. HF attentions with
            batch size 1 — TODO confirm.
        add_res (bool, optional): If True, return
            ``(1 - lambda_res) * mean + lambda_res * I``. Defaults to True.
        lambda_res (float, optional): Weight of the identity term. Defaults to 0.5.

    Returns:
        torch.Tensor: The head-averaged (and optionally residual-blended)
        attentions. Under the shape assumption above this is (n_layers, T, T),
        with the same dtype and device as the inputs.
    """
    attentions_tensor = torch.cat(list(attentions), dim=0)
    mean_attention_tensor = attentions_tensor.mean(dim=1)

    if not add_res:
        return mean_attention_tensor

    # Create the identity on the attentions' device/dtype so CUDA and half-precision inputs work.
    identity_tensor = torch.eye(
        mean_attention_tensor.shape[1],
        dtype=mean_attention_tensor.dtype,
        device=mean_attention_tensor.device,
    ).unsqueeze(0)

    return (1 - lambda_res) * mean_attention_tensor + lambda_res * identity_tensor

def get_modified_attention(attention: torch.Tensor, removed_indices: List[int], verbose: bool = True) -> torch.Tensor:
    """
    Return a copy of ``attention`` with the rows and columns at ``removed_indices`` dropped.

    Args:
    - attention: A 4-D tensor of shape number_layers x number_heads x length_tokens x length_tokens.
      The unpacking below raises ValueError for any other rank.
    - removed_indices: Token positions to remove from both of the last two axes.
      Integer-like values (including 0-d tensors) are accepted; out-of-range
      positions are ignored.
    - verbose: If True (the default, matching previous behavior), print the
      input shape, the kept/removed indices and the output shape.

    Returns:
    - modified_attention: A new tensor of shape
      number_layers x number_heads x new_length x new_length. The input is not modified.
    """
    number_layers, number_heads, length_tokens, _ = attention.shape
    if verbose:
        print(f"The shape of the attention tensor is {number_layers} x {number_heads} x {length_tokens} x {length_tokens}")

    # Set lookup keeps the filter O(n) instead of O(n * len(removed_indices)).
    removed = {int(index) for index in removed_indices}
    remained_indices = [i for i in range(length_tokens) if i not in removed]
    if verbose:
        print(f"The remained indices are {remained_indices} and the removed indices are {removed_indices}")

    # Advanced (list) indexing returns a copy, so no explicit clone is needed.
    modified_attention = attention[:, :, remained_indices, :][:, :, :, remained_indices]

    if verbose:
        print(f"The shape of the modified attention tensor is {modified_attention.shape}")

    return modified_attention


def get_normalized_attention(attention: torch.Tensor) -> torch.Tensor:
    """
    Rescale ``attention`` so that every slice along its last dimension sums to 1.0.

    Args:
    - attention (torch.Tensor): Tensor of attention weights, of any rank.

    Returns:
    - normalized_attention (torch.Tensor): ``attention`` divided by its row sums
      over the last dimension. A row that sums to zero yields NaN/inf entries.
    """
    row_totals = torch.sum(attention, dim=-1, keepdim=True)
    return torch.div(attention, row_totals)


def get_adj_matrix_backward(agg_attentions:torch.Tensor) -> Tuple[np.ndarray, np.ndarray, float, np.ndarray, np.ndarray]:
    """
    Build the backward attention-flow graph and compute its maximum flow.

    Node layout (N = (n_layers+1) * length_tokens + 2 nodes):
        - index 0 is the target (sink);
        - index ``layer * length_tokens + k + 1`` is token ``k`` at layer ``layer``,
          for ``layer`` in ``0..n_layers``;
        - index N-1 is the source.

    Edges:
        - source -> every token of layer ``n_layers``, with capacity ``length_tokens``;
        - layer ``l+1`` token ``i`` -> layer ``l`` token ``j``, with capacity
          ``agg_attentions[l][i][j]``;
        - every token of layer 0 -> target, with capacity ``length_tokens``.

    Args:
        agg_attentions (torch.Tensor): Aggregated attentions of shape
            (n_layers, length_tokens, length_tokens). The tensor is detached and
            copied to the CPU, so grad-tracking and GPU tensors are accepted. It
            must contain at least one strictly positive value; otherwise
            ``Tensor.min()`` fails on the empty selection.

    Returns:
        Tuple[np.ndarray, np.ndarray, float, np.ndarray, np.ndarray]:
            - adj_mat (np.ndarray): Float capacity matrix of shape (N, N).
            - scaled_adj_mat (np.ndarray): ``adj_mat`` multiplied by a power-of-ten
              scale factor and truncated to integers. These are the capacities the
              max-flow solver consumes.
            - max_flow_val (float): The maximum flow value, rescaled back to attention units.
            - max_flow_arr (np.ndarray): Dense (N, N) edge flows, rescaled, with
              negative entries clipped to 0.
            - shapley_vals_layerwise (np.ndarray): Total outgoing flow of every token
              node, with shape (n_layers+1, length_tokens).
    """
    n_layers, length_tokens = agg_attentions.shape[0], agg_attentions.shape[1]
    n_nodes = (n_layers + 1) * length_tokens + 2
    src, target = n_nodes - 1, 0

    # NOTE: process-wide side effect on NumPy printing, kept for compatibility.
    np.set_printoptions(linewidth=sys.maxsize)

    # Detach and move to CPU once so grad-tracking / CUDA tensors convert cleanly.
    attn_cpu = agg_attentions.detach().cpu()

    # Choose 10**beta so that the smallest positive capacity scales into [1, 10).
    beta_min = attn_cpu[attn_cpu > 0].min().numpy()
    print(beta_min)
    beta = -np.floor(np.log10(beta_min))
    scale_factor = 10**beta

    attentions = attn_cpu.numpy()
    adj_mat = np.zeros((n_nodes, n_nodes))

    # Stand-in for "infinite" capacity on the source/sink edges.
    # NOTE(review): presumably large enough because each attention row sums to ~1; confirm.
    inf_cap = length_tokens

    # Layer 0 -> target; source -> last layer (layer n_layers).
    adj_mat[1:length_tokens + 1, target] = inf_cap
    adj_mat[src, n_layers * length_tokens + 1:src] = inf_cap

    # Block from layer+1 (rows) to layer (columns) holds agg_attentions[layer] as-is.
    for layer in range(n_layers):
        upper = slice((layer + 1) * length_tokens + 1, (layer + 2) * length_tokens + 1)
        lower = slice(layer * length_tokens + 1, (layer + 1) * length_tokens + 1)
        adj_mat[upper, lower] = attentions[layer]

    # astype(int) truncates (it does not round); maximum_flow needs integer capacities.
    scaled_adj_mat = (scale_factor * adj_mat).astype(int)

    graph = sp_sparse.csr_matrix(scaled_adj_mat)
    maxflow = sp_sparse.csgraph.maximum_flow(graph, src, target, method='edmonds_karp')

    max_flow_val = maxflow.flow_value * (scale_factor)**-1

    # Drop negative (reverse-direction) entries so each row sums to that node's outgoing flow.
    max_flow_arr = maxflow.flow.toarray() * (scale_factor)**-1
    max_flow_arr[max_flow_arr < 0] = 0

    # Per-node outgoing flow, excluding the target (first) and source (last) nodes.
    shapley_vals = np.sum(max_flow_arr, axis=1)
    shapley_vals_layerwise = shapley_vals[1:-1].reshape((n_layers + 1, length_tokens))

    return adj_mat, scaled_adj_mat, max_flow_val, max_flow_arr, shapley_vals_layerwise


def get_adj_matrix_backward_v0(agg_attentions:torch.Tensor) -> Tuple[np.ndarray, np.ndarray, float, np.ndarray, np.ndarray]:
    """
    Build the backward attention-flow graph (target at node 0, source at the last node) and compute its maximum flow.

    Token ``k`` at layer ``layer`` (``layer`` in ``0..n_layers``) is node
    ``layer * length_tokens + k + 1``. The edge from layer ``l+1`` token ``i``
    to layer ``l`` token ``j`` has capacity ``agg_attentions[l][i][j]``. The
    source feeds the last layer and layer 0 drains into the target, both with
    capacity ``length_tokens``.

    Args:
        agg_attentions (torch.Tensor): A tensor of shape (n_layers, length_tokens, length_tokens).
            It is detached and copied to the CPU, so grad-tracking and GPU tensors are accepted.

    Returns:
        Tuple[np.ndarray, np.ndarray, float, np.ndarray, np.ndarray]: A tuple containing the following:
            - adj_mat (np.ndarray): The float capacity matrix, of shape
              ((n_layers+1) * length_tokens + 2, (n_layers+1) * length_tokens + 2).
            - scaled_adj_mat (np.ndarray): ``adj_mat`` times a power-of-ten scale factor,
              truncated to integers.
            - max_flow_val (float): The maximum flow value from the source to the
              target, in attention units.
            - max_flow_arr (np.ndarray): Rescaled edge flows with negative entries
              clipped to 0; same shape as adj_mat.
            - shapley_vals_layerwise (np.ndarray): Outgoing flow per token node,
              with shape (n_layers+1, length_tokens).
    """

    n_layers, length_tokens = agg_attentions.shape[0], agg_attentions.shape[1]

    # NOTE: changes NumPy's print options for the whole process.
    np.set_printoptions(linewidth=sys.maxsize)

    # .cpu() lets CUDA tensors through: Tensor.numpy() only works on CPU tensors.
    attn_cpu = agg_attentions.detach().cpu()

    # Smallest strictly positive capacity decides the scale: 10**beta maps it into [1, 10).
    beta_min = attn_cpu[attn_cpu > 0].min().numpy()
    print(beta_min)

    beta = -np.floor(np.log10(beta_min))
    scale_factor = 10**beta

    agg_attentions = attn_cpu.numpy()

    n_nodes = (n_layers+1) * length_tokens + 2
    adj_mat = np.zeros((n_nodes, n_nodes))

    # Stand-in for infinite capacity on source/sink edges (equal to length_tokens, not 1).
    inf_cap = length_tokens

    # Fill first layer -> target
    for i in range(length_tokens):
        adj_mat[i+1][0] = inf_cap

    # Fill source -> last layer
    for i in range(length_tokens):
        adj_mat[-1][-i-2] = inf_cap

    # Fill the block from layer+1 (rows) to layer (columns) with agg_attentions[layer]
    for layer in range(0, n_layers):
        adj_mat[length_tokens*(layer+1)+1:length_tokens*(layer+2)+1, length_tokens*(layer)+1:length_tokens*(layer+1)+1] = agg_attentions[layer]

    # Scale, then truncate (astype(int) does not round) to integer capacities for maximum_flow
    scaled_adj_mat = (scale_factor * adj_mat).astype(int)

    # Calculate the maximum flow from the source to the target
    src = (n_layers+1)*length_tokens+1
    target = 0

    graph = sp_sparse.csr_matrix(scaled_adj_mat)
    maxflow = sp_sparse.csgraph.maximum_flow(graph, src, target, method='edmonds_karp')

    max_flow_val = maxflow.flow_value * (scale_factor)**-1

    max_flow_arr = maxflow.flow.toarray() * (scale_factor)**-1
    max_flow_arr[max_flow_arr<0] = 0

    # Outgoing flow per token node (drop the target row 0 and the source row -1)
    shapley_vals = np.sum(max_flow_arr, axis=1)
    shapley_vals_layerwise = shapley_vals[1:-1].reshape((n_layers+1, length_tokens))

    return adj_mat, scaled_adj_mat, max_flow_val, max_flow_arr, shapley_vals_layerwise


def get_adj_matrix_forward(agg_attentions:torch.Tensor) -> Tuple[np.ndarray, object, float, np.ndarray, np.ndarray]:
    """
    Build the forward attention-flow graph and compute its maximum flow.

    Node layout (N = (n_layers+1) * length_tokens + 2 nodes):
        - index 0 is the source;
        - index ``layer * length_tokens + k + 1`` is token ``k`` at layer ``layer``,
          for ``layer`` in ``0..n_layers``;
        - index N-1 is the target.

    Edges:
        - source -> every token of layer 0, with capacity ``length_tokens``;
        - layer ``l`` token ``i`` -> layer ``l+1`` token ``j``, with capacity
          ``agg_attentions[l][j][i]``. This is the same edge set as
          ``get_adj_matrix_backward`` with directions reversed, and it matches
          ``get_adj_matrix_forward_v0``;
        - every token of layer ``n_layers`` -> target, with capacity ``length_tokens``.

    Args:
        agg_attentions (torch.Tensor): Aggregated attentions of shape
            (n_layers, length_tokens, length_tokens). The tensor is detached and
            copied to the CPU, so grad-tracking and GPU tensors are accepted. It
            must contain at least one strictly positive value.

    Returns:
        Tuple[np.ndarray, object, float, np.ndarray, np.ndarray]:
            - adj_mat (np.ndarray): Float capacity matrix of shape (N, N).
            - maxflow: The result object returned by ``scipy.sparse.csgraph.maximum_flow``,
              in scaled integer units.
            - max_flow_val (float): The maximum flow value, rescaled back to attention units.
            - max_flow_arr (np.ndarray): Dense (N, N) edge flows, rescaled, with
              negative entries clipped to 0.
            - shapley_vals_layerwise (np.ndarray): Total outgoing flow of every token
              node, with shape (n_layers+1, length_tokens).
    """
    n_layers, length_tokens = agg_attentions.shape[0], agg_attentions.shape[1]
    n_nodes = (n_layers + 1) * length_tokens + 2
    src, target = 0, n_nodes - 1

    # NOTE: process-wide side effect on NumPy printing, kept for compatibility.
    np.set_printoptions(linewidth=sys.maxsize)

    # Detach and move to CPU once so grad-tracking / CUDA tensors convert cleanly.
    attn_cpu = agg_attentions.detach().cpu()

    # Choose 10**beta so that the smallest positive capacity scales into [1, 10).
    beta_min = attn_cpu[attn_cpu > 0].min().numpy()
    print(beta_min)
    beta = -np.floor(np.log10(beta_min))
    scale_factor = 10**beta

    attentions = attn_cpu.numpy()
    adj_mat = np.zeros((n_nodes, n_nodes))

    # Stand-in for "infinite" capacity on the source/sink edges.
    # NOTE(review): presumably large enough because each attention row sums to ~1; confirm.
    inf_cap = length_tokens

    # Source -> first layer; last layer (layer n_layers) -> target.
    adj_mat[src, 1:length_tokens + 1] = inf_cap
    adj_mat[n_layers * length_tokens + 1:target, target] = inf_cap

    # Edge lower-token i -> upper-token j gets agg_attentions[layer][j][i], hence the transpose.
    # This assumes row j of agg_attentions[layer] is upper token j's attention over lower tokens,
    # the convention the backward builder uses.
    for layer in range(n_layers):
        lower = slice(layer * length_tokens + 1, (layer + 1) * length_tokens + 1)
        upper = slice((layer + 1) * length_tokens + 1, (layer + 2) * length_tokens + 1)
        adj_mat[lower, upper] = attentions[layer].T

    # astype(int) truncates (it does not round); maximum_flow needs integer capacities.
    scaled_adj_mat = (scale_factor * adj_mat).astype(int)

    graph = sp_sparse.csr_matrix(scaled_adj_mat)
    maxflow = sp_sparse.csgraph.maximum_flow(graph, src, target, method='edmonds_karp')

    max_flow_val = maxflow.flow_value * (scale_factor)**-1

    # Drop negative (reverse-direction) entries so each row sums to that node's outgoing flow.
    max_flow_arr = maxflow.flow.toarray() * (scale_factor)**-1
    max_flow_arr[max_flow_arr < 0] = 0

    # Per-node outgoing flow, excluding the source (first) and target (last) nodes.
    shapley_vals = np.sum(max_flow_arr, axis=1)
    shapley_vals_layerwise = shapley_vals[1:-1].reshape((n_layers + 1, length_tokens))

    return adj_mat, maxflow, max_flow_val, max_flow_arr, shapley_vals_layerwise


def get_adj_matrix_forward_v0(agg_attentions:torch.Tensor) -> Tuple[np.ndarray, np.ndarray, float, np.ndarray, np.ndarray]:
    """
    Build the forward attention-flow graph (source at node 0, target at the last node) and compute its maximum flow.

    Token ``k`` at layer ``layer`` (``layer`` in ``0..n_layers``) is node
    ``layer * length_tokens + k + 1``. The edge from layer ``l`` token ``i`` to
    layer ``l+1`` token ``j`` has capacity ``agg_attentions[l][j][i]`` (the
    transpose below). The source feeds layer 0 and the last layer drains into
    the target, both with capacity ``length_tokens``.

    Args:
        agg_attentions (torch.Tensor): A tensor of shape (n_layers, length_tokens, length_tokens).
            It is detached and copied to the CPU, so grad-tracking and GPU tensors are accepted.

    Returns:
        Tuple[np.ndarray, np.ndarray, float, np.ndarray, np.ndarray]: A tuple containing:
            - adj_mat (np.ndarray): The float capacity matrix, of shape
              ((n_layers+1) * length_tokens + 2, (n_layers+1) * length_tokens + 2).
            - scaled_adj_mat (np.ndarray): ``adj_mat`` times a power-of-ten scale factor,
              truncated to integers. These are the capacities given to the max-flow solver.
            - max_flow_val (float): The value of the maximum flow, in attention units.
            - max_flow_arr (np.ndarray): Rescaled edge flows as a dense matrix, with
              negative entries clipped to 0.
            - shapley_vals_layerwise (np.ndarray): Outgoing flow per token node,
              with shape (n_layers+1, length_tokens).
    """

    n_layers, length_tokens = agg_attentions.shape[0], agg_attentions.shape[1]

    # NOTE: changes NumPy's print options for the whole process.
    np.set_printoptions(linewidth=sys.maxsize)

    # .cpu() lets CUDA tensors through: Tensor.numpy() only works on CPU tensors.
    attn_cpu = agg_attentions.detach().cpu()

    # Smallest strictly positive capacity decides the scale: 10**beta maps it into [1, 10).
    beta_min = attn_cpu[attn_cpu > 0].min().numpy()
    print(beta_min)

    beta = -np.floor(np.log10(beta_min))
    scale_factor = 10**beta

    agg_attentions = attn_cpu.numpy()

    n_nodes = (n_layers+1) * length_tokens + 2
    adj_mat = np.zeros((n_nodes, n_nodes))

    # Stand-in for infinite capacity on source/sink edges (equal to length_tokens, not 1).
    inf_cap = length_tokens

    # Fill source -> First Layer
    for i in range(length_tokens):
        adj_mat[0][i+1] = inf_cap

    # Fill Last Layer -> target
    for i in range(length_tokens):
        adj_mat[-i-2][-1] = inf_cap

    # Fill the block from layer (rows) to layer+1 (columns) with the transposed attentions
    for layer in range(0, n_layers):
        adj_mat[length_tokens*(layer)+1:length_tokens*(layer+1)+1, length_tokens*(layer+1)+1:length_tokens*(layer+2)+1] = agg_attentions[layer].transpose(1,0)

    # Scale, then truncate (astype(int) does not round) to integer capacities for maximum_flow
    scaled_adj_mat = (scale_factor * adj_mat).astype(int)

    # Calculate the maximum flow from the source to the target
    src = 0
    target = (n_layers+1)*length_tokens + 1

    graph = sp_sparse.csr_matrix(scaled_adj_mat)
    maxflow = sp_sparse.csgraph.maximum_flow(graph, src, target, method='edmonds_karp')

    max_flow_val = maxflow.flow_value * (scale_factor)**-1

    max_flow_arr = maxflow.flow.toarray() * (scale_factor)**-1
    max_flow_arr[max_flow_arr<0] = 0

    # Outgoing flow per token node (drop the source row 0 and the target row -1)
    shapley_vals = np.sum(max_flow_arr, axis=1)
    shapley_vals_layerwise = shapley_vals[1:-1].reshape((n_layers+1, length_tokens))

    return  adj_mat, scaled_adj_mat, max_flow_val, max_flow_arr, shapley_vals_layerwise


def get_pagerank_adj_matrix(agg_attentions: torch.Tensor) -> np.ndarray:
    """
    Build a block-diagonal adjacency matrix from per-layer aggregated attentions.

    Args:
        agg_attentions (torch.Tensor): Aggregated attentions of shape
            (n_layers, length_tokens, length_tokens). Grad-tracking, GPU and
            NumPy inputs are converted to a CPU NumPy array first.

    Returns:
        np.ndarray: A matrix of shape (n_layers * length_tokens, n_layers * length_tokens).
        Diagonal block ``layer`` holds ``agg_attentions[layer]``; every
        off-diagonal block is zero.
    """
    n_layers, length_tokens = agg_attentions.shape[0], agg_attentions.shape[1]

    # NOTE(review): only diagonal blocks are filled, so the graph has no edges
    # between layers. Confirm this is intended for the PageRank use case.
    attentions = torch.as_tensor(agg_attentions).detach().cpu().numpy()

    size = n_layers * length_tokens
    adj_mat = np.zeros((size, size))

    for layer in range(n_layers):
        block = slice(layer * length_tokens, (layer + 1) * length_tokens)
        adj_mat[block, block] = attentions[layer]

    return adj_mat


def get_pagerank(adj_matrix:np.ndarray, alpha=0.85) -> np.ndarray:
    """
    Score every node of the directed graph encoded by ``adj_matrix`` with PageRank.

    Parameters:
    - adj_matrix (np.ndarray): Adjacency matrix; it is passed to ``nx.DiGraph`` to build the graph.
    - alpha (float): Damping factor forwarded to ``nx.pagerank``. Default is 0.85.

    Returns:
    - np.ndarray: An array of shape (1, n_nodes) with one score per node, in the
      iteration order of the dict returned by ``nx.pagerank``. This is presumably
      node index order 0..n-1 — confirm.
    """
    scores = nx.pagerank(nx.DiGraph(adj_matrix), alpha=alpha)
    # A single row with one column per node.
    return np.fromiter(scores.values(), dtype=float)[np.newaxis, :]


def plot_shap_vals(layer_number, shapley_vals_layerwise, input_tokens, fig_size):
    """
    Plot one layer's Shapley values as a single-row heatmap over the input tokens.

    The selected row is also printed, and the figure is displayed with ``plt.show()``.

    Args:
        layer_number (int): Row of ``shapley_vals_layerwise`` to plot.
        shapley_vals_layerwise (numpy.ndarray): 2-D array of Shapley values, one row per layer.
        input_tokens (list): Token labels for the x-axis. Its length should not
            exceed the row length.
        fig_size (tuple): The size of the figure, passed to ``plt.subplots``.

    Returns:
        tuple: ``(fig, axes)`` as returned by ``plt.subplots(1, 1, ...)``.
    """
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 1, figsize=fig_size)

    # Reshape to a single row so imshow renders a 1 x n_tokens strip.
    shap_vals = shapley_vals_layerwise[layer_number, :].reshape(1, -1)
    print(shap_vals)

    # The color scale is fixed to [0, 1]; values outside it saturate.
    axes.imshow(shap_vals, cmap="YlOrRd", vmin=0, vmax=1)

    for i in range(len(input_tokens)):
        # Dashed separator to the right of each token cell.
        axes.axvline(x=i+0.5, color="black", linestyle="--", linewidth=1)

        # Annotate each cell with its value rounded to two decimals.
        axes.text(i, 0, str(shap_vals[0, i].round(2)), color="black", fontsize=8, fontweight="bold", ha='center')

    axes.set_yticks([])
    axes.set_xticks(list(range(len(input_tokens))))
    axes.set_xticklabels(labels=input_tokens)
    axes.set_xlabel("Input Tokens", labelpad=10, fontsize=12)

    plt.show()

    return fig, axes


    