from __future__ import print_function, division
import numpy as np
import torch
import pytorch_lightning as pl
import math

from utils.plotting import threshold_metric_plot, \
    make_edge_subplots, make_all_edge_plot, logistic_func_plot, make_edge_error_subplots, make_all_edge_error_subplots,\
    metrics_distrib, metrics_2d, model_outputs, ridgeline, multiple_ridgeline, edge_metrics, intermediate_outputs, \
    polynomial_filters, fc_distributions, intermediate_value_distributions
from model.model_utils import dict_param, construct_named_coeffs_dict, metrics_at_thresholds, \
    score_graphs_batch, are_scs_binary, prediction_metrics
from utils.util_funcs import cdp, upper_tri_as_vec, sparsity, normalize_slices

class LoggingCallback(pl.Callback):
    """Lightning callback that logs parameters, gradients, losses and validation metrics.

    Scalars and figures are written through ``pl_module.logger.experiment``
    (``add_scalars`` / ``add_figure``); scalar logging is skipped whenever
    ``pl_module.logging`` is ``None``. Per-epoch figures are only produced when
    ``pl_module.logging == 'all'`` and the epoch reaches the lowest validation
    loss seen so far.
    """

    def __init__(self):
        super().__init__()
        # val_batch_losses: raw outputs of every validation batch of the running epoch
        # val_epoch_losses: one float (mean 'val_loss') per finished validation epoch
        self.val_batch_losses, self.val_epoch_losses = [], []
        # train_epoch_losses: one mean-loss tensor per finished training epoch
        # (train_batch_losses is never written by the code visible in this class)
        self.train_batch_losses, self.train_epoch_losses = [], []
        # NOTE(review): nothing visible in this class appends to this list; the
        # statement that used to fill it is commented out in on_validation_epoch_end.
        self.mean_squared_errors_epoch = []
        # only filled by the disabled on_train_start below
        self.extreme_eigvals = {}

    """
    def on_train_start(self, trainer, pl_module):
        fcs, scs, subject_ids, scan_dirs, tasks =  pl_module.val_dataloader().dataset.full_ds()
        model_threshold = pl_module.threshold[0].item()

        ######### ▼▼▼▼ distribution of metrics with FIXED threshold ▼▼▼▼ #########
        scan_mses, scan_maes = mae_per_scan(scs=scs, preds=pl_module.prior).numpy()
        scs = scs.numpy()
        pr, re, f1, macro_F1, acc, mcc \
            = score_graphs_batch(threshold=model_threshold, adjs=scs,
                                 preds=pl_module.prior.repeat((len(scs), 1, 1)), o="raw")
        fig_distrib = metrics_distrib(pr=pr, re=re, f1=f1, macro_f1=macro_F1, acc=acc, mcc=mcc,
                                      threshold=model_threshold)
        pl_module.logger.experiment.add_figure(f'Prior as Prediction/Metric Distrib', fig_distrib)

        fig_2d_mse_acc_macro_F1 = \
            metrics_2d(x=scan_mses, x_label=f'MSE',
                       y=acc, y_label=f'Acc @ {model_threshold:.2f}',
                       subject_ids=subject_ids, scan_dirs=scan_dirs,
                       scores=macro_F1, scores_label='macro-f1', annot_pts=5,
                       xlims=None, ylims=(0, 1))
        sparsities = sparsity(As=scs, directed=False, self_loops=False)
        fig_2d_acc_sparsity_mse = \
            metrics_2d(x=sparsities, x_label=f'Sparsities',
                       y=acc, y_label=f'Acc @ {model_threshold:.2f}',
                       subject_ids=subject_ids, scan_dirs=scan_dirs,
                       scores=scan_mses, scores_label='MSE', annot_pts=5,
                       xlims=(0, 1), ylims=(0, 1))
        pl_module.logger.experiment.add_figure(f'Prior as Prediction/2D ACC vs MSE with macro-F1 score', fig_2d_mse_acc_macro_F1)
        pl_module.logger.experiment.add_figure(f'Prior as Prediction/2D ACC vs Sparsity with MSE score', fig_2d_acc_sparsity_mse)
        ######### ▲▲▲▲ distribution of metrics with FIXED threshold ▲▲▲▲ #########


        ##### ▼▼▼▼ Initial Polynomial Filters ▼▼▼▼ ######
        # What x-axis range should be used for polynomial plotting?
        # Unclear...take eigenvalues of all components (S_in, Cov, Prior)
        # and use extreme vals
        normed_fcs = normalize_slices(fcs, which_norm=pl_module.fc_norm).numpy()
        extreme_fc_eigvals = []
        for fc_idx in range(len(normed_fcs)):
            fc_eigs = np.linalg.eigvalsh(normed_fcs[fc_idx])  # in sorted order
            extreme_fc_eigvals.extend([fc_eigs[0], fc_eigs[-1]])
        self.extreme_eigvals['fc'] = [min(extreme_fc_eigvals), max(extreme_fc_eigvals)]

        extreme_sc_eigvals = []
        for sc_idx in range(len(scs)):
            sc_eigs = np.linalg.eigvalsh(scs[sc_idx])  # in sorted order
            extreme_sc_eigvals.extend([sc_eigs[0], sc_eigs[-1]])
        self.extreme_eigvals['sc'] = [min(extreme_sc_eigvals), max(extreme_sc_eigvals)]

        prior_eigvals = np.linalg.eigvalsh(pl_module.prior)
        extreme_prior_eigvals = [prior_eigvals[0], prior_eigvals[-1]]
        self.extreme_eigvals['prior'] = extreme_prior_eigvals

        all_extreme_eigs = extreme_fc_eigvals + extreme_sc_eigvals + extreme_prior_eigvals
        self.extreme_eigvals['all'] = [min(all_extreme_eigs), max(all_extreme_eigs)]

        # now plot them
        h1_coeffs = [l.coeffs_1.detach().numpy() for l in pl_module.prox_layers]
        h2_coeffs = [l.coeffs_2.detach().numpy() for l in pl_module.prox_layers]

        poly_fig = polynomial_filters(h1_coeffs, h2_coeffs,
                                      xlims=(self.extreme_eigvals['all'][0], self.extreme_eigvals['all'][1]),
                                      num_points=10)
        pl_module.logger.experiment.add_figure(f'Initial Filters/polynomial filters', poly_fig)
        ##### ▲▲▲▲ Initial Polynomial Filters ▲▲▲▲ ######

        ##### ▼▼▼▼ FC norm distribution ▼▼▼▼ ######
        train_fcs, _, _, _ = pl_module.train_dataloader().dataset.full_ds()
        val_fcs, _, _, _ = pl_module.val_dataloader().dataset.full_ds()
        fcs = torch.cat((train_fcs, val_fcs)).numpy()
        raw_fig = fc_distributions(fcs=fcs, title_str='No FC scaling')
        scaled_fig = fc_distributions(fcs=fcs*pl_module.fc_scaling, title_str=f'FC Scaling of {pl_module.fc_scaling:.8f}')

        pl_module.logger.experiment.add_figure(f'FC Distributions/Raw FCs', raw_fig)
        pl_module.logger.experiment.add_figure(f'FC Distributions/Scaled FCs', scaled_fig)

        #fig = intermediate_value_distributions(pl_module.forward_intermed_outs(fcs), pl_module.depth)
        #pl_module.logger.experiment.add_figure(f'Arg Distributions/Intermediate Sizes', fig)

        ##### ▲▲▲▲ FC norm distribution ▲▲▲▲ ######
    """

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        """Log the current parameter values of every prox layer before the batch runs."""
        if pl_module.logging is None:
            return
        # log parameter values
        global_step = pl_module.global_step
        for layer, prox_layer in enumerate(pl_module.prox_layers):
            param_dict = prox_layer.param_dict()
            pl_module.logger.experiment.add_scalars(f'layer:{layer}', param_dict, global_step)

        """
        tau_dict, coeffs1_dict, coeffs2_dict, log_regr_dict = dict_param(pl_module)

        if pl_module.learn_tau:
            pl_module.logger.experiment.add_scalars('tau', tau_dict, global_step)

        all_coeffs1_dict = construct_named_coeffs_dict(coeffs1_dict)
        pl_module.logger.experiment.add_scalars('coeffs 1', all_coeffs1_dict, global_step)

        all_coeffs2_dict = construct_named_coeffs_dict(coeffs2_dict)
        pl_module.logger.experiment.add_scalars('coeffs_2', all_coeffs2_dict, global_step)

        if pl_module.log_regr:
            pl_module.logger.experiment.add_scalars('log_regr', log_regr_dict, global_step)

        """

    def on_after_backward(self, trainer, pl_module):
        """Log the gradients of every prox layer right after the backward pass."""
        if pl_module.logging is None:
            return
        # log gradients (same 'layer:<idx>' tag as the parameter values above)
        global_step = pl_module.global_step
        for layer, prox_layer in enumerate(pl_module.prox_layers):
            grad_dict = prox_layer.param_dict(gradients=True)
            pl_module.logger.experiment.add_scalars(f'layer:{layer}', grad_dict, global_step)

        """
        for layer, prox_layer in enumerate(pl_module.prox_layers):
            ave_killed_by_sr_per_batch = prox_layer.ave_killed_by_sr_per_batch
            C_out, C_in = ave_killed_by_sr_per_batch.shape
            sr_kill_dict = {}
            for out_channel in range(C_out):
                s = f'out:{out_channel}'
                for in_channel in range(C_in):
                    s_ = s + f'||in:{in_channel}'
                    sr_kill_dict[s_] = ave_killed_by_sr_per_batch[out_channel, in_channel]

            pl_module.logger.experiment.add_scalars(f'Shift ReLU Kill % per in channel: {layer}', sr_kill_dict, global_step)
        """

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Intentionally a no-op."""
        return

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        """Record the mean training loss of the finished epoch in ``train_epoch_losses``."""
        # NOTE(review): the layout of `outputs` is chosen by CUDA availability here;
        # presumably it really depends on the Lightning version/accelerator - confirm.
        if torch.cuda.is_available():
            ave_loss = torch.stack([x['loss'] for x in outputs]).mean()
        else:
            ave_loss = torch.stack([x[0][0]['minimize'] for x in outputs]).mean()

        self.train_epoch_losses.append(ave_loss)
        #self.log('per_epoch/ave_train_loss', ave_loss, on_epoch=True, logger=True)
        #pl_module.logger.experiment.add_scalar('per_epoch/ave_train_loss', ave_loss, pl_module.current_epoch)

    # we must do this bc on_validation_epoch_end does not have outputs like on_train_epoch_end
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Collect the outputs of each validation batch (sanity-check batches are ignored)."""
        if trainer.running_sanity_check:
            return
        self.val_batch_losses.append(outputs)  # dict with val_loss, val_macro_F1, val_acc, val_mcc

    @staticmethod
    def _build_validation_prior(pl_module, batch_size, N):
        """Return the prior tensor fed to the model for one validation batch.

        Raises:
            ValueError: if ``pl_module.prior_construction`` is not a known mode.
        """
        construction = pl_module.prior_construction
        target_shape = (1, batch_size, N, N)
        if construction in ['single', None, 'block']:
            return torch.broadcast_to(pl_module.prior, target_shape)
        if construction in ['single_grouped']:
            # USE FULL PRIOR IN VALIDATION
            return torch.broadcast_to(pl_module.full_prior, target_shape)
        if construction in ['multi']:
            prior_0 = torch.broadcast_to(pl_module.prior[0][0], target_shape)
            prior_1 = torch.broadcast_to(pl_module.prior[1][0], target_shape)
            # cat (rather than a view) gives a contiguous layout, needed for normalization
            return torch.cat((prior_0, prior_1), dim=0)
        raise ValueError('Indeterminate Prior Construction')

    def on_validation_epoch_end(self, trainer, pl_module):
        """Aggregate the epoch's validation metrics, log them, and plot on a new best loss.

        The collected batch outputs are averaged and cleared, scalars are logged
        (unless ``pl_module.logging`` is ``None``), one validation batch is pushed
        through the model, and - when this epoch's mean validation loss is the
        lowest so far and ``pl_module.logging == 'all'`` - ``make_epoch_plots``
        is called with that batch.

        Raises:
            ValueError: if ``pl_module.prior_construction`` is not a known mode.
        """
        if trainer.running_sanity_check:
            return
        if not self.val_batch_losses:
            # No validation batch was recorded this epoch; torch.stack([]) would raise.
            return

        def epoch_mean(key):
            return torch.stack([x[key] for x in self.val_batch_losses]).mean()

        ave_loss = epoch_mean('val_loss')
        ave_macro_F1 = epoch_mean('val_macro_F1')
        ave_acc = epoch_mean('val_acc')
        ave_mcc = epoch_mean('val_mcc')
        self.val_batch_losses = []

        if len(self.train_epoch_losses) == 0:
            losses = {'validation': ave_loss}
        else:
            losses = {'train (ave across batches)': self.train_epoch_losses[-1], 'validation': ave_loss}

        # log current performance and prior performance at each epoch
        # (the prior does not change, but it is visually nice as a reference line)
        if pl_module.logging is not None:
            experiment = pl_module.logger.experiment
            metrics_dict = {'macro_F1': ave_macro_F1, 'acc': ave_acc, 'mcc': ave_mcc, 'mse': ave_loss}
            experiment.add_scalars(f'metrics/(epoch)@best threshold ({pl_module.threshold_metric})/', metrics_dict, pl_module.current_epoch)
            for prior_type, prior_metric in pl_module.prior_metrics.items():
                experiment.add_scalars(f'metrics/{prior_type}', prior_metric, pl_module.current_epoch)
            experiment.add_scalars('per_epoch_losses', losses, global_step=pl_module.current_epoch)

        # do a pass on one batch of val data for the per-epoch plots
        batch = next(iter(trainer.val_dataloaders[0]))
        fcs, scs, subject_ids, scan_dirs, tasks = batch
        batch_size, N, _ = fcs.shape

        with torch.no_grad():
            prior = self._build_validation_prior(pl_module, batch_size, N)
            # Tensor.device is a property, not a method: calling it raised TypeError
            # on machines without CUDA.
            device = torch.device('cuda') if torch.cuda.is_available() else prior.device
            fcs, scs = fcs.to(device), scs.to(device)
            preds = pl_module(fcs, prior.to(device)).detach()

        # find mse across scans
        # use performance_metrics() in model_utils to find loss_per_scan
        #self.mean_squared_errors_epoch.append(loss_per_scan(scs=scs, preds=preds)[0].numpy())
        #self.mean_absolute_errors_epoch.append(loss_per_scan(scs=scs, preds=preds)[1].numpy())

        # lowest validation loss thus far? (ties with the previous best count as best)
        current_loss = ave_loss.item()
        self.val_epoch_losses.append(current_loss)
        is_best_val_loss = (current_loss <= min(self.val_epoch_losses))
        if not is_best_val_loss:
            return

        #print(f"New Best Val Loss found on {pl_module.current_epoch} epoch")
        pl_module.log('Best Epoch (loss)', pl_module.current_epoch, logger=False, prog_bar=True)

        if pl_module.logging != 'all':
            return

        logging_info = {'logger': pl_module.logger,
                        'epoch': pl_module.current_epoch,
                        'global_step': pl_module.global_step
                        }
        # pull out cortical labels; they are optional, so any lookup failure
        # (no datamodule/metadata, missing key) falls back to None
        try:
            cortical_labels = trainer.datamodule.metadata['idx2label_cortical']['label_name']
        except (AttributeError, KeyError, TypeError):
            cortical_labels = None
        make_epoch_plots(log_info=logging_info, model=pl_module,
                         fcs=fcs, scs=scs, preds=preds,
                         epoch_mses=self.mean_squared_errors_epoch,
                         subject_ids=subject_ids,
                         extreme_eigvals=self.extreme_eigvals,
                         scan_dirs=scan_dirs,
                         tasks=tasks,
                         cortical_labels=cortical_labels)
