import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def draw_matrix(matrix, names, fmt='.2f', vmin=0, vmax=None,
                annotsize=20, labelsize=18, xlabel='Predicted', ylabel='Actual', return_fig=False):
    """Draw ``matrix`` as an annotated heatmap on pyplot's current axes.

    No figure is created here; seaborn draws into the currently active Axes,
    so a caller can size the figure (e.g. ``plt.figure(figsize=...)``) first.
    Cells whose value is NaN are masked and show the white axes background.

    Args:
        matrix: 2-D array-like of values to plot (must be accepted by
            ``np.isnan``).
        names: Labels used for both the x and y tick labels.
        fmt: Format string for the per-cell annotations.
        vmin, vmax: Colour-scale limits passed to ``sns.heatmap``.
        annotsize: Font size of the cell annotations.
        labelsize: Font size of the tick labels (both rotated 45 degrees).
        xlabel, ylabel: Axis titles; skipped when falsy.
        return_fig: If true, return the figure instead of showing it.

    Returns:
        The current figure, already closed in pyplot, when ``return_fig`` is
        true. Otherwise the plot is shown, then closed, and ``None`` is returned.
    """
    heatmap_kwargs = dict(
        mask=np.isnan(matrix),
        annot=True,
        annot_kws={'size': annotsize},
        fmt=fmt,
        vmin=vmin,
        vmax=vmax,
        linewidth=1,
        cmap=sns.color_palette("light:b", as_cmap=True),
        xticklabels=names,
        yticklabels=names,
    )
    axes = sns.heatmap(matrix, **heatmap_kwargs)
    # Masked cells reveal the axes background; keep it white.
    axes.set_facecolor('white')
    for axis_name in ('x', 'y'):
        axes.tick_params(axis=axis_name, labelsize=labelsize, rotation=45)

    if xlabel:
        axes.set_xlabel(xlabel)
    if ylabel:
        axes.set_ylabel(ylabel)

    if not return_fig:
        plt.show()
        plt.close()
        return None

    # Close the figure in pyplot so it is not auto-displayed, but hand the
    # Figure object back to the caller.
    figure = plt.gcf()
    plt.close()
    return figure


def draw_matrix_adj(matrix, names, fmt='.2f', vmin=0, vmax=None,
                    annotsize=20, labelsize=18, xlabel='Predicted', ylabel='Actual', return_fig=False, figsize=(20, 18),
                    title=None):
    """Draw ``matrix`` as an annotated heatmap on a new figure of size ``figsize``.

    A fresh figure/Axes pair is created for every call, and tick labels are
    kept horizontal.

    Args:
        matrix: 2-D array-like of values to plot.
        names: Labels used for both the x and y tick labels.
        fmt: Format string for the per-cell annotations.
        vmin, vmax: Colour-scale limits passed to ``sns.heatmap``.
        annotsize: Font size of the cell annotations.
        labelsize: Font size of the tick labels.
        xlabel, ylabel: Axis titles; skipped when falsy.
        return_fig: If true, return the figure instead of showing it.
        figsize: ``(width, height)`` in inches for the new figure.
        title: Optional plot title; skipped when falsy.

    Returns:
        The ``Figure`` when ``return_fig`` is true. Otherwise the figure is
        shown, then closed, and ``None`` is returned.
    """
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(
        matrix,
        annot=True, annot_kws=dict(size=annotsize), fmt=fmt, vmin=vmin, vmax=vmax, linewidth=1,
        cmap=sns.color_palette("light:b", as_cmap=True),
        xticklabels=names,
        yticklabels=names,
        ax=ax,
    )
    ax.set_facecolor('white')
    ax.tick_params(axis='x', labelsize=labelsize, rotation=0)
    ax.tick_params(axis='y', labelsize=labelsize, rotation=0)

    # Label through the Axes this function owns rather than pyplot's implicit
    # "current axes" state.
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)

    if return_fig:
        # NOTE: the figure is intentionally left registered with pyplot so that
        # callers who keep using pyplot state (plt.savefig, plt.gcf) still see
        # it; such callers should plt.close(fig) when done.
        return fig

    plt.show()
    # pyplot holds a reference to every figure it creates until it is closed.
    # Without this, repeated calls (especially under a non-interactive
    # backend) accumulate large figures and trigger the open-figures warning.
    plt.close(fig)
    return None