import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# matplotlib 3.8 removed the old names, so plt.style.use('seaborn-whitegrid')
# raises OSError on current releases. Prefer the new name, fall back to the old
# one for older matplotlib, and keep the default style (rather than failing at
# import time) if neither is available.
_WHITEGRID_STYLE = next(
    (name for name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid')
     if name in plt.style.available),
    None,
)
if _WHITEGRID_STYLE is not None:
    plt.style.use(_WHITEGRID_STYLE)


def heat_plotter(array_plot, values=True, plot_args=None, fig=None, ax=None,
                 vmin=0, vmax=100, colourbar=True, number_of_dp=2):
    '''

    Plot a 2-D array as a heatmap, optionally writing each cell's value on it.

    Arguments
    ---------
        array_plot: array-like
            The data to plot. Anything np.asanyarray() accepts works, e.g. a
            numpy array, a nested list or a pandas DataFrame.

        values: bool
            If True (default), each pixel is labelled with its value. If
            False, only the coloured image (and optional colour bar) is drawn.

        plot_args: dict or None
            Extra keyword arguments passed to ax.imshow(); they must be
            compatible with plt.imshow(). Use the dedicated vmin and vmax
            parameters instead of putting those keys in this dictionary,
            otherwise imshow() receives them twice and raises a TypeError.

        fig: matplotlib figure or None
            The figure to draw on. If omitted but ax is given, ax.figure is
            used.

        ax: matplotlib axis or None
            The axis the heatmap is drawn on. If omitted but fig is given,
            fig.gca() is used. If both are omitted, a new 8x3 inch figure
            with a single axis is created.

        vmin: float or None
            Data value mapped to the lowest colour. None lets imshow()
            autoscale to the data.

        vmax: float or None
            Data value mapped to the highest colour. None lets imshow()
            autoscale to the data.

        colourbar: bool
            Whether to add a colour bar next to ax.

        number_of_dp: int
            Number of decimal places the cell labels are rounded to. With 0,
            finite labels are shown as integers.

    Returns
    -------
        fig: matplotlib figure
            The figure containing the axis.

        ax: matplotlib axis
            The axis the heatmap was drawn on.

    Raises
    ------
        ValueError
            If values is True and the data is not 2-D (e.g. an RGB image),
            because there is no single number per pixel to label.

    '''
    # asanyarray keeps ndarray subclasses (e.g. masked arrays) intact while
    # converting lists / DataFrames, so .shape and [i, j] indexing work below.
    data = np.asanyarray(array_plot)
    if plot_args is None:
        plot_args = {}

    # Only create a new figure when neither a figure nor an axis was supplied;
    # otherwise derive the missing one so the caller's object is used.
    if ax is None:
        if fig is None:
            fig, ax = plt.subplots(1, 1, figsize=(8, 3))
        else:
            ax = fig.gca()
    elif fig is None:
        fig = ax.figure

    im = ax.imshow(data, origin='upper', **plot_args, vmin=vmin, vmax=vmax)
    if colourbar:
        # Pass ax explicitly so the colour bar is attached to (and takes space
        # from) this axis, even on figures with several axes.
        fig.colorbar(im, ax=ax, shrink=0.6)

    ax.grid(False)

    if not values:
        return fig, ax

    if data.ndim != 2:
        raise ValueError(
            'Cell labels need 2-D data, got an array of shape {}. '
            'Pass values=False to plot it without labels.'.format(data.shape))

    # When a limit is left to autoscale, use the data range (ignoring NaNs),
    # which is what imshow() autoscales to.
    low = np.nanmin(data) if vmin is None else vmin
    high = np.nanmax(data) if vmax is None else vmax
    # Cells above 60% of the colour range get white text, the rest black.
    # NOTE(review): this assumes the colormap is dark at its high end (e.g.
    # 'Blues'); with matplotlib's default 'viridis' the high end is light, so
    # confirm the cmap callers pass via plot_args.
    colour_switch = low + 0.6 * (high - low)

    n_rows, n_cols = data.shape
    for i in range(n_rows):
        for j in range(n_cols):
            value = data[i, j]
            label = round(value, number_of_dp)
            # int() would raise on NaN/inf, so only convert finite labels.
            if number_of_dp == 0 and np.isfinite(label):
                label = int(label)

            # imshow puts pixel (row i, column j) centred at x=j, y=i.
            ax.text(j, i, '{}'.format(label), fontsize=10,
                    color='white' if value > colour_switch else 'black',
                    verticalalignment='center',
                    horizontalalignment='center')

    return fig, ax