"""
    # move this functionality to on_test_end?
    def on_train_end(self, trainer, pl_module):

        # dont use output of trainer.test() but the test call fetches test_dataloader
        trainer.test()
        batch = next(iter(trainer.test_dataloaders[0]))
        fcs, scs, subject_ids, scan_dirs, tasks = batch
        with torch.no_grad():
            preds = pl_module(fcs)
        if pl_module.log_regr:
            test_loss = F.binary_cross_entropy_with_logits(preds, scs, reduction='mean')
        else:
            test_loss = F.mse_loss(preds, scs, reduction='mean')

        # find mse FOR EACH SCAN  in test set
        scan_mses, scan_maes = loss_per_scan(scs=scs, preds=preds).numpy()

        fcs, scs, preds = fcs.detach().numpy(), scs.detach().numpy(), preds.detach().numpy()
        prior =  pl_module.prior.detach().numpy()
        model_threshold = pl_module.threshold.detach().item()

        ##### ▼▼▼▼ Test Set Metrics using model threshold  ▼▼▼▼ ######
        ave_pr, ave_re, ave_f1, ave_macro_F1, ave_acc, ave_mcc \
            = score_graphs_batch(threshold=model_threshold, adjs=scs, preds=preds, o="ave")
        test_txt = f'Ave MSE: {test_loss:.5f}\n'
        test_txt += f'Ave MCC: {ave_mcc:.5f}\n'
        test_txt += f'Ave ACC: {ave_acc:.3f}. Threshold @ {model_threshold:.3f}\n'
        test_txt += f'macro_F1:{ave_macro_F1:.3f}'
        pl_module.logger.experiment.add_text('Test metrics', test_txt)
        ##### ▲▲▲▲ Test Set Metrics using model threshold  ▲▲▲▲ ######

        ##### ▼▼▼▼ MSE Distribution w/ best threshold vs Epoch ▼▼▼▼ ######
        labels = [f'{e:.0f}' for e in range(len(self.mean_squared_errors_epoch))]
        fig_rl_ell_epoch_mse, _ = ridgeline(data=self.mean_squared_errors_epoch, xlabel='Val MSE', ylabel='Epochs',
                                            labels=labels, xlow=0.0)
        pl_module.logger.experiment.add_figure(f'mse distribution (@ best thresh) vs epoch/ridgeline MSEs',
                                               fig_rl_ell_epoch_mse)
        ##### ▲▲▲▲ MSE Distribution w/ best threshold vs Epoch ▲▲▲▲ ######

        ##### ▼▼▼▼ Metric Distribution vs Threshold ▼▼▼▼ ######
        thresholds = np.concatenate((np.arange(0,.25,.01), np.arange(.25,.8,.05)), axis=None)

        precisions, recalls, f1s, macro_F1s, accs, mccs = metrics_at_thresholds(thresholds=thresholds, adjs=scs, preds=preds, o='raw')
        labels = [f'{cv:.3f}' for cv in thresholds]
        fig_rl_f1s, _ = ridgeline(data=f1s, xlabel='F1', ylabel='Thresholds', labels=labels, xlow=0.0, xhigh=1.0)
        fig_rl_macro_F1s, _ = ridgeline(data=macro_F1s, xlabel='Macro F1', ylabel='Thresholds', labels=labels, xlow=0.0, xhigh=1.0)
        fig_rl_accs, _ = ridgeline(data=accs, xlabel='Accuracies', ylabel='Thresholds', labels=labels, xlow=0.0, xhigh=1.0)
        valid_mccs = [[mcc for mcc in mccs_ if ((not math.isnan(mcc)) and (not math.isinf(mcc)))] for mccs_ in mccs]
        fig_rl_mccs, _ = ridgeline(data=valid_mccs, xlabel='MCCs', ylabel='Thresholds', labels=labels, xlow=-1.0, xhigh=1.0)
        pl_module.logger.experiment.add_figure(f'test/ridgeline F1s', fig_rl_f1s)
        pl_module.logger.experiment.add_figure(f'test/ridgeline macro F1s', fig_rl_macro_F1s)
        pl_module.logger.experiment.add_figure(f'test/ridgeline accs', fig_rl_accs)
        pl_module.logger.experiment.add_figure(f'test/ridgeline mccs', fig_rl_mccs)
        ##### ▲▲▲▲ Metric Distribution vs Threshold ▲▲▲▲ ######


        ##### ▼▼▼▼ Model Outputs  ▼▼▼▼ ######
        # representative outputs (best, median, worst) by some metric (acc, mse, fmsr)
        fig_samples_sorted_outputs_acc = \
            model_outputs(fcs=fcs, scs=scs, preds=preds, mses=scan_mses,
                          threshold=model_threshold,
                          subject_ids=subject_ids, scan_dirs=scan_dirs,
                          prior=prior, sort_by='acc')
        pl_module.logger.experiment.add_figure(f'test/samples sorted by acc', fig_samples_sorted_outputs_acc)
        ##### ▲▲▲▲ Model Outputs  ▲▲▲▲ ######

        ######### ▼▼▼▼ metrics vs thresholds ▼▼▼▼ #############
        fig_metrics_vs_threshs = threshold_metric_plot(adjs=scs, preds=preds,
                                           curve_thresholds=pl_module.threshold_metric_test_points,
                                           o='ave')
        pl_module.logger.experiment.add_figure(f'test/Metrics vs Thresholds', fig_metrics_vs_threshs)
        ######### ▲▲▲▲ metrics vs thresholds ▲▲▲▲ #############

        ######### ▼▼▼▼ distribution of metrics with FIXED threshold ▼▼▼▼ #########
        pr, re, f1, macro_F1, acc, mcc \
            = score_graphs_batch(threshold=model_threshold, adjs=scs, preds=preds, o="raw")
        fig_distrib = metrics_distrib(pr=pr, re=re, f1=f1, macro_F1=macro_F1, acc=acc, mcc=mcc, threshold=model_threshold)
        pl_module.logger.experiment.add_figure(f'test/Metric Distrib', fig_distrib)

        fig_2d_mse_acc_macro_F1 = \
            metrics_2d(x=scan_mses, x_label=f'MSE',
                       y=acc, y_label=f'Acc @ {model_threshold:.2f}',
                       subject_ids=subject_ids, scan_dirs=scan_dirs,
                       scores=macro_F1, scores_label='macro-f1', annot_pts=5,
                       xlims=None, ylims=(0, 1))
        sparsities = sparsity(As=scs, directed=False, self_loops=False)
        fig_2d_acc_sparsity_mse = \
            metrics_2d(x=sparsities, x_label=f'Sparsities',
                       y=acc, y_label=f'Acc @ {model_threshold:.2f}',
                       subject_ids=subject_ids, scan_dirs=scan_dirs,
                       scores=scan_mses, scores_label='MSE', annot_pts=5,
                       xlims=(0, 1), ylims=(0, 1))
        pl_module.logger.experiment.add_figure(f'test/2D Acc vs MSE with macro-F1 score', fig_2d_mse_acc_macro_F1)
        pl_module.logger.experiment.add_figure(f'test/2D ACC vs Sparsity with MSE score', fig_2d_acc_sparsity_mse)
        ######### ▲▲▲▲ distribution of metrics with FIXED threshold ▲▲▲▲ #########

        ######### ▼▼▼▼ Performance on each EDGE in TEST set ▼▼▼▼ #########
        fig_edge_metrics = edge_metrics(scs=scs, preds=preds, threshold=model_threshold, node_labels=[])
        pl_module.logger.experiment.add_figure(f'test/Edge Metrics', fig_edge_metrics)
        ######### ▲▲▲▲ Performance on each EDGE in TEST set ▲▲▲▲ #########

"""

