import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import procrustes

from scale_bar import add_scalebar


def plot_idxs(array, idxs, figsize=(10, 10), title=None, values=None, tick_off=True,
              colors=None):
    """Scatter-plot selected 2-D point clouds from ``array`` in a square grid.

    Only the largest square number of panels that fits in ``len(idxs)`` is drawn:
    ``int(sqrt(len(idxs)))**2`` panels, taken from the front of ``idxs``.

    Args:
        array: Indexed as ``array[idx][:, 0]`` / ``array[idx][:, 1]``; presumably a
            stack of 2-D embeddings of shape (n_items, n_points, >=2) -- TODO confirm.
        idxs: Indices into ``array`` to plot, in panel order.
        figsize: Figure size passed to ``plt.figure``.
        title: Optional figure-level title.
        values: Optional sequence aligned with ``idxs``; ``str(values[i])`` becomes
            the title of panel ``i``.
        tick_off: Hide x/y ticks on every panel.
        colors: Optional per-point colour values for ``scatter(c=...)``, mapped
            through the 'Spectral' colormap. The previous implementation read an
            undefined global ``y_train`` here and always raised NameError; pass the
            labels explicitly instead. ``None`` uses matplotlib's default colour.
    """
    plt.figure(figsize=figsize)

    if title is not None:
        # suptitle: plt.title would attach to a stray full-figure axes that the
        # subplots created below then cover (or delete, on older matplotlib).
        plt.suptitle(title)

    # Only pass a colormap when there is data to map; otherwise matplotlib warns.
    color_kwargs = {} if colors is None else {'c': colors, 'cmap': 'Spectral'}

    n_plot = int(np.sqrt(len(idxs)))

    for i in range(n_plot ** 2):
        plt.subplot(n_plot, n_plot, i + 1)
        cloud = array[idxs[i]]
        plt.scatter(cloud[:, 0], cloud[:, 1], s=0.01, **color_kwargs)
        if values is not None:
            plt.title(str(values[i]))
        if tick_off:
            plt.xticks([])
            plt.yticks([])

def plot_low_k_idxs(array, metric, k, title=None):
    """Plot the ``k`` entries of ``array`` with the smallest ``metric`` values.

    Panels are ordered by ascending metric and titled with the metric value
    (drawing is delegated to ``plot_idxs``).

    Args:
        array: Collection of point clouds indexable by integer index (see ``plot_idxs``).
        metric: 1-D array-like with one score per entry of ``array``.
        k: Number of entries to select; clamped to ``[0, len(metric)]``.
        title: Optional figure title.
    """
    metric = np.asarray(metric)
    k = max(0, min(int(k), len(metric)))

    if k == 0:
        idxs_arg = np.empty(0, dtype=np.intp)
    else:
        # Partition at k - 1 (not k): everything in [:k] is then <= the k-th
        # smallest, and k == len(metric) stays a valid kth for argpartition.
        idxs_arg = np.argpartition(metric, k - 1)[:k]

    values = metric[idxs_arg]

    # argpartition leaves the selected block unordered; sort it by value.
    order = np.argsort(values)

    plot_idxs(array, idxs_arg[order], title=title, values=values[order])

def plot_high_k_idxs(array, metric, k, title=None):
    """Plot the ``k`` entries of ``array`` with the largest ``metric`` values.

    Panels are ordered by ascending metric and titled with the metric value
    (drawing is delegated to ``plot_idxs``).

    Args:
        array: Collection of point clouds indexable by integer index (see ``plot_idxs``).
        metric: 1-D array-like with one score per entry of ``array``.
        k: Number of entries to select; clamped to ``[0, len(metric)]``.
        title: Optional figure title.
    """
    metric = np.asarray(metric)
    n = len(metric)
    k = max(0, min(int(k), n))

    if k == 0:
        # Explicit guard: with k == 0 the slice [-k:] is [0:], i.e. every entry.
        idxs_arg = np.empty(0, dtype=np.intp)
    else:
        # Everything from position n - k onwards is >= the k-th largest value.
        idxs_arg = np.argpartition(metric, n - k)[n - k:]

    values = metric[idxs_arg]

    # argpartition leaves the selected block unordered; sort it by value.
    order = np.argsort(values)

    plot_idxs(array, idxs_arg[order], title=title, values=values[order])

