from typing import List, Any, Optional
import math, sys, numpy as np, torch, pytorch_lightning as pl
from collections import OrderedDict
from torch import nn
from torch.nn import functional as F
from pytorch_lightning import seed_everything
from typing import Any, Dict, List, Optional, Type
from pathlib import Path

# Absolute, symlink-resolved location of this source file.
file = Path(__file__).resolve()
# Directory two levels above the directory containing this file, with a trailing '/'.
path2project = str(file.parents[2]) + '/'
# Working directory of the running process (not of this file), with a trailing '/'.
path2currDir = str(Path.cwd()) + '/'
# Make the top-level project directory importable so the project-local imports
# that follow (utils, metrics, ...) resolve regardless of where the process was started.
sys.path.append(path2project) # add top level directory -> geom_dl/

from utils import adj2vec, vec2adj, normalize_slices, mimo_tensor_polynomial, adjs2fvs
from metrics import best_threshold_by_metric
# gdn_utils is importable directly when this directory is on sys.path; otherwise
# fall back to the package-qualified path. Only ImportError (which includes
# ModuleNotFoundError) triggers the fallback, so unrelated failures such as
# KeyboardInterrupt or a bug raised while importing gdn_utils are not masked.
try:
    from gdn_utils import filter_repeats, format_color, shallowest_layer_all_zero, resample_params, clamp_tau, \
        construct_prior, hinge_loss, \
        apply_mask, print_subnet_perf_dict, percent_change_metrics
except ImportError:
    from unroll.gdn.gdn_utils import filter_repeats, format_color, shallowest_layer_all_zero, resample_params, clamp_tau, \
        construct_prior, hinge_loss, \
        apply_mask, print_subnet_perf_dict, percent_change_metrics

from metrics import compute_metrics


def gdn_layer_input_verification(tau_info, normalization_info):
    """Validate the configuration dicts handed to a gdn layer.

    Args:
        tau_info: threshold configuration. ``tau_info['type']`` selects the
            variant: ``'scalar'`` requires ``tau_info['learn']`` to be True or
            False and the keys ``'low'`` and ``'high'`` to be present;
            ``'mlp'`` is not supported yet and is always rejected.
        normalization_info: ``normalization_info['method']`` must be one of
            ``"frob"``, ``"max_abs"``, ``"percentile"``, ``"none"`` or ``None``
            and ``normalization_info['where']`` one of ``'before_reduction'``
            or ``'after_reduction'``.

    Raises:
        AssertionError: if a field holds an invalid value, or the tau type is ``'mlp'``.
        ValueError: if the tau type is neither ``'scalar'`` nor ``'mlp'``.
        KeyError: if ``'type'``, ``'learn'``, ``'method'`` or ``'where'`` is missing.
    """
    tau_type = tau_info['type']
    if tau_type == 'scalar':
        assert tau_info['learn'] in [True, False]
        assert 'low' in tau_info and 'high' in tau_info
    elif tau_type == 'mlp':
        # MLP-parameterised tau is not implemented: reject it outright.
        assert False, 'NEED TO THINK MORE ABOUT USING MLP FOR TAU'
        # Checks kept for when MLP support lands. The MLP settings live under
        # tau_info['mlp'], so the membership test must look in tau_info itself
        # (it previously looked for the key 'mlp' inside tau_info['mlp']).
        assert 'mlp' in tau_info
        tau_mlp = tau_info['mlp']
        assert 'num_features' in tau_mlp and 'h_layer_size' in tau_mlp and 'depth' in tau_mlp
    else:
        raise ValueError(f"Unrecognized tau type {tau_info['type']}: must be scalar or MLP")

    assert normalization_info['method'] in ["frob", "max_abs", "percentile", "none", None]
    assert normalization_info['where'] in ['before_reduction', 'after_reduction']


