""" 
This module provides utility functions for computing attention-attribution results and for plotting and saving them.
"""

from typing import List
from matplotlib import pyplot as plt
from Code.attention_utils import add_residual_block, concat_tensors, get_modified_tensor, get_normalized_attribution
from Code.graph_utils import get_adj_matrix_backward, get_adj_matrix_forward, get_shap_info
from Code.nlp_utils import get_modified_input_tokens
from Code.post_processing_utils import plot_attr_vals, plot_shap_vals
import os

def get_plots(attention_info, input_tokens, fig_info, removed_indices:List[int],
              normalize:bool=True, save_path=None, fixed_width:bool=False):
    """
    Draw every backward Shapley-value plot for one input as a single stacked figure.

    The figure has one row per plot: first the 'attention_base' entry (labelled
    'AF'), then one row per entry of 'attention_grad' (labelled 'AGF-Class=i'),
    then one row per entry of 'attention_grad_base' (labelled 'GF-Class=i').

    Parameters:
    - attention_info (dict): Must contain 'attention_base', 'attention_grad' and
      'attention_grad_base'. The grad containers are looked up with the keys
      'attention_grad_logit_{i}' and 'attention_grad_base_logit_{i}'. Every plotted
      entry must provide ['bw_shap_info']['normalized_shapley_vals_layerwise'],
      which is rounded to 4 decimals before plotting.
    - input_tokens (list): Tokens forwarded to plot_shap_vals; their count also
      sets the figure width when fixed_width is False (2.5 per token).
    - fig_info (dict): Figure settings. 'width', 'height', 'font_size' and 'dpi'
      are read here (defaults 20, 3, 14 and 300); the dict is also forwarded to
      plot_shap_vals.
    - removed_indices (List[int]): Forwarded unchanged to plot_shap_vals.
    - normalize (bool): Forwarded unchanged to plot_shap_vals.
    - save_path: If not None, the figure is written to this path as a PDF.
    - fixed_width (bool): Use fig_info['width'] instead of the token-based width.

    Returns:
    - fig (Figure): The generated figure object.
    - axes: 1-D array with one Axes per row (also when there is only one row).
    """
    attention_base_info = attention_info.get('attention_base')
    attention_grad_info = attention_info.get('attention_grad')
    attention_grad_base_info = attention_info.get('attention_grad_base')

    l1 = len(attention_grad_info)
    l2 = len(attention_grad_base_info)

    num_plots = 1 + l1 + l2
    print(f'number of plots: {num_plots}')

    width = fig_info.get('width', 20) if fixed_width else len(input_tokens) * 2.5
    height = fig_info.get("height", 3)
    y_font_size = fig_info.get("font_size", 14) - 3
    dpi = fig_info.get("dpi", 300)

    # Set the default font family to Arial (this changes global matplotlib state).
    plt.rcParams["font.family"] = "Arial"

    # squeeze=False keeps a 2-D grid even for a single row, so the column slice
    # below is always an indexable 1-D array of Axes.
    fig, axes_grid = plt.subplots(num_plots, 1, squeeze=False)
    axes = axes_grid[:, 0]

    # One (entry, y-label) pair per row, in top-to-bottom order.
    rows = [(attention_base_info, 'AF')]
    rows += [(attention_grad_info.get(f'attention_grad_logit_{i}'), f'AGF-Class={i}')
             for i in range(l1)]
    rows += [(attention_grad_base_info.get(f'attention_grad_base_logit_{i}'), f'GF-Class={i}')
             for i in range(l2)]

    for axis, (entry, label) in zip(axes, rows):
        svl_lb = entry.get('bw_shap_info').get('normalized_shapley_vals_layerwise').round(4)
        plot_shap_vals(0, svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=axis)
        axis.set_ylabel(label, fontdict={'fontsize': y_font_size, 'fontweight': 'normal'})

    # Lay out this specific figure (not whichever figure happens to be current).
    fig.tight_layout()

    # Reduce the space between rows; the same spacing is used for every plot type.
    fig.subplots_adjust(wspace=0, hspace=0.5)

    # Width and height are divided by 2.54 before being passed on as inches,
    # i.e. they are treated as centimetres.
    fig.set_size_inches(width / 2.54, (1.25)*height * num_plots / 2.54)

    if save_path is not None:
        fig.savefig(save_path, format='pdf', dpi=dpi, bbox_inches='tight')
    return fig, axes

