from __future__ import print_function, division
import torch
from matplotlib import cm
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import numpy as np
from numpy.polynomial.polynomial import polyval

import math
from scipy.special import expit
from scipy.stats.kde import gaussian_kde

from model.model_utils import score_graphs_batch, metrics_at_thresholds
from model.metrics import edge_mse, batch_graph_metrics
from utils.util_funcs import upper_tri_as_vec, upper_tri_as_vec_batch, correlation_from_covariance, \
    percentile_upper_matrices, round_up, sparsity, normalize_slices

import copy
# https://matplotlib.org/stable/tutorials/colors/colormaps.html
def _copied_cmap(name):
    """Return a private copy of the named matplotlib colormap.

    The module mutates these colormaps (``set_bad``), so each one is copied to
    keep those edits from leaking into matplotlib's registered instances.
    ``plt.get_cmap`` is used instead of ``matplotlib.cm.get_cmap``, which was
    deprecated in matplotlib 3.7 and removed in 3.9.
    """
    return copy.copy(plt.get_cmap(name))


fc_cm = _copied_cmap("RdYlGn")  # diverging
raw_sc_cm = _copied_cmap("gist_gray")  # perceptually linear; may need to clip very large values
binary_sc_cm = _copied_cmap("gist_gray")
raw_pred_cm = _copied_cmap("inferno")
intermed_out_cm = _copied_cmap("gist_gray")

inferno_cm = cm.inferno

import matplotlib.colors as colors



colormaps = [fc_cm, raw_sc_cm, binary_sc_cm, raw_pred_cm, intermed_out_cm]
# Render NaN ("bad") entries in black. The loop variable is deliberately not
# called ``cm``: that would rebind the ``matplotlib.cm`` module imported above
# to the last colormap for the rest of the module.
for _cmap in colormaps:
    _cmap.set_bad("black")


def log_heatmaps(scs, preds, prior, thresholds, num_patients, subject_ids=None, scan_dirs=None, title_str = ""):
    """Plot model outputs as heatmaps, picking the layout from the label type.

    The labels are treated as binary when ``scs`` is ``np.allclose`` to its own
    ``> 0`` binarization. In that case ``binary_label_heatmap_outputs`` is used
    and ``prior``, ``subject_ids`` and ``scan_dirs`` are not forwarded.
    Otherwise ``contin_label_heatmap_outputs`` receives all arguments.

    Returns:
        The matplotlib figure produced by the selected helper.
    """
    binarized_scs = (scs > 0) + 0.0
    if not np.allclose(scs, binarized_scs):
        return contin_label_heatmap_outputs(
            scs, preds, thresholds, num_patients,
            prior=prior, subject_ids=subject_ids, scan_dirs=scan_dirs, title_str=title_str,
        )
    return binary_label_heatmap_outputs(scs, preds, thresholds, num_patients, title_str=title_str)


def binary_label_heatmap_outputs(scs, preds, thresholds, num_patients, title_str = ""):
    """Plot raw and thresholded predictions next to binary labels.

    One row is drawn per patient (the first ``num_patients`` entries). The
    columns are: raw prediction | ``prediction > t`` for each ``t`` in
    ``thresholds`` | label matrix. Each thresholded panel is annotated with the
    averaged acc/f1/macro-f1/mcc from ``score_graphs_batch`` at that threshold.
    A single colorbar on the left describes the raw-prediction column.

    Args:
        scs: 3-D stack of label matrices indexed ``scs[i]``.
        preds: 3-D stack of predictions with the same node count as ``scs``.
            It must support ``.view``/``.max(dim)`` (presumably a torch tensor;
            TODO confirm).
        thresholds: Iterable of thresholds, one column each.
        num_patients: Number of rows to draw.
        title_str: Appended to the figure title.

    Returns:
        The matplotlib ``Figure``.
    """
    assert len(scs.shape)==3 and len(preds.shape)==3 and scs.shape[1]==preds.shape[1]
    _, N, _ = scs.shape

    fs = 8
    ncols = 2 + len(thresholds)  # raw pred in front, true adj at end

    # squeeze=False keeps ``axes`` 2-D even when only one patient is plotted.
    fig, axes = plt.subplots(nrows=num_patients, ncols=ncols, squeeze=False)
    axes[0, 0].set_title('Raw Pred', fontsize=fs)
    axes[0, ncols-1].set_title('Raw Adj', fontsize=fs)
    for i, threshold in enumerate(thresholds):
        axes[0, i+1].set_title(f'> {round(threshold, 3)}', fontsize=fs)

    # Shared colour limit for the raw-prediction column: the largest plotted
    # prediction, but never below 1.
    max_val = preds[0:num_patients].view(num_patients, int(N**2)).max(1)[0].max().item()
    max_val = np.max([max_val, 1])

    # One colorbar for the raw predictions; the other columns are binary.
    # See https://stackoverflow.com/questions/13784201/matplotlib-2-subplots-1-colorbar
    raw_pred_im = None
    for row in range(num_patients):
        # Identical vmin/vmax on every row keeps the single colorbar valid for all rows.
        raw_pred_im = axes[row, 0].imshow(preds[row], vmin=0, vmax=max_val, cmap=raw_pred_cm)

        for i, threshold in enumerate(thresholds):
            _, _, ave_f1, ave_macro_f1, ave_acc, ave_mcc = \
                score_graphs_batch(threshold=threshold, adjs=scs[row], preds=preds[row], o='ave')
            thresh_ax = axes[row, i+1]
            # Show the binarized prediction that this column is titled and scored for.
            thresh_ax.imshow((preds[row] > threshold) + 0.0, vmin=0, vmax=1, cmap=intermed_out_cm)
            thresh_ax.set_xlabel(f'acc/f1/macro-f1/mcc: {ave_acc:.2f}/{ave_f1:.2f}/{ave_macro_f1:.2f}/{ave_mcc:.2f}', fontsize=fs-3)

        axes[row, ncols-1].imshow(scs[row], vmin=0, vmax=1, cmap=binary_sc_cm)

        axes[row, 0].set_ylabel(f'Patient {row}', rotation=90, fontsize=fs)
        # Turn off tick labels.
        for ax in axes[row]:
            ax.set_xticklabels([])
            ax.set_yticklabels([])

    fig.suptitle('Model Outputs: ' + title_str, fontsize=fs + 4)
    fig.subplots_adjust(left=0.18, bottom=.05, wspace=.08)
    cbar_ax = fig.add_axes([0.07, 0.1, .03, 0.8])
    fig.colorbar(raw_pred_im, cax=cbar_ax)
    cbar_ax.yaxis.set_ticks_position('left')

    return fig


def find_best_by(pr, re, f1, macro_f1, acc, mcc, mse, sort_by='acc'):
    """Return scan indices ordered from best to worst according to one metric.

    Args:
        pr, re, f1, macro_f1, acc, mcc: Per-scan metric values; larger is better.
        mse: Per-scan errors; smaller is better.
        sort_by: Metric to rank by. Accepted names are 'acc'/'accs',
            'mcc'/'mccs', 'macro_f1'/'macro-f1'/'macro_f1s'/'macro-f1s',
            'f1'/'f1s'/'fmsr'/'f_msr'/'fmsrs'/'f_msrs', 're', 'pr' and
            'mse'/'mses'.

    Returns:
        np.ndarray of indices with the best scan first. Ties keep ascending
        index order. NaN values (e.g. an undefined MCC) are always ranked last.

    Raises:
        ValueError: If ``sort_by`` is not an accepted metric name.
    """
    # (accepted names, metric values, larger-is-better)
    candidates = (
        (('acc', 'accs'), acc, True),
        (('mcc', 'mccs'), mcc, True),
        (('macro_f1', 'macro-f1', 'macro_f1s', 'macro-f1s'), macro_f1, True),
        (('f1', 'f1s', 'fmsr', 'f_msr', 'fmsrs', 'f_msrs'), f1, True),
        (('re',), re, True),
        (('pr',), pr, True),
        (('mse', 'mses'), mse, False),
    )
    for names, values, higher_is_better in candidates:
        if sort_by in names:
            break
    else:
        raise ValueError(f"sort_by {sort_by} must be a valid metric name")

    values = np.asarray(values, dtype=float)
    # Negate rather than flip an ascending sort: argsort always puts NaN last,
    # and flipping would turn NaN scans into the "best" ones.
    keys = -values if higher_is_better else values
    return np.argsort(keys, kind='stable')


def polynomial_filters(h1_coeffs, h2_coeffs, xlims = (0, 20), num_points=10):
    """Plot the H1 and H2 polynomial filters of every layer.

    Column ``i`` shows layer ``i + 1``, with H1 in the top row and H2 in the
    bottom row. Each curve is ``numpy.polynomial.polynomial.polyval(x, coeffs)``
    (coefficients in increasing-degree order), evaluated at ``num_points``
    evenly spaced x values in ``xlims``. Panel titles list the coefficients.

    Args:
        h1_coeffs, h2_coeffs: Per-layer coefficient sequences. They may be numpy
            arrays or torch tensors (CPU or GPU) and must have the same number
            of layers.
        xlims: (low, high) range of the x axis.
        num_points: Number of x samples.

    Returns:
        The matplotlib ``Figure``.
    """
    assert len(h1_coeffs) == len(h2_coeffs), 'h1_coeffs and h2_coeffs must have the same number of layers'
    # .cpu() so GPU tensors can be converted as well.
    if torch.is_tensor(h1_coeffs):
        h1_coeffs = h1_coeffs.detach().cpu().numpy()
    if torch.is_tensor(h2_coeffs):
        h2_coeffs = h2_coeffs.detach().cpu().numpy()

    x = np.linspace(xlims[0], xlims[1], num_points)

    fs = 10
    ncols = len(h1_coeffs)
    # Two rows: H1 and H2. squeeze=False keeps ``axes`` 2-D for a single layer.
    fig, axes = plt.subplots(nrows=2, ncols=ncols, sharex=True, sharey='row',
                             constrained_layout=True, squeeze=False)

    for col, (h1_c, h2_c) in enumerate(zip(h1_coeffs, h2_coeffs)):
        h1_ax, h2_ax = axes[0, col], axes[1, col]
        h1_label = ', '.join(f'{c:.2f}' for c in h1_c)
        h2_label = ', '.join(f'{c:.2f}' for c in h2_c)
        h1_ax.set_title(f'{col+1}\n{h1_label}', fontsize=fs-1)
        h2_ax.set_title(h2_label, fontsize=fs-1)

        h1_ax.plot(x, polyval(x, h1_c))
        h2_ax.plot(x, polyval(x, h2_c))

        # Pass the visibility positionally: the keyword was ``b`` before
        # matplotlib 3.5 and ``visible`` afterwards (``b`` was later removed).
        h1_ax.grid(True, axis='y')
        h2_ax.grid(True, axis='y')

    axes[0, 0].set_ylabel('H1', rotation=90, fontsize=fs)
    axes[1, 0].set_ylabel('H2', rotation=90, fontsize=fs)

    # Zero padding between panels. Newer matplotlib (3.6+) exposes this through
    # the layout engine, and set_constrained_layout_pads is discouraged there.
    if hasattr(fig, 'get_layout_engine'):
        fig.get_layout_engine().set(w_pad=0, h_pad=0, hspace=0, wspace=0)
    else:
        fig.set_constrained_layout_pads(w_pad=0, h_pad=0, hspace=0, wspace=0)

    return fig