class gdn_layer(nn.Module):
    """One multi-input/multi-output layer mapping ``c_in`` square slices to ``c_out`` slices.

    For every (output channel, input channel) pair the layer combines the input
    slice ``A``, the term ``A @ A_O + A_O @ A`` and a polynomial of ``A_O``
    (computed by ``mimo_tensor_polynomial``), zeroes the diagonal of every
    slice, reduces over the input channels, optionally normalizes, and finally
    subtracts the threshold ``tau`` (followed by a ReLU unless disabled).

    ``k_3``, ``k_4`` and ``commute`` are registered as parameters but are not
    read by ``forward`` (the corresponding terms are disabled). They are kept so
    the random sampling order at construction and the ``state_dict`` keys stay
    unchanged.
    """

    def __init__(self,
                 c_in: int = 1,
                 c_out: int = 1,
                 order_poly_fc: int = 1,
                 gradient_ablation: bool = False,  # removes all params except tau
                 tau_info: Dict = {'type': 'scalar', 'learn': True, 'low': .01, 'high': .5},
                 include_nonlinearity: bool = True,
                 poly_basis: Optional[str] = 'cheb',  # basis (standard, cheb) used for P(A_hat)
                 channel_reduction=torch.mean,
                 normalization_info: Dict = {'method': 'max_abs', 'value': 99, 'where': 'after_reduction'}):
        """Create the layer parameters.

        Args:
            c_in: number of input channels expected by ``forward``.
            c_out: number of output channels produced by ``forward``.
            order_poly_fc: order of the polynomial of ``A_O``; must be >= 1.
            gradient_ablation: if True, ``forward`` uses ``A + poly(A_O)`` and
                ignores ``alpha`` and ``beta``.
            tau_info: threshold configuration, validated by
                ``gdn_layer_input_verification``. For type ``'scalar'`` tau is
                initialised to ``c_out`` evenly spaced values in
                ``[low, high]`` and is a learnable parameter iff ``'learn'``.
            include_nonlinearity: apply a ReLU after subtracting tau.
            poly_basis: ``'cheb'``, ``'standard'`` or ``None``; forwarded to
                ``mimo_tensor_polynomial`` as ``cob``.
            channel_reduction: callable invoked as ``channel_reduction(x, dim=1)``
                to collapse the input-channel dimension.
            normalization_info: keys ``'method'``, ``'value'`` and ``'where'``
                controlling the calls to ``normalize_slices``.

        The dict defaults are shared objects; this class only reads them.
        """
        super().__init__()
        self.c_in, self.c_out = c_in, c_out

        # parameters
        self.order_poly_fc = order_poly_fc
        assert order_poly_fc >= 1, f'order of polynomial of A_O must be >= 1'
        self.gradient_ablation = gradient_ablation
        self.stdv_scaling = 1 / 3  # scale N(0,1)*stdv_sampling-> N(0,stdv_sampling^2)
        self.init_alpha_sample_mean, self.init_beta_sample_mean = 1, 0
        # NOTE: the order of the torch.randn calls below determines the values drawn
        # for a given seed; do not reorder them.
        # alpha scales the input received from last layer. Make it ~1 -> making small changes to previous layer.
        self.alpha = nn.Parameter(torch.randn(c_out, c_in) * self.stdv_scaling + self.init_alpha_sample_mean)  # N(1, 1/9)
        # beta scales the part of the gradient. Make it ~0.
        self.beta = nn.Parameter(torch.randn(c_out, c_in) * self.stdv_scaling + self.init_beta_sample_mean)

        # Coefficients of the currently disabled higher-order/commutator terms (unused in forward).
        self.k_3 = nn.Parameter(torch.randn(c_out, c_in) * self.stdv_scaling + self.init_alpha_sample_mean)
        self.k_4 = nn.Parameter(torch.randn(c_out, c_in) * self.stdv_scaling + self.init_alpha_sample_mean)
        self.commute = nn.Parameter(torch.randn(c_out, c_in) * self.stdv_scaling + self.init_alpha_sample_mean)

        # these define the coefficients of the polynomial of the FC. Make these all ~0 except for the 1st order term,
        # which should be ~1. because we group it with an unscaled A_in. See paper.
        # Note that 0th order term is NOT used because it only affects diagonal, but will be included in parameter
        # count by pytorch.
        # coeffs_poly_fc[i] in R^(order+1 x c_in) are all polynomial coeffs used by output channel i.
        # coeffs_poly_fc[i, j] in R^c_in are the poly coeffs corresponding to jth polynomial basis
        # coeffs_poly_fc[i, j, k] in R is the poly coeff corresponding to jth polynomial basis for the
        # k^th input channel
        self.coeffs_poly_fc = nn.Parameter(torch.randn(c_out, order_poly_fc + 1, c_in) * self.stdv_scaling)
        with torch.no_grad():
            self.coeffs_poly_fc[:, 1, :] = 1

        gdn_layer_input_verification(tau_info, normalization_info)
        if tau_info['type'] == 'scalar':
            # one threshold per output channel, broadcastable against [c_out, batch_size, N, N]
            tau = torch.linspace(start=tau_info['low'], end=tau_info['high'], steps=c_out).view(c_out, 1, 1, 1)
            if tau_info['learn']:
                self.tau = nn.Parameter(tau)
            else:
                # detach().clone() instead of torch.tensor(tau): copy-constructing a tensor
                # from a tensor emits a UserWarning.
                self.register_buffer('tau', tau.detach().clone().to(dtype=torch.float32), persistent=True)
        elif tau_info['type'] == 'mlp':
            # Unreachable today: the verification above already rejects 'mlp'.
            assert False, f'need to do this'
            self.tau_mlp = tauNN(**tau_info['mlp'])
        self.tau_info = tau_info

        # end of parameters
        self.normalization_info = normalization_info

        self.include_nonlinearity = include_nonlinearity

        self.poly_basis = poly_basis
        assert poly_basis in ['cheb', 'standard', None]
        self.channel_reduction = channel_reduction

        # set by forward(): True when the most recent output was (numerically) all zeros
        self.output_zeros = False

    def forward(self, S_in, A_O, extra_outs=False, layer=None, normalize=True):
        """Apply the layer.

        Args:
            S_in: input slices, shape ``[c_in, batch_size, N, N]``.
            A_O: observed matrices, shape ``[batch_size, N, N]``.
            extra_outs: if True, return ``(S_out, S_in, temp)`` where ``temp``
                is the same tensor as ``S_out``.
            layer: unused; accepted so callers can pass the layer index.
            normalize: if False, skip ``normalize_slices`` regardless of
                ``normalization_info``.

        Returns:
            ``S_out`` (or the tuple described above): the thresholded result of
            reducing the ``[c_out, c_in, batch_size, N, N]`` intermediate over
            dim 1 — presumably ``[c_out, batch_size, N, N]``, assuming
            ``normalize_slices`` preserves shape.

        Side effects:
            Sets ``self.output_zeros`` and prints a warning when the output is
            all zeros (up to ``torch.allclose`` tolerance).
        """
        # S_in = [c_in, batch_size, N, N]
        # A_O = [batch_size, N, N]
        assert (len(S_in.shape) == 4) and (len(A_O.shape) == 3) and (S_in.shape[1] == A_O.shape[0])
        assert (S_in.shape[-1] == S_in.shape[-2]) and (A_O.shape[-1] == A_O.shape[-2])
        c_in = S_in.shape[0]
        assert (c_in == self.c_in), f'input has {c_in} input channels but expected {self.c_in} input channels'
        batch_size, N, _ = A_O.shape
        assert torch.all(self.tau >= 0), f'tau is negative {self.tau}'

        # c0*I + c1*A_O + c2*A_O^2 + ...
        poly_A_O = mimo_tensor_polynomial(A_O.expand(S_in.shape), self.coeffs_poly_fc, cob=self.poly_basis)

        # shape of intermediate tensors: [c_out, c_in, batch_size, N, N]
        intermed_shape = (self.c_out, self.c_in, batch_size, N, N)
        A, A_O = S_in.expand(intermed_shape), A_O.expand(intermed_shape)

        # Disabled terms, kept for reference (k_2_grad is the term computed below):
        #   k_3_grad = A.matmul(k_2_grad) + A_O.matmul(A.matmul(A))           (scaled by self.k_3)
        #   k_4_grad = A.matmul(k_3_grad) + A_O.matmul(A.matmul(A).matmul(A)) (scaled by self.k_4)
        #   commute_gradient = A_O.matmul(k_2_grad) + k_2_grad.matmul(A_O)    (scaled by self.commute)
        if not self.gradient_ablation:
            # gradient term A*A_O + A_O*A; only computed when it is actually used
            k_2_grad = A.matmul(A_O) + A_O.matmul(A)
            pair_shape = (self.c_out, self.c_in, 1, 1, 1)
            temp = \
                self.alpha.view(pair_shape) * A \
                + self.beta.view(pair_shape) * k_2_grad \
                + poly_A_O
        else:
            temp = A + poly_A_O

        # Zero the diagonal of every slice.
        # zd = 0 on slice diagonals, 1s everywhere else.
        zd = (torch.ones((N, N), device=S_in.device) - torch.eye(N, device=S_in.device)).expand(intermed_shape)
        temp = temp * zd

        # Normalizing here forces all slices to have entries in [-1, 1].
        # Thus by reducing (mean) right after this, entries in output channel will still be in [-1, 1].
        # But this limits the ability of the model to scale input channels differently.
        if 'before' in self.normalization_info['where'] and normalize:
            temp = normalize_slices(temp, which_norm=self.normalization_info['method'], extra=self.normalization_info['value'])

        # this can be a mean or sum reduction over the channels
        temp = self.channel_reduction(temp, dim=1)

        # Normalizing here forces entries in output channels to be in [-1, 1] but also allows relative scaling between
        # input channels to be learned.
        if 'after' in self.normalization_info['where'] and normalize:
            temp = normalize_slices(temp, which_norm=self.normalization_info['method'], extra=self.normalization_info['value'])

        # tau is either a direct learned parameter or output of MLP. Shared MLP for entire layer.
        tau = self.tau_mlp(adjs2fvs([A_O, A, temp])).view(temp.shape) if (self.tau_info['type'] == 'mlp') else self.tau

        temp = F.relu(temp - tau) if self.include_nonlinearity else (temp - tau)

        S_out = temp
        # zeros_like keeps dtype and device of the output, so the comparison also works for
        # non-float32 inputs.
        self.output_zeros = False
        if torch.allclose(S_out, torch.zeros_like(S_out)):
            self.output_zeros = True
            print('==================================')
            print('WARNING: A layer is outputting an all 0 S_out')
            print('==================================')

        if extra_outs:
            return S_out, S_in, temp
        else:
            return S_out


