""" 
This module provides utility functions for post-processing and plotting the results."""

from typing import List, Tuple
from matplotlib import cm as mcm
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

def create_custom_colormap(colors, cmap_name:str=None):
    """
    Creates a custom colormap.

    When ``cmap_name`` is None, a new colormap is interpolated from ``colors``.
    Otherwise the named colormap is sampled at each of its ``N`` lookup
    entries and the first ``len(colors)`` entries are overwritten with
    ``colors`` before the colormap is rebuilt.

    Args:
        colors (list): The list of colors to use in the colormap. Any
            Matplotlib color specification is accepted (name, hex string,
            RGB or RGBA tuple).
        cmap_name (str): The name of the existing colormap to modify, or
            None to build a colormap from ``colors`` alone.

    Returns:
        new_cmap (matplotlib.colors.LinearSegmentedColormap): The new colormap.

    Raises:
        IndexError: If ``cmap_name`` is given and ``colors`` has more entries
            than the sampled colormap.
    """
    if cmap_name is None:
        # Build the colormap directly from the supplied colors.
        return mcolors.LinearSegmentedColormap.from_list("new_cmap", colors)

    # Sample the existing colormap at each of its N entries -> (N, 4) RGBA array.
    cmap = plt.get_cmap(cmap_name)
    cmap_colors = cmap(np.linspace(0, 1, cmap.N))

    for i, color in enumerate(colors):
        # Normalise to an RGBA tuple first: assigning a color name, a hex
        # string or a 3-component RGB tuple straight into the RGBA row fails.
        cmap_colors[i] = mcolors.to_rgba(color)

    return mcolors.LinearSegmentedColormap.from_list('new_cmap', cmap_colors, cmap.N)