def model_outputs(fcs, scs, preds, mses, threshold, subject_ids, scan_dirs, prior, fc_pctile=95.0, pred_pctile=95.0, sort_by='acc'):
    """Plot inputs, prior, predictions and labels for the best/median/worst scans.

    Scans are ranked with ``find_best_by`` on ``sort_by``. The metrics come from
    ``score_graphs_batch`` at ``threshold``, with ``mses`` used for MSE. Six rows
    are drawn: the two best, the two middle and the two worst scans. Columns are
    FC | prior > 0 | raw prediction | prediction > threshold | label > 0.
    Because rows are picked by position (``sort_idxs[1]``, ``sort_idxs[-2]``),
    at least two scans are required.

    Args:
        fcs, scs, preds: Per-scan matrices indexed ``[i]``. ``scs`` and ``preds``
            must be 3-D with equal length and node count.
        mses: Per-scan MSE values indexed by scan.
        threshold: Threshold used for scoring and for binarizing predictions.
        subject_ids, scan_dirs: Per-scan identifiers shown in the row labels.
        prior: A single matrix. It is tiled across scans to score it as a
            baseline, and drawn binarized at ``> 0``.
        fc_pctile, pred_pctile: Percentiles passed to
            ``percentile_upper_matrices`` to clip the FC and prediction colour
            scales.
        sort_by: Metric name accepted by ``find_best_by``.

    Returns:
        The matplotlib ``Figure``.
    """
    assert len(scs.shape)==3 and len(preds.shape)==3 and scs.shape[1]==preds.shape[1] and len(scs)==len(preds)

    fs = 10
    # Baseline: how well the prior alone (one copy per scan) scores against each label.
    _, _, _, prior_macro_f1, prior_acc, prior_mcc = \
        score_graphs_batch(threshold, adjs=scs, preds=np.tile(prior, (len(scs), 1, 1)), o='raw')

    # Rank scans by the requested metric to pick best/median/worst rows.
    pr, re, f1, macro_f1, acc, mcc = \
        score_graphs_batch(threshold, adjs=scs, preds=preds, o='raw')
    sort_idxs = find_best_by(pr=pr, re=re, f1=f1, macro_f1=macro_f1,
                             acc=acc, mcc=mcc, mse=mses, sort_by=sort_by)
    sc_sparsities = sparsity(As=scs, directed=False, self_loops=False)

    n = len(sort_idxs)
    patient_idxs = [sort_idxs[0], sort_idxs[1], sort_idxs[n//2 - 1], sort_idxs[n//2], sort_idxs[-2], sort_idxs[-1]]
    patient_txts = ["best1", "best2", "median1", "median2", "worst2", "worst1"]

    ncols = 5  # input fc | prior | raw output | threshold output | truth
    nrows = len(patient_idxs)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, constrained_layout=True)

    axes_titles = [f'FC (%:{fc_pctile:.0f})', 'Prior>0', f'Pred (%:{pred_pctile:.0f})', f'Pred>{threshold:.3f}', 'True Adj>0']
    for ax, title in zip(axes[0, :], axes_titles):
        ax.set_title(title, fontsize=fs)

    # Colour limits come only from the matrices that are actually plotted.
    fc_pos_pctile_lim, fc_neg_pctile_lim = percentile_upper_matrices([fcs[j] for j in patient_idxs], fc_pctile, center=0)
    # The tiny offset keeps vmin strictly below vcenter=0, which TwoSlopeNorm requires.
    divnorm_fc = colors.TwoSlopeNorm(vmin=fc_neg_pctile_lim-1e-8, vcenter=0, vmax=fc_pos_pctile_lim)
    pred_pctile_lim, _ = percentile_upper_matrices([preds[j] for j in patient_idxs], pred_pctile, center=0)

    prior = (prior > 0) + 0.0
    fc_im = raw_output_im = None
    for row, idx in enumerate(patient_idxs):
        fc_ax, prior_ax, raw_output_ax, threshold_output_ax, true_adj_ax = axes[row]
        pred = preds[idx]

        fc_im = fc_ax.imshow(fcs[idx], norm=divnorm_fc, cmap=fc_cm)
        prior_ax.imshow(prior, vmin=0, vmax=1, cmap=binary_sc_cm)
        raw_output_im = raw_output_ax.imshow(pred, vmin=0, vmax=pred_pctile_lim)  # lower vmax for more resolution
        threshold_output_ax.imshow((pred > threshold) + 0.0, vmin=0, vmax=1, cmap=binary_sc_cm)
        true_adj_ax.imshow((scs[idx] > 0) + 0.0, vmin=0, vmax=1, cmap=binary_sc_cm)

        raw_output_ax.set_xlabel(f'mse: {mses[idx]:.4f}', fontsize=fs - 2, labelpad=0)
        threshold_output_ax.set_xlabel(f'{macro_f1[idx]:.2f}||{acc[idx]:.2f}||{mcc[idx]:.2f}', fontsize=fs - 2, labelpad=0)
        prior_ax.set_xlabel(f'{prior_macro_f1[idx]:.2f}||{prior_acc[idx]:.2f}||{prior_mcc[idx]:.2f}', fontsize=fs - 2, labelpad=0)
        true_adj_ax.set_xlabel(f'sparsity {sc_sparsities[idx]:.2f}', fontsize=fs - 2, labelpad=-1)

        # Row label: rank category plus scan identity.
        fc_ax.set_ylabel(f'{patient_txts[row]}\n{subject_ids[idx]}-{scan_dirs[idx]}', rotation=90, fontsize=fs-2)
        # Remove ticks and tick labels.
        for ax in axes[row]:
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_xticks([])
            ax.set_yticks([])

    # Note the space before "MSE": the two parts were previously glued together ("0.10.MSE").
    title_txt = (f'Ave M-F1||Acc||MCC {100*np.mean(macro_f1):.0f}||{100*np.mean(acc):.0f}||{100*np.mean(mcc):.0f} @ {threshold:.2f}.'
                 f' MSE: {np.mean(mses):.3f}. Sorted {sort_by}.')
    fig.suptitle(title_txt, fontsize=fs + 3)

    cb_fc = plt.colorbar(fc_im, ax=axes[:, 0], shrink=.8, pad=0.00, aspect=60, location='right')
    cb_pred = plt.colorbar(raw_output_im, ax=axes[:, 2], shrink=.8, pad=0.00, aspect=60, location='right')
    cb_fc.ax.tick_params(labelsize=6, rotation=90)
    cb_pred.ax.tick_params(labelsize=6, rotation=90)
    cb_fc.ax.yaxis.set_ticks_position('left')

    # Zero padding between panels (layout-engine API on matplotlib 3.6+).
    if hasattr(fig, 'get_layout_engine'):
        fig.get_layout_engine().set(w_pad=0, h_pad=0, hspace=0, wspace=0)
    else:
        fig.set_constrained_layout_pads(w_pad=0, h_pad=0, hspace=0, wspace=0)

    return fig


def intermediate_outputs(model, depth, fcs: np.ndarray, scs: np.ndarray, preds: np.ndarray, subject_ids, scan_dirs, threshold, sort_by='acc', fc_pctile=95, pred_pctile=95):
    """Plot per-layer model outputs for the best/median/worst scans (CURRENTLY DISABLED).

    NOTE(review): this function unpacks ``fcs.shape`` and then hits a bare
    ``return``, so it always returns ``None`` and everything after that line is
    unreachable. It cannot be re-enabled as-is: ``scan_mses`` is ``None``
    (its source, ``loss_per_scan``, was removed), yet it is passed to
    ``find_best_by`` and indexed for the x-labels (``scan_mses[idx]``).

    What the unreachable code would do: rank scans with ``find_best_by`` on
    ``sort_by``, pick the two best, two middle and two worst, and draw one row
    per scan with FC | layer-0 output ("Prior") | layer outputs 1..depth | SC.
    The layer outputs come from
    ``model.forward_intermed_outs(covs_raw=...)['S_outs']``, reshaped to
    ``(depth + 1, N, N)``.

    Args:
        model: Object exposing ``forward_intermed_outs`` (unreachable code only).
        depth: Number of layer outputs drawn after the layer-0 "Prior" column.
        fcs: Must have a 3-part shape; it is unpacked before the early return.
        scs, preds, subject_ids, scan_dirs, threshold, sort_by, fc_pctile,
            pred_pctile: Used only by the unreachable code.

    Returns:
        None (see the note above).
    """
    batch_size, N, _ = fcs.shape
    # sort data by performance
    # NOTE(review): early exit; everything below is dead code until scan_mses is restored.
    return
    # removed loss_per_scan <- use prediction_metrics in model_utils
    scan_mses, scan_maes = None, None #loss_per_scan(scs=scs, preds=preds)
    # select which plots to choose by the best/worst/median f1/acc/mcc/mse
    pr, re, f1, macro_f1, acc, mcc = \
        score_graphs_batch(threshold, adjs=scs, preds=preds, o='raw')
    sort_idxs = find_best_by(pr=pr, re=re, f1=f1, macro_f1=macro_f1,
                             acc=acc, mcc=mcc, mse=scan_mses, sort_by=sort_by)
    sc_sparsities = sparsity(As=scs, directed=False, self_loops=False)

    # set up labels: two best, two middle, two worst scans
    l = len(sort_idxs)
    patient_idxs = [sort_idxs[0], sort_idxs[1], sort_idxs[l // 2 - 1], sort_idxs[l // 2], sort_idxs[-2], sort_idxs[-1]]
    patient_txts = ["best1", "best2", "median1", "median2", "worst2", "worst1"]
    #x_subset = x[patient_idxs, :, :]

    # set up plots
    fs = 10
    ncols = 1 + 1 + depth + 1  # Cov, Prior (input), intermed out 1, ... intermed out L, true sc

    fig, axes = plt.subplots(nrows=len(patient_idxs), ncols=ncols,  constrained_layout=True)
    axes_titles = [f"FC ({fc_pctile}%)", "Prior"] + [str(num+1) for num in range(depth-1)] + [f"{depth} ({pred_pctile}%)"] + ["SC"]
    for i, ax in enumerate(axes[0, :]):
        ax.set_title(axes_titles[i], fontsize=fs)

    # collect all fcs/preds that we are going to plot (colour limits use only these)
    fcs_to_plot = [fcs[j] for j in patient_idxs]
    fc_pos_pctile_lim, fc_neg_pctile_lim  = percentile_upper_matrices(fcs_to_plot, fc_pctile, center=0)
    # tiny offset keeps vmin strictly below vcenter=0, as TwoSlopeNorm requires
    divnorm_fc = colors.TwoSlopeNorm(vmin=fc_neg_pctile_lim-1e-8, vcenter=0, vmax=fc_pos_pctile_lim)

    preds_to_plot = [preds[j] for j in patient_idxs]
    pred_pos_pctile_lim, _ = percentile_upper_matrices(preds_to_plot, pred_pctile, center=0)

    # place plots
    for row, idx in enumerate(patient_idxs):
        fc, sc, pred, subject_id, scan_dir = \
            fcs[idx], scs[idx], preds[idx], subject_ids[idx], scan_dirs[idx]

        fc_ax, prior_ax, intermed_out_axes, true_adj_ax = \
            axes[row, 0], axes[row, 1], axes[row, 2:(2+depth)], axes[row, -1]
        # run the model on this single scan to get every layer's output
        intermed_outs = model.forward_intermed_outs(covs_raw=torch.from_numpy(fc).view(1, N, N))['S_outs']
        intermed_outs = intermed_outs.view(depth+1, N, N).numpy() # remove batch dim

        fc_im = fc_ax.imshow(fc, norm=divnorm_fc, cmap=fc_cm)# vmin=-fc_pctile_lim, vmax=fc_pctile_lim, cmap=fc_cm)
        # layer-0 output is shown in the "Prior" column
        prior_im = prior_ax.imshow(intermed_outs[0], vmin=0, vmax=pred_pos_pctile_lim, cmap=inferno_cm)
        for im_idx, ax in enumerate(intermed_out_axes):
            ax.imshow(intermed_outs[im_idx+1], vmin=0, vmax=pred_pos_pctile_lim, cmap=inferno_cm)
        true_adj_im = true_adj_ax.imshow(sc, vmin=0, vmax=pred_pos_pctile_lim, cmap=inferno_cm)

        # NOTE(review): scan_mses is None here, so this line would raise TypeError.
        intermed_out_axes[-1].set_xlabel('mse'+f'{scan_mses[idx]:.4f}'[1:], fontsize=fs - 2, labelpad=0)

        #.set_xlabel(f'{macro_f1[idx]:.2f}||{acc[idx]:.2f}||{mcc[idx]:.2f}', fontsize=fs - 3, labelpad=0)
        #prior_ax.set_xlabel(f'{prior_macro_f1[idx]:.2f}||{prior_acc[idx]:.2f}||{prior_mcc[idx]:.2f}', fontsize=fs - 3, labelpad=0)
        true_adj_ax.set_xlabel('spsty'+f'{sc_sparsities[idx]:.2f}'[1:], fontsize=fs - 2, labelpad=0)

        # y label information on scan => per row
        axes[row, 0].set_ylabel(f'{patient_txts[row]}\n{subject_ids[idx]}-{scan_dirs[idx]}', rotation=90,
                                fontsize=fs - 2)  # size='large')
        # Turn off tick labels
        for col in range(ncols):
            axes[row, col].set_xticklabels([])
            axes[row, col].set_yticklabels([])
            axes[row, col].set_xticks([])  # minor=True
            axes[row, col].set_yticks([])

    cb_fc = plt.colorbar(fc_im, ax=axes[:, 0], shrink=.8, pad=0.00, aspect=60,  location='right')
    cb_pred = plt.colorbar(true_adj_im, ax=axes[:, -1],  shrink=.8, pad=0.00, aspect=60, location='right')
    cb_fc.ax.tick_params(labelsize=6, rotation=90)
    cb_fc.ax.yaxis.set_ticks_position('left')
    cb_pred.ax.tick_params(labelsize=6, rotation=90)
    fig.set_constrained_layout_pads(w_pad=0, h_pad=0, hspace=0, wspace=0)

    return fig


def contin_label_heatmap_outputs(scs, preds,  thresholds, num_patients, prior, pt=98.0, subject_ids=None, scan_dirs=None, title_str = ""):
    """Plot continuous labels, raw/thresholded predictions and the prior.

    One row is drawn per patient (the first ``num_patients`` entries). The
    columns are:

        raw label | raw prediction | prior > 0 | prediction > t
        (one per ``t`` in ``thresholds``) | label > 0

    Each thresholded panel is annotated with the averaged f1/macro-f1/acc/mcc
    from ``score_graphs_batch``. A single colorbar on the left describes the
    raw-prediction column. Its upper limit is the ``pt``-th percentile of the
    plotted upper-triangle prediction values, and never below 1.

    Args:
        scs, preds: 3-D stacks with equal length and node count. ``scs`` must
            support ``.view``/``.max(dim)`` (presumably torch tensors; TODO
            confirm).
        thresholds: Iterable of thresholds, one column each.
        num_patients: Number of rows to draw.
        prior: A single matrix, drawn binarized at ``> 0``.
        pt: Percentile used to cap the raw-prediction colour scale.
        subject_ids, scan_dirs: Accepted for interface compatibility; not
            currently used.
        title_str: Appended to the figure title.

    Returns:
        The matplotlib ``Figure``.
    """
    assert len(scs.shape)==3 and len(preds.shape)==3 and scs.shape[1]==preds.shape[1] and len(scs)==len(preds)
    N = scs.shape[1]

    fs = 10
    num_pre_thresh_plots = 3  # raw adj | raw pred | prior
    ncols = num_pre_thresh_plots + len(thresholds) + 1  # ... + threshold preds + threshold adj

    # squeeze=False keeps ``axes`` 2-D even when only one patient is plotted.
    fig, axes = plt.subplots(nrows=num_patients, ncols=ncols, squeeze=False)
    axes[0, 0].set_title('Raw Adj', fontsize=fs)
    axes[0, 1].set_title('Raw Pred', fontsize=fs)
    axes[0, 2].set_title('Prior>0', fontsize=fs)
    for i, threshold in enumerate(thresholds):
        axes[0, i+num_pre_thresh_plots].set_title(f'> {round(threshold, 3)}', fontsize=fs)
    axes[0, ncols-1].set_title('True Adj>0', fontsize=fs)

    # Predictions contain outliers: cap the colour scale at the pt-th percentile
    # of the plotted values so outliers do not wash out everything else.
    pred_vals = []
    for i in range(num_patients):
        pred_vals.extend(upper_tri_as_vec(preds[i]))
    max_cbar_val = max(np.percentile(pred_vals, pt), 1.0)

    # Shared upper colour limit for the raw-label column: the largest plotted label value.
    max_scs = scs[0:num_patients].view(num_patients, int(N**2)).max(1)[0].max().item()

    # A scale anchored at 0 would clip negative predictions, so let imshow pick
    # the lower limit when any are present.
    min_pred = np.min(preds[0:num_patients])
    raw_pred_vmin = None if min_pred < 0.0 else 0

    prior = (prior > 0) + 0.0
    # One colorbar for the raw predictions; the other columns are binary.
    # See https://stackoverflow.com/questions/13784201/matplotlib-2-subplots-1-colorbar
    raw_pred_im = None
    for row in range(num_patients):
        # Identical vmin/vmax on every row keeps the single colorbar valid for all rows.
        axes[row, 0].imshow(scs[row], vmin=0, vmax=max_scs, cmap=raw_sc_cm)
        raw_pred_im = axes[row, 1].imshow(preds[row], vmin=raw_pred_vmin, vmax=max_cbar_val, cmap=raw_pred_cm)
        axes[row, 2].imshow(prior, vmin=0, vmax=1, cmap=binary_sc_cm)

        for i, threshold in enumerate(thresholds):
            threshold_pred = (preds[row] > threshold) + 0.0
            _, _, f1, macro_f1, acc, mcc = \
                score_graphs_batch(threshold=threshold, adjs=scs[row], preds=threshold_pred, o='ave')
            thresh_ax = axes[row, i + num_pre_thresh_plots]
            thresh_ax.imshow(threshold_pred, vmin=0, vmax=1, cmap=intermed_out_cm)
            thresh_ax.set_xlabel(f'f1{f1:.2f}||macro-f1{macro_f1:.2f}||ACC{acc:.2f}||MCC{mcc:.2f}', fontsize=fs/2)

        axes[row, ncols-1].imshow((scs[row] > 0) + 0.0, vmin=0, vmax=1, cmap=binary_sc_cm)

        axes[row, 0].set_ylabel(f'Patient {row}', rotation=90, fontsize=fs)
        # Turn off tick labels.
        for ax in axes[row]:
            ax.set_xticklabels([])
            ax.set_yticklabels([])

    fig.suptitle('Model Outputs: ' + title_str, fontsize=fs + 4)
    fig.subplots_adjust(left=0.18, bottom=.05, wspace=.1)
    cbar_ax = fig.add_axes([0.07, 0.1, .03, 0.8])
    fig.colorbar(raw_pred_im, cax=cbar_ax)
    cbar_ax.yaxis.set_ticks_position('left')

    return fig


def viz_prior(mask, summary_values, mask_str = '', values_str = '', title_str=''):
    """Show a prior as its mask, its summary values, and their elementwise product.

    The three panels sit side by side. ``mask`` is drawn on a [0, 1] scale.
    ``summary_values / 10`` and ``(mask * summary_values) / 10`` are drawn on a
    fixed [0, 1] scale. The colorbar on the left belongs to the product panel.

    Args:
        mask: 2-D matrix (its shape is unpacked into two values).
        summary_values: Matrix combined elementwise with ``mask``.
        mask_str, values_str: Appended to the first two panel titles.
        title_str: Inserted into the figure title.

    Returns:
        The matplotlib ``Figure``.
    """
    n_rows, n_cols = mask.shape  # unpacking enforces a 2-D mask; sizes otherwise unused
    font_size = 10

    fig, (mask_ax, values_ax, prior_ax) = plt.subplots(nrows=1, ncols=3)

    panel_titles = (
        (mask_ax, f'Mask: {mask_str}'),
        (values_ax, f'Values: {values_str}'),
        (prior_ax, 'Prior: Mask.*Values'),
    )
    for ax, text in panel_titles:
        ax.set_title(text, fontsize=font_size)

    scale = 10  # divide values so they fit the fixed [0, 1] colour range
    mask_ax.imshow(mask + 0.0, vmin=0, vmax=1, cmap=binary_sc_cm)
    values_ax.imshow(summary_values / scale, vmin=0, vmax=1, cmap=raw_pred_cm)
    product_im = prior_ax.imshow((mask * summary_values) / scale, vmin=0, vmax=1, cmap=raw_pred_cm)

    values_ax.set_yticklabels([])
    prior_ax.yaxis.set_label_position("right")
    prior_ax.yaxis.tick_right()

    fig.suptitle(f'Prior over {title_str} TRAINING SCs', fontsize=font_size + 4)
    fig.subplots_adjust(left=0.18, bottom=.05, wspace=.1)
    # Single shared colorbar on the far left, with ticks on its left side.
    colorbar_ax = fig.add_axes([0.07, 0.1, .03, 0.8])
    fig.colorbar(product_im, cax=colorbar_ax)
    colorbar_ax.yaxis.set_ticks_position('left')

    return fig


# TODO: viz H1/ H2 polynomials for each layer
def intermed_heatmap_outputs(intermediate_outputs, scs, fcs, num_patients, pt = 98.0, title_str = ""):
    """Plot every layer's output for several patients next to the raw FC input.

    The columns are:

        raw FC | ``intermediate_outputs[0]`` ... ``intermediate_outputs[depth-2]``
        (fixed [0, 1] scale) | final output ``intermediate_outputs[depth-1]``

    The FC scale is symmetric, ±(``pt``-th percentile of the plotted
    upper-triangle FC values, at least 1). The final-output scale is
    [0, ``pt``-th percentile of its plotted values, at least 1]. Colorbars are
    drawn on the left (FC) and on the right (final output).

    Args:
        intermediate_outputs: 4-D, indexed ``[layer][patient]``. Presumably
            shaped (depth, batch, N, N); TODO confirm.
        scs: 3-D stack of labels. Only its rank is checked here.
        fcs: Per-patient FC matrices indexed ``fcs[i]``.
        num_patients: Number of rows to draw.
        pt: Percentile used to cap both colour scales.
        title_str: Appended to the figure title.

    Returns:
        The matplotlib ``Figure``.
    """
    assert len(scs.shape) == 3
    assert len(intermediate_outputs.shape) == 4
    depth = intermediate_outputs.shape[0]

    preds = intermediate_outputs[depth-1]  # final layer output
    fs = 12
    ncols = 1 + depth  # raw fc + one column per layer output

    # squeeze=False keeps ``axes`` 2-D even when only one patient is plotted.
    fig, axes = plt.subplots(nrows=num_patients, ncols=ncols, squeeze=False)
    axes[0, 0].set_title('Raw FC', fontsize=fs)
    for i in range(depth):
        axes[0, i + 1].set_title(f'out {i}', fontsize=fs)

    # FCs and predictions both contain outliers: cap each colour scale at the
    # pt-th percentile of the plotted upper-triangle values (never below 1).
    fc_vals = [upper_tri_as_vec(fcs[i]).tolist() for i in range(num_patients)]
    max_fc_cbar_val = max(np.percentile(fc_vals, pt), 1.0)

    pred_vals = [upper_tri_as_vec(preds[i]).tolist() for i in range(num_patients)]
    max_pred_cbar_val = max(np.percentile(pred_vals, pt), 1.0)

    # One colorbar for the FCs and another for the final output.
    # See https://stackoverflow.com/questions/13784201/matplotlib-2-subplots-1-colorbar
    fc_im, pred_im = None, None
    for row in range(num_patients):
        fc_im = axes[row, 0].imshow(fcs[row], vmin=-max_fc_cbar_val, vmax=max_fc_cbar_val, cmap=fc_cm)
        for i in range(depth - 1):
            axes[row, i + 1].imshow(intermediate_outputs[i][row], vmin=0, vmax=1, cmap=intermed_out_cm)
        pred_im = axes[row, ncols - 1].imshow(preds[row], vmin=0, vmax=max_pred_cbar_val, cmap=raw_pred_cm)

        axes[row, 0].set_ylabel(f'Patient {row}', rotation=90, fontsize=fs)
        # Turn off tick labels.
        for ax in axes[row]:
            ax.set_xticklabels([])
            ax.set_yticklabels([])

    fig.suptitle('Intermediate Model Outputs ' + title_str, fontsize=23)
    cbar_left_ax = fig.add_axes([0.07, 0.1, .03, 0.8])
    cbar_right_ax = fig.add_axes([0.91, 0.1, .03, 0.8])

    # Label the colorbars via Colorbar.set_label. Axes.set_label only sets an
    # artist label for legends and is never drawn.
    cbar_left = fig.colorbar(fc_im, cax=cbar_left_ax)
    cbar_left.set_label('Raw FC (Cov)')
    cbar_right = fig.colorbar(pred_im, cax=cbar_right_ax)
    cbar_right.set_label('Pred SC')
    cbar_left_ax.yaxis.set_ticks_position('left')
    cbar_left_ax.yaxis.set_label_position('left')
    cbar_right_ax.yaxis.set_ticks_position('right')

    return fig


#def threshold_f_msr_plot(adjs, preds, curve_thresholds = np.linspace(0, .2, 50), o='ave'):
def threshold_metric_plot(adjs, preds, curve_thresholds = np.linspace(0, .2, 50), o='ave'):
    """Scatter accuracy, MCC and macro-F1 against a sweep of thresholds.

    Metrics come from
    ``metrics_at_thresholds(thresholds=curve_thresholds, adjs=adjs, preds=preds, o=o)``.

    Args:
        adjs: Label matrices passed through to ``metrics_at_thresholds``.
        preds: Predicted matrices passed through to ``metrics_at_thresholds``.
        curve_thresholds: Thresholds to evaluate (the default array is only
            read, never mutated).
        o: Aggregation mode forwarded to ``metrics_at_thresholds``.

    Returns:
        The matplotlib ``Figure``.
    """
    _, _, _, macro_f1s, accs, mccs = \
        metrics_at_thresholds(thresholds=curve_thresholds, adjs=adjs, preds=preds, o=o)

    fig, ax = plt.subplots()
    ax.scatter(curve_thresholds, accs, label='Acc')
    ax.scatter(curve_thresholds, mccs, label='MCC')
    ax.scatter(curve_thresholds, macro_f1s, label='Macro-F1')

    fig.suptitle('Metrics vs threshold', fontsize=20)
    ax.grid(True)
    ax.yaxis.grid(which='both', linestyle='--', linewidth=2)
    ax.xaxis.grid(which='both', linestyle='--', linewidth=2)

    x_max = round_up(np.max(curve_thresholds), decimals=1)
    ax.set_xlim(-.05, x_max)
    # MCC may be NaN (presumably where it is undefined), so ignore non-finite
    # values. If none are finite, fall back to 0 so set_ylim never receives NaN.
    mcc_vals = np.asarray(mccs, dtype=float)
    finite_mccs = mcc_vals[np.isfinite(mcc_vals)]
    y_min = (finite_mccs.min() if finite_mccs.size else 0.0) - .1
    ax.set_ylim(y_min, 1)
    ax.set_xlabel('Thresholds', fontsize='18')
    ax.set_ylabel('Metrics', fontsize='18')
    ax.legend()

    return fig


def prec_recall_plot(adjs, preds, curve_thresholds = np.linspace(0, 1, 30), tight = False):
    """Scatter precision against recall over a sweep of thresholds.

    Each point is annotated with "<F-measure> | <threshold>". Labels more than
    two positions before or after the best-F-measure point are rotated to
    reduce overlap.

    Args:
        adjs: Label matrices passed to ``score_graphs_batch``.
        preds: Predicted matrices passed to ``score_graphs_batch``.
        curve_thresholds: Thresholds to evaluate.
        tight: If true, fit the axes to the observed recall/precision range;
            otherwise use [0, 1] on both axes.

    Returns:
        The matplotlib ``Figure``.
    """
    n_points = len(curve_thresholds)
    precisions = np.zeros(n_points)
    recalls = np.zeros(n_points)
    f_msrs = np.zeros(n_points)
    for k, thr in enumerate(curve_thresholds):
        prec, rec, fm, _, _, _ = score_graphs_batch(threshold=thr, adjs=adjs, preds=preds, o='ave')
        precisions[k] = prec
        recalls[k] = rec
        f_msrs[k] = fm

    best_idx = np.argmax(f_msrs)
    fig, ax = plt.subplots()
    ax.scatter(recalls, precisions)
    for k, (rec, prec) in enumerate(zip(recalls, precisions)):
        # Tilt labels on either side of the best point so they overlap less.
        if k < best_idx - 2:
            rotation = -30
        elif k > best_idx + 2:
            rotation = 20
        else:
            rotation = 0
        label = f' {round(f_msrs[k], 3)} | {round(curve_thresholds[k], 3)}'
        ax.annotate(label, (rec, prec), fontsize=5, rotation=rotation)

    fig.suptitle('Re/Pr @ Thresh. (F-Msr, Thresh) labels.', fontsize=23)
    plt.grid('on')
    for axis in (ax.yaxis, ax.xaxis):
        axis.grid(which='both', linestyle='--', linewidth=2)

    if not tight:
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
    else:
        ax.set_xlim(min(recalls), max(recalls))
        ax.set_ylim(min(precisions), max(precisions))
    ax.set_xlabel('Recall', fontsize='20')
    ax.set_ylabel('Precision', fontsize='20')

    return fig


def make_edge_error_subplots(scs, preds, num_patients, true_edge_max = 1.0, num_bins_error=10, num_bins_true=10):#, pred_edge_min = 0.0, pred_edge_max = 2.0):
    batch_size, N, _ = scs.shape

    small_fs = 7

    #assumes undirected
    num_total_edges = N*(N-1)/2
    true_edge_bins = np.linspace(0, true_edge_max, num_bins_true)


    max_error = round(np.max(np.absolute(preds-scs)), 1)+.1

    # make num_patients subplots showing pred edge weight vs true edge weight
    fig, axes = plt.subplots(num_patients, len(true_edge_bins))#, sharex='row')
    fig.suptitle(f'Error Distrib: (PRED-LABEL)', fontsize=15)


    for row in range(num_patients):
        # extract graphs
        truth_graph, pred_graph = scs[row], preds[row]
        errors = pred_graph - truth_graph
        # create arrays of edge weights
        #output_truth, output_pred = edge_weights_true_pred(truth_graph, pred_graph)
        #output_truth, output_pred = np.array(output_truth), np.array(output_pred)

        # create arrays of errors and edge weights
        scs_arr, errors_arr = edge_weights_true_pred(truth_graph, errors)
        scs_arr, errors_arr = np.array(scs_arr), np.array(errors_arr)


        #do zero by hand
        zero_edge_mask = (scs_arr==0)
        errors_non_edge  = errors_arr[zero_edge_mask]  #predictions correcsponding to zero edges
        percent_total_edges = len(errors_non_edge) / num_total_edges * 100

        ax = axes[row, 0]
        ax.hist(errors_non_edge, bins=num_bins_error, range=[-max_error, max_error], orientation='horizontal', density=True)
        ax.set_title(f'edges: {len(errors_non_edge)}\ntotal:{percent_total_edges:.1f}%', fontsize=small_fs)
        ax.set(xlabel=f'non-edges')
        ax.xaxis.get_label().set_fontsize(small_fs)
        ax.set_ylabel(f'Patient {row}', rotation=90, fontsize=15)
        ax.set_xticks([])
        #axes[i, 0].set_ylabel(f'Edge Pred') #rotation=90
        #axes[i,0].y

        for col, low in enumerate(true_edge_bins):
            if col == (len(true_edge_bins)-1):
                break
            low, high = low, true_edge_bins[col+1]
            mask = np.logical_and(scs_arr>low, scs_arr<=high)
            #mask = (output_truth>low) and (output_truth<=high)
            errors_bin = errors_arr[mask]
            percent_total_edges = len(errors_bin)/num_total_edges*100
            ax = axes[row, col+1]
            ax.hist(errors_bin, bins=num_bins_error, range=[-max_error, max_error], orientation='horizontal')
            ax.set_title(f'{len(errors_bin)}\n{percent_total_edges:.2f}%', fontsize=small_fs)
            ax.set(xlabel=f'({str(round(low,2))[1:]},{str(round(high,2))[1:]}]')#, fontsize=small_fs)
            ax.xaxis.get_label().set_fontsize(small_fs)
            ax.set_xticks([])
            ax.set_yticks([])
    plt.subplots_adjust(top=0.85, bottom=0.05, hspace=.5, wspace=0.5)
    return fig


def make_all_edge_error_subplots(scs, preds, true_edge_max=1.0, num_bins_error=10,num_bins_true=10):  # , pred_edge_min = 0.0, pred_edge_max = 2.0):
    """Histogram prediction errors (pred - label) of all edges, pooled over every scan.

    Only strict-upper-triangle entries are used (graphs are assumed undirected).
    Column 0 shows errors on true non-edges (label == 0); column k >= 1 shows errors
    on edges whose true weight lies in the half-open bin (low, high] taken from
    ``np.linspace(0, true_edge_max, num_bins_true)``.

    Args:
        scs: true adjacency matrices, shape (batch, N, N).
        preds: predicted adjacency matrices, same shape as ``scs``.
        true_edge_max: upper end of the true-edge-weight bins.
        num_bins_error: number of histogram bins along the error axis.
        num_bins_true: number of bin *edges* for true weights, i.e.
            ``num_bins_true - 1`` weight bins (plus the non-edge column).

    Returns:
        The matplotlib Figure.
    """
    def bin_edge_label(value):
        # Drop the leading zero of values in [0, 1) (".25"); an unconditional
        # ``[1:]`` slice would render 1.0 as ".0".
        text = str(round(value, 2))
        return text[1:] if text.startswith('0') else text

    batch_size, N, _ = scs.shape
    small_fs = 7

    # One symmetric error range for every histogram so the columns are comparable.
    abs_errors = np.absolute(preds - scs)
    max_error = round(np.max(abs_errors), 1) + .1

    all_scs, all_errors = [], []
    for i in range(batch_size):
        truth_graph = scs[i]
        errors = preds[i] - truth_graph
        scs_sub_arr, errors_sub_arr = edge_weights_true_pred(truth_graph, errors)
        all_scs += scs_sub_arr
        all_errors += errors_sub_arr
    all_scs, all_errors = np.array(all_scs), np.array(all_errors)

    # Upper-triangle entries over the whole batch (assumes undirected graphs).
    num_total_edges = batch_size * N * (N - 1) / 2
    true_edge_bins = np.linspace(0, true_edge_max, num_bins_true)

    # One column for non-edges plus one per weight bin; squeeze=False keeps
    # ``axes`` indexable even when there is only a single column.
    fig, axes = plt.subplots(1, len(true_edge_bins), squeeze=False)  # , sharex='row')
    axes = axes[0]
    fig.suptitle(f'Error Distrib: (PRED-LABEL)', fontsize=15)

    # Non-edges (label exactly 0) are not covered by any (low, high] bin, so do them by hand.
    errors_non_edge = all_errors[all_scs == 0]
    percent_total_edges = len(errors_non_edge) / num_total_edges * 100

    ax = axes[0]
    ax.hist(errors_non_edge, bins=num_bins_error, range=[-max_error, max_error], orientation='horizontal', density=True)
    ax.set_title(f'edges: {len(errors_non_edge)}\ntotal:{percent_total_edges:.1f}%', fontsize=small_fs)
    ax.set(xlabel=f'non-edges')
    ax.xaxis.get_label().set_fontsize(small_fs)
    ax.set_ylabel(f'Error Over All Patients', rotation=90, fontsize=15)
    ax.set_xticks([])

    for col, (low, high) in enumerate(zip(true_edge_bins[:-1], true_edge_bins[1:]), start=1):
        errors_bin = all_errors[np.logical_and(all_scs > low, all_scs <= high)]
        percent_total_edges = len(errors_bin) / num_total_edges * 100
        ax = axes[col]
        ax.hist(errors_bin, bins=num_bins_error, range=[-max_error, max_error], orientation='horizontal')
        ax.set_title(f'{len(errors_bin)}\n{percent_total_edges:.2f}%', fontsize=small_fs)
        ax.set(xlabel=f'({bin_edge_label(low)},{bin_edge_label(high)}]')
        ax.xaxis.get_label().set_fontsize(small_fs)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.subplots_adjust(top=0.85, bottom=0.05, hspace=.45, wspace=0.5)
    return fig


def make_edge_subplots(scs, preds, mses, subject_ids, scan_dirs, threshold=0.0, sort_by='acc'):
    """Scatter predicted vs. true edge weights for six representative scans.

    Scans are scored with ``score_graphs_batch`` at ``threshold`` and ranked by
    ``find_best_by`` on the metric named ``sort_by`` (``mses`` supplies the MSE).
    The first two, middle two and last two scans of that ranking are drawn, one
    per row, labelled "best1", "best2", "median1", "median2", "worst2", "worst1".

    Args:
        scs: true adjacency matrices, indexed by scan along the first axis.
        preds: predicted adjacency matrices, same layout as ``scs``.
        mses: per-scan MSE values.
        subject_ids, scan_dirs: per-scan identifiers used in the row labels.
        threshold: binarisation threshold passed to ``score_graphs_batch``.
        sort_by: metric name passed to ``find_best_by``.

    Returns:
        The matplotlib Figure.

    NOTE(review): needs at least two scans, otherwise ``sort_idxs[1]`` is out of range.
    """
    # select which scans to show from the best/median/worst of the chosen metric
    pr, re, f1, macro_f1, acc, mcc = \
        score_graphs_batch(threshold=threshold, adjs=scs, preds=preds, o='raw')
    sort_idxs = find_best_by(pr=pr, re=re, f1=f1, macro_f1=macro_f1, acc=acc,
                             mcc=mcc, mse=mses, sort_by=sort_by)
    num_sorted = len(sort_idxs)
    patient_idxs = [sort_idxs[0], sort_idxs[1],
                    sort_idxs[num_sorted // 2 - 1], sort_idxs[num_sorted // 2],
                    sort_idxs[-2], sort_idxs[-1]]
    patient_txts = ["best1", "best2", "median1", "median2", "worst2", "worst1"]

    # Axis extents are identical for every row, so compute them once.
    sc_max = np.max(scs)
    pred_min, pred_max = np.min(preds), np.max(preds)
    top_right = np.max([sc_max, pred_max])

    fig, axes = plt.subplots(nrows=len(patient_idxs), ncols=1)
    for row, idx in enumerate(patient_idxs):
        output_sc, output_pred = edge_weights_true_pred(scs[idx], preds[idx])
        output_sc, output_pred = np.array(output_sc), np.array(output_pred)

        ax = axes[row]
        ax.scatter(output_sc, output_pred, s=10)
        # y = x reference line, plus a thin y = 0 line
        ax.add_line(mlines.Line2D([0, top_right], [0, top_right], linewidth=1, color='k'))
        ax.add_line(mlines.Line2D([0, sc_max], [0, 0], linewidth=0.25, color='k'))

        ax.set_xlim([0, sc_max])
        ax.set_ylim([pred_min, pred_max])
        ax.set_yticks([0, 1])
        ax.set_yticklabels(["0", "1"])
        ax.yaxis.tick_right()
        ax.set_xticks([])
        ax.set_xticklabels([])

        scan_info = f'macro_f1/acc/mcc/mse: {macro_f1[idx]:.2f} || {acc[idx]:.2f} || {mcc[idx]:.2f} || {mses[idx]:.5f}'
        # ax.text (not plt.text) so the annotation belongs to this row's axes
        # rather than to whichever axes pyplot currently considers active.
        ax.text(.02, .88, scan_info, ha='left', va='center', fontsize=6, transform=ax.transAxes)
        # y label identifies the scan shown in this row
        ax.set_ylabel(f'{patient_txts[row]}\n{subject_ids[idx]}-{scan_dirs[idx]}', rotation=90,
                      fontsize=9)

    # shared x label and ticks are placed on the bottom row only
    ax.set_xlabel('True Edge Weight')
    ax.set_xticks([0, .25, .5, .75, 1])
    ax.set_xticklabels(["0", ".25", ".5", ".75", "1"])
    fig.suptitle(f'Pred vs True. Sorted by {sort_by}.', fontsize=12)

    return fig


def make_all_edge_plot(scs, preds):#, true_edge_max=1.0, pred_edge_max = 2.0):
    """Scatter predicted vs. true weight for every upper-triangle edge of every scan.

    Args:
        scs: true adjacency matrices, shape (batch, N, N).
        preds: predicted adjacency matrices, same shape as ``scs``.

    Returns:
        The matplotlib Figure, with a y = x reference line.
    """
    batch_size, _, _ = scs.shape
    true_edge_max = np.max(scs)
    pred_edge_max = np.max(preds)
    # Keep negative predictions visible instead of clipping the y-axis at 0.
    pred_edge_min = min(0, np.min(preds))

    output_all_truth, output_all_pred = [], []
    for i in range(batch_size):
        output_truth, output_pred = edge_weights_true_pred(scs[i], preds[i])
        output_all_truth += output_truth
        output_all_pred += output_pred

    fig = plt.figure()
    plt.scatter(np.array(output_all_truth), np.array(output_all_pred), s=10)
    # Identity (y = x) reference line. Ending it at (true_edge_max, pred_edge_max)
    # would only coincide with y = x when both maxima are equal.
    top_right = max(true_edge_max, pred_edge_max)
    plt.gca().add_line(mlines.Line2D([0, top_right], [0, top_right], linewidth=1, color='k'))
    plt.xlim([0, true_edge_max])
    plt.ylim([pred_edge_min, pred_edge_max])
    plt.title('All Edges: Predicted Edge Weight Vs True Edge Weight')
    plt.xlabel('True Edge Weight')
    plt.ylabel('Predicted Edge Weight')

    return fig


def edge_weights_true_pred(truth_graph, pred_graph):
    """Return the strict-upper-triangle entries of two equally shaped matrices.

    Entries are listed in row-major order: (0, 1), (0, 2), ..., (1, 2), ...
    Plain lists are returned so results from several graphs can be
    concatenated with ``+=``.

    Args:
        truth_graph: square matrix (N x N) of true edge weights.
        pred_graph: matrix with the same shape as ``truth_graph``.

    Returns:
        Tuple ``(truth, pred)`` of lists, each with N * (N - 1) / 2 elements.

    Raises:
        AssertionError: if the two shapes differ.
    """
    assert truth_graph.shape == pred_graph.shape
    # triu_indices with k=1 excludes the diagonal and yields the same row-major
    # order as a nested ``for i: for j > i`` loop.
    rows, cols = np.triu_indices(truth_graph.shape[0], k=1)
    return list(truth_graph[rows, cols]), list(pred_graph[rows, cols])

#https://scikit-learn.org/stable/auto_examples/linear_model/plot_logistic.html
def logistic_func_plot(adjs, preds, num_patients, a,b, title_str = ""):
    """Plot a fitted logistic curve over the (prediction, label) pairs of each scan's edges.

    For each of the first ``num_patients`` scans, every strict-upper-triangle edge is
    drawn at x = predicted value, y = label taken from ``adjs``. The plot also shows
    ``expit(a * x + b)`` and a vertical line at the decision boundary ``x = -b / a``.

    Args:
        adjs: label adjacency matrices, shape (batch, N, N); expected to be 0/1 (not checked).
        preds: predicted adjacency matrices, shape (batch, N, N).
        num_patients: number of scans (rows) to plot; clamped to the batch size.
        a, b: logistic slope and bias.
        title_str: currently unused.

    Returns:
        The matplotlib Figure.
    """
    assert len(adjs.shape)==3 and len(preds.shape)==3 and adjs.shape[1]==preds.shape[1]
    # Avoid indexing past the batch when more rows are requested than scans exist.
    num_patients = min(num_patients, adjs.shape[0], preds.shape[0])

    fs = 10
    # squeeze=False keeps ``axes`` 2-D even when num_patients == 1.
    fig, axes = plt.subplots(nrows=num_patients, ncols=1, sharex=True, squeeze=False)

    db = -b / a  # decision boundary: a*x + b == 0, where the sigmoid equals 0.5
    fig.suptitle(f'Log Regres on Outs: a: {a:.1f}, b: {b:.1f}. Dec Bdary: {db:.1f}\nX_axis=edge value. Black: Raw (y=1 -> true edge)', fontsize=fs + 4)

    # x-range covers all data and the decision boundary, but never extends past 5
    min_x_lim = np.min([0, math.floor(db - .5)])
    max_x_lim = np.max([np.max(preds), 1, math.ceil(db + .5)])
    max_x_lim = np.min([max_x_lim, 5])

    # sigmoid curve and x ticks are the same for every row
    x_sig = np.linspace(min_x_lim, max_x_lim, 75)
    sig = expit(x_sig * a + b).ravel()
    xticks = np.arange(min_x_lim, max_x_lim, step=0.5)

    for row in range(num_patients):
        ax = axes[row, 0]
        ax.set_ylabel(f'Patient {row}', rotation=90, fontsize=fs)

        p_vec = upper_tri_as_vec(preds[row], offset=1)  # continuous predictions
        t_vec = upper_tri_as_vec(adjs[row], offset=1)   # labels for the same edges

        # x is the predicted edge value; true non-edges sit on y=0, true edges on y=1
        ax.scatter(p_vec, t_vec, color='black', zorder=20)
        ax.plot(x_sig, sig, color='red', linewidth=3)

        ax.set_xlim(left=min_x_lim, right=max_x_lim)
        ax.set_ylim(bottom=-.3, top=1.3)
        ax.set_xticks(xticks)
        ax.set_xticklabels([str(x) for x in xticks])

        ax.vlines(db, -.3, 1.3)  # vertical line at the decision boundary

    return fig


def metrics_distrib(pr, re, f1, macro_f1, acc, mcc, threshold):
    """Histogram per-scan F1, macro-F1, accuracy and MCC computed at one threshold.

    Each metric gets one row, titled with its mean / median / min / max. MCC
    statistics use the NaN-ignoring NumPy reductions and its histogram drops NaNs.
    ``pr`` and ``re`` are accepted but not plotted.

    Returns:
        The matplotlib Figure.
    """
    fig, axes = plt.subplots(nrows=4, ncols=1, sharex=True)
    fig.suptitle(f'Metrics @ Threshold {threshold:.3f}', fontsize=20)

    plain_stats = (np.mean, np.median, np.min, np.max)
    nan_stats = (np.nanmean, np.nanmedian, np.nanmin, np.nanmax)
    # (title prefix, values for statistics, values for histogram, reductions)
    panels = [
        ('F1', f1, f1.flatten(), plain_stats),
        ('Macro F1', macro_f1, macro_f1.flatten(), plain_stats),
        ('Acc', acc, acc.flatten(), plain_stats),
        ('MCC (ignore nan)', mcc, mcc[~np.isnan(mcc)].flatten(), nan_stats),
    ]

    for ax, (name, values, hist_values, (mean_of, median_of, min_of, max_of)) in zip(axes, panels):
        ax.set_title(
            f'{name}: mean {mean_of(values):.2f} | med {median_of(values):.2f} | '
            f'min {min_of(values):.2f} | max {max_of(values):.2f}',
            fontsize=10,
        )
        ax.hist(hist_values, density=True, bins=25)

    fig.tight_layout()
    return fig


def metrics_2d(x, x_label, y, y_label, subject_ids, scan_dirs, scores=None, scores_label=None, annot_pts=5, xlims=None, ylims=None):
    """Scatter two per-scan metrics against each other, optionally coloured by a third.

    Marginal histograms of ``x`` and ``y`` are overlaid on the axes, and markers on the
    axes show each metric's mean (star) and median (cross). When ``scores`` is given,
    points are coloured by it, and scans are ranked best-first (ascending when
    ``scores_label == 'mses'``, descending otherwise). The ``annot_pts`` first, middle
    and last ranked scans are then marked and listed in text boxes.

    Args:
        x, y: per-scan metric arrays of equal length.
        x_label, y_label: axis labels.
        subject_ids, scan_dirs: per-scan identifiers used in the annotations.
        scores: optional per-scan values used for colour and ranking; NaN scores are
            coloured but never ranked/annotated.
        scores_label: name of ``scores``; must be given iff ``scores`` is given.
        annot_pts: scans to annotate per group (capped at the number of non-NaN scores).
        xlims, ylims: optional axis limits.

    Returns:
        The matplotlib Figure.
    """
    assert len(x) == len(y), f'oppposing data must be same length'
    if scores is not None:
        assert len(x) == len(scores), f'scores must be same length as data'
    assert not ((scores is not None) ^ (scores_label is not None)), f'must provide labels for scores'
    fs = 10
    fig, ax = plt.subplots(nrows=1, ncols=1)

    if scores is not None:
        scores = np.asarray(scores)
        vmin = math.floor(np.nanmin(scores) * 10) / 10
        vmax = math.ceil(np.nanmax(scores) * 10) / 10
        im = ax.scatter(x=x, y=y, c=scores, cmap=fc_cm, vmin=vmin, vmax=vmax)
    else:
        im = ax.scatter(x=x, y=y, c=scores)

    # marginal histograms along the bottom and left edges
    ax.hist(x.flatten(), bins=25, weights=len(x) * [.01], orientation='vertical', alpha=.5)
    ax.hist(y.flatten(), bins=25, weights=len(y) * [.01], orientation='horizontal', alpha=.5)

    if scores is not None:
        # Rank only non-NaN scores: argsort places NaNs last, so they would otherwise
        # be reported as the worst scans (or, after the flip, as the best).
        valid_idxs = np.flatnonzero(~np.isnan(scores))
        sorted_idxs = valid_idxs[np.argsort(scores[valid_idxs])]
        if scores_label != 'mses':
            sorted_idxs = np.flip(sorted_idxs)  # best-first: descending except for MSE
        n_annot = min(annot_pts, len(sorted_idxs))

        def scan_text(idx):
            return f'{subject_ids[idx]}{scan_dirs[idx].upper()}||{scores[idx]:.2f}'

        # mark best/worst/median points and collect their ids
        best_texts, worst_texts, median_texts = [], [], []
        for i in range(n_annot):
            idx_best, idx_worst = sorted_idxs[i], sorted_idxs[-(i + 1)]
            idx_median = sorted_idxs[len(sorted_idxs) // 2 - (n_annot // 2 - i)]

            best_texts.append(scan_text(idx_best))
            ax.scatter(x=[x[idx_best]], y=[y[idx_best]], color='b', s=[50], marker='x')

            worst_texts.append(scan_text(idx_worst))
            ax.scatter(x=[x[idx_worst]], y=[y[idx_worst]], color='k', s=[50], marker='D')

            median_texts.append(scan_text(idx_median))
            ax.scatter(x=[x[idx_median]], y=[y[idx_median]], color='y', s=[50], marker='P')

        worst_texts.reverse()
        # Join with newlines (no trailing newline); trimming two characters off a
        # newline-terminated string would also cut the last digit of the final score.
        s_best = '\n'.join(['HIGHEST (blue)'] + best_texts)
        s_median = '\n'.join(['MEDIAN (yellow)'] + median_texts)
        s_worst = '\n'.join(['LOWEST (black)'] + worst_texts)

        ax.text(x=.01, y=.28, s=s_best, bbox=dict(color='blue', alpha=0.2), fontsize=7, transform=ax.transAxes)
        ax.text(x=.01, y=.05, s=s_median, bbox=dict(color='yellow', alpha=0.2), fontsize=7, transform=ax.transAxes)
        ax.text(x=.25, y=.05, s=s_worst, bbox=dict(color='black', alpha=0.2), fontsize=7, transform=ax.transAxes)

        scores_mean, scores_median = np.nanmean(scores), np.nanmedian(scores)
        ax.set_title(f'{scores_label}: Mean={scores_mean:.2f}, Median={scores_median:.2f}', fontsize=fs)

    # markers on the axes denote the mean (star) and median (cross) of each metric
    x_mean, x_median = np.nanmean(x), np.nanmedian(x)
    y_mean, y_median = np.nanmean(y), np.nanmedian(y)
    ax.scatter(x=[x_mean, 0], y=[0, y_mean], s=2 * [50], marker='*')
    ax.scatter(x=[x_median, 0], y=[0, y_median], s=2 * [50], marker='x')

    ax.set_xlabel(f'{x_label}')
    ax.set_ylabel(f'{y_label}')
    if xlims is not None:
        ax.set_xlim(xlims)
    if ylims is not None:
        ax.set_ylim(ylims)
    if scores is not None:
        fig.colorbar(im, label=f'{scores_label}')

    return fig


def multiple_ridgeline(data_list, metric_list, thresholds):
    """Draw one ridgeline panel per metric, showing its distribution at every threshold.

    Args:
        data_list: one entry per metric, each a sequence of 1-D arrays (one per threshold)
            passed to ``ridgeline``.
        metric_list: metric names, used as x labels; same length as ``data_list``.
        thresholds: threshold values; every 4th one labels the y axis of the last panel.

    Returns:
        The matplotlib Figure.
    """
    assert len(metric_list) == len(data_list)

    fs = 20
    # squeeze=False keeps ``axes`` indexable even for a single metric
    fig, axes = plt.subplots(nrows=1, ncols=len(metric_list), squeeze=False)  # , sharey=True)
    axes = axes[0]

    fig.suptitle(f'Metric Distrib vs Thresholds', fontsize=20)
    axes[0].set_ylabel('Thresholds', fontsize=fs)

    yticklabels = []
    for idx, threshold in enumerate(thresholds):
        if idx % 4 == 0:
            label = f'{threshold:.3f}'
            # Drop the leading zero of values in [0, 1) (".250") but leave e.g.
            # "1.000" or "-0.100" intact instead of slicing off their first character.
            yticklabels.append(label[1:] if label.startswith('0') else label)
        else:
            yticklabels.append('')

    last_idx = len(metric_list) - 1
    for idx, (metric, data, ax) in enumerate(zip(metric_list, data_list, axes)):
        # only the right-most panel carries threshold tick labels
        labels = yticklabels if idx == last_idx else []
        ridgeline(data=data, ax=ax, xlabel=metric, fs=fs, labels=labels, yticks_right=True, ylabel_right=True)

    return fig


# https://glowingpython.blogspot.com/2020/03/ridgeline-plots-in-pure-matplotlib.html
def ridgeline(data, fig=None, ax=None, overlap=0.0, fill='y', labels=None, n_points=150, xlabel=None, ylabel=None, ylabel_right=False, yticks_right=False, fs=18, xlow=None, xhigh=None):
    """
    Creates a standard ridgeline plot: one Gaussian-KDE curve per dataset, stacked vertically.

    data, sequence of 1-D array-likes (converted to arrays; the caller's sequence is not modified).
    fig, ax, figure/axes to draw on; a new figure and axes are created when ax is None.
    overlap, overlap between distributions. 1 max overlap, 0 no overlap.
    fill, matplotlib color to fill the distributions (falsy to disable filling).
    labels, values to place on the y axis to describe the distributions.
    n_points, number of points to evaluate each distribution function.
    xlabel, ylabel, fs, axis labels and their font size.
    ylabel_right, yticks_right, move the y label / y ticks to the right-hand side.
    xlow, xhigh, evaluation range (defaults to the min / max over all data).

    Returns (fig, ax); fig is the value passed in when ax is supplied.
    Raises ValueError if overlap is outside [0, 1].
    """
    if overlap > 1 or overlap < 0:
        raise ValueError('overlap must be in [0 1]')

    # Convert into a new list rather than writing back into the caller's sequence
    # (which also made tuples fail).
    data = [np.asarray(ds) for ds in data]

    if ax is None:
        fig, ax = plt.subplots()

    all_values = np.concatenate(data)
    if xlow is None:
        xlow = np.min(all_values)
    if xhigh is None:
        xhigh = np.max(all_values)

    xx = np.linspace(xlow, xhigh, n_points)
    ys = []
    for i, d in enumerate(data):
        y = i * (1.0 - overlap)
        ys.append(y)
        try:
            curve = gaussian_kde(d)(xx)
        except (np.linalg.LinAlgError, ValueError):
            # gaussian_kde cannot be fitted to fewer than two points or to
            # zero-variance data (e.g. a metric identical for every scan);
            # draw a flat ridge instead of aborting the whole figure.
            curve = np.zeros(n_points)
        zorder = len(data) - i + 1  # earlier ridges are drawn on top
        if fill:
            ax.fill_between(xx, np.ones(n_points) * y, curve + y, zorder=zorder, color=fill)
        ax.plot(xx, curve + y, c='k', zorder=zorder)

    if labels is not None:
        ax.set_yticks(ys)
        ax.set_yticklabels(labels)
        if yticks_right:
            ax.yaxis.tick_right()

    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=fs)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=fs)
        if ylabel_right:
            ax.yaxis.set_label_position("right")

    return fig, ax


def edge_metrics(scs, preds, threshold, node_labels):
    """Plot per-edge metrics computed across scans as heatmaps.

    Panels: per-edge MSE (``edge_mse``); accuracy, macro-F1 and MCC from
    ``batch_graph_metrics(..., graph_or_edge='edge')`` on predictions binarised at
    ``threshold`` and labels binarised at ``> 0``; and the fraction of scans in
    which each edge is present in ``scs``.

    Args:
        scs: label matrices, indexed by scan along the first axis.
        preds: predicted matrices, same layout as ``scs``.
        threshold: cut-off applied to ``preds``.
        node_labels: currently unused.

    Returns:
        The matplotlib Figure.
    """
    binary_preds = (preds > threshold)
    binary_scs = (scs > 0)
    e_mse = edge_mse(preds, scs)
    _, _, _, e_macro_f1, e_acc, e_mcc = batch_graph_metrics(x=binary_preds, y=binary_scs, graph_or_edge='edge')

    fig, (mse_ax, acc_ax, macro_f1_ax, mcc_ax, counts_ax) = plt.subplots(nrows=1, ncols=5)  # , sharey=True)
    fig.suptitle(f'Edge Metrics @ Threshold {threshold:.3f}', fontsize=20)

    num_scans = len(binary_scs)
    edge_freq = np.sum(binary_scs, axis=0) / num_scans  # fraction of scans containing each edge

    mse_im = mse_ax.imshow(e_mse, vmin=0, vmax=np.max(e_mse), cmap=raw_pred_cm, interpolation='None')
    acc_im = acc_ax.imshow(e_acc, vmin=0, vmax=1, cmap=raw_pred_cm, interpolation='None')
    f1_im = macro_f1_ax.imshow(e_macro_f1, vmin=0, vmax=1, cmap=raw_pred_cm, interpolation='None')
    mcc_im = mcc_ax.imshow(e_mcc, vmin=-1, vmax=1, cmap=fc_cm, interpolation='None')
    counts_im = counts_ax.imshow(edge_freq, vmin=0, vmax=1, cmap=binary_sc_cm, interpolation='None')

    shrink = 1
    fig.colorbar(mse_im, ax=[mse_ax], label='MSE', location='top', shrink=shrink)
    fig.colorbar(acc_im, ax=[acc_ax], label='Accuracy', location='top', shrink=shrink)
    fig.colorbar(f1_im, ax=[macro_f1_ax], label='Macro-F1', location='top', shrink=shrink)
    fig.colorbar(mcc_im, ax=[mcc_ax], label='MCC', location='top', shrink=shrink)
    fig.colorbar(counts_im, ax=[counts_ax], label=f'% edge occur ({num_scans} scans)', location='top', shrink=shrink)

    # label every 10th node index plus the last one; leave the rest blank
    idxs = np.arange(0, len(e_mse))
    last_idx = len(idxs) - 1
    idxs_str = [f'{i}' if (i % 10 == 0 or i == last_idx) else '' for i in idxs]
    tick_fs = 6
    for ax in (mse_ax, acc_ax, macro_f1_ax, mcc_ax, counts_ax):
        ax.set_xticks(idxs)
        ax.set_yticks(idxs)
        # lowercase ``fontsize``: Matplotlib >= 3.5 rejects the capitalised ``FontSize`` alias
        ax.set_xticklabels(idxs_str, fontsize=tick_fs)
        ax.set_yticklabels(idxs_str, fontsize=tick_fs)

    return fig


def adj_power_distributions(synth, real, max_power=5, which_norm='max_eig', include_zero_entries=False):
    """Compare distributions of matrix powers A^1..A^max_power of synthetic vs. real adjacencies.

    Each batch is normalised with ``normalize_slices(..., which_norm=which_norm)`` and raised
    to every power with ``torch.matrix_power``. For each power, the returned figure has
    histograms of two per-matrix statistics of the upper-triangle entries
    (``upper_tri_as_vec_batch(..., offset=0)``): the median and the max |entry|.
    A second figure (not returned) histograms all such entries. The mean and max |entry|
    per power are also printed.

    Args:
        synth, real: batches of square matrices as torch tensors — assumed shape
            (batch, N, N); TODO confirm.
        max_power: highest power to plot.
        which_norm: normalisation passed to ``normalize_slices``.
        include_zero_entries: if False, entries with magnitude below 1e-9 are replaced by
            NaN and ignored by the statistics.

    Returns:
        The main matplotlib Figure.
    """
    fs = 10
    metrics = ['median', 'max_abs_val']
    nrows, ncols = len(metrics), max_power
    # squeeze=False keeps both axes grids 2-D even when max_power == 1
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, constrained_layout=True, sharex='all', squeeze=False)
    fig.suptitle(f'Distribs of Powers of Adjs normalized by {which_norm}. Include zero entires?: {include_zero_entries}', fontsize=20)

    fig_all, axes_all = plt.subplots(nrows=1, ncols=ncols, sharex='all', squeeze=False)
    axes_all = axes_all[0]

    # row titles: median, max
    for row, title in enumerate(metrics):
        axes[row, 0].set_ylabel(title, rotation=90, fontsize=fs)

    # column titles: A^1, ..., A^max_power
    for power in range(1, max_power + 1):
        axes[-1, power - 1].set_xlabel(f'A^{power}', fontsize=fs)

    # (normalised batch, legend label, colour) per source, in plotting order
    sources = [
        (normalize_slices(synth, which_norm=which_norm), 'Synth Adjs', 'blue'),
        (normalize_slices(real, which_norm=which_norm), 'Real Adjs', 'magenta'),
    ]

    bins = 50
    min_mag = 1e-9
    for power in range(1, max_power + 1):
        col = power - 1
        flat_vals = []
        for normed, label, color in sources:
            vv = upper_tri_as_vec_batch(torch.matrix_power(normed, n=power), offset=0).numpy()

            if not include_zero_entries:
                # Treat (near-)zero entries as missing; compare magnitudes so that
                # negative entries are not discarded along with the zeros.
                vv[np.abs(vv) < min_mag] = np.nan

            axes[0, col].hist(np.nanmedian(vv, axis=1).flatten(), bins=bins, label=label, color=color, alpha=.5)
            axes[1, col].hist(np.nanmax(np.abs(vv), axis=1).flatten(), bins=bins, label=label, color=color, alpha=.5)
            axes_all[col].hist(vv.flatten(), bins=bins, label=label, color=color, alpha=.5)
            flat_vals.append(vv)

        synth_vv, real_vv = flat_vals
        # NaN-aware reductions: near-zero entries may have been replaced by NaN above.
        print(f'\npower {power}: MEAN    - synthetic: {np.nanmean(synth_vv):.5f}, real: {np.nanmean(real_vv):.5f}')
        print(f'power {power}:   MAX ABS - synthetic: {np.nanmax(np.abs(synth_vv)):.5f},  real: {np.nanmax(np.abs(real_vv)):.5f}')
        axes_all[col].legend()

    fig.set_constrained_layout_pads(w_pad=0, h_pad=0, hspace=0, wspace=0)
    return fig


def fc_power_distributions(synth_fcs, ps_fcs, real_fcs, max_fc_power=5, which_norm='max_eig', include_zero_entries=False):
    """Compare distributions of powers FC^1..FC^max_fc_power for synthetic, PS and real FCs.

    Each batch is normalised with ``normalize_slices(..., which_norm=which_norm)`` and raised
    to every power with ``torch.matrix_power``. For each power, the returned figure
    histograms three per-matrix statistics, on log-x axes. Two use the upper-triangle
    entries (``upper_tri_as_vec_batch(..., offset=0)``): their median and their
    max |entry|. The third is the Frobenius norm of the full matrix. A second figure
    (not returned) histograms all upper-triangle entries. The mean and max |entry| per
    power are also printed.

    Args:
        synth_fcs, ps_fcs, real_fcs: batches of square matrices as torch tensors — assumed
            shape (batch, N, N) (``dim=(1, 2)`` is used for the norm); TODO confirm.
        max_fc_power: highest power to plot.
        which_norm: normalisation passed to ``normalize_slices``.
        include_zero_entries: if False, entries with magnitude below 1e-9 are replaced by
            NaN and ignored by the entry statistics.

    Returns:
        The main matplotlib Figure.
    """
    fs = 10
    metrics = ['median', 'max_abs_val', 'frob']
    nrows, ncols = len(metrics), max_fc_power
    # squeeze=False keeps both axes grids 2-D even when max_fc_power == 1
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, constrained_layout=True, sharex='all', squeeze=False)
    fig.suptitle(f'Distributions of Powers of FCs', fontsize=20)

    fig_all, axes_all = plt.subplots(nrows=1, ncols=ncols, sharex='all', squeeze=False)
    axes_all = axes_all[0]

    # row titles: median, max, frob
    for row, title in enumerate(metrics):
        axes[row, 0].set_ylabel(title, rotation=90, fontsize=fs)

    # column titles: FC^1, ..., FC^max_fc_power
    for fc_power in range(1, max_fc_power + 1):
        axes[-1, fc_power - 1].set_xlabel(f'FC^{fc_power}', fontsize=fs)

    # (normalised batch, legend label, colour) per FC source, in plotting order
    sources = [
        (normalize_slices(synth_fcs, which_norm=which_norm), 'Synth FCs', 'blue'),
        (normalize_slices(ps_fcs, which_norm=which_norm), 'PS FCs', 'red'),
        (normalize_slices(real_fcs, which_norm=which_norm), 'Real FCs', 'magenta'),
    ]

    bins = 50
    min_mag = 1e-9
    for fc_power in range(1, max_fc_power + 1):
        col = fc_power - 1
        flat_vals = []
        for normed, label, color in sources:
            powered = torch.matrix_power(normed, n=fc_power)
            vv = upper_tri_as_vec_batch(powered, offset=0).numpy()

            # ignore (near-)zero entries in the summary statistics
            if not include_zero_entries:
                vv[np.abs(vv) < min_mag] = np.nan

            frob = torch.linalg.norm(powered, ord='fro', dim=(1, 2)).view(-1, 1).numpy()

            axes[0, col].hist(np.nanmedian(vv, axis=1).flatten(), bins=bins, label=label, color=color, alpha=.5)
            axes[1, col].hist(np.nanmax(np.abs(vv), axis=1).flatten(), bins=bins, label=label, color=color, alpha=.5)
            axes[2, col].hist(frob.flatten(), bins=bins, label=label, color=color, alpha=.5)
            axes_all[col].hist(vv.flatten(), bins=bins, label=label, color=color, alpha=.5)
            flat_vals.append(vv)

        synth_vv, ps_vv, real_vv = flat_vals
        # NaN-aware reductions: near-zero entries may have been replaced by NaN above.
        print(f'\npower {fc_power}: MEAN -    synthetic: {np.nanmean(synth_vv)}, ps: {np.nanmean(ps_vv)}, real: {np.nanmean(real_vv)}')
        print(f'power {fc_power}: MAX ABS - synthetic: {np.nanmax(np.abs(synth_vv))}, ps: {np.nanmax(np.abs(ps_vv))}, real: {np.nanmax(np.abs(real_vv))}')
        axes_all[col].legend()

        for ax in axes[:, col]:
            ax.set_xscale('log', base=10)

    fig.set_constrained_layout_pads(w_pad=0, h_pad=0, hspace=0, wspace=0)
    return fig


def fc_distributions(fcs: np.ndarray, title_str=''):
    """Compare summary statistics of raw FC matrices with their Frobenius-normalised versions.

    Draws five log-x histograms, raw in blue and normalised in red:
    per-matrix Frobenius norm, max entry, median entry, mean entry, and all entries
    pooled. Normalisation is ``normalize_slices(torch.tensor(fcs), which_norm='frob')``.

    Args:
        fcs: stack of FC matrices, at least 3-D (reductions use ``axis=(1, 2)``);
            presumably (batch, N, N).
        title_str: appended to the figure title.

    Returns:
        The matplotlib Figure.
    """
    def summarize(mats):
        # per-matrix Frobenius norm, max, median and mean, then every entry flattened
        return [
            np.linalg.norm(mats, ord='fro', axis=(1, 2)),
            np.amax(mats, axis=(1, 2)),
            np.median(mats, axis=(1, 2)),
            np.mean(mats, axis=(1, 2)),
            # positional shape: the ``newshape=`` keyword is deprecated since NumPy 2.1
            np.reshape(mats, -1),
        ]

    fs = 10
    fig, axes = plt.subplots(nrows=5, ncols=1, constrained_layout=True)
    fig.suptitle(f'FC Norm Distributions: {title_str}', fontsize=20)

    normed_fcs = normalize_slices(torch.tensor(fcs), which_norm='frob').numpy()
    raw_vals, normed_vals = summarize(fcs), summarize(normed_fcs)
    titles = ['Frob Norms', 'Max Entries', 'Median Entries', 'Ave Entires', 'All Entires']

    bins = 50
    for raw, normed, title, ax in zip(raw_vals, normed_vals, titles, axes):
        ax.hist(raw.flatten(), bins=bins, label='Raw FCs', color='blue', alpha=.5)
        ax.hist(normed.flatten(), bins=bins, label='Normed FCs', color='red', alpha=.5)
        ax.set_title(title, fontsize=fs)
        ax.legend()
        ax.set_xscale('log', base=10)

    fig.set_constrained_layout_pads(w_pad=0, h_pad=0, hspace=0, wspace=0)
    return fig


def intermediate_value_distributions(intermed_values, depth, bins=25):
    """Plot per-layer histograms of intermediate model values.

    Builds a ``depth`` x 2 grid of subplots:
      * column 0 ('Frob Norm'): histograms of the Frobenius norm of each matrix
        in S_in, H1 and H2 for that layer;
      * column 1 ('Entries'): histograms of every entry of S_in, the normalized
        FCs, H1 and H2, using ``5 * bins`` bins.
    Both columns use a log-scaled x-axis and hide the y ticks.

    Args:
        intermed_values: mapping read with keys 'S_ins', 'H1s', 'H2s' (indexed
            per layer) and 'covs_normed' (used as-is for every layer). Values
            must expose ``.numpy()`` (presumably torch tensors shaped
            (batch, N, N) -- TODO confirm with callers).
        depth: number of layers, i.e. subplot rows. Must be >= 1.
        bins: number of bins for the Frobenius-norm histograms.

    Returns:
        The matplotlib ``Figure``.
    """
    fs = 10
    alpha = .4
    bin_scale = 5  # the 'Entries' column pools every entry, so it gets finer bins

    # squeeze=False keeps ``axes`` 2-D even when depth == 1; the default would
    # return a 1-D array and make ``axes[row, col]`` raise IndexError.
    fig, axes = plt.subplots(nrows=depth, ncols=2, constrained_layout=True, sharex='col', squeeze=False)
    axes[0, 0].set_title('Frob Norm', fontsize=fs)
    axes[0, 1].set_title('Entries', fontsize=fs)

    # NOTE(review): 'covs_normed' is not indexed by layer, so the same values are
    # drawn in every row -- confirm that is intended.
    cov_norm = intermed_values['covs_normed'].numpy()

    for layer in range(depth):
        S_in = intermed_values['S_ins'][layer].numpy()
        H1 = intermed_values['H1s'][layer].numpy()
        H2 = intermed_values['H2s'][layer].numpy()
        norm_ax, entry_ax = axes[layer]

        norm_ax.set_ylabel(f'Layer {layer+1}', rotation=90, fontsize=fs)

        # Column 0: one Frobenius norm per matrix.
        for values, label, color in ((S_in, 'S_in', 'green'), (H1, 'H1', 'blue'), (H2, 'H2', 'red')):
            norm_ax.hist(np.linalg.norm(values, ord='fro', axis=(1, 2)), bins=bins,
                         label=label, color=color, alpha=alpha)

        # Column 1: every individual entry.
        for values, label, color in ((S_in, 'S_in', 'green'), (cov_norm, 'Norm FCs', 'yellow'),
                                     (H1, 'H1', 'blue'), (H2, 'H2', 'red')):
            entry_ax.hist(values.flatten(), bins=bin_scale * bins,
                          label=label, color=color, alpha=alpha)

        for ax in (norm_ax, entry_ax):
            ax.set_xscale('log', base=10)
            ax.set_yticklabels([])
            ax.set_yticks([])

    # The 'Entries' column holds every series, so its handles cover the full legend.
    hists, labels = axes[-1, 1].get_legend_handles_labels()
    fig.legend(hists, labels, loc='center')
    fig.set_constrained_layout_pads(w_pad=0, h_pad=0, hspace=0, wspace=0)
    return fig


if __name__ == "__main__":
    # Script entry point: plot power distributions of the synthetic and real
    # adjacency matrices. Requires the project's ``data.pl_data`` package.

    # fc power distribution
    from data.pl_data import SyntheticDataModule, PsuedoSyntheticDataModule, RealDataModule
    synth_dm = SyntheticDataModule(num_samples_train=1002, num_signals=500, coeffs=np.array([0.5, 0.5, 0.2]))
    ps_dm = PsuedoSyntheticDataModule(num_patients_test=1, num_patients_val=1, num_signals=500, coeffs=np.array([0.5, 0.5, 0.2]))
    real_dm = RealDataModule(num_patients_test=1, num_patients_val=1)
    for dm in (synth_dm, ps_dm, real_dm):
        dm.setup()

    # Materialize each training set once, so the FCs (index 0) and SCs (index 1)
    # unpacked below come from the same full_ds() result.
    synth_full = synth_dm.train_dataloader().dataset.full_ds()
    ps_full = ps_dm.train_dataloader().dataset.full_ds()
    real_full = real_dm.train_dataloader().dataset.full_ds()
    synth_fcs, synth_scs = synth_full[0], synth_full[1]
    ps_fcs = ps_full[0]
    real_fcs, real_scs = real_full[0], real_full[1]

    #fig = fc_power_distributions(synth_fcs=synth_fcs, ps_fcs=ps_fcs, real_fcs=real_fcs, max_fc_power=5, which_norm='max_eig')
    fig = adj_power_distributions(synth=synth_scs, real=real_scs, max_power=4, which_norm='max_eig')
    plt.show()

    # Disabled scratch/manual-test snippets kept inside a string literal; never executed.
    """
    np.random.seed(19680801)

    depth = 8
    h1_coeffs = []
    h2_coeffs = []
    for i in range(depth):
        h1_coeffs.append(np.random.randn(2))
        h2_coeffs.append(np.random.randn(2))

    polynomial_filters(h1_coeffs, h2_coeffs, xlims=(0, 10))




    bs, N = 3, 10
    sz = (bs, N, N)
    truth_sampled = np.random.rand(*sz)

    true_edge_mask = (np.random.rand(*sz)>.5)
    true_edges = np.where(true_edge_mask, truth_sampled, 0)
    pred_edges = true_edges + np.random.randn(*sz)/3 #add noise
    pred_edges = np.where(pred_edges > 0, pred_edges, 0)

    # make symmetric
    true_edges = (true_edges+np.transpose(true_edges, axes=(0, 2, 1)))/2
    pred_edges = (pred_edges+np.transpose(pred_edges, axes=(0, 2, 1)))/2
    for slice in range(bs):
        true_edges[slice][np.diag_indices(N)] = 0  # remove diagonal entries
        pred_edges[slice][np.diag_indices(N)] = 0


    node_labels = [f'{x}' for x in range(10)]
    edge_metrics(true_edges, pred_edges, threshold=0.5, node_labels=node_labels)

    fig, axes = plt.subplots(nrows=2, ncols=1)
    data = [np.random.normal(loc=i, scale=2, size=100) for i in range(8)]
    for i in range(len(axes)):
        ridgeline(data, fig=fig, ax=axes[i], overlap=.85, fill='y', xlabel='FMSR', ylabel='Threshold')





    batch = (None, true_edges)
    num_patients = 3
    #make_edge_error_subplots(batch, pred_edges, num_patients, true_edge_max=1, num_bins_error=10, num_bins_true=10)

    make_all_edge_error_subplots(batch, pred_edges, true_edge_max=1, num_bins_error=10, num_bins_true=10)

    #make function to viz logistic regression of matrix outputs
    bs, N = 2,10
    sz = (bs, N, N)
    truth_raw = torch.rand(sz) # [0,1)
    truths = (truth_raw>0.5) + 0.0
    preds = truth_raw + torch.randn((sz))/5
    max_preds = preds.max().item()
    num_patients = 2
    fs = 10
    ncols = 1  # 2 + len(thresholds) # raw pred in front, true adj at end
    fig, axes = plt.subplots(nrows=num_patients, ncols=ncols, sharex=True)

    # set titles of each column
    a, b = 10.0, -5
    axes[0].set_title(f'Logistic Func: a: {a}, b: {b}. Decision: {-b/a}', fontsize=fs+4)

    x_sig = np.linspace(0, max_preds, 50)
    sig = expit(x_sig * a + b).ravel()
    for row in range(num_patients):
        ax = axes[row]
        ax.set_ylabel(f'Patient {row}', rotation=90, fontsize=fs)
        p_vec = upper_tri_as_vec(preds[row], offset=1)
        t_vec = upper_tri_as_vec(truths[row], offset=1)
        ax.scatter(p_vec.numpy(), t_vec.numpy(), color='black', zorder=20)
        a, b = 10.0, -5
        ax.plot(x_sig, sig, color='red', linewidth=3)
        ax.set_xticks(np.arange(0, 1, step=0.2))
        ax.set_ylim(bottom=-.1, top=1.1)
        max_x_lim = np.max(max_preds, 1)
        ax.set_xlim(left=0, right = max_x_lim)
        ax.vlines(-b/a, -.1, 1.1)
    """