def save_plots(attention_info, input_tokens, fig_info, removed_indices:List[int],
              normalize:bool=True, save_dir=None, fixed_width:bool=False):
    """
    Save the backward Shapley-value plots for one input as up to three PDF files.

    Files written into save_dir (plot_type is fig_info['plot_type'], default 'heatmap'):
    - '{plot_type}_AF.pdf':  the single 'attention_base' plot
    - '{plot_type}_AGF.pdf': one stacked row per entry of 'attention_grad'
    - '{plot_type}_GF.pdf':  one stacked row per entry of 'attention_grad_base'
    A stacked file is skipped when its group has no entries.

    Parameters:
    - attention_info: dictionary with the keys 'attention_base', 'attention_grad'
      and 'attention_grad_base'; every plotted entry must provide
      ['bw_shap_info']['normalized_shapley_vals_layerwise']
    - input_tokens: list of input tokens, forwarded to plot_shap_vals; its length
      sets the width when fixed_width is False (2.5 per token)
    - fig_info: dictionary containing figure information ('width', 'height',
      'font_size', 'dpi' and 'plot_type' are read here; also forwarded to plot_shap_vals)
    - removed_indices: list of indices, forwarded unchanged to plot_shap_vals
    - normalize: boolean forwarded unchanged to plot_shap_vals (default True)
    - save_dir: directory the PDF files are written to; it is joined with the
      file names, so the default of None is not usable
    - fixed_width: boolean indicating whether to use fig_info['width'] (default False)

    Returns:
    None
    """
    attention_base_info = attention_info.get('attention_base')
    attention_grad_info = attention_info.get('attention_grad')
    attention_grad_base_info = attention_info.get('attention_grad_base')

    l1 = len(attention_grad_info)
    l2 = len(attention_grad_base_info)

    h_space = 0.5

    width = fig_info.get('width', 20) if fixed_width else len(input_tokens) * 2.5
    height = fig_info.get("height", 3)
    y_font_size = fig_info.get("font_size", 14) - 3
    dpi = fig_info.get("dpi", 300)
    plot_type = fig_info.get('plot_type', 'heatmap')

    # Set the default font family to Arial (this changes global matplotlib state).
    plt.rcParams["font.family"] = "Arial"

    def _draw(entry, axis, label):
        # Plot one entry's rounded backward Shapley values on the given axis.
        svl_lb = entry.get('bw_shap_info').get('normalized_shapley_vals_layerwise').round(4)
        plot_shap_vals(0, svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=axis)
        axis.set_ylabel(label, fontdict={'fontsize': y_font_size, 'fontweight': 'normal'})

    def _save(fig, name):
        # Report the final size (converted back from inches) and write the PDF.
        print(fig.get_size_inches()*2.54)
        fig.savefig(os.path.join(save_dir, f'{plot_type}_{name}.pdf'),
                    format='pdf', dpi=dpi, bbox_inches='tight')

    def _save_group(group_info, key_prefix, count, name):
        # Stack one row per entry of a group into a single figure and save it.
        if count == 0:
            # A figure cannot be created with zero rows; nothing to save.
            return
        # squeeze=False keeps the axes indexable even when count == 1.
        fig, grid = plt.subplots(count, 1, squeeze=False)
        for i, axis in enumerate(grid[:, 0]):
            _draw(group_info.get(f'{key_prefix}{i}'), axis, f'Class={i}')
        if count > 1:
            # Reduce the space between rows and budget for it in the height.
            fig.subplots_adjust(wspace=0, hspace=h_space)
            fig.set_size_inches(width / 2.54, ((height+h_space)*count+h_space) / 2.54)
        else:
            # A single row gets the same size as the stand-alone AF figure.
            fig.set_size_inches(width / 2.54, height / 2.54)
        _save(fig, name)

    fig_0, axes_0 = plt.subplots(1, 1)
    _draw(attention_base_info, axes_0, ' ')
    # Sizes are divided by 2.54 before being passed on as inches,
    # i.e. width and height are treated as centimetres.
    fig_0.set_size_inches(width / 2.54, (height)*1 / 2.54)
    _save(fig_0, 'AF')

    _save_group(attention_grad_info, 'attention_grad_logit_', l1, 'AGF')
    _save_group(attention_grad_base_info, 'attention_grad_base_logit_', l2, 'GF')

def save_single_plots(attention_info, input_tokens, fig_info, removed_indices:List[int],
              normalize:bool=True, save_dir=None, fixed_width:bool=False):
    """
    Save each backward Shapley-value plot for one input as its own PDF file.

    One figure is created per plotted entry and handed to plot_shap_vals together
    with a target path inside save_dir. File names (plot_type is
    fig_info['plot_type'], default 'heatmap'):
    'AF_{plot_type}.pdf', 'AGF-Class={i}_{plot_type}.pdf' and
    'GF-Class={i}_{plot_type}.pdf'.

    Args:
        attention_info: Dictionary with the keys 'attention_base', 'attention_grad'
            and 'attention_grad_base'; every plotted entry must provide
            ['bw_shap_info']['normalized_shapley_vals_layerwise'].
        input_tokens: List of input tokens, forwarded to plot_shap_vals.
        fig_info: Figure information; 'plot_type' is read here and the dict is
            forwarded to plot_shap_vals.
        removed_indices: List of indices, forwarded to plot_shap_vals.
        normalize: Boolean forwarded to plot_shap_vals.
        save_dir: Directory that the file names are joined onto.
        fixed_width: Boolean forwarded to plot_shap_vals.

    Returns:
        None
    """
    base_entry = attention_info.get('attention_base')
    grad_entries = attention_info.get('attention_grad')
    grad_base_entries = attention_info.get('attention_grad_base')

    n_grad = len(grad_entries.keys())
    n_grad_base = len(grad_base_entries.keys())
    plot_type = fig_info.get('plot_type', 'heatmap')

    plt.rcParams["font.family"] = "Arial"

    # (entry, file name) pairs in the order the files are produced.
    jobs = [(base_entry, f'AF_{plot_type}.pdf')]
    jobs.extend(
        (grad_entries.get(f'attention_grad_logit_{k}'), f'AGF-Class={k}_{plot_type}.pdf')
        for k in range(n_grad)
    )
    jobs.extend(
        (grad_base_entries.get(f'attention_grad_base_logit_{k}'), f'GF-Class={k}_{plot_type}.pdf')
        for k in range(n_grad_base)
    )

    for entry, file_name in jobs:
        shap_vals = entry.get('bw_shap_info').get('normalized_shapley_vals_layerwise').round(4)
        _, axis = plt.subplots()
        plot_shap_vals(
            layer_number=0,
            shapley_vals_layerwise=shap_vals,
            input_tokens=input_tokens,
            fig_info=fig_info,
            removed_indices=removed_indices,
            normalize=normalize,
            axis=axis,
            save_path=os.path.join(save_dir, file_name),
            fixed_width=fixed_width,
        )

