import numpy as np

import matplotlib.pyplot as plt
import matplotlib
import matplotlib.colors as colors
from matplotlib.pyplot import ylabel
from matplotlib.font_manager import FontProperties

# Global matplotlib setting: use a sans-serif font family for the text of every
# figure created after this module has been imported.
plt.rcParams["font.family"] = "sans-serif"
# Optional LaTeX text rendering (disabled). Uncomment both lines below to enable
# it and set the preamble.
# plt.rcParams['text.usetex'] = True
# plt.rcParams['text.latex.preamble'] = r'\usepackage{mathptmx} \usepackage{amsmath}'


def _outline_cells(ax, n_rows, n_cols):
    """Draw a black line around every matrix cell on ``ax`` and hide the tick marks."""
    # minor ticks sit on the cell borders (half-integers); the grid follows them
    ax.set_xticks(np.arange(-.5, n_cols), minor=True)
    ax.set_yticks(np.arange(-.5, n_rows), minor=True)
    ax.grid(which='minor', color='k', linestyle='-', linewidth=1)

    # keep the tick labels but remove the little tick marks themselves
    ax.tick_params(which='both', left=False, right=False, top=False, bottom=False, pad=1)


def plot_SVD(u, s, vt, dataset, fig_title):
    """Plot a dataset next to the three factors of its SVD (dataset = U x S x V^T).

    Four colour-coded matrices are drawn side by side (``dataset``, ``u``, ``s``,
    ``vt``) with one shared colour bar. Tick positions and tick labels are derived
    from the shape of each matrix, so the figure is not tied to one dataset size.

    Args:
        u (np.array): left singular vectors, orthonormal vectors (2-D)
        s (np.array): Diagonal singular value matrix (square root of eigenvalues of A^{T}A on diagonal); has to be a 2-D matrix
        vt (np.array): Right singular vectors, orthonormal vectors (2-D)
        dataset (np.array): the dataset that was decomposed (MxN)
        fig_title (String): Title of Figure

    Returns:
        tuple: the figure and the array holding its four matrix axes
    """

    # Define figure
    fig, axes = plt.subplots(1, 4, figsize=(22, 8), dpi=200, facecolor='w')
    cmap = plt.get_cmap('bwr')

    # one norm for all panels so that a single colour bar describes every matrix
    norm = colors.TwoSlopeNorm(vcenter=0, vmin=-.5, vmax=.5)

    # (matrix, title, name of the row axis, name of the column axis)
    # raw strings: '\S' is not a valid escape sequence in a normal string literal
    panels = [
        (dataset, r'$\Sigma^{yx}$', "Properties", "Items"),
        (u, r'$U$', "Properties", "Modes"),
        (s, r'$S$', "Modes", "Modes"),
        (vt, r'$V^{T}$', "Modes", "Items"),
    ]
    singular_value_panel = 2

    for idx, (ax, (matrix, title, row_name, col_name)) in enumerate(zip(axes.flatten(), panels)):
        im = ax.matshow(matrix, cmap=cmap, norm=norm)
        n_rows, n_cols = np.shape(matrix)

        ax.set_title(title, fontsize=34)
        # matshow puts the x tick labels on top; move them below the matrix
        ax.tick_params(labeltop=False, labelbottom=True)

        _outline_cells(ax, n_rows, n_cols)

        # One major tick per cell centre. The positions are fixed *before* the
        # labels are assigned so that label k always ends up on row/column k.
        ax.set_xticks(range(n_cols))
        ax.set_yticks(range(n_rows))
        ax.set_yticklabels(list(range(1, n_rows + 1)), fontsize=17)

        if col_name == "Items":
            # item columns get slanted class names
            item_names = [f'Class {k}' for k in range(1, n_cols + 1)]
            ax.set_xticklabels(item_names, rotation=-45, fontsize=16, ha="left", rotation_mode="anchor")
        else:
            # mode columns are simply numbered
            ax.set_xticklabels(list(range(1, n_cols + 1)), fontsize=17, ha="left", rotation_mode="anchor")

        ax.set_ylabel(row_name, fontweight='bold', fontsize=20, labelpad=6)
        ax.set_xlabel(col_name, fontweight='bold', fontsize=20, labelpad=10)

        if idx == singular_value_panel:
            # write a_1, a_2, ... into the diagonal cells of the singular value matrix
            for j in range(min(n_rows, n_cols)):
                ax.text(j, j, f'$a_{{{j+1}}}$', va='center', ha='center', fontsize=16)

    # add a supertitle
    fig.suptitle(fig_title, fontweight='bold', fontsize=40)

    # spacing between the subplots; right=0.8 leaves room for the colour bar
    fig.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.8,
                        top=0.9,
                        wspace=0.7,
                        hspace=0.4)

    # add a colourbar in its own axes to the right of the matrices
    # (all panels share the norm, so the last image can stand in for all of them)
    cbar_ax = fig.add_axes([0.82, 0.3, 0.025, 0.4])
    cbar = fig.colorbar(im, cax=cbar_ax)
    cbar.ax.tick_params(size=0, labelsize=15)

    # add the '=' and the two multiplication signs between the panels
    # (fixed figure coordinates, tuned for the layout above)
    fig.text(0.24, 0.5, '=', fontsize=40, fontweight='bold', ha='center', va='center')
    fig.text(0.44, 0.5, r'$\times$', fontsize=45, fontweight='bold', ha='center', va='center')
    fig.text(0.64, 0.5, r'$\times$', fontsize=45, fontweight='bold', ha='center', va='center')

    return fig, axes