def procrustes_distances(standard_array, array):
    """Procrustes-align each item of ``array`` to ``standard_array``.

    Uses ``scipy.spatial.procrustes`` for every item and prints the mean and
    standard deviation of the resulting disparities.

    Args:
        standard_array: Reference point set, passed as the first argument to
            ``procrustes``.
        array: Iterable of point sets with the same shape as ``standard_array``.

    Returns:
        A pair ``(pds, X_pdx)``. ``pds`` is a 1-D array of disparities, one per
        item. ``X_pdx`` is the array of aligned point sets (the second matrix
        returned by ``procrustes``), stacked in input order.
    """
    disparities = []
    aligned = []

    for candidate in array:
        # procrustes returns (standardised reference, aligned candidate, disparity).
        _, aligned_candidate, disparity = procrustes(standard_array, candidate)
        disparities.append(disparity)
        aligned.append(aligned_candidate)

    pds = np.array(disparities)
    X_pdx = np.array(aligned)
    print('Procrusted Distance: Mean: ', np.mean(pds), ' STD: ', np.std(pds))

    return pds, X_pdx




def plot_matrix(embs, alphas, betas,
                sil_scores,
                t_scores,
                y_train,
                neg_y_axis=False,
                ttitle='Trustworthiness',
                stitle='Silhouette Score',
                vt=False, vmint=None, vmaxt=None,
                vs=False, vmins=None, vmaxs=None,
                colorbar=True,
                xlabel=True,
                ylabel=True,
                plot_matrix=True,
                savename=None):
    """Plot silhouette/trustworthiness heatmaps and, optionally, the embedding grid.

    Up to three figures are produced:

    1. ``sil_scores[1:, 1:]`` as a heatmap, saved as ``<savename>_k1k2_silhouette.eps``.
    2. ``t_scores[1:, 1:]`` as a heatmap, saved as ``<savename>_k1k2_trustworthiness.eps``.
    3. If ``plot_matrix`` is true, one scatter panel per ``(alphas[i], betas[j])``
       with ``i, j >= 1``, saved as ``<savename>.png``.

    Nothing is written to disk when ``savename`` is None. Previously the
    silhouette figure was saved unconditionally, which raised TypeError for
    ``savename=None``.

    Args:
        embs: Sequence of 2-D point arrays. The embedding for
            ``(alphas[i], betas[j])`` is read from ``embs[i * len(betas) + j]``.
            This assumes ``embs`` is flattened row-major over (alphas, betas),
            matching ``sil_scores[i, j]`` -- TODO confirm against the producer.
        alphas, betas: Parameter grids; lambda_a runs along x and lambda_r along y.
            Index 0 of each grid is left out of every plot.
        sil_scores, t_scores: 2-D score arrays indexed ``[alpha_index, beta_index]``.
        y_train: Per-point colour values for ``scatter(c=...)``.
        neg_y_axis: Flip the sign of the second embedding coordinate.
        ttitle, stitle: Titles of the trustworthiness and silhouette heatmaps.
        vt, vmint, vmaxt: If ``vt``, use ``[vmint, vmaxt]`` as the trustworthiness colour range.
        vs, vmins, vmaxs: If ``vs``, use ``[vmins, vmaxs]`` as the silhouette colour range.
        colorbar: Draw a colour bar next to each heatmap.
        xlabel, ylabel: Show the axis label and tick labels on the heatmaps.
        plot_matrix: Also draw the grid of embeddings.
        savename: Path prefix for saved files, or None to skip saving.
    """
    _draw_score_heatmap(
        sil_scores, cmap='cividis', title=stitle, alphas=alphas, betas=betas,
        vmin=vmins if vs else None, vmax=vmaxs if vs else None,
        xlabel=xlabel, ylabel=ylabel, colorbar=colorbar,
        savepath=None if savename is None else savename + '_k1k2_silhouette.eps',
    )
    _draw_score_heatmap(
        t_scores, cmap='brg', title=ttitle, alphas=alphas, betas=betas,
        vmin=vmint if vt else None, vmax=vmaxt if vt else None,
        xlabel=xlabel, ylabel=ylabel, colorbar=colorbar,
        savepath=None if savename is None else savename + '_k1k2_trustworthiness.eps',
    )

    if not plot_matrix:
        return

    _draw_embedding_grid(embs, alphas, betas, y_train, neg_y_axis, savename)