def get_all_info(model, removed_indices, **kwargs) -> tuple:
    """
    Build the attribution, graph and Shapley information for one model input.

    Args:
        model: The model object; its `get_info` method is called with
            kwargs['model_params'] and must return a dict containing
            'input_tokens', 'attentions' and 'attention_grads'.
        removed_indices: A list of indices to be removed from the input tokens
            (forwarded to get_modified_input_tokens and get_modified_tensor).
        **kwargs: Must contain the mappings 'model_params',
            'attention_base_shap_params', 'attention_grad_shap_params' and
            'attention_grad_base_shap_params'. The three *_shap_params mappings
            are expanded into the get_shap_info calls of the matching section.

    Returns:
        A 2-tuple (attention_info, model_info):
        attention_info: dict with the keys 'attention_base' (one entry),
            'attention_grad' (entries keyed 'attention_grad_logit_{i}', built
            from gradient * attention) and 'attention_grad_base' (entries keyed
            'attention_grad_base_logit_{i}', built from the gradient alone).
            Every entry is a dict with 'attribution', 'fw_gr_info',
            'bw_gr_info', 'fw_shap_info' and 'bw_shap_info'.
        model_info: the dict returned by model.get_info, extended in place with
            'modified_input_tokens' and 'removed_indices'.
    """
    model_params = kwargs.get('model_params')
    model_info = model.get_info(**model_params)
    modified_input_tokens = get_modified_input_tokens(model_info.get('input_tokens'), removed_indices)

    model_info['modified_input_tokens'] = modified_input_tokens
    model_info['removed_indices'] = removed_indices

    attentions, attention_grads = model_info['attentions'], model_info['attention_grads']

    attention_tensor = concat_tensors(attentions, dim=0)

    # Index-based access is kept on purpose: the container type of
    # attention_grads is not visible here, only that it supports len() and [i].
    attention_grads_tensor_list = [concat_tensors(attention_grads[i], dim=0) for i in range(len(attention_grads))]

    modified_attentions = get_modified_tensor(attention_tensor, removed_indices)
    modified_attention_grads_list = [get_modified_tensor(grads_tensor, removed_indices)
                                     for grads_tensor in attention_grads_tensor_list]

    def _build_entry(scores, shap_params):
        # Shared pipeline: normalise -> add residual -> backward/forward graphs -> Shapley info.
        attribution = get_normalized_attribution(scores, input_epsilon=1e-10)
        agg_attentions = add_residual_block(attribution, add_res=True, lambda_res=0.05)

        bw_attention_gr_info = get_adj_matrix_backward(agg_attentions)
        fw_attention_gr_info = get_adj_matrix_forward(agg_attentions)

        bw_shap_info = get_shap_info(bw_attention_gr_info, len(modified_input_tokens), backward=True, **shap_params)
        fw_shap_info = get_shap_info(fw_attention_gr_info, len(modified_input_tokens), backward=False, **shap_params)

        return {
            'attribution': attribution,
            'fw_gr_info': fw_attention_gr_info,
            'bw_gr_info': bw_attention_gr_info,
            'fw_shap_info': fw_shap_info,
            'bw_shap_info': bw_shap_info
        }

    attention_info = {}

    # Section 1: attention values alone.
    attention_info['attention_base'] = _build_entry(
        modified_attentions, kwargs.get('attention_base_shap_params'))

    # Section 2: gradient * attention, one entry per gradient tensor.
    attention_grad_shap_params = kwargs.get('attention_grad_shap_params')
    attention_grad_info = {}
    for i, modified_attention_grads in enumerate(modified_attention_grads_list):
        attention_grad_info[f'attention_grad_logit_{i}'] = _build_entry(
            modified_attention_grads * modified_attentions, attention_grad_shap_params)
    attention_info['attention_grad'] = attention_grad_info

    # Section 3: gradient alone, one entry per gradient tensor.
    attention_grad_base_shap_params = kwargs.get('attention_grad_base_shap_params')
    attention_grad_base_info = {}
    for i, modified_attention_grads in enumerate(modified_attention_grads_list):
        attention_grad_base_info[f'attention_grad_base_logit_{i}'] = _build_entry(
            modified_attention_grads, attention_grad_base_shap_params)
    attention_info['attention_grad_base'] = attention_grad_base_info

    return attention_info, model_info