def make_epoch_plots(log_info, model, fcs, scs, preds, epoch_mses, subject_ids, extreme_eigvals, scan_dirs=None, tasks=None, cortical_labels=None, visible_figs = False):
    """Create the per-epoch diagnostic figures and add them to the logger.

    Args:
        log_info: dict with keys 'logger', 'epoch' and 'global_step'; every
            figure is written via ``log_info['logger'].experiment.add_figure``.
        model: the module being trained; its ``prior``, ``threshold``,
            ``fc_norm``, ``threshold_metric_test_points`` and ``log_regr``
            attributes are read here.
        fcs, scs, preds: torch tensors for one batch (inputs, labels, model
            outputs). They may live on any device; they are copied to CPU.
        epoch_mses: sequence whose last entry holds the per-scan MSEs of this
            epoch. When it is empty or ``None`` the per-scan MSEs are computed
            here from ``scs`` and ``preds`` instead.
        subject_ids, scan_dirs: per-scan identifiers forwarded to the plots.
        extreme_eigvals, tasks, cortical_labels, visible_figs: currently unused
            by the enabled code paths; kept for interface compatibility.
    """
    logger, current_epoch, global_step = log_info['logger'], log_info['epoch'], log_info['global_step']
    pl_module = model

    # .cpu() is needed before .numpy(): the caller may hand over CUDA tensors,
    # and Tensor.numpy() only works on CPU tensors (no-op if already on CPU).
    fcs, scs, preds = (t.detach().cpu().numpy() for t in (fcs, scs, preds))
    prior = pl_module.prior.detach().cpu().numpy()

    num_patients_display = 3 # remove this eventually

    # thresholded outputs of model compared to labels
    best_threshold = pl_module.threshold[0].item()

    normed_fcs = normalize_slices(torch.from_numpy(fcs), which_norm=pl_module.fc_norm).numpy()

    # Per-scan MSEs for this epoch. Indexing epoch_mses[-1] unconditionally raised
    # IndexError whenever the caller had not recorded any, so fall back to
    # computing them directly.
    if epoch_mses is not None and len(epoch_mses) > 0:
        scan_mses = epoch_mses[-1]
    else:
        # assumes preds and scs share a shape with scans along axis 0 - TODO confirm
        scan_mses = np.mean((preds - scs) ** 2, axis=tuple(range(1, preds.ndim)))

    """
    ##### ▼▼▼▼ Learned Polynomial Filters ▼▼▼▼ ######
    # must be REMADE FOR MIMO
    h1_coeffs = [l.coeffs_1.detach().numpy() for l in pl_module.prox_layers]
    h2_coeffs = [l.coeffs_2.detach().numpy() for l in pl_module.prox_layers]

    poly_fig = polynomial_filters(h1_coeffs, h2_coeffs, xlims=(extreme_eigvals['all'][0], extreme_eigvals['all'][1]), num_points=10)
    pl_module.logger.experiment.add_figure(f'epoch {current_epoch}/polynomial filters', poly_fig,
                                           global_step=global_step)
    ##### ▲▲▲▲ Learned Polynomial Filters ▲▲▲▲ ######
    """
    """
    # MUST BE REMADE FOR MIMO
    ##### ▼▼▼▼ Intermediate Outputs of model ▼▼▼▼ ######
    int_out_fig = intermediate_outputs(model=pl_module, depth=pl_module.depth,
                                       fcs=normed_fcs, scs=scs, preds=preds,
                                       subject_ids=subject_ids, scan_dirs=scan_dirs,
                                       threshold=best_threshold, sort_by='acc',
                                       fc_pctile=100, pred_pctile=95)
    pl_module.logger.experiment.add_figure(f'epoch {current_epoch}/intermediate outputs by acc', int_out_fig,
                                           global_step=global_step)

    fig = intermediate_value_distributions(pl_module.forward_intermed_outs(fcs), pl_module.depth)
    pl_module.logger.experiment.add_figure(f'epoch {current_epoch}/Intermediate Sizes', fig)
    ##### ▲▲▲▲ Intermediate Outputs of model ▲▲▲▲ ######
    """

    ##### ▼▼▼▼ Metric Distribution vs Threshold ▼▼▼▼ ######
    test_thresholds = pl_module.threshold_metric_test_points
    precisions, recalls, f1s, macro_F1s, accs, mccs = \
        metrics_at_thresholds(thresholds=test_thresholds,
                              adjs=scs, preds=preds, o='raw')
    # remove nans/infs in mcc
    valid_mccs = [[mcc for mcc in mccs_ if math.isfinite(mcc)] for mccs_ in mccs]
    try:
        fig_metric_distribs = multiple_ridgeline(data_list=[macro_F1s, accs, valid_mccs],
                                                 metric_list=['macro_F1', 'acc', '(non-NaN) mcc'],
                                                 thresholds=test_thresholds)
    except ValueError: # mccs could pose problems with nans: retry without them
        fig_metric_distribs = multiple_ridgeline(data_list=[macro_F1s, accs],
                                                 metric_list=['macro_F1', 'acc'],
                                                 thresholds=test_thresholds)
    # use the logger handed in via log_info, like every other figure in this function
    logger.experiment.add_figure(f'epoch {current_epoch}/metrics distrib vs thresholds', fig_metric_distribs, global_step=global_step)
    ##### ▲▲▲▲ Metric Distribution vs Threshold ▲▲▲▲ ######

    ##### ▼▼▼▼ Model Outputs  ▼▼▼▼ ######
    fig_repr_outputs_acc = model_outputs(fcs=normed_fcs, scs=scs, preds=preds, mses=scan_mses,
                                         threshold=best_threshold,
                                         subject_ids=subject_ids, scan_dirs=scan_dirs, prior=prior,
                                         sort_by='acc')
    logger.experiment.add_figure(f'epoch {current_epoch}/repres outputs by acc', fig_repr_outputs_acc, global_step=global_step)
    ##### ▲▲▲▲ Model Outputs  ▲▲▲▲ ######


    ######### ▼▼▼▼ metrics vs range of thresholds ▼▼▼▼ #############
    fig_thresh = threshold_metric_plot(adjs=scs, preds=preds,
                                       curve_thresholds=test_thresholds,
                                       o='ave')
    logger.experiment.add_figure(f'epoch {current_epoch}/Metrics vs Thresholds', fig_thresh, global_step=global_step)
    ######### ▲▲▲▲ metrics vs range of thresholds ▲▲▲▲ #############

    ######### ▼▼▼▼ distribution of metrics with FIXED threshold ▼▼▼▼ #########
    pr, re, f1, macro_F1, acc, mcc \
        = score_graphs_batch(threshold=best_threshold, adjs=scs, preds=preds, o="raw")
    fig_distrib = metrics_distrib(pr=pr, re=re, f1=f1, macro_f1=macro_F1, acc=acc, mcc=mcc, threshold=best_threshold)
    logger.experiment.add_figure(f'epoch {current_epoch}/Metric Distrib', fig_distrib, global_step=global_step)

    fig_2d_mse_acc_macro_F1 = \
        metrics_2d(x=scan_mses, x_label='MSE',
                   y=acc, y_label=f'Acc @ {best_threshold:.2f}',
                   subject_ids=subject_ids, scan_dirs=scan_dirs,
                   scores=macro_F1, scores_label='macro-F1', annot_pts=5,
                   xlims=None, ylims=None)
    sparsities = sparsity(As=scs, directed=False, self_loops=False)
    fig_2d_acc_sparsity_mse = \
        metrics_2d(x=sparsities, x_label='Sparsities',
                   y=acc, y_label=f'Acc @ {best_threshold:.2f}',
                   subject_ids=subject_ids, scan_dirs=scan_dirs,
                   scores=scan_mses, scores_label='MSE', annot_pts=5,
                   xlims=(0, 1), ylims=(0, 1))
    # remove nan values in both mcc and epoch_mses[-1]
    """
    remove_mask_nan = np.isnan(mcc)
    fig_2d_mse_mcc = \
        metrics_2d(x=epoch_mses[-1][remove_mask_nan], x_label=f'MSE',
                   y=mcc[remove_mask_nan], y_label=f'MCC @ {best_threshold:.2f}',
                   subject_ids=subject_ids, scan_dirs=scan_dirs,
                   scores=acc[remove_mask_nan], scores_label='ACC', annot_pts=5,
                   xlims=None, ylims=(-1, 1))
    """
    logger.experiment.add_figure(f'epoch {current_epoch}/2D ACC vs MSE with macro-F1 score', fig_2d_mse_acc_macro_F1, global_step=global_step)
    logger.experiment.add_figure(f'epoch {current_epoch}/2D ACC vs Sparsity with MSE score', fig_2d_acc_sparsity_mse, global_step=global_step)

    ######### ▲▲▲▲ distribution of metrics with FIXED threshold ▲▲▲▲ #########

    ######### ▼▼▼▼ Performance on each EDGE in VALIDATION set ▼▼▼▼ #########
    #fig_edge_metrics = edge_metrics(scs=scs, preds=preds, threshold=best_threshold, node_labels=[])
    #logger.experiment.add_figure(f'epoch {current_epoch}/Edge Metrics', fig_edge_metrics, global_step=global_step)
    ######### ▲▲▲▲ Performance on each EDGE in VALIDATION set ▲▲▲▲ #########

    ######### ▼▼▼▼ Graphs Outputs at each INTERMEDIATE layer ▼▼▼▼ #########
    """
    intermediate_outputs = pl_module.forward_intermed_outs(fcs)['S_outs']
    fig_int = intermed_heatmap_outputs(intermediate_outputs, scs, fcs, num_patients_display)
    logger.experiment.add_figure(f'epoch {current_epoch}/Intermediate Outputs', fig_int, global_step=global_step)
    """
    ######### ▲▲▲▲ Graphs Outputs at each INTERMEDIATE layer ▲▲▲▲ #########

    binarized_labels = are_scs_binary(scs)
    # The continuous-weight plots further down are nonsense when labels are 0/1,
    # so binary labels only get the logistic-regression plot.
    if binarized_labels and pl_module.log_regr:
        # plot fit logistic function
        a, bias = cdp(pl_module.log_regr_layer.a)[0], cdp(pl_module.log_regr_layer.bias)[0]
        fig_logr = logistic_func_plot(scs, preds, num_patients_display, a, bias)
        logger.experiment.add_figure(f'epoch {current_epoch}/Logistic Regression', fig_logr, global_step=global_step)

    if not binarized_labels:
        ######### ▼▼▼▼ Pred vs True edge weights for subset/all validation scans ▼▼▼▼ #########
        # pred edge weight vs label edge weight for a few validation patients
        fig_edge_by_acc = make_edge_subplots(scs=scs, preds=preds, mses=scan_mses,
                                             subject_ids=subject_ids,
                                             scan_dirs=scan_dirs,
                                             threshold=best_threshold,
                                             sort_by='acc')
        logger.experiment.add_figure(f'epoch {current_epoch}/weight predictions by acc', fig_edge_by_acc, global_step=global_step)

        # pred edge weight vs label edge weight for all validation patients
        fig_all_edge = make_all_edge_plot(scs, preds)#, true_edge_max=true_edge_max, pred_edge_max=pred_edge_max)
        logger.experiment.add_figure(f'epoch {current_epoch}/All weight predictions', fig_all_edge, global_step=global_step)
        ######### ▲▲▲▲ Pred vs True edge weights for subset/all validation scans ▲▲▲▲ #########


        ######### ▼▼▼▼  Distribution of Error:=(Pred - True) edge weights for subset/all validation scans ▼▼▼▼ #########
        # Group errors by True edge weight
        """
        fig = make_edge_error_subplots(scs, preds, num_patients_display, true_edge_max=true_edge_max, num_bins_error=10, num_bins_true=5)
        logger.experiment.add_figure(f'epoch {current_epoch}/errors', fig, global_step=global_step)

        fig = make_all_edge_error_subplots(scs, preds, true_edge_max=1, num_bins_error=15, num_bins_true=5)
        logger.experiment.add_figure(f'epoch {current_epoch}/All errors', fig, global_step=global_step)
        """
        ######### ▲▲▲▲  Distribution of (Pred - True) edge weights for subset/all validation scans ▲▲▲▲ #########


if __name__ == "__main__":
    # Running this module as a script only prints a marker line; nothing else is executed.
    print('custom callback main')