def _heatmap_ticks(grid_values):
    """Return tick positions 0, 2, 4, ... over ``grid_values[1:]`` and their labels."""
    # Heatmap column/row p shows grid_values[p + 1] because index 0 is dropped.
    # For the grid 0.0, 0.1, ..., 1.0 this reproduces the former hard-coded
    # ticks [0, 2, 4, 6, 8] labelled 0.1, 0.3, ..., 0.9.
    positions = list(range(0, len(grid_values) - 1, 2))
    labels = ['{:g}'.format(grid_values[p + 1]) for p in positions]
    return positions, labels


def _draw_score_heatmap(scores, cmap, title, alphas, betas, vmin, vmax,
                        xlabel, ylabel, colorbar, savepath):
    """Draw ``scores[1:, 1:]`` in a new figure, with the alpha index on x and the beta index on y."""
    plt.figure()
    # Transpose so the first (alpha) index runs along x; origin='lower' puts
    # the first beta at the bottom.
    plt.imshow(scores[1:, 1:].T, cmap=cmap, origin='lower', vmax=vmax, vmin=vmin)

    if xlabel:
        plt.xlabel(r'$\lambda_a$', fontsize=30)
        x_pos, x_lab = _heatmap_ticks(alphas)
        plt.xticks(x_pos, x_lab, fontsize=20)
    else:
        plt.xticks([])

    if ylabel:
        plt.ylabel(r'$\lambda_r$', fontsize=30)
        y_pos, y_lab = _heatmap_ticks(betas)
        plt.yticks(y_pos, y_lab, fontsize=20)
    else:
        plt.yticks([])

    plt.title(title, fontsize=30)
    if colorbar:
        cbar = plt.colorbar()
        cbar.ax.tick_params(labelsize=15)
    if savepath is not None:
        plt.savefig(savepath, bbox_inches='tight')


def _draw_embedding_grid(embs, alphas, betas, y_train, neg_y_axis, savename):
    """Draw one scatter panel per (alphas[i], betas[j]) with i, j >= 1 in a single figure."""
    n_cols = len(alphas) - 1
    n_rows = len(betas) - 1
    if n_cols < 1 or n_rows < 1:
        # add_gridspec rejects an empty grid; there is nothing to draw anyway.
        return

    # Scale the default figure size by the number of panels in each direction.
    default_figsize = plt.rcParams['figure.figsize']
    fig = plt.figure(figsize=(default_figsize[0] * n_cols, default_figsize[1] * n_rows))

    # Disable any layout engine so the explicit gridspec margins below are used.
    # Matplotlib versions without set_layout_engine simply skip this step.
    try:
        fig.set_layout_engine(None)
    except AttributeError:
        pass

    gs = fig.add_gridspec(
        n_rows, n_cols,
        # tighten these if you want even less outer whitespace
        left=0.06, right=0.99, bottom=0.07, top=0.98,
        wspace=0.20, hspace=0.22,
    )

    ysign = -1 if neg_y_axis else 1
    n_betas = len(betas)

    for i in range(1, len(alphas)):
        for j in range(1, n_betas):
            # The largest beta goes in the top row, so lambda_r increases
            # upwards as in the heatmaps.
            row = n_betas - 1 - j
            ax = fig.add_subplot(gs[row, i - 1])

            # Row-major over (alphas, betas): the stride is len(betas).
            emb = embs[i * n_betas + j]
            ax.scatter(emb[:, 0], ysign * emb[:, 1], c=y_train, s=0.1, cmap='Spectral')

            # Remove per-axes padding INSIDE the axes so the cloud fills the panel
            ax.margins(x=0, y=0)
            if i == 1:
                ax.set_ylabel('{:0.1f}'.format(betas[j]), fontsize=60)
            if j == 1:  # bottom row
                ax.set_xlabel('{:0.1f}'.format(alphas[i]), fontsize=60)
            add_scalebar(ax, hidex=True, hidey=True)

    fig.supxlabel(r'$\lambda_a$', fontsize=60, y=0.02)
    fig.supylabel(r'$\lambda_r$', fontsize=60)
    if savename is not None:
        fig.savefig(savename + '.png', bbox_inches='tight', dpi=30)