def get_plots_qa(attention_info, input_tokens, fig_info, removed_indices:List[int],
                normalize:bool=True, save_path=None, fixed_width=False):
    """
    Draw every backward Shapley-value plot of a QA input as a single stacked figure.

    Rows from top to bottom: the 'attention_base' entry ('AF'), the entries of
    'start_attention_grad' ('AGF-ST=i'), 'end_attention_grad' ('AGF-ET=i'),
    'start_attention_grad_base' ('GF-ST=i') and 'end_attention_grad_base'
    ('GF-ET=i'). Each y-label is set on the same axis its plot is drawn on.

    Parameters:
    - attention_info (dict): Must contain the five keys named above. The grad
      containers are looked up with 'attention_grad_logit_{i}', the grad-base
      containers with 'attention_grad_base_logit_{i}'. Every plotted entry must
      provide ['bw_shap_info']['normalized_shapley_vals_layerwise'].
    - input_tokens (list): Tokens forwarded to plot_shap_vals; their count also
      sets the figure width when fixed_width is False (2.5 per token).
    - fig_info (dict): Figure settings. 'width', 'height', 'font_size' and 'dpi'
      are read here (defaults 20, 3, 14 and 300); the dict is also forwarded to
      plot_shap_vals.
    - removed_indices (List[int]): Forwarded unchanged to plot_shap_vals.
    - normalize (bool): Forwarded unchanged to plot_shap_vals.
    - save_path: If not None, the figure is written to this path as a PDF.
    - fixed_width (bool): Use fig_info['width'] instead of the token-based width.

    Returns:
    - fig (Figure): The generated figure object.
    - axes: 1-D array with one Axes per row (also when there is only one row).
    """
    attention_base_info = attention_info.get('attention_base')

    start_attention_grad_info = attention_info.get('start_attention_grad')
    end_attention_grad_info = attention_info.get('end_attention_grad')

    start_attention_grad_base_info = attention_info.get('start_attention_grad_base')
    end_attention_grad_base_info = attention_info.get('end_attention_grad_base')

    start_l1 = len(start_attention_grad_info)
    end_l1 = len(end_attention_grad_info)
    start_l2 = len(start_attention_grad_base_info)
    end_l2 = len(end_attention_grad_base_info)

    num_plots = 1 + start_l1 + end_l1 + start_l2 + end_l2
    print(f'num_plots: {num_plots}')

    width = fig_info.get('width', 20) if fixed_width else len(input_tokens) * 2.5
    height = fig_info.get('height', 3)
    y_font_size = fig_info.get('font_size', 14) - 3
    dpi = fig_info.get("dpi", 300)

    # Set the default font family to Arial (this changes global matplotlib state).
    plt.rcParams["font.family"] = "Arial"

    # squeeze=False keeps a 2-D grid even for a single row, so the column slice
    # below is always an indexable 1-D array of Axes.
    fig, axes_grid = plt.subplots(num_plots, 1, squeeze=False)
    axes = axes_grid[:, 0]

    # One (entry, y-label) pair per row, in top-to-bottom order. Pairing rows
    # with axes by position removes the per-group offset arithmetic, so a plot
    # and its label can never end up on different axes.
    rows = [(attention_base_info, 'AF')]
    rows += [(start_attention_grad_info.get(f'attention_grad_logit_{i}'), f'AGF-ST={i}')
             for i in range(start_l1)]
    rows += [(end_attention_grad_info.get(f'attention_grad_logit_{i}'), f'AGF-ET={i}')
             for i in range(end_l1)]
    rows += [(start_attention_grad_base_info.get(f'attention_grad_base_logit_{i}'), f'GF-ST={i}')
             for i in range(start_l2)]
    rows += [(end_attention_grad_base_info.get(f'attention_grad_base_logit_{i}'), f'GF-ET={i}')
             for i in range(end_l2)]

    for axis, (entry, label) in zip(axes, rows):
        svl_lb = entry.get('bw_shap_info').get('normalized_shapley_vals_layerwise').round(4)
        plot_shap_vals(0, svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=axis)
        axis.set_ylabel(label, fontdict={'fontsize': y_font_size, 'fontweight': 'normal'})

    # Lay out this specific figure (not whichever figure happens to be current).
    fig.tight_layout()

    # Reduce the space between rows; the same spacing is used for every plot type.
    fig.subplots_adjust(wspace=0, hspace=0.5)

    # Width and height are divided by 2.54 before being passed on as inches,
    # i.e. they are treated as centimetres.
    fig.set_size_inches(width / 2.54, (1.25)*height * num_plots / 2.54)

    if save_path is not None:
        fig.savefig(save_path, format='pdf', dpi=dpi, bbox_inches='tight')
    return fig, axes