class gdn(pl.LightningModule):
    def __init__(self,
                 # architecture
                 depth: int,
                 mimo_architecture: List[int],
                 share_parameters: bool,
                 poly_fc_orders: List[int],
                 gradient_ablation: bool = False,
                 include_nonlinearity: bool = True,
                 poly_basis: Optional[str] = 'cheb',  # basis (standard, cheb) used for P(A_hat)
                 channel_reduction= torch.mean,
                 # normalization procedure: where and what type
                 normalization_info: Dict = {'method': 'max_abs', 'value': 99, 'where': 'after_reduction',
                                             'norm_last_layer': False},
                 # loss, optimizer
                 which_loss: str = 'hinge',
                 monitor: str = 'error',  # monitor is used by optimizer to check to reduce lr, stop running, etc
                 gamma: float = 0.3,
                 optimizer: str = 'adam',
                 learning_rate: float = .01,
                 adam_beta_1: float = .85, adam_beta_2: float = .99, momentum: float = .95,
                 hinge_margin: float = 1.0, hinge_slope: float = 1.0,
                 weight_decay: float = 0,
                 l2_strength: float = 0.0, l1_strength: float = 0.0, l2_cutoff: float = 0.0, l1_cutoff: float = 0.0,
                 # subnetwork of interest: None for synthetics.
                 # # 'full' means entire network. for brain data, other options are ['occipital', 'frontal', ..., 'full' ] etc
                 real_network: Optional[str] = 'full',
                 # tau: low/high (for init), max_clamp_val for max allowable tau value
                 tau_info: Dict = {'learn': True, 'low': .01, 'high': .99, 'max_clamp_val': .99},
                 # prior
                 n_train: int = 68,
                 prior_construction: Optional[str] = 'zeros', #'mean',
                 prior_frac_contains: float = 0.0,  # used for when construction prior with mean/median. What fraction of samples should contain an edge for us to count it in prior (use in mean/median calculation)?
                 learn_prior: bool = False,
                 # reproducability
                 seed: int = 50,
                 # threshold finding
                 threshold_metric: str = 'acc'  # which metric to use when choosing threshold value
                 ):
        """Build the stack of ``gdn_layer`` modules and the buffers used for priors and thresholds.

        Args (grouped as in the signature):
            depth, mimo_architecture, share_parameters, poly_fc_orders:
                ``depth`` layers are created; ``mimo_architecture`` lists the
                channel count at each layer boundary (``depth + 1`` entries) and
                ``poly_fc_orders`` the polynomial order per layer (``depth``
                entries). With ``share_parameters`` a single layer object is
                reused ``depth`` times and all channel counts must be equal.
                Otherwise the first and last entries of ``mimo_architecture``
                are overwritten with 1 **in place**.
            gradient_ablation, include_nonlinearity, poly_basis,
            channel_reduction, normalization_info, tau_info:
                forwarded to every ``gdn_layer``. ``channel_reduction`` must be
                ``torch.mean`` or ``torch.sum``. A ``tau_info`` without a
                ``'type'`` entry is treated as ``'scalar'``.
            n_train, prior_construction, learn_prior:
                ``n_train`` sizes the ``training_prior`` buffer (and the
                learned prior vector when ``learn_prior``);
                ``prior_construction`` must be one of the names asserted below
                and may not be ``'multi'``.
            seed: passed to ``seed_everything`` before any parameter is sampled.
            Remaining arguments are only stored in ``self.hparams`` by this method.
        """
        super().__init__()
        # The signature default for tau_info has no 'type' entry, but gdn_layer's input
        # verification and on_train_batch_end both read tau_info['type'], so constructing
        # with defaults raised KeyError. Default the type to 'scalar' on a copy (the
        # shared default dict is left untouched). This is done before
        # save_hyperparameters() so self.hparams.tau_info carries the 'type' entry as well
        # (NOTE(review): relies on save_hyperparameters() collecting init arguments from
        # this frame's current local values — confirm for the installed Lightning version).
        tau_info = {'type': 'scalar', **tau_info}
        self.save_hyperparameters()
        seed_everything(seed, workers=True)
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1225
        # for logging of hparams and computational graph
        # self.example_input_array = torch.zeros(10, 68, 68)

        assert depth == len(mimo_architecture) - 1, f'depth defined as number of layers'
        assert len(poly_fc_orders) == depth
        assert channel_reduction in [torch.mean, torch.sum]# ['mean', 'sum']


        #### Prior Construction and stuff to make model work ####
        assert prior_construction in ['zeros', 'ones', 'median', 'mean', 'multi', 'block', 'sbm'], f'prior_construction is {prior_construction}'
        assert prior_construction != 'multi', f'must change mimo_architecture and prior_prep to accomodate >1 prior channel'

        # register_buffer() wont save tensor in state_dict if fed None. Thus we must know what size graphs
        # are used for training to initialize training_prior tensor (so it can be save/loaded later on)
        num_prior_channels = 1  # FORCE THIS TO BE
        if learn_prior:
            # prior parameter as a vector. Convert to adjacency matrix form with vec2adj
            num_edges = n_train * (n_train - 1) // 2
            self.learned_prior = nn.Parameter(torch.randn(1, num_edges) / 10)
        # All three buffers are initialised to -1 as a "not computed yet" marker.
        self.register_buffer('training_prior', -1 * torch.ones(size=(num_prior_channels, n_train, n_train)), persistent=True)
        # best threshold found on training set when training_prior was constructed
        self.register_buffer('training_prior_threshold', -1 * torch.ones(num_prior_channels, 1), persistent=True)
        # best threshold found on validation set after training done, to be used for testing
        self.register_buffer('testing_prior_threshold', -1 * torch.ones(num_prior_channels, 1), persistent=True)

        self.prior_channels, self.prior_channel_names = None, None

        # threshold applied to model outputs during training/validation
        self.register_buffer('threshold', torch.tensor([-1.0]), persistent=True)
        self.test_threshold = None  # placeholder, will save in checkpoint and load when needed
        ######

        # on validation set
        self.list_of_metrics = []
        self.prior_metrics_val, self.prior_metrics_test = {}, {}

        common_layer_params = {'poly_basis': poly_basis,
                               'gradient_ablation': gradient_ablation,
                               'normalization_info': normalization_info,
                               'tau_info': tau_info,
                               'channel_reduction': channel_reduction,
                               'include_nonlinearity': include_nonlinearity}
        if share_parameters:
            print(f'Given {mimo_architecture} and Sharing Parameters.  Mimo_architecture must be the flat (= same)!')
            assert all([c == mimo_architecture[0] for c in mimo_architecture]), f'for shared parameters, all input/output must be same'
            print(f'Shared parameters. Looping over Single MIMO layer with c_in/c_out = {mimo_architecture[0]}, {len(mimo_architecture) - 1} times')
            c = mimo_architecture[0]  # they must all be the same
            # the SAME module object is repeated, so every "layer" shares one set of parameters
            single_layer = gdn_layer(**common_layer_params, c_in=c, c_out=c, order_poly_fc=poly_fc_orders[0])
            layers = [single_layer for _ in range(depth)]
        else:
            print('\tNot sharing parameters. First channel overridden for proper prior dim. Last channel dim will be overridden to reduce unused params')
            print(f'\t\tGIVEN {mimo_architecture}', end="")
            # in-place override: the list passed by the caller is modified
            mimo_architecture[0], mimo_architecture[-1] = 1, 1
            print(f' ---> USING {mimo_architecture}')

            layers = []
            for layer in range(depth):
                c_in, c_out = mimo_architecture[layer], mimo_architecture[layer + 1]
                layers.append(gdn_layer(**common_layer_params, c_in=c_in, c_out=c_out,
                                        order_poly_fc=poly_fc_orders[layer]))

        self.layers = nn.ModuleList(layers)

        # private variables to be used for printing/logging
        self.epoch_val_losses = []

        self.subnetwork_masks = None
        self.checkpoint_loaded = False  # are we being loaded from a checkpoint?

        # range of the most recent training outputs; used to place threshold candidates
        self.min_output = torch.tensor([-1])
        self.max_output = torch.tensor([1])

    def setup(self, stage: Optional[str]):
        """Lightning hook: fetch subnetwork masks and evaluate the prior as a baseline.

        Nothing beyond storing the masks happens when the model was restored
        from a checkpoint. For ``stage == 'fit'`` the prior is constructed and
        scored on the validation labels (threshold chosen on the training
        labels). For ``stage == 'test'`` the already constructed prior is scored
        on the test labels (threshold chosen on the validation labels).
        """
        datamodule = self.trainer.datamodule
        self.subnetwork_masks = datamodule.subnetwork_masks

        def label_pair(make_first_dl, make_second_dl):
            # Labels are element [1] of the dataset. Prefer the dataset's full_ds();
            # datasets without it (e.g. Subset) are stacked sample by sample instead.
            try:
                return make_first_dl().dataset.full_ds()[1], make_second_dl().dataset.full_ds()[1]
            except AttributeError:  # AttributeError: 'Subset' object has no attribute 'full_ds'
                def stacked(make_dl):
                    return torch.cat([sample[1].unsqueeze(dim=0) for sample in make_dl().dataset], dim=0)
                return stacked(make_first_dl), stacked(make_second_dl)

        if self.checkpoint_loaded:
            return

        if stage == 'fit':
            # first-time construction of the prior; done in self.construct_prior_()
            self.prior_channels, self.prior_channel_names = \
                self.construct_prior_(prior_dl=datamodule.train_dataloader())
            train_scs, val_scs = label_pair(datamodule.train_dataloader, datamodule.val_dataloader)

            # threshold is optimized on the training labels, performance reported on validation labels
            for idx, (channel, channel_name) in enumerate(zip(self.prior_channels, self.prior_channel_names)):
                self.training_prior_threshold[idx], self.prior_metrics_val[channel_name] = \
                    self.prior_performance(prior_channel=channel, prior_channel_name=channel_name,
                                           holdout_set_threshold=train_scs, holdout_set_metrics=val_scs)
        elif stage == 'test':
            val_scs, test_scs = label_pair(datamodule.val_dataloader, datamodule.test_dataloader)

            # if testing is NOT right after train (loading from checkpoint), self.prior_channels is None
            # threshold is optimized on the validation labels, performance reported on test labels
            print(f'PERFORMANCE OF {self.hparams.prior_construction} PRIOR ON TEST SET USING VALIDATION SET TO FIND THRESHOLD')
            for idx, (channel, channel_name) in enumerate(zip(self.prior_channels, self.prior_channel_names)):
                self.testing_prior_threshold[idx], self.prior_metrics_test[channel_name] = \
                    self.prior_performance(prior_channel=channel, prior_channel_name=channel_name,
                                           holdout_set_threshold=val_scs, holdout_set_metrics=test_scs)

    def prior_performance(self, prior_channel, prior_channel_name, holdout_set_threshold, holdout_set_metrics):
        """Score a prior channel used directly as the prediction for every sample.

        Args:
            prior_channel: the prior; broadcast to the shape of each label set.
            prior_channel_name: name used in the printed report.
            holdout_set_threshold: labels on which the decision threshold is chosen
                (via ``best_threshold_by_metric`` and ``hparams.threshold_metric``).
            holdout_set_metrics: labels on which the metrics are computed.

        Returns:
            ``(threshold, mean_metrics)`` where ``mean_metrics`` maps each
            subnetwork name in ``self.subnetwork_masks`` to
            ``{metric_name: mean metric value}``.
        """
        # Candidate thresholds. 'threshold_metric_test_points' is not an __init__
        # argument, so it only exists in hparams if someone added it explicitly; fall
        # back to 200 evenly spaced points over the prior's value range, mirroring how
        # on_validation_start builds its candidates from the model output range.
        # NOTE(review): confirm this fallback range is the intended search grid.
        thresholds = getattr(self.hparams, 'threshold_metric_test_points', None)
        if thresholds is None:
            prior_values = prior_channel.detach()
            thresholds = torch.linspace(start=prior_values.min().item(), end=prior_values.max().item(), steps=200)

        # choose the threshold on the first holdout set
        threshold = \
            best_threshold_by_metric(
                y_hat=adj2vec(prior_channel.expand(holdout_set_threshold.shape).detach()),
                y=adj2vec(holdout_set_threshold.detach()),
                thresholds=thresholds,
                metric=self.hparams.threshold_metric,
                non_neg=True
            )

        # per-subnetwork metrics on the second holdout set, using the threshold found above
        y, y_hat = holdout_set_metrics.detach(), torch.broadcast_to(prior_channel, holdout_set_metrics.shape).detach()
        subnetwork_metrics = {}
        for subnetwork_name, subnetwork_mask in self.subnetwork_masks.items():
            y_subnet = adj2vec(apply_mask(y, subnetwork_mask))
            y_hat_subnet = adj2vec(apply_mask(y_hat.detach(), subnetwork_mask))
            subnetwork_metrics[subnetwork_name] = compute_metrics(y_hat=y_hat_subnet, y=y_subnet, threshold=threshold,
                                                                  non_neg=self.trainer.datamodule.non_neg_labels,
                                                                  self_loops=self.trainer.datamodule.self_loops)
        # display metrics found
        print(f"Prior {prior_channel_name} metrics using {threshold}")
        print(f'ON VAL')
        # only interested in mean at the moment...TODO: print stde
        # Keyed by each subnetwork's OWN name. (Previously the stale loop variable from
        # the loop above was used as the key, so every subnetwork overwrote the same
        # entry and only the last one survived.)
        mean_metrics = {
            subnetwork: {metric_name: torch.mean(metric_values)
                         for metric_name, metric_values in metric_dict.items()}
            for subnetwork, metric_dict in subnetwork_metrics.items()
        }
        print_subnet_perf_dict(subnetwork_metrics_dict=mean_metrics,
                               indents=2, convert_to_percent=['acc', 'error'],
                               metrics2print=['se', 'ae', 'nmse', 'acc', 'error', 'mcc'])
        return threshold, mean_metrics

    def on_train_start(self) -> None:
        """Lightning hook: seed the learned prior (if any) and log the prior's baseline metrics."""
        if self.hparams.learn_prior:
            with torch.no_grad():
                # learned_prior was initialized with a little noise to break symmetry;
                # the constructed prior (in vector form) is added on top of it.
                prior_vec = adj2vec(self.prior_channels[0]).to(self.learned_prior.device)
                self.learned_prior.data += prior_vec

        def log_prior_means(stage_tag, prior_metrics):
            # NOTE(review): the subnetwork name is not part of the logged key, so several
            # subnetworks of one prior channel are logged under the same name — confirm intended.
            for prior_channel_name, metrics_per_subnet in prior_metrics.items():
                prefix = f'{stage_tag}/{prior_channel_name} prior/'
                for subnetwork_metrics in metrics_per_subnet.values():
                    for metric_name, value in subnetwork_metrics.items():
                        self.log(name=prefix + metric_name + '/mean', value=value)

        # add mean values of the prior baseline, validation first and then test
        log_prior_means('val', self.prior_metrics_val)
        log_prior_means('test', self.prior_metrics_test)

    def forward(self, batch) -> Any:
        """Return the prediction of ``shared_step`` for ``batch`` (no intermediate outputs)."""
        adjs_hat = self.shared_step(batch=batch)
        return adjs_hat

    def shared_step(self, batch, int_out=False):
        """Run the prior through every layer and return the predicted adjacencies.

        Args:
            batch: sequence whose first two elements are ``(fcs, adjs)``; only
                ``fcs`` is used here.
            int_out: if True, also return the (squeezed) output of every layer.

        Returns:
            ``adjs_hat`` — channel 0 of the last layer's output — or
            ``(adjs_hat, intermediate_outputs)`` when ``int_out`` is True.
        """
        # fcs, adjs, subject_ids, scan_dirs, tasks = batch
        fcs, _ = batch[:2]
        batch_size, N = fcs.shape[:-1]

        # the prior is the input to the first layer
        s_out = self.prior_prep(batch_size=batch_size, N=N)
        intermediate_outputs = []

        for idx, gdn_block in enumerate(self.layers):
            # every layer normalizes, except the last one which follows 'norm_last_layer'
            if idx == (self.hparams.depth - 1):
                normalize = self.hparams.normalization_info['norm_last_layer']
            else:
                normalize = True
            s_out = gdn_block(s_out, fcs, layer=idx, normalize=normalize)
            if int_out:
                intermediate_outputs.append(s_out.squeeze())

        # s_out shape = [c_out, bs, N, N], where c_out of final layer.
        # For non-shared_params c_out will always be 1. For shared params will be>=1.
        # Thus simply only look at 0th output channel for loss.
        adjs_hat = s_out[0]

        if int_out:
            return adjs_hat, intermediate_outputs
        return adjs_hat

    def training_step(self, batch, batch_idx):
        """Lightning hook: compute the training loss and metrics for one batch.

        Returns:
            dict with ``'loss'`` (from ``compute_intermed_loss`` over all layer
            outputs), ``'metrics'`` (per-subnetwork metrics of the prediction at
            the current ``self.threshold``) and ``'batch_size'``.

        Side effects:
            Updates ``self.min_output`` / ``self.max_output`` with the range of
            the prediction; they are later used to place threshold candidates.
        """
        x, y = batch[0:2]

        y_hat, intermediate_outputs = self.shared_step(batch, int_out=True)

        # Track the output range on a detached tensor so the stored values do not keep
        # the autograd graph of this step alive. Fall back to +/-3 if the output is NaN.
        y_hat_detached = y_hat.detach()
        out_max, out_min = y_hat_detached.max(), y_hat_detached.min()
        self.max_output = out_max if not torch.isnan(out_max) else torch.tensor([3])  # IGNORE DIAG IN MAX SEARCH
        self.min_output = out_min if not torch.isnan(out_min) else torch.tensor([-3])

        loss = self.compute_intermed_loss(intermediate_outputs=intermediate_outputs, y=y, per_edge_loss=True)
        # Metrics must compare the PREDICTION with the labels (previously the labels were
        # passed as y_hat, i.e. compared with themselves). Detached: metrics need no gradients.
        subnetwork_metrics = self.compute_subnetwork_metrics(y=y, y_hat=y_hat_detached, threshold=self.threshold[0])
        return {'loss': loss, 'metrics': subnetwork_metrics, 'batch_size': len(x)}

    def on_train_batch_end(self, outputs, batch, batch_idx: int, unused: Optional[int] = 0) -> None:
        """Lightning hook: repair degenerate layers and keep tau within bounds after each batch.

        If some layer produced an all-zero output, the parameters of the
        shallowest such layer are resampled. When tau is a learnable scalar it
        is clamped via ``clamp_tau`` with ``tau_info['max_clamp_val']`` as the
        upper bound.
        """
        # find the shallowest layer where model outputting ALL zeros. -1 if no layer outputting all zeros.
        sl = shallowest_layer_all_zero(self)

        if sl > -1:
            print(f'\n\t~~~~shallowest layer with zero outputs = {sl}. Resample param vals of layer.~~~')
            resample_params(self.layers[sl])

        # keep tau at or below max_clamp_val (clamp_tau decides the lower bound)
        tau_info = self.hparams.tau_info
        if tau_info['type'] == 'scalar' and tau_info['learn']:
            clamp_tau(self, large_tau=tau_info['max_clamp_val'])

        # To inspect the learned prior's gradient while debugging, plot
        # vec2adj(self.learned_prior.grad, self.hparams.n_train).squeeze() here.

    def training_epoch_end(self, train_step_outputs):
        """Lightning hook: log the epoch-average training loss and the aggregated subnetwork metrics.

        Overriding this hook is also what makes the per-batch outputs reach callbacks, see
        https://github.com/PyTorchLightning/pytorch-lightning/issues/4326
        """
        step_losses = [step_output['loss'] for step_output in train_step_outputs]
        avg_loss = torch.stack(step_losses).mean()
        self.log(name=f'train/{self.hparams.which_loss}_epoch', value=avg_loss, on_step=False, on_epoch=True)

        aggregated = self.aggregate_subnetwork_step_outputs(outputs=train_step_outputs)
        self.log_subnetwork_metrics(aggregated, stage='train')
        return None

    def on_validation_start(self) -> None:
        """Lightning hook: pick the decision threshold once before every validation epoch.

        200 candidate thresholds are spread evenly between the smallest and
        largest model output seen in the latest training step, and the best one
        (by ``hparams.threshold_metric``) is searched on the training dataloader.
        """
        lowest, highest = self.min_output.item(), self.max_output.item()
        candidates = torch.linspace(start=lowest, end=highest, steps=200)
        best = self.find_threshold(dl=self.trainer.datamodule.train_dataloader(),
                                   threshold_test_points=candidates,
                                   metric2chooseThresh=self.hparams.threshold_metric)
        self.threshold[0] = best
        self.log('threshold', self.threshold, prog_bar=True, on_epoch=True, on_step=False)

    def validation_step(self, batch, batch_idx):
        """Lightning hook: loss and per-subnetwork metrics for one validation batch.

        Metrics use ``self.threshold[0]``, the threshold chosen in ``on_validation_start``.
        """
        inputs, targets = batch[0:2]
        predictions, per_layer_outputs = self.shared_step(batch, int_out=True)
        step_output = {
            'loss': self.compute_intermed_loss(intermediate_outputs=per_layer_outputs, y=targets,
                                               per_edge_loss=True),
            'metrics': self.compute_subnetwork_metrics(y=targets, y_hat=predictions,
                                                       threshold=self.threshold[0]),
            'batch_size': len(inputs),
        }
        return step_output

    def validation_epoch_end(self, val_step_outputs):
        """Lightning hook: aggregate, log and print the validation results of this epoch.

        The aggregated metrics are appended (with the epoch number) to
        ``self.list_of_metrics`` and compared against the prior's validation
        metrics as a percent change.
        """
        outputs = self.aggregate_subnetwork_step_outputs(outputs=val_step_outputs)
        self.log_subnetwork_metrics(outputs, stage='val')

        epoch_record = dict(outputs)
        epoch_record['epoch'] = self.current_epoch
        self.list_of_metrics.append(epoch_record)

        self.progress_bar_update(outputs=outputs)

        # keep only the mean of every subnetwork's metrics
        mean_outputs = {}
        for subnetwork, subnetwork_metrics in outputs.items():
            mean_outputs[subnetwork] = subnetwork_metrics['mean']

        perc_change = percent_change_metrics(prior_metrics=self.prior_metrics_val,
                                             prediction_metrics=mean_outputs)
        self.log_change_relative_to_prior(outputs=mean_outputs, perc_change_curr_priors_subnets=perc_change,
                                          stage='val')

        # the best we've done vs how were doing now
        change_for_network = perc_change[self.hparams.prior_construction][self.hparams.real_network]
        self.print_training_progress(outputs=mean_outputs, perc_change_curr=change_for_network)

        # to able to see/log lr, need to do this
        first_optimizer_state = self.trainer.optimizers[0].state_dict()
        current_lr = first_optimizer_state['param_groups'][0]['lr']
        self.log("lr", round(current_lr, 10), logger=False, prog_bar=True)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Lightning hook: store prior information, metric history and the test threshold.

        The test threshold is searched on the validation dataloader over 200
        candidates spanning the latest training output range, every time a
        checkpoint is written.
        """
        for key, value in (('prior_metrics_val', self.prior_metrics_val),
                           ('prior_channels', self.prior_channels),
                           ('prior_channel_names', self.prior_channel_names),
                           ('list_of_metrics', self.list_of_metrics)):
            checkpoint[key] = value

        lowest, highest = self.min_output.item(), self.max_output.item()
        candidates = torch.linspace(start=lowest, end=highest, steps=200)
        checkpoint['test_threshold'] = self.find_threshold(dl=self.trainer.datamodule.val_dataloader(),
                                                           threshold_test_points=candidates,
                                                           metric2chooseThresh=self.hparams.threshold_metric)
        print(f"\n\tsaving checkpoint: saving threshold found {checkpoint['test_threshold']:.3f} using validation set during training")

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Lightning hook: restore what ``on_save_checkpoint`` stored and mark the model as loaded.

        Setting ``self.checkpoint_loaded`` makes ``setup`` skip constructing and
        scoring the prior again.
        """
        self.prior_metrics_val = checkpoint['prior_metrics_val']
        self.prior_channels = checkpoint['prior_channels']
        self.prior_channel_names = checkpoint['prior_channel_names']
        self.list_of_metrics = checkpoint['list_of_metrics']

        self.test_threshold = checkpoint['test_threshold']
        self.checkpoint_loaded = True  # ensures we dont do model setup procedure again

        message = f"\nLoading threshold found using validaiton set during training: {self.test_threshold:.5f}"
        # The metric history is empty when the checkpoint was written before the first
        # validation epoch finished; loading must not fail just because of the report.
        if self.list_of_metrics:
            message += f" which achieved {self.list_of_metrics[-1]['full']['mean']['error'] * 100:.4f}% error"
        print(message)

    """
    def on_test_start(self) -> None:
        # find best threshold only once before every validation epoch
        # use training set to optimize threshold during training
        # this is to be used to threshold the MODEL OUTPUTS
        self.test_threshold[0] = self.find_threshold(dl=self.trainer.datamodule.val_dl, threshold_test_points=self.hparams.threshold_metric_test_points, metric2chooseThresh=self.hparams.threshold_metric)
        self.log('test_threshold', self.test_threshold, prog_bar=True, on_epoch=True, on_step=False)
    """

    def test_step(self, batch, batch_idx):
        x, y = batch[0:2]
        y_hat, intermediate_outputs = self.shared_step(batch, int_out=True)
        loss = self.compute_intermed_loss(intermediate_outputs=intermediate_outputs, y=y, per_edge_loss=True)
        subnetwork_metrics = self.compute_subnetwork_metrics(y=y, y_hat=y_hat, threshold=self.test_threshold)
        return {'loss': loss, 'metrics': subnetwork_metrics, 'batch_size': len(x)}

    def test_epoch_end(self, test_step_outputs: List[Any]) -> None:
        """Aggregate and log the metrics gathered over a whole test epoch.

        The per-step outputs are reduced to a mean / standard error per
        subnetwork and logged under the ``'test'`` stage. The means are then
        compared against ``self.prior_metrics_val`` via
        ``percent_change_metrics`` and that comparison is logged as well,
        followed by the best validation metrics.

        Args:
            test_step_outputs: list of the dictionaries returned by
                ``test_step``.
        """
        stage = 'test'
        aggregated = self.aggregate_subnetwork_step_outputs(outputs=test_step_outputs)
        self.log_subnetwork_metrics(aggregated, stage=stage)

        # keep only the means of every subnetwork for the comparison with the prior
        subnet_means = {}
        for subnet_name, subnet_stats in aggregated.items():
            subnet_means[subnet_name] = subnet_stats['mean']

        change_vs_priors = percent_change_metrics(prior_metrics=self.prior_metrics_val,
                                                  prediction_metrics=subnet_means)
        self.log_change_relative_to_prior(outputs=subnet_means,
                                          perc_change_curr_priors_subnets=change_vs_priors,
                                          stage=stage)

        self.log_best_validation_metrics()

    def configure_optimizers(self):
        """Build the optimizer selected by ``self.hparams.optimizer``.

        The optimizer name is matched by substring: a name containing
        ``'adam'`` yields ``torch.optim.Adam`` (with betas taken from
        ``hparams.adam_beta_1`` / ``hparams.adam_beta_2``), a name containing
        ``'sgd'`` yields ``torch.optim.SGD`` (with ``hparams.momentum``).
        Both use ``hparams.learning_rate``.

        Returns:
            The constructed optimizer over ``self.parameters()``.

        Raises:
            ValueError: if the name matches neither ``'adam'`` nor ``'sgd'``.
        """
        lr = self.hparams.learning_rate
        optimizer_name = self.hparams.optimizer

        if 'adam' in optimizer_name:
            betas = (self.hparams.adam_beta_1, self.hparams.adam_beta_2)
            return torch.optim.Adam(self.parameters(), lr=lr, betas=betas)
        if 'sgd' in optimizer_name:
            return torch.optim.SGD(self.parameters(), lr=lr, momentum=self.hparams.momentum)

        raise ValueError(f'only configured Adam and SGD optimizzer. Given {self.hparams.optimizer}')

    ### HELPER METHODS ###
    @torch.no_grad()
    def compute_subnetwork_metrics(self, y, y_hat, threshold):
        """Compute the metrics of a prediction separately for every subnetwork.

        For each entry of ``self.subnetwork_masks`` the targets and the
        (detached) predictions are masked with ``apply_mask``, converted with
        ``adj2vec`` and handed to ``compute_metrics`` together with the
        ``non_neg_labels`` / ``self_loops`` settings of the trainer's
        datamodule.

        Args:
            y: target adjacency tensors.
            y_hat: predicted adjacency tensors.
            threshold: threshold forwarded to ``compute_metrics``.

        Returns:
            dict mapping subnetwork name -> result of ``compute_metrics``.
        """
        per_subnet = {}
        for name, mask in self.subnetwork_masks.items():
            target_vec = adj2vec(apply_mask(y, mask))
            pred_vec = adj2vec(apply_mask(y_hat.detach(), mask))
            datamodule = self.trainer.datamodule
            per_subnet[name] = compute_metrics(y_hat=pred_vec, y=target_vec, threshold=threshold,
                                               non_neg=datamodule.non_neg_labels,
                                               self_loops=datamodule.self_loops)
        return per_subnet

    @torch.no_grad()
    def aggregate_subnetwork_step_outputs(self, outputs):
        """Reduce per-step outputs to a mean / standard error per subnetwork.

        The subnetwork names are read from the first step output. For each
        of them the matching metrics of every step are collected (keeping the
        step's ``'batch_size'``) and reduced with
        ``self.aggregate_step_outputs``.

        Args:
            outputs: non-empty list of step dictionaries with the keys
                ``'metrics'`` (subnetwork name -> metrics) and
                ``'batch_size'``.

        Returns:
            dict mapping subnetwork name -> ``{'mean': ..., 'stde': ...}``.
        """
        aggregated = {}
        for name in outputs[0]['metrics'].keys():
            per_step = []
            for step in outputs:
                per_step.append({'metrics': step['metrics'][name], 'batch_size': step['batch_size']})
            mean, stde = self.aggregate_step_outputs(per_step)
            aggregated[name] = dict(mean=mean, stde=stde)
        return aggregated

    @torch.no_grad()
    def aggregate_step_outputs(self, outputs):
        """Pool the metric values of all steps and summarise each metric.

        The metric names are read from the first step output. Values of a
        metric whose name contains ``'glad'`` are combined with
        ``torch.tensor`` (they cannot be concatenated); all other metrics are
        combined with ``torch.cat``.

        Args:
            outputs: non-empty list of step dictionaries, each with a
                ``'metrics'`` mapping of metric name -> values.

        Returns:
            ``(means, stdes)``: two dicts keyed by metric name holding the
            mean and the standard error (std / sqrt(number of values)).
        """
        # gather the values of every step, metric by metric
        collected = {name: [] for name in outputs[0]['metrics'].keys()}
        for step in outputs:
            for name, values in step['metrics'].items():
                collected[name].append(values)

        means, stdes = {}, {}
        for name, chunks in collected.items():
            if 'glad' in name:
                # will throw error bc cant concat single dim tensors
                pooled = torch.tensor(chunks)
            else:
                pooled = torch.cat(chunks)
            means[name] = torch.mean(pooled)
            stdes[name] = torch.std(pooled) / math.sqrt(len(pooled))

        return means, stdes

    # DOES NOT CHANGE ANY PRIVATE VARIABLES: STATIC METHOD
    @torch.no_grad()
    def find_threshold(self, dl, threshold_test_points, metric2chooseThresh):
        """Pick the candidate threshold that scores best on a dataloader.

        Every batch of ``dl`` is moved to ``self.device`` (entries 0 and 1
        are replaced in place, so batches must be mutable sequences), run
        through ``self.shared_step`` and the targets / predictions are
        concatenated along dim 0. The choice itself is delegated to
        ``best_threshold_by_metric``, always called with ``non_neg=True``.

        Args:
            dl: dataloader to evaluate (train or validation data).
            threshold_test_points: candidate thresholds.
            metric2chooseThresh: name of the metric used to rank thresholds.

        Returns:
            Whatever ``best_threshold_by_metric`` returns.
        """
        targets, predictions = [], []
        for batch in dl:  # batch by batch so we dont run out of memory
            # move inputs / targets to the module's device (lightning doesnt do this for us here)
            batch[0] = batch[0].to(self.device)
            batch[1] = batch[1].to(self.device)
            targets.append(batch[1])
            predictions.append(self.shared_step(batch))
        all_targets = torch.cat(targets, dim=0)
        all_predictions = torch.cat(predictions, dim=0)

        # candidate thresholds are scored over the FULL network
        return best_threshold_by_metric(thresholds=threshold_test_points,
                                        y=all_targets, y_hat=all_predictions,
                                        metric=metric2chooseThresh, non_neg=True)

    def loss_func(self, y, y_hat, per_sample_loss, per_edge_loss):
        """Compute the loss selected by ``self.hparams.which_loss``.

        Supported values: ``'nse'`` (squared error normalised per sample by
        the squared target, averaged over the batch), ``'se'`` (summed
        squared error), ``'ae'`` (summed absolute error) and ``'hinge'``
        (summed ``hinge_loss`` against ``y > 0``).

        Args:
            y: targets, a 3-dimensional tensor.
            y_hat: predictions, a 3-dimensional tensor.
            per_sample_loss: divide the summed loss by the batch size.
            per_edge_loss: divide the summed loss by
                ``batch_size * n * (n - 1) // 2``; only consulted when
                ``per_sample_loss`` is falsy.

        Returns:
            Scalar loss tensor. NOTE: the ``'nse'`` loss is returned as is;
            the two normalisation flags do not apply to it.

        Raises:
            ValueError: for an unrecognised ``which_loss``.
        """
        assert y.ndim == y_hat.ndim == 3
        batch_size, n = y.shape[:2]
        total_edges = batch_size * n * (n - 1) // 2
        which = self.hparams.which_loss

        if which == 'nse':
            # already a mean over the batch -> no further normalisation
            per_sample_dims = (1, 2)
            squared_error = ((y - y_hat) ** 2).sum(dim=per_sample_dims)
            target_energy = (y ** 2).sum(dim=per_sample_dims)
            return (squared_error / target_energy).mean()

        if which == 'se':
            total = F.mse_loss(y_hat, y, reduction='sum')
        elif which == 'ae':
            total = F.l1_loss(y_hat, y, reduction='sum')
        elif which == 'hinge':
            total = hinge_loss(y=y > 0, y_hat=y_hat, margin=self.hparams.hinge_margin,
                               slope=self.hparams.hinge_slope, per_edge=False).sum()
        else:
            raise ValueError(f'loss {self.hparams.which_loss} not recognized')

        if per_sample_loss:
            # averaged over samples in batch
            return total / batch_size
        if per_edge_loss:
            # averaged over each possible edge: N*(N-1)/2 for symm w/o self-loops
            return total / total_edges
        return total

    def compute_intermed_loss(self, intermediate_outputs, y, per_sample_loss=False, per_edge_loss=False):
        """Combine the losses of the intermediate (per-layer) outputs.

        When ``hparams.gamma`` is (numerically) zero only the last output
        contributes. Otherwise the loss of output ``d`` (0-based) is weighted
        by ``gamma ** (depth - (d + 1))``, so later layers count more, and
        the weighted losses are summed.

        Args:
            intermediate_outputs: one output per layer, INCLUDING the final
                output; its length must equal ``hparams.depth``.
            y: targets forwarded to ``self.loss_func``.
            per_sample_loss: forwarded to ``self.loss_func``.
            per_edge_loss: forwarded to ``self.loss_func``; at most one of
                the two normalisation flags may be set.

        Returns:
            Scalar loss tensor, on the device / dtype produced by
            ``self.loss_func``.
        """
        depth = self.hparams.depth
        gamma = self.hparams.gamma
        assert len(intermediate_outputs) == depth
        assert not (per_sample_loss and per_edge_loss), f'can only choose at most one loss normalization'

        # NOTE(review): .squeeze() without a dim also drops any other size-1
        # axis (e.g. a batch of one sample) -- confirm the expected shapes.

        # math.isclose also accepts an integer gamma; torch.isclose on a
        # freshly built tensor raised a dtype mismatch for gamma == 0 (int).
        if math.isclose(gamma, 0.0, abs_tol=1e-8):
            # just use straight loss from final outputs.
            return self.loss_func(y=y, y_hat=intermediate_outputs[-1].squeeze(),
                                  per_sample_loss=per_sample_loss, per_edge_loss=per_edge_loss)

        # use intermediate losses from intermediate outputs; collect them as
        # they are so the sum stays on their device and keeps their dtype
        weighted_losses = []
        for d, y_hat_i in enumerate(intermediate_outputs):
            loss = self.loss_func(y=y, y_hat=y_hat_i.squeeze(),
                                  per_sample_loss=per_sample_loss, per_edge_loss=per_edge_loss)
            # weight loss by depth of unrolling: weight loss more as we get closer to end
            weighted_losses.append(loss * gamma ** (depth - (d + 1)))

        return torch.stack(weighted_losses).sum()

    def compute_loss(self, y, y_hat, N, bs):
        """Compute the regularised, per-edge normalised training loss.

        The data term is chosen by ``self.hparams.which_loss`` (``'se'``,
        ``'ae'``, ``'se+ae'``, ``'cross entropy'``, ``'hinge'`` or
        ``'hinge+se'``) and is always a SUM over all entries. Optional L1 /
        L2 penalties over ``self.parameters()`` are added when
        ``hparams.l1_strength`` / ``hparams.l2_strength`` are positive.

        Args:
            y: targets.
            y_hat: predictions (logits for ``'cross entropy'``).
            N: number of nodes per graph.
            bs: batch size.

        Returns:
            The summed loss divided by ``bs * N * N - bs * N`` (the number of
            off-diagonal entries in the batch).

        Raises:
            ValueError: for an unrecognised ``which_loss``.
        """
        which = self.hparams.which_loss
        if which == 'se':
            loss = F.mse_loss(y_hat, y, reduction='sum')
        elif which == 'ae':
            loss = F.l1_loss(y_hat, y, reduction='sum')
        elif which == 'se+ae':
            loss = F.mse_loss(y_hat, y, reduction='sum') + F.l1_loss(y_hat, y, reduction='sum')
        elif which == 'cross entropy':
            # BCE needs a floating point target of the logits' dtype; an
            # integer target raises. Summed like every other branch so the
            # per-edge division below is applied exactly once.
            binary_adjs = (y > 0).to(y_hat.dtype)
            loss = F.binary_cross_entropy_with_logits(input=y_hat, target=binary_adjs, reduction='sum')
        elif which == 'hinge':
            loss = hinge_loss(y=y, y_hat=y_hat, margin=self.hparams.hinge_margin, slope=self.hparams.hinge_slope,
                              per_edge=False).sum()
        elif which == 'hinge+se':
            # was `self.hparams_hinge_slope`, an attribute that does not exist
            loss = hinge_loss(y=y, y_hat=y_hat, margin=self.hparams.hinge_margin, slope=self.hparams.hinge_slope,
                              per_edge=False).sum()
            loss = loss + F.mse_loss(y_hat, y, reduction='sum')
        else:
            raise ValueError(f'which_loss {self.hparams.which_loss} not recognized')

        # link to implimentation: https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/regression/linear_regression.py
        # L1 regularizer
        if self.hparams.l1_strength > 0:
            l1_reg = sum(param.abs().sum() for param in self.parameters())
            loss = loss + self.hparams.l1_strength * l1_reg

        # L2 regularizer
        if self.hparams.l2_strength > 0:
            l2_reg = sum(param.pow(2).sum() for param in self.parameters())
            loss = loss + self.hparams.l2_strength * l2_reg

        denom = bs * N * N - bs * N  # ignore diagonals
        return loss / denom

    def construct_prior_(self, prior_dl=None, prior_dtype=torch.float32):
        """Build the prior channel(s) selected by ``hparams.prior_construction``.

        Data-driven priors (``'mean'``, ``'median'``, ``'multi'``) are built
        with ``construct_prior`` from the de-duplicated graphs returned by
        ``prior_dl.dataset.full_ds()``. The remaining priors (``'block'``,
        ``'sbm'``, ``'zeros'``, ``'ones'``) only depend on
        ``hparams.n_train`` (``'sbm'`` additionally reads
        ``prior_dl.dataset.prob_matrix()``).

        Side effect: the constructed prior is written into
        ``self.training_prior``.

        Args:
            prior_dl: dataloader providing the data for data-driven priors
                and the probability matrix for ``'sbm'``.
            prior_dtype: dtype of the synthetic priors.

        Returns:
            ``(prior_channels, prior_channel_names)``: the list of prior
            tensors and the list of their names.

        Raises:
            ValueError: for an unrecognised ``prior_construction``.
        """
        construction = self.hparams.prior_construction

        # constructing prior from data
        if construction in ['mean', 'median', 'multi']:
            # create prior with prior_ds set (train set): prior_dl cannot be none
            assert prior_dl is not None, f'to construct prior from data, need train data'
            _, prior_scs, prior_subject_ids, _, _ = prior_dl.dataset.full_ds()
            unique_scs_train_set = filter_repeats(prior_scs, prior_subject_ids)
            frac_contains = self.hparams.prior_frac_contains
            if construction in ['mean', 'median']:
                prior = construct_prior(unique_scs_train_set, frac_contains=frac_contains,
                                        reduction=construction)
                self.training_prior[0] = prior[0]
                prior_channels = [prior]
                prior_channel_names = [construction]
            else:  # multi
                self.training_prior[0, 0] = construct_prior(unique_scs_train_set, frac_contains=frac_contains, reduction='mean')
                self.training_prior[1, 0] = construct_prior(unique_scs_train_set, frac_contains=frac_contains, reduction='median')
                prior_channels = [self.training_prior[0, 0], self.training_prior[1, 0]]
                prior_channel_names = ['mean', 'median']
            return prior_channels, prior_channel_names

        # these priors do not use real data. Reconstructed each time, only need graph size N.
        # This is useful when we want to test on different sized data than we trained on: simply feed in test set as
        # val_dl, and prior will be constructed appropriately.
        n = self.hparams.n_train
        if construction == 'block':
            block_scale = .35  # minimizes se
            assert n % 2 == 0, f'for block prior, n must be even (or in general divisible by number of communities'
            ones = torch.ones((n // 2), (n // 2), dtype=prior_dtype)
            prior_channels = [torch.block_diag(ones, ones).view(1, n, n) * block_scale]
            prior_channel_names = ['block']
        # must be `elif`: as a separate `if` chain, 'block' fell through to
        # the final `else` and always raised ValueError
        elif construction == 'sbm':
            prob_matrix = prior_dl.dataset.prob_matrix()
            prior_channels = [prob_matrix.expand(1, n, n).to(prior_dtype)]
            prior_channel_names = ['sbm']
        elif construction in ['zeros']:
            prior_channels = [torch.zeros(1, n, n, dtype=prior_dtype)]
            prior_channel_names = ['zeros']
        elif construction in ['ones']:
            prior_channels = [torch.ones(1, n, n, dtype=prior_dtype)]
            prior_channel_names = ['ones']
        else:
            raise ValueError('unrecognized prior construction arg')
        self.training_prior[0] = prior_channels[0]

        return prior_channels, prior_channel_names

    def prior_prep(self, batch_size, N):
        """Build the prior tensor of shape ``(channels, batch_size, N, N)``.

        The number of channels is ``self.layers[0].c_in`` when parameters
        are shared and 1 otherwise. With ``hparams.learn_prior`` the learned
        prior (``vec2adj`` of ``self.learned_prior``) is added on top of
        ``self.training_prior``. Otherwise the prior follows
        ``hparams.prior_construction``: ``'zeros'``, ``'ones'`` (scaled by
        0.5) and ``'block'`` (scaled by 0.35) are rebuilt for the requested
        ``N``, whereas ``'mean'``, ``'median'`` and ``'sbm'`` expand
        ``self.training_prior`` and therefore need a matching ``N``.

        Args:
            batch_size: number of samples in the batch.
            N: number of nodes per graph.

        Returns:
            The expanded prior tensor, or ``None`` when
            ``prior_construction`` is none of the values above.
        """
        # By NOT using self.training_prior (constructed in setup() during initial training) when we don't
        # need to, we can test on graphs of sizes different than those trained on.
        if self.hparams.share_parameters:
            prior_channels = self.layers[0].c_in
        else:
            prior_channels = 1
        out_shape = (prior_channels, batch_size, N, N)

        if self.hparams.learn_prior:
            # add learned prior on top of given prior => smaller gradients
            learned = vec2adj(v=self.learned_prior, n=self.hparams.n_train).expand(*out_shape)
            given = self.training_prior.expand(*out_shape).to(learned.device)
            return learned + given

        kind = self.hparams.prior_construction
        if kind == 'zeros':
            return torch.zeros(size=(N, N), device=self.device).expand(*out_shape)
        if kind == 'ones':
            return 0.5 * torch.ones(size=(N, N), device=self.device).expand(*out_shape)
        if kind == 'block':
            block_scale = .35  # minimizes se for brain graphs
            assert (N % 2) == 0, f'block diagram requires even N'
            half_block = torch.ones(size=((N // 2), (N // 2)), device=self.device)
            return torch.block_diag(half_block, half_block).expand(*out_shape) * block_scale
        if kind in ['mean', 'median', 'sbm']:
            # must use training_prior
            assert self.training_prior.shape[-1] == N, f'when using prior constructed from data, we cannot test on data of different size than it'
            return self.training_prior.expand(*out_shape)
        return None

    def best_metrics(self, sort_metric, sort_subnetwork='full', top_k=1, maximize=True):
        """Return the ``top_k`` entries of ``self.list_of_metrics`` ranked by a metric.

        Args:
            sort_metric: name of the metric to rank by.
            sort_subnetwork: subnetwork whose ``'mean'`` value of the metric
                is used for ranking.
            top_k: number of entries to return.
            maximize: rank from largest to smallest when true, otherwise
                from smallest to largest.

        Returns:
            New list with the best ``top_k`` entries, best first.
        """
        def ranking_value(entry):
            return entry[sort_subnetwork]['mean'][sort_metric]

        ranked = list(self.list_of_metrics)
        ranked.sort(key=ranking_value, reverse=maximize)
        return ranked[:top_k]

    def extra_repr(self):
        """Return a short description holding the MIMO architecture and the depth."""
        hp = self.hparams
        return 'mimo_architecture {}# Layers: {}'.format(hp.mimo_architecture, hp.depth)

    def log_change_relative_to_prior(self, outputs, perc_change_curr_priors_subnets, stage):
        """Log how the current metrics compare to every prior. POSITIVE = GOOD.

        Args:
            outputs: current metrics (unused here).
            perc_change_curr_priors_subnets: nested mapping
                prior channel name -> subnetwork name -> metric name -> change.
            stage: prefix of the logged names (e.g. ``'test'``).
        """
        for prior_name, per_subnet in perc_change_curr_priors_subnets.items():
            for subnet_name, changes in per_subnet.items():
                # Note that all positive changes are 'good', and all negative changes are 'bad'
                prefix = f'{stage}/{subnet_name}/current vs {prior_name} prior'
                for metric_name in changes:
                    self.log(name=prefix + '/' + metric_name, value=changes[metric_name])

    def log_best_validation_metrics(self):
        """Log the metrics of the best epoch, raw and relative to the prior.

        Does nothing when ``self.prior_metrics_val`` is ``None``. Otherwise,
        for every prior channel and every ranking metric (se, ae, error,
        mcc), the best epoch according to that metric is looked up with
        ``self.best_metrics`` and all metrics of that epoch are logged twice:
        their raw value and their change versus the prior as computed by
        ``percent_change_metrics``.
        """
        if self.prior_metrics_val is None:
            return

        logged_metric_names = ['val/full/se/mean', 'val/full/ae/mean', 'val/full/error/mean',
                               'val/full/mcc/mean']
        for prior_channel_name in self.prior_metrics_val.keys():
            for log_metric_name in logged_metric_names:
                metric_name = log_metric_name.split("/")[-2]
                # these metrics are better when larger, all others when smaller
                maximize = any(m in metric_name for m in ['acc', 'mcc', 'f1'])
                best_epoch = self.best_metrics(sort_metric=metric_name,
                                               sort_subnetwork=self.hparams.real_network,
                                               maximize=maximize)[0]
                best_full_means = best_epoch['full']['mean']
                change_by_prior = percent_change_metrics(prior_metrics=self.prior_metrics_val,
                                                         prediction_metrics={'full': best_full_means})
                change_full = change_by_prior[self.hparams.prior_construction][self.hparams.real_network]

                raw_prefix = f'best_val_by_{metric_name}/' + 'full/'
                vs_prior_prefix = f'best_val_by_{metric_name}/' + f'full/current vs {prior_channel_name} prior/'
                # positive = good change
                for other_metric, change in change_full.items():
                    # at the epoch where metric_name was best, how did all the other metrics do? raw and vs prior
                    self.log(name=raw_prefix + other_metric, value=best_full_means[other_metric])
                    self.log(name=vs_prior_prefix + other_metric, value=change)

    def print_training_progress(self, outputs, perc_change_curr):
        """Print a colour-coded summary of the training progress.

        For each of the metrics error, ae, se and nse one line is printed
        showing the change versus the prior at the best epoch (with the raw
        value and the epoch) and the change of the current epoch. Positive
        changes are printed in green, all others in red.

        Args:
            outputs: current metrics; ``outputs['full'][metric]`` is printed.
            perc_change_curr: mapping metric name -> current change versus
                the prior; values must support ``.abs()``.
        """
        # print summary of training results on full network: green good, red bad
        print(f'\nPercent reduction (using training set ONLY for train/threshold finding, and eval on validation) over subnetwork "{self.hparams.real_network}" using loss: *{self.hparams.which_loss}*: {format_color("Good", "green")}/{format_color("Bad", "red")}')

        for log_metric in ('val/full/error/mean', 'val/full/ae/mean', 'val/full/se/mean', 'val/full/nse/mean'):
            metric = log_metric.split("/")[-2]
            maximize = any(m in metric for m in ['acc', 'mcc', 'f1'])
            # best_epoch is an entry of the metric history: {'subnetwork': {'mean': __, 'stde': __ }, ...}
            best_epoch = self.best_metrics(sort_metric=metric, sort_subnetwork=self.hparams.real_network,
                                           maximize=maximize)[0]
            change_by_prior = percent_change_metrics(prior_metrics=self.prior_metrics_val,
                                                     prediction_metrics={'full': best_epoch['full']['mean']})
            best_change = change_by_prior[self.hparams.prior_construction][self.hparams.real_network]

            label = f'  {metric}: Best '
            print(f"{label:<15}", end="")

            best_text = f"{best_change[metric].abs():.5f}% ({best_epoch['full']['mean'][metric]:.5f})"
            num_in_color = format_color(best_text, color='green' if best_change[metric] > 0 else 'red')
            best_ep = f" on epoch {best_epoch['epoch']}"
            best_column = f'{num_in_color}  {best_ep}'
            print(f"{best_column:<15}", end="")

            current_text = f"{perc_change_curr[metric].abs():.5f}% ({outputs['full'][metric]:.5f})"
            current_column = ' | Current: ' + format_color(current_text,
                                                           color='green' if perc_change_curr[metric] > 0 else 'red')
            print(current_column)

    def progress_bar_update(self, outputs):
        """Send the headline metrics of the real network to the progress bar.

        The mean of error, mcc, se, nse and ae is read from
        ``outputs[hparams.real_network]['mean']``. Every metric except se and
        ae is multiplied by 100 before being logged with ``self.log_dict``
        (progress bar only, not the logger).

        Args:
            outputs: mapping subnetwork name -> ``{'mean': {metric: value}}``.
        """
        unscaled_metrics = ['se', 'ae', 'hinge']
        prog_bar_metrics = {}
        for metric in ('error', 'mcc', 'se', 'nse', 'ae'):
            value = outputs[self.hparams.real_network]['mean'][metric]
            if metric not in unscaled_metrics:
                value = 100 * value
            prog_bar_metrics[metric] = value
        self.log_dict(prog_bar_metrics, logger=False, prog_bar=True)

    def log_subnetwork_metrics(self, outputs, stage):
        """Log the mean / standard error of every subnetwork.

        Args:
            outputs: mapping subnetwork name -> ``{'mean': ..., 'stde': ...}``.
            stage: stage prefix; the subnetwork name is appended to it as
                ``'<stage>/<subnetwork>'`` before calling ``self.log_metrics``.
        """
        for subnetwork_name in outputs:
            stats = outputs[subnetwork_name]
            subnet_stage = stage + f'/{subnetwork_name}'
            self.log_metrics(means=stats['mean'], stdes=stats['stde'], stage=subnet_stage)

    def log_metrics(self, means, stdes, stage):
        """Log the mean and the standard error of every metric.

        Each metric is logged under ``'<stage>/<metric>/mean'`` and
        ``'<stage>/<metric>/stde'``.

        Args:
            means: mapping metric name -> mean value.
            stdes: mapping metric name -> standard error; must hold every
                key of ``means``.
            stage: prefix of the logged names.
        """
        for metric_name in means:
            base = f'{stage}/{metric_name}'
            self.log(name=base + '/mean', value=means[metric_name])
            self.log(name=base + '/stde', value=stdes[metric_name])


# Script entry point: only prints a marker line when the file is executed
# directly; nothing runs when the module is imported.
if __name__ == "__main__":
    print('gdn main loop')