def plot_data(dataset, xlabel, x_tick_labels):
    """Create a new figure that shows the dataset as a colour-coded matrix.

    A single-axes figure is created and the actual drawing is delegated to
    ``plot_data_axis`` so that both functions always produce the same plot.

    Args:
        dataset (np.array): the dataset to be plotted as matrix
        xlabel (String): forwarded to ``plot_data_axis`` (which currently does not draw it)
        x_tick_labels (list): labels of the matrix columns, shown below the matrix

    Returns:
        tuple: the new figure and its axes
    """
    # Define figure
    fig, ax = plt.subplots(1, 1, figsize=(4.5, 6), dpi=100, facecolor='w')

    # shared drawing code: scaling, colour map, grid and tick labels
    plot_data_axis(ax, dataset, xlabel, x_tick_labels)

    return fig, ax

def plot_data_axis(ax,
                   dataset,
                   xlabel,
                   x_tick_labels):
    """Draw the dataset as a colour-coded matrix with a cell grid on an existing axes.

    Args:
        ax (matplotlib axes): the axes to draw on
        dataset (np.array): the dataset to be plotted as matrix (2-D); it is divided by 3.5 before plotting
        xlabel (String): currently unused, kept so that existing calls keep working
        x_tick_labels (list): labels of the matrix columns, shown below the matrix

    Returns:
        the axes that was passed in
    """
    # NOTE(review): 3.5 is a hard-coded scale factor, presumably chosen so that the
    # data fits the -0.5..0.5 colour range below -- confirm
    scaled = np.asanyarray(dataset) / 3.5

    cmap = plt.get_cmap('bwr')
    # norm for plotting of colour bar
    norm = colors.TwoSlopeNorm(vcenter=0, vmin=-.5, vmax=.5)

    ax.matshow(scaled, cmap=cmap, norm=norm)
    n_rows, n_cols = scaled.shape

    # make grid: minor ticks on every cell border, including the outer ones
    # (n_rows + 1 horizontal and n_cols + 1 vertical lines)
    ax.set_xticks(np.arange(-.5, n_cols), minor=True)
    ax.set_yticks(np.arange(-.5, n_rows), minor=True)

    # Gridlines based on minor ticks
    ax.grid(which='minor', color='k', linestyle='-', linewidth=1.6)

    # one major x tick per label so that ticks and labels always match in number
    ax.set_xticks(range(len(x_tick_labels)))

    # remove the tick marks and show the x tick labels below the matrix
    ax.tick_params(which='both', left=False, right=False, top=False, bottom=False, labelbottom=True, labeltop=False, pad=1)
    ax.set_xticklabels(x_tick_labels, fontsize=16)
    # remove y tick labels
    ax.set_yticklabels([])

    return ax