def save_plots_qa(attention_info, input_tokens, fig_info, removed_indices:List[int],
                  normalize:bool=True, save_dir=None, fixed_width:bool=False):
    """
    Save the backward Shapley-value plots of a QA input as five PDF files.

    Files written into save_dir (plot_type is fig_info['plot_type'], default 'heatmap'):
    - '{plot_type}_AF.pdf':        the single 'attention_base' plot
    - '{plot_type}_AGF-Start.pdf': stacked rows for 'start_attention_grad'
    - '{plot_type}_AGF-End.pdf':   stacked rows for 'end_attention_grad'
    - '{plot_type}_GF-Start.pdf':  stacked rows for 'start_attention_grad_base'
    - '{plot_type}_GF-End.pdf':    stacked rows for 'end_attention_grad_base'
    The start and end figures use distinct file names so that neither overwrites
    the other. A group without entries is saved as a figure with one empty axis.

    Parameters:
    - attention_info (dict): Must contain 'attention_base' and the four group
      keys named above. The grad groups are looked up with
      'attention_grad_logit_{i}', the grad-base groups with
      'attention_grad_base_logit_{i}'. Every plotted entry must provide
      ['bw_shap_info']['normalized_shapley_vals_layerwise'].
    - input_tokens (list): Tokens forwarded to plot_shap_vals; their count also
      sets the figure width when fixed_width is False (2.5 per token).
    - fig_info (dict): Figure settings ('width', 'height', 'font_size', 'dpi' and
      'plot_type' are read here); also forwarded to plot_shap_vals.
    - removed_indices (List[int]): Forwarded unchanged to plot_shap_vals.
    - normalize (bool): Forwarded unchanged to plot_shap_vals.
    - save_dir: Directory the PDF files are written to; it is joined with the
      file names, so the default of None is not usable.
    - fixed_width (bool): Use fig_info['width'] instead of the token-based width.

    Returns:
    None
    """
    attention_base_info = attention_info.get('attention_base')

    start_attention_grad_info = attention_info.get('start_attention_grad')
    end_attention_grad_info = attention_info.get('end_attention_grad')

    start_attention_grad_base_info = attention_info.get('start_attention_grad_base')
    end_attention_grad_base_info = attention_info.get('end_attention_grad_base')

    start_l1 = len(start_attention_grad_info)
    end_l1 = len(end_attention_grad_info)
    start_l2 = len(start_attention_grad_base_info)
    end_l2 = len(end_attention_grad_base_info)

    h_space = 0.5

    width = fig_info.get('width', 20) if fixed_width else len(input_tokens) * 2.5
    height = fig_info.get("height", 3)
    y_font_size = fig_info.get("font_size", 14) - 3
    dpi = fig_info.get("dpi", 300)
    plot_type = fig_info.get('plot_type', 'heatmap')

    # Set the default font family to Arial (this changes global matplotlib state).
    plt.rcParams["font.family"] = "Arial"

    def _draw(entry, axis, label):
        # Plot one entry's rounded backward Shapley values on the given axis.
        svl_lb = entry.get('bw_shap_info').get('normalized_shapley_vals_layerwise').round(4)
        plot_shap_vals(0, svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=axis)
        axis.set_ylabel(label, fontdict={'fontsize': y_font_size, 'fontweight': 'normal'})

    def _save(fig, name):
        # Report the final size (converted back from inches) and write the PDF.
        print(fig.get_size_inches()*2.54)
        fig.savefig(os.path.join(save_dir, f'{plot_type}_{name}.pdf'),
                    format='pdf', dpi=dpi, bbox_inches='tight')

    def _save_group(group_info, key_prefix, count, name):
        # Stack one row per entry of a group into a single figure and save it.
        # At least one axis is always created; squeeze=False keeps the axes
        # indexable for a single row as well.
        fig, grid = plt.subplots(max(count, 1), 1, squeeze=False)
        group_axes = grid[:, 0]
        for i in range(count):
            _draw(group_info.get(f'{key_prefix}{i}'), group_axes[i], f'Class={i}')
        if count > 1:
            # Reduce the space between rows and budget for it in the height.
            fig.subplots_adjust(wspace=0, hspace=h_space)
            fig.set_size_inches(width / 2.54, ((height+h_space)*count+h_space) / 2.54)
        else:
            fig.set_size_inches(width / 2.54, height / 2.54)
        _save(fig, name)

    fig_0, axes_0 = plt.subplots(1, 1)
    _draw(attention_base_info, axes_0, ' ')
    # Sizes are divided by 2.54 before being passed on as inches,
    # i.e. width and height are treated as centimetres.
    fig_0.set_size_inches(width / 2.54, (height)*1 / 2.54)
    _save(fig_0, 'AF')

    _save_group(start_attention_grad_info, 'attention_grad_logit_', start_l1, 'AGF-Start')
    _save_group(end_attention_grad_info, 'attention_grad_logit_', end_l1, 'AGF-End')
    _save_group(start_attention_grad_base_info, 'attention_grad_base_logit_', start_l2, 'GF-Start')
    _save_group(end_attention_grad_base_info, 'attention_grad_base_logit_', end_l2, 'GF-End')

def save_single_plots_qa(attention_info, input_tokens, fig_info, removed_indices:List[int],
              normalize:bool=True, save_dir=None, fixed_width:bool=False):
    """
    Save each backward Shapley-value plot of a QA input as its own PDF file.

    One figure is created per plotted entry and handed to plot_shap_vals together
    with a target path inside save_dir. File names (plot_type is
    fig_info['plot_type'], default 'heatmap'):
    'AF_{plot_type}.pdf', 'AGF-Start={i}_{plot_type}.pdf',
    'AGF-End={i}_{plot_type}.pdf', 'GF-Start={i}_{plot_type}.pdf' and
    'GF-End={i}_{plot_type}.pdf'.

    Args:
        attention_info (dict): Must contain 'attention_base',
            'start_attention_grad', 'end_attention_grad',
            'start_attention_grad_base' and 'end_attention_grad_base'; every
            plotted entry must provide
            ['bw_shap_info']['normalized_shapley_vals_layerwise'].
        input_tokens (list): The list of input tokens, forwarded to plot_shap_vals.
        fig_info (dict): Figure information; 'plot_type' is read here and the
            dict is forwarded to plot_shap_vals.
        removed_indices (list): The list of indices, forwarded to plot_shap_vals.
        normalize (bool): Forwarded to plot_shap_vals.
        save_dir (str): The directory that the file names are joined onto.
        fixed_width (bool): Forwarded to plot_shap_vals.
    """
    base_entry = attention_info.get('attention_base')

    start_grad = attention_info.get('start_attention_grad')
    end_grad = attention_info.get('end_attention_grad')

    start_grad_base = attention_info.get('start_attention_grad_base')
    end_grad_base = attention_info.get('end_attention_grad_base')

    n_start_grad = len(start_grad.keys())
    n_end_grad = len(end_grad.keys())
    n_start_grad_base = len(start_grad_base.keys())
    n_end_grad_base = len(end_grad_base.keys())

    plot_type = fig_info.get('plot_type', 'heatmap')

    plt.rcParams["font.family"] = "Arial"

    # (group, key prefix, number of entries, file-name stem) in output order.
    groups = (
        (start_grad, 'attention_grad_logit_', n_start_grad, 'AGF-Start'),
        (end_grad, 'attention_grad_logit_', n_end_grad, 'AGF-End'),
        (start_grad_base, 'attention_grad_base_logit_', n_start_grad_base, 'GF-Start'),
        (end_grad_base, 'attention_grad_base_logit_', n_end_grad_base, 'GF-End'),
    )

    # (entry, file name) pairs, base plot first.
    jobs = [(base_entry, f'AF_{plot_type}.pdf')]
    for group, key_prefix, count, stem in groups:
        for k in range(count):
            jobs.append((group.get(f'{key_prefix}{k}'), f'{stem}={k}_{plot_type}.pdf'))

    for entry, file_name in jobs:
        shap_vals = entry.get('bw_shap_info').get('normalized_shapley_vals_layerwise').round(4)
        _, axis = plt.subplots()
        plot_shap_vals(
            layer_number=0,
            shapley_vals_layerwise=shap_vals,
            input_tokens=input_tokens,
            fig_info=fig_info,
            removed_indices=removed_indices,
            normalize=normalize,
            axis=axis,
            save_path=os.path.join(save_dir, file_name),
            fixed_width=fixed_width,
        )
    