def plot_shap_vals(layer_number, shapley_vals_layerwise, input_tokens, fig_info:dict, removed_indices: List[int], normalize:bool=True, axis=None, save_path=None, fixed_width:bool=False):
    """
    Plots the Shapley values of one layer as a heatmap or bar chart.

    Args:
        layer_number (int): Row index into ``shapley_vals_layerwise``.
        shapley_vals_layerwise (numpy.ndarray): 2-D array indexed as
            ``[layer, token]``.
        input_tokens (list): The list of input tokens; must have the same
            length as one row of ``shapley_vals_layerwise``.
        fig_info (dict): Figure options, all optional: ``cmap`` ("YlOrRd"),
            ``vmin`` (0), ``vmax`` (1), ``width`` (20), ``height`` (3),
            ``dpi`` (300), ``font_size`` (10) and ``plot_type`` ("heatmap";
            also "bare_heatmap" or "bar"). Any other ``plot_type`` draws
            nothing.
        removed_indices (List[int]): Token positions to drop before plotting.
        normalize (bool): If True, min-max normalize the values first.
        axis (matplotlib.axes.Axes): Axes to draw on. If None, a new figure
            and axes are created.
        save_path: If not None, the figure is saved there as a PDF.
        fixed_width (bool): If False, the width is ``2.5 * len(input_tokens)``
            instead of ``fig_info["width"]``.

    Returns:
        tuple: ``(fig, axes)`` - the figure and the axes that were drawn on.
    """
    cmap = fig_info.get("cmap", "YlOrRd")
    vmin = fig_info.get("vmin", 0)
    vmax = fig_info.get("vmax", 1)
    width = fig_info.get("width", 20)
    height = fig_info.get("height", 3)
    dpi = fig_info.get("dpi", 300)
    font_size = fig_info.get("font_size", 10)
    plot_type = fig_info.get("plot_type","heatmap")

    if not fixed_width:
        # NOTE(review): uses the token count *before* removed_indices are
        # dropped - confirm that this is intended.
        width = len(input_tokens) * 2.5

    # Matplotlib sizes are in inches; width/height are presumably centimetres.
    figsize = (width / 2.54, height / 2.54)

    if axis is None:
        fig, axes = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    else:
        axes = axis
        # Resize/save the figure that owns the supplied axes, not whichever
        # figure happens to be current.
        fig = axes.figure
        fig.set_size_inches(*figsize)

    shap_vals = shapley_vals_layerwise[layer_number,:]
    input_tokens, shap_vals = get_modified_vals(shap_vals, input_tokens, removed_indices)

    if normalize:
        shap_vals = get_normalized_shap_vals(shap_vals)

    positions = list(range(len(input_tokens)))

    if plot_type == "heatmap":
        axes.imshow(shap_vals.reshape(1,-1), cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')

        for i in positions:
            # Separator on the right edge of each cell, value written in the cell.
            axes.axvline(x=i+0.5, color="black", linestyle="-", linewidth=1)
            axes.text(i, 0, str(shap_vals[i].round(2)), color="black", fontsize=font_size-2, fontweight="bold", ha='center')

        axes.set_yticks([])
        axes.set_xticks(positions)
        axes.set_xticklabels(labels=input_tokens, fontsize=font_size)

    elif plot_type == "bare_heatmap":
        axes.imshow(shap_vals.reshape(1,-1), cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')

        for i, token in enumerate(input_tokens):
            # Tokens are written inside the cells instead of as tick labels.
            axes.axvline(x=i+0.5, color="black", linestyle="-", linewidth=1)
            axes.text(i, 0, token, ha='center', va='center', color='black')

        axes.set_yticks([])
        axes.set_xticks(positions)
        axes.set_xticklabels(labels=[])

    elif plot_type == "bar":
        axes.bar(range(len(input_tokens)), shap_vals, color="skyblue")

        # Write each value just above its bar.
        for i in positions:
            axes.text(i, shap_vals[i]+0.01, str(shap_vals[i].round(2)), color="black", fontsize=font_size-2, fontweight="bold", ha='center')

        axes.set_yticks([])
        axes.set_xticks(positions)
        axes.set_xticklabels(labels=input_tokens, fontsize=font_size)

        # Leave headroom above 1.0 so the value labels stay inside the axes.
        axes.set_ylim(0, 1.25)

    if save_path is not None:
        fig.savefig(save_path, format='pdf', dpi=dpi, bbox_inches='tight')

    return fig, axes


def plot_attr_vals(attr_vals, input_tokens, fig_info:dict, removed_indices:List[int], normalize: bool=False, axis=None, save_path=None, fixed_width:bool=False):
    """
    Plots one vector of attribution values as a heatmap or bar chart.

    Args:
        attr_vals (array-like): One attribution value per input token.
        input_tokens (list): The list of input tokens; must have the same
            length as ``attr_vals``.
        fig_info (dict): Figure options, all optional: ``cmap`` ("YlOrRd"),
            ``vmin`` (0), ``vmax`` (1), ``width`` (20), ``height`` (3),
            ``dpi`` (300), ``font_size`` (10) and ``plot_type`` ("heatmap";
            also "bare_heatmap" or "bar"). Any other ``plot_type`` draws
            nothing.
        removed_indices (List[int]): Token positions to drop before plotting.
        normalize (bool): If True, min-max normalize the values first.
        axis (matplotlib.axes.Axes): Axes to draw on. If None, a new figure
            and axes are created.
        save_path: If not None, the figure is saved there as a PDF.
        fixed_width (bool): If False, the width is ``2.5 * len(input_tokens)``
            instead of ``fig_info["width"]``.

    Returns:
        matplotlib.axes.Axes: The axes that were drawn on.
    """
    cmap = fig_info.get("cmap", "YlOrRd")
    vmin = fig_info.get("vmin", 0)
    vmax = fig_info.get("vmax", 1)
    width = fig_info.get("width", 20)
    height = fig_info.get("height", 3)
    dpi = fig_info.get("dpi", 300)
    font_size = fig_info.get("font_size", 10)
    plot_type = fig_info.get("plot_type","heatmap")

    if not fixed_width:
        # NOTE(review): uses the token count *before* removed_indices are
        # dropped - confirm that this is intended.
        width = len(input_tokens) * 2.5

    # Matplotlib sizes are in inches; width/height are presumably centimetres.
    figsize = (width / 2.54, height / 2.54)

    if axis is None:
        fig, axes = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    else:
        axes = axis
        # Resize/save the figure that owns the supplied axes, not whichever
        # figure happens to be current.
        fig = axes.figure
        fig.set_size_inches(*figsize)

    input_tokens, shap_vals = get_modified_vals(attr_vals, input_tokens, removed_indices)
    # get_modified_vals hands the input back untouched when nothing is
    # removed; coerce so that plain lists support .reshape/.round below.
    shap_vals = np.asarray(shap_vals)
    if normalize:
        shap_vals = get_normalized_shap_vals(shap_vals)

    positions = list(range(len(input_tokens)))

    if plot_type == "heatmap":
        axes.imshow(shap_vals.reshape(1,-1), cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')

        for i in positions:
            # Separator on the right edge of each cell, value written in the cell.
            axes.axvline(x=i+0.5, color="black", linestyle="-", linewidth=1)
            axes.text(i, 0, str(shap_vals[i].round(2)), color="black", fontsize=font_size-2, fontweight="bold", ha='center')

        axes.set_yticks([])
        axes.set_xticks(positions)
        axes.set_xticklabels(labels=input_tokens, fontsize=font_size)

    elif plot_type == "bare_heatmap":
        axes.imshow(shap_vals.reshape(1,-1), cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')

        for i, token in enumerate(input_tokens):
            # Tokens are written inside the cells instead of as tick labels.
            axes.axvline(x=i+0.5, color="black", linestyle="-", linewidth=1)
            axes.text(i, 0, token, ha='center', va='center', color='black')

        axes.set_yticks([])
        axes.set_xticks(positions)
        axes.set_xticklabels(labels=[])

    elif plot_type == "bar":
        axes.bar(range(len(input_tokens)), shap_vals, color="skyblue")

        # Write each value just above its bar.
        for i in positions:
            axes.text(i, shap_vals[i]+0.01, str(shap_vals[i].round(2)), color="black", fontsize=font_size-2, fontweight="bold", ha='center')

        axes.set_yticks([])
        axes.set_xticks(positions)
        axes.set_xticklabels(labels=input_tokens, fontsize=font_size)

        # Leave headroom above 1.0 so the value labels stay inside the axes.
        axes.set_ylim(0, 1.25)

    if save_path is not None:
        fig.savefig(save_path, format='pdf', dpi=dpi, bbox_inches='tight')

    return axes

def get_normalized_shap_vals(shap_vals: np.ndarray, epsilon: float = 1e-20) -> np.ndarray:
    """
    Min-max normalize the values along the last axis to the range [0, 1].

    Each slice along the last axis is rescaled independently as
    ``(x - min) / (max - min)``. A slice whose values are all equal has a
    zero range; it is mapped to all ones (``epsilon / epsilon``) instead of
    dividing by zero, whether or not the other slices are constant.

    Args:
        shap_vals (numpy.ndarray): Array of values; normalization is applied
            along the last axis.
        epsilon (float): Small non-zero constant substituted for numerator
            and denominator of constant slices.

    Returns:
        numpy.ndarray: Array of the same shape holding the normalized values.
        An empty input is returned as an empty float array.
    """
    shap_vals = np.asarray(shap_vals)
    if shap_vals.size == 0:
        # min()/max() are undefined on empty arrays; nothing to normalize.
        return shap_vals.astype(float)

    lowest = shap_vals.min(axis=-1, keepdims=True)
    span = shap_vals.max(axis=-1, keepdims=True) - lowest

    # Decide per slice, not globally: a single constant row in an otherwise
    # varying array must not end up as 0/0 = NaN.
    is_constant = span == 0
    numerator = np.where(is_constant, epsilon, shap_vals - lowest)
    denominator = np.where(is_constant, epsilon, span)
    return numerator / denominator


def plot_attr_vals_html(tokens, scores,  model_type, cmap='YlOrRd', font_size='5px'):
    """
    Renders tokens as HTML spans whose background color encodes their score.

    Scores are min-max normalized over the whole sequence and mapped through
    the given colormap. If all scores are equal, every token gets the
    colormap's top color (normalized value 1.0).

    Args:
        tokens (iterable): The tokens to render, paired with ``scores``.
        scores (iterable): One numeric score per token.
        model_type: Label shown in maroon on its own line before the tokens.
            It is inserted as-is (not HTML-escaped).
        cmap: Name of the Matplotlib colormap used for the backgrounds.
        font_size (str): CSS font size applied to the label and the tokens.

    Returns:
        str: The HTML snippet. With no scores only the label line is returned.
    """
    # Local import keeps this block self-contained.
    from html import escape

    # plt.get_cmap works across Matplotlib versions; matplotlib.cm.get_cmap
    # was removed in Matplotlib 3.9.
    colormap = plt.get_cmap(cmap)

    scores = list(scores)  # materialize: scanned for min/max and iterated again
    min_score = min(scores) if scores else 0.0
    score_range = (max(scores) - min_score) if scores else 0.0

    pieces = [f'<span style="color: maroon; font-size: {font_size}; font-family: Times New Roman;">{model_type }</span><br> ']

    # Text is always black, regardless of how dark the background is.
    text_color = "black"

    for token, score in zip(tokens, scores):
        if score_range != 0:
            normalized_score = (score - min_score) / score_range
        else:
            # All scores equal: avoid 0/0 (must be a float, an int would be
            # treated as a raw lookup-table index by the colormap).
            normalized_score = 1.0
        color = mcolors.rgb2hex(colormap(normalized_score))

        # Escape the token so that special tokens such as "<s>" or "<pad>"
        # are displayed literally instead of being parsed as HTML tags.
        pieces.append(f'<span style="background-color: {color}; color: {text_color}; font-size: {font_size}; font-family: Times New Roman;">{escape(str(token))} </span>')

    return "".join(pieces)


def get_modified_vals(shap_vals: np.ndarray, input_tokens: List[str], removed_indices: List[int]) -> Tuple[List[str], np.ndarray]:
    """
    Drops the tokens at the given positions together with their values.

    Parameters:
        shap_vals (array): One value per entry of input_tokens.
        input_tokens (list): A list of words.
        removed_indices (list): Positions to remove. Positions outside
            ``0 .. len(input_tokens) - 1`` (including negative ones) are
            ignored. ``None`` is treated like an empty list.

    Returns:
        modified_input_tokens (list): The tokens that were kept, in order.
        modified_shap_vals (array): The values of the kept tokens. When
            nothing is to be removed, both inputs are returned unchanged
            (same objects); otherwise a new numpy array is built.

    Raises:
        ValueError: If input_tokens and shap_vals differ in length.
    """
    if len(input_tokens) != len(shap_vals):
        raise ValueError(f"""
                         Input tokens and shap values must have the same length. 
                         Input tokens length: {len(input_tokens)}, shap values length: {len(shap_vals)}
                        """)

    # A set gives O(1) membership tests; int() also normalizes numpy integers.
    removed = set() if removed_indices is None else {int(i) for i in removed_indices}
    if not removed:
        return input_tokens, shap_vals

    kept = [i for i in range(len(input_tokens)) if i not in removed]
    modified_input_tokens = [input_tokens[i] for i in kept]
    modified_shap_vals = np.array([shap_vals[i] for i in kept])
    return modified_input_tokens, modified_shap_vals
    

def rgba_to_hex(rgba):
    """
    Convert an RGBA color to its hexadecimal representation.

    Each channel is scaled from [0, 1] to [0, 255] by truncation and clamped
    to that range so the result is always a valid ``#rrggbb`` string. The
    alpha channel is ignored.

    Args:
    rgba (tuple): A tuple containing the red, green, blue and (optionally)
        alpha values of the color, each in [0, 1].

    Returns:
    str: The hexadecimal representation of the color, without alpha.

    Example:
    >>> rgba = (0.98, 0.165, 0.333, 1.0)
    >>> hex_color = rgba_to_hex(rgba)
    >>> print(hex_color)
    #f92a54
    """
    # Alpha (and anything after it) is not part of the #rrggbb output.
    r, g, b, *_ = rgba
    # Clamp: out-of-range channels would otherwise produce malformed hex
    # (three digits or a minus sign).
    channels = [min(255, max(0, int(c * 255))) for c in (r, g, b)]
    return "#{:02x}{:02x}{:02x}".format(*channels)