def plot_rows(embs, alphas, betas, sil_scores, t_scores, y_train, neg_y_axis=False, savename=None):
    """Plot the two one-parameter sweeps of the embedding grid, one single-row figure each.

    * ``<savename>_alphas.png``: one panel per ``alphas[1:]`` value, with lambda_r
      fixed at ``betas[0]``.
    * ``<savename>_betas.png``: one panel per ``betas[1:]`` value, with lambda_a
      fixed at ``alphas[0]``.

    Each panel's x label is the swept value. Previously the alpha sweep was
    labelled with ``betas`` values and vice versa, and the fixed value was
    hard-coded as ``0.0``.

    Args:
        embs: Sequence of 2-D point arrays. The embedding for
            ``(alphas[i], betas[j])`` is read from ``embs[i * len(betas) + j]``.
            This assumes ``embs`` is flattened row-major over (alphas, betas)
            -- TODO confirm against the producer.
        alphas, betas: Parameter grids (lambda_a and lambda_r).
        sil_scores, t_scores: Not used by these plots; kept for call compatibility.
        y_train: Per-point colour values for ``scatter(c=...)``.
        neg_y_axis: Flip the sign of the second embedding coordinate.
        savename: Path prefix for saved files, or None to skip saving.
    """
    ysign = -1 if neg_y_axis else 1
    n_betas = len(betas)

    # Sweep lambda_a with lambda_r held at betas[0] (j = 0).
    _draw_sweep_row(
        embs, y_train, ysign,
        emb_indices=[i * n_betas for i in range(1, len(alphas))],
        panel_values=alphas[1:],
        fixed_value=betas[0],
        sweep_symbol=r'$\lambda_a$',
        fixed_symbol=r'$\lambda_r$',
        scalebar_kwargs={'hideframe': False},
        savepath=None if savename is None else savename + '_alphas.png',
    )

    # Sweep lambda_r with lambda_a held at alphas[0] (i = 0, so the index is j).
    _draw_sweep_row(
        embs, y_train, ysign,
        emb_indices=list(range(1, n_betas)),
        panel_values=betas[1:],
        fixed_value=alphas[0],
        sweep_symbol=r'$\lambda_r$',
        fixed_symbol=r'$\lambda_a$',
        scalebar_kwargs={'hidex': True, 'hidey': True},
        savepath=None if savename is None else savename + '_betas.png',
    )


def _draw_sweep_row(embs, y_train, ysign, emb_indices, panel_values, fixed_value,
                    sweep_symbol, fixed_symbol, scalebar_kwargs, savepath):
    """Draw a single-row figure with one scatter panel per entry of ``emb_indices``."""
    n_cols = len(emb_indices)
    if n_cols == 0:
        # add_gridspec rejects a zero-column grid; there is nothing to draw anyway.
        return

    # Width scales with the number of panels; height is the default figure height.
    default_figsize = plt.rcParams['figure.figsize']
    fig = plt.figure(figsize=(default_figsize[0] * n_cols, default_figsize[1]),
                     layout="constrained")

    gs = fig.add_gridspec(
        1, n_cols,
        # tighten these if you want even less outer whitespace
        left=0.06, right=0.99, bottom=0.07, top=0.98,
        wspace=0.20, hspace=0.22,
    )

    for col, (emb_idx, value) in enumerate(zip(emb_indices, panel_values)):
        ax = fig.add_subplot(gs[0, col])
        emb = embs[emb_idx]
        ax.scatter(emb[:, 0], ysign * emb[:, 1], c=y_train, s=0.1, cmap='Spectral')

        if col == 0:
            # The first panel's y label carries the value of the fixed parameter.
            ax.set_ylabel('{:0.1f}'.format(fixed_value), fontsize=60)
        ax.set_xlabel('{:0.1f}'.format(value), fontsize=60)
        add_scalebar(ax, **scalebar_kwargs)

    fig.supxlabel(sweep_symbol, fontsize=60, y=-0.15)
    fig.supylabel(fixed_symbol, fontsize=60)

    if savepath is not None:
        fig.savefig(savepath, bbox_inches='tight', dpi=30)