def _build_attribution_info(raw_scores, num_tokens, shap_params):
    """
    Run the attribution -> graph -> Shapley pipeline on a single score tensor.

    Args:
        raw_scores: Tensor handed to `get_normalized_attribution` (attentions,
            gradients, or their element-wise product).
        num_tokens: Number of tokens remaining after removal; passed to `get_shap_info`.
        shap_params: Dict of extra keyword arguments unpacked into `get_shap_info`.

    Returns:
        A dictionary with the keys 'attribution', 'fw_gr_info', 'bw_gr_info',
        'fw_shap_info' and 'bw_shap_info'.
    """
    attribution = get_normalized_attribution(raw_scores, input_epsilon=1e-10)
    agg_attentions = add_residual_block(attribution, add_res=True, lambda_res=0.05)

    # Keep the original call order (backward before forward) in case the
    # helpers are order-sensitive.
    bw_gr_info = get_adj_matrix_backward(agg_attentions)
    fw_gr_info = get_adj_matrix_forward(agg_attentions)

    bw_shap_info = get_shap_info(bw_gr_info, num_tokens, backward=True, **shap_params)
    fw_shap_info = get_shap_info(fw_gr_info, num_tokens, backward=False, **shap_params)

    return {
        'attribution': attribution,
        'fw_gr_info': fw_gr_info,
        'bw_gr_info': bw_gr_info,
        'fw_shap_info': fw_shap_info,
        'bw_shap_info': bw_shap_info
    }


def get_all_info_qa(model, removed_indices, **kwargs):
    """
    Generates a dictionary of attention and model information for a QA model.

    Args:
        model: The model object; must provide a `get_info(**model_params)` method
            whose result contains 'input_tokens', 'attentions',
            'start_attention_grads' and 'end_attention_grads'.
        removed_indices: A list of indices to be removed from the input tokens.
        **kwargs: Expected keys (each a dict that gets unpacked as keyword arguments):
            - 'model_params': passed to `model.get_info`.
            - 'attention_base_shap_params': passed to `get_shap_info` for the attention-only flow.
            - 'attention_grad_shap_params': passed to `get_shap_info` for the attention*gradient flows.
            - 'attention_grad_base_shap_params': passed to `get_shap_info` for the gradient-only flows.

    Returns:
        attention_info: A dictionary with the keys 'attention_base',
            'start_attention_grad', 'end_attention_grad',
            'start_attention_grad_base' and 'end_attention_grad_base'.
            The grad entries map 'attention_grad_logit_{i}' /
            'attention_grad_base_logit_{i}' to per-logit info dictionaries.
        model_info: The dictionary returned by `model.get_info`, extended with
            'modified_input_tokens' and 'removed_indices'.
    """
    attention_info = {}

    model_params = kwargs.get('model_params')
    model_info = model.get_info(**model_params)
    modified_input_tokens = get_modified_input_tokens(model_info.get('input_tokens'), removed_indices)

    model_info['modified_input_tokens'] = modified_input_tokens
    model_info['removed_indices'] = removed_indices

    attentions = model_info['attentions']
    end_attention_grads = model_info['end_attention_grads']
    start_attention_grads = model_info['start_attention_grads']

    attention_tensor = concat_tensors(attentions, dim=0)

    # Index-based access is kept on purpose: the container type of the
    # gradient collections is not visible here (could be a list or an
    # int-keyed mapping).
    start_attention_grads_tensor_list = [concat_tensors(start_attention_grads[i], dim=0) for i in range(len(start_attention_grads))]
    end_attention_grads_tensor_list = [concat_tensors(end_attention_grads[i], dim=0) for i in range(len(end_attention_grads))]

    modified_attentions = get_modified_tensor(attention_tensor, removed_indices)
    modified_start_attention_grads_list = [get_modified_tensor(grads, removed_indices) for grads in start_attention_grads_tensor_list]
    modified_end_attention_grads_list = [get_modified_tensor(grads, removed_indices) for grads in end_attention_grads_tensor_list]

    num_tokens = len(modified_input_tokens)

    # Attention-only flow.
    attention_base_shap_params = kwargs.get('attention_base_shap_params')
    attention_info['attention_base'] = _build_attribution_info(
        modified_attentions, num_tokens, attention_base_shap_params)

    # Attention * gradient flows, one entry per start / end logit.
    attention_grad_shap_params = kwargs.get('attention_grad_shap_params')

    start_attention_grad_info = {}
    for i, modified_start_attention_grads in enumerate(modified_start_attention_grads_list):
        start_attention_grad_info[f'attention_grad_logit_{i}'] = _build_attribution_info(
            modified_start_attention_grads*modified_attentions, num_tokens, attention_grad_shap_params)
    attention_info['start_attention_grad'] = start_attention_grad_info

    end_attention_grads_info = {}
    for i, modified_end_attention_grads in enumerate(modified_end_attention_grads_list):
        end_attention_grads_info[f'attention_grad_logit_{i}'] = _build_attribution_info(
            modified_end_attention_grads*modified_attentions, num_tokens, attention_grad_shap_params)
    attention_info['end_attention_grad'] = end_attention_grads_info

    # Gradient-only flows, one entry per start / end logit.
    attention_grad_base_shap_params = kwargs.get('attention_grad_base_shap_params')

    start_attention_grad_base_info = {}
    for i, modified_start_attention_grads in enumerate(modified_start_attention_grads_list):
        start_attention_grad_base_info[f'attention_grad_base_logit_{i}'] = _build_attribution_info(
            modified_start_attention_grads, num_tokens, attention_grad_base_shap_params)
    attention_info['start_attention_grad_base'] = start_attention_grad_base_info

    end_attention_grad_base_info = {}
    for i, modified_end_attention_grads in enumerate(modified_end_attention_grads_list):
        end_attention_grad_base_info[f'attention_grad_base_logit_{i}'] = _build_attribution_info(
            modified_end_attention_grads, num_tokens, attention_grad_base_shap_params)
    attention_info['end_attention_grad_base'] = end_attention_grad_base_info

    return attention_info, model_info


def set_fontsize(axes, alpha:float) -> float:
    """
    Compute a font size proportional to the height of the figure owning `axes`.

    Nothing is modified on the axes; the value is only calculated and returned
    so the caller can apply it.

    Parameters:
        axes (Axes): The axes whose parent figure provides the height.
        alpha (float): Multiplier applied to the figure height.

    Returns:
        float: `alpha` times the figure height.
    """
    figure_height = axes.figure.get_figheight()
    return alpha * figure_height


def get_benchmark_plots(attention_info, benchamrk_attribution_info, input_tokens, 
                        fig_info, removed_indices:List[int], normalize:bool=True, 
                        save_path=None, fixed_width:bool=False):
    """
    Plot the attention-flow attributions next to the benchmark attributions, one row per method.

    Rows are stacked in this order: 'AF' (attention base), 'AGF-Class={i}'
    (attention * gradient per logit), 'GF-Class={i}' (gradient per logit) and
    'LIG-Class={i}' (benchmark attributions per label).

    Parameters:
    - attention_info (dict): Must contain 'attention_base', 'attention_grad' and
      'attention_grad_base'; the Shapley values are read from
      ['bw_shap_info']['normalized_shapley_vals_layerwise'].
    - benchamrk_attribution_info (dict): Must contain 'model_normalized_attributions_info',
      a dict keyed by 'normalized_attributions_label_{i}'.
    - input_tokens (list): Tokens forwarded to the plotting helpers; also drives the
      figure width when `fixed_width` is False.
    - fig_info (dict): Figure settings. 'dpi' (default 300) and 'width' (default 20,
      used only when `fixed_width` is True) are read here; the dict is also forwarded
      to `plot_shap_vals` / `plot_attr_vals`.
    - removed_indices (List[int]): Forwarded to the plotting helpers.
    - normalize (bool): Forwarded to the plotting helpers.
    - save_path: If not None, the figure is saved there as a PDF.
    - fixed_width (bool): Use fig_info['width'] instead of a token-count based width.

    Returns:
    - fig (Figure): The generated figure object.
    - axes (ndarray of Axes): 1-D array with one axes per row.
    """
    attention_base_info = attention_info.get('attention_base')
    attention_grad_info = attention_info.get('attention_grad')
    attention_grad_base_info = attention_info.get('attention_grad_base')
    benchmark_attributions = benchamrk_attribution_info.get('model_normalized_attributions_info')

    l1 = len(attention_grad_info)
    l2 = len(attention_grad_base_info)
    l3 = len(benchmark_attributions)

    num_plots = 1 + l1 + l2 + l3
    print(f'number of plots: {num_plots}')

    dpi = fig_info.get('dpi', 300)

    if fixed_width:
        width = fig_info.get('width', 20)
    else:
        width = len(input_tokens) * 2.5

    # Set the default font to Arial
    plt.rcParams["font.family"] = "Arial"

    # squeeze=False always yields a 2-D array, so indexing also works when only
    # one row is needed (plt.subplots(1, 1) would otherwise return a bare Axes).
    fig, axes_grid = plt.subplots(num_plots, 1, figsize=(width, 2.75*num_plots), squeeze=False)
    axes = axes_grid[:, 0]

    def _label_row(ax, text):
        # Y-label whose font size scales with the figure height.
        ax.set_ylabel(text, fontdict={'fontsize': set_fontsize(ax, 0.9), 'fontweight': 'normal'})

    svl_lb = attention_base_info.get('bw_shap_info').get('normalized_shapley_vals_layerwise').round(4)
    plot_shap_vals(0, svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=axes[0])
    _label_row(axes[0], 'AF')

    # Row offsets of the three per-class groups that follow the 'AF' row.
    grad_offset = 1
    grad_base_offset = grad_offset + l1
    benchmark_offset = grad_base_offset + l2

    for i in range(l1):
        ax = axes[grad_offset + i]
        svl_lb = attention_grad_info.get(f'attention_grad_logit_{i}').get('bw_shap_info').get('normalized_shapley_vals_layerwise').round(4)
        plot_shap_vals(0, svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=ax)
        _label_row(ax, f'AGF-Class={i}')

    for i in range(l2):
        ax = axes[grad_base_offset + i]
        svl_lb = attention_grad_base_info.get(f'attention_grad_base_logit_{i}').get('bw_shap_info').get('normalized_shapley_vals_layerwise').round(4)
        plot_shap_vals(0, svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=ax)
        _label_row(ax, f'GF-Class={i}')

    for i in range(l3):
        ax = axes[benchmark_offset + i]
        svl_lb = benchmark_attributions.get(f'normalized_attributions_label_{i}').round(4)
        plot_attr_vals(svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=ax)
        _label_row(ax, f'LIG-Class={i}')

    # use tight_layout to adjust the spacing between subplots
    plt.tight_layout()

    # Set the space between plots. The spacing used to be chosen by a
    # heatmap / non-heatmap branch, but both branches were identical.
    fig.subplots_adjust(wspace=0, hspace=0.5)

    if save_path is not None:
        fig.savefig(save_path, format='pdf', dpi=dpi, bbox_inches='tight')
    return fig, axes


def get_benchmark_plots_qa(attention_info, benchamrk_attribution_info, input_tokens, 
                           fig_info, removed_indices:List[int], normalize: bool=True, 
                           save_path=None, fixed_width=False):
    """
    Generate the plots for QA attention analysis next to the benchmark attributions.

    Rows are stacked in this order: 'AF', 'AGF-ST={i}', 'AGF-ET={i}', 'GF-ST={i}',
    'GF-ET={i}', 'LIG-Start', 'LIG-End'.

    Parameters:
    - attention_info (dict): Must contain 'attention_base', 'start_attention_grad',
      'end_attention_grad', 'start_attention_grad_base' and 'end_attention_grad_base';
      the Shapley values are read from ['bw_shap_info']['normalized_shapley_vals_layerwise'].
    - benchamrk_attribution_info (dict): Must contain 'model_normalized_attributions_info'
      with the keys 'attributions_start' and 'attributions_end'.
    - input_tokens (list): A list of input tokens; also drives the figure width when
      `fixed_width` is False.
    - fig_info (dict): Figure settings. 'dpi' (default 300) and 'width' (default 20,
      used only when `fixed_width` is True) are read here; the dict is also forwarded
      to `plot_shap_vals` / `plot_attr_vals`.
    - removed_indices (List[int]): Forwarded to the plotting helpers.
    - normalize (bool): Forwarded to the plotting helpers.
    - save_path: If not None, the figure is saved there as a PDF.
    - fixed_width (bool): Use fig_info['width'] instead of a token-count based width.

    Returns:
    - fig (Figure): The generated figure object.
    - axes (Axes): The axes objects for the subplots in the figure.
    """
    attention_base_info = attention_info.get('attention_base')

    start_attention_grad_info = attention_info.get('start_attention_grad')
    end_attention_grad_info = attention_info.get('end_attention_grad')

    start_attention_grad_base_info = attention_info.get('start_attention_grad_base')
    end_attention_grad_base_info = attention_info.get('end_attention_grad_base')

    benchmark_attributions = benchamrk_attribution_info.get('model_normalized_attributions_info')

    start_l1 = len(start_attention_grad_info)
    end_l1 = len(end_attention_grad_info)
    start_l2 = len(start_attention_grad_base_info)
    end_l2 = len(end_attention_grad_base_info)

    # One 'AF' row, the four per-logit groups, plus the two benchmark rows.
    num_plots = 1 + start_l1 + end_l1 + start_l2 + end_l2 + 1+ 1
    print(f'num_plots: {num_plots}')

    dpi = fig_info.get('dpi', 300)

    if fixed_width:
        width = fig_info.get('width', 20)
    else:
        width = len(input_tokens) * 2.5

    # Set the default font to Arial
    plt.rcParams["font.family"] = "Arial"

    fig, axes = plt.subplots(num_plots, 1, figsize=(width, 2.75*num_plots))

    def _label_row(ax, text):
        # Y-label whose font size scales with the figure height.
        ax.set_ylabel(text, fontdict={'fontsize': set_fontsize(ax, 0.9), 'fontweight': 'normal'})

    def _plot_shap_row(ax, info, text):
        # Draw one Shapley-value row and label the same axes it was drawn on.
        svl_lb = info.get('bw_shap_info').get('normalized_shapley_vals_layerwise').round(4)
        plot_shap_vals(0, svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=ax)
        _label_row(ax, text)

    _plot_shap_row(axes[0], attention_base_info, 'AF')

    # Row offsets of each group; computing them once keeps the plot call and
    # the label on the same axes (the 'AGF-ET' label previously used end_l1
    # as its offset and ended up on the wrong row when start_l1 != end_l1).
    start_l1_offset = 1
    end_l1_offset = start_l1_offset + start_l1
    start_l2_offset = end_l1_offset + end_l1
    end_l2_offset = start_l2_offset + start_l2

    for i in range(start_l1):
        _plot_shap_row(axes[start_l1_offset + i],
                       start_attention_grad_info.get(f'attention_grad_logit_{i}'),
                       f'AGF-ST={i}')

    for i in range(end_l1):
        _plot_shap_row(axes[end_l1_offset + i],
                       end_attention_grad_info.get(f'attention_grad_logit_{i}'),
                       f'AGF-ET={i}')

    for i in range(start_l2):
        _plot_shap_row(axes[start_l2_offset + i],
                       start_attention_grad_base_info.get(f'attention_grad_base_logit_{i}'),
                       f'GF-ST={i}')

    for i in range(end_l2):
        _plot_shap_row(axes[end_l2_offset + i],
                       end_attention_grad_base_info.get(f'attention_grad_base_logit_{i}'),
                       f'GF-ET={i}')

    # The two benchmark rows always occupy the last two axes.
    svl_lb = benchmark_attributions.get('attributions_start')
    plot_attr_vals(svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=axes[num_plots-2])
    _label_row(axes[num_plots-2], 'LIG-Start')

    svl_lb = benchmark_attributions.get('attributions_end')
    plot_attr_vals(svl_lb, input_tokens, fig_info, removed_indices, normalize, axis=axes[num_plots-1])
    _label_row(axes[num_plots-1], 'LIG-End')

    # use tight_layout to adjust the spacing between subplots
    plt.tight_layout()

    # Set the space between plots. The spacing used to be chosen by a
    # heatmap / non-heatmap branch, but both branches were identical.
    fig.subplots_adjust(wspace=0, hspace=0.5)

    if save_path is not None:
        fig.savefig(save_path, format='pdf', dpi=dpi, bbox_inches='tight')
    return fig, axes
