from typing import List, Any, Optional

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning import seed_everything

from typing import Any, Dict, List, Optional, Type

from utils.util_funcs import filter_repeats, normalize_slices, mimo_tensor_polynomial, format_color
from model.model_utils import shallowest_layer_all_zero, resample_params,  clamp_tau, best_threshold_by_metric,\
    construct_prior, hinge_loss, prediction_metrics_for_each_subnetwork
from utils.util_funcs import add_bool_arg, print_subnet_perf_dict, percent_change_metrics
from model.metrics import batch_graph_metrics # for graph size gen eval

# Module-level debug switch.
# NOTE(review): nothing in the visible part of this file reads it -- confirm it is still needed.
DEBUG = False

#logistic regression loss: with logits: https://sebastianraschka.com/faq/docs/pytorch-crossentropy.html
class logistic_regression(nn.Module):
    """Element-wise affine map ``a * x + bias`` that emits logits.

    A single learnable slope ``a`` (initialised to 5.0) and a single learnable
    offset ``bias`` (initialised to -2.5) are shared by every entry of the
    input. No sigmoid is applied here; the output is meant to be consumed by a
    loss that works on logits.
    """

    def __init__(self):
        super().__init__()
        # Registered in the order (a, bias) so parameter ordering is stable.
        self.a = nn.Parameter(torch.tensor([5.0]))
        self.bias = nn.Parameter(torch.tensor([-2.5]))

    def forward(self, x):
        """Apply the shared slope and offset to every entry of ``x``.

        Per the original author's notes ``x`` is a ``batch_size x N x N``
        tensor whose diagonals are already zero; the same two scalars are
        broadcast over all entries.
        """
        scaled = x * self.a
        return scaled + self.bias

    def extra_repr(self):
        """Show the current slope and offset in the module's repr."""
        return f'log regr: a = {self.a}, bias = {self.bias}'

    def resample(self, cuttoff_fmsr_pairs):
        """Placeholder: currently does nothing and returns ``None``.

        Author's sketch of the intent: pick the best cutoff ``bc`` from the
        given pairs and reset the parameters so that the decision boundary
        ``-bias / a`` equals ``bc`` (with gradients turned off while doing so).
        """
        return None


class mimo_prox_layer(nn.Module):

    def __init__(self,
                 c_in: int = 1,
                 c_out: int = 1,
                 order_poly_fc: int = 4,
                 low_tau: float = .01,
                 high_tau: float = .5,
                 learn_tau: bool = False,
                 include_nonlinearity: bool = True,
                 poly_basis: Optional[str] = 'cheb',  # basis (standard, cheb) used for P(A_hat)
                 channel_reduction = torch.mean,
                 where_normalize_slices: str = 'after_reduction',
                 which_slice_normalization: Optional[str] = 'max_abs',
                 slice_normalization_val: Optional[float] = 99):
        super().__init__()
        self.c_in, self.c_out = c_in, c_out

        # parameters
        self.order_poly_fc = order_poly_fc
        assert order_poly_fc >= 1, f'order of polynomial of A_hat_obs must be >= 1'
        self.stdv_scaling = 1/3 # scale N(0,1)*stdv_sampling-> N(0,stdv_sampling^2)
        self.init_alpha_sample_mean, self.init_beta_sample_mean = 1, 0
        # alpha scales the input recieved from last layer. Make it ~1 -> making small changes to previous layer.
        self.alpha = nn.Parameter(torch.randn(c_out, c_in)*self.stdv_scaling + self.init_alpha_sample_mean) # N(1, 1/9)
        # beta scales the part of the gradient. Make it ~0.
        self.beta = nn.Parameter(torch.randn(c_out, c_in)*self.stdv_scaling + self.init_beta_sample_mean)
        # these define the coefficients of the polynomial of the FC. Make these all ~0 except for the 1st order term,
        # which should be ~1. because we group it with an unscaled A_in. See paper.
        # Note that 0th order term is NOT used because it only affects diagonal, but will be included in parameter
        # count by pytorch.
        # coeffs_poly_fc[i] in R^(order+1 x c_in) are all polynomial coeffs used by output channel i.
        # coeffs_poly_fc[i, j] in R^c_in are the poly coeffs corresponding to jth polynomial basis
        # coeffs_poly_fc[i, j, k] in R is the poly coeff corresponding to jth polynomial basis for the
        # k^th input channel
        self.coeffs_poly_fc = nn.Parameter(torch.randn(c_out, order_poly_fc + 1, c_in)*self.stdv_scaling)
        with torch.no_grad():
            self.coeffs_poly_fc[:, 1, :] = 1

        self.learn_tau = learn_tau
        tau = torch.linspace(start=low_tau, end=high_tau, steps=c_out).view(c_out, 1, 1, 1)
        if self.learn_tau:
            self.tau = nn.Parameter(tau)#torch.tensor([low_tau]))
        else:
            self.register_buffer('tau', torch.tensor(tau, dtype=torch.float32), persistent=True)
        # end of parameters

        assert where_normalize_slices in ['before_reduction', 'after_reduction']
        self.where_normalize_slices = where_normalize_slices

        self.include_nonlinearity = include_nonlinearity

        self.poly_basis = poly_basis
        assert poly_basis in ['cheb', 'standard', None]
        self.channel_reduction = channel_reduction

        self.which_slice_normalization, self.slice_normalization_val = which_slice_normalization, slice_normalization_val
        assert which_slice_normalization in ["frob", "max_abs", "percentile", "none", None]
        self.output_zeros = False

    def forward(self, S_in, Cov, extra_outs=False, layer=None, normalize=True):
        """Run one layer update on a stack of square slices.

        Args:
            S_in: 4-D tensor laid out as [c_in, batch_size, N, N].
            Cov: 3-D tensor laid out as [batch_size, N, N].
            extra_outs: when True, return ``(S_out, S_in, S_out)`` instead of
                just ``S_out``.
            layer: accepted for call-site compatibility; not used here.
            normalize: when falsy, both slice-normalisation steps are skipped.

        Returns:
            ``S_out`` ([c_out, batch_size, N, N] after the channel reduction),
            or the 3-tuple described above.

        Side effects:
            Sets ``self.output_zeros`` and prints a warning banner when the
            output is (numerically) all zeros.
        """
        # ---- shape checks ----
        assert len(S_in.shape) == 4
        assert len(Cov.shape) == 3
        assert S_in.shape[1] == Cov.shape[0]
        assert S_in.shape[-1] == S_in.shape[-2]
        assert Cov.shape[-1] == Cov.shape[-2]
        n_in = S_in.shape[0]
        assert (n_in == self.c_in), f'input has {n_in} input channels but expected {self.c_in} input channels'
        batch_size, N = Cov.shape[0], Cov.shape[1]
        assert torch.all(self.tau >= 0), f'tau is negative {self.tau}'

        # Every intermediate lives on a [c_out, c_in, batch_size, N, N] grid.
        full_shape = (self.c_out, self.c_in, batch_size, N, N)
        per_pair = (self.c_out, self.c_in, 1, 1, 1)

        prev = torch.broadcast_to(S_in, full_shape)
        cov_b = torch.broadcast_to(Cov, full_shape)
        bilinear = prev @ cov_b + cov_b @ prev
        poly_cov = mimo_tensor_polynomial(torch.broadcast_to(Cov, S_in.shape), self.coeffs_poly_fc, cob=self.poly_basis)

        combined = self.alpha.view(per_pair) * prev
        combined = combined + self.beta.view(per_pair) * bilinear
        combined = combined + poly_cov

        # Zero every slice diagonal by multiplying with (ones - identity).
        off_diagonal = torch.ones((N, N), device=S_in.device) - torch.eye(N, device=S_in.device)
        combined = combined * off_diagonal.broadcast_to(full_shape)

        def _normalized(t):
            return normalize_slices(t, which_norm=self.which_slice_normalization, extra=self.slice_normalization_val)

        # Normalising before the reduction bounds each (out, in) slice on its
        # own, which removes relative scaling between input channels.
        if 'before' in self.where_normalize_slices and normalize:
            combined = _normalized(combined)

        # Collapse the input-channel axis (mean or sum, as configured).
        reduced = self.channel_reduction(combined, dim=1)

        # Normalising after the reduction bounds the output channels while
        # still letting input channels be scaled relative to one another.
        if 'after' in self.where_normalize_slices and normalize:
            reduced = _normalized(reduced)

        shifted = reduced - self.tau
        S_out = F.relu(shifted) if self.include_nonlinearity else shifted

        self.output_zeros = False
        if torch.allclose(S_out, torch.zeros(S_out.shape, device=S_out.device)):
            self.output_zeros = True
            print('==================================')
            print('WARNING: A prox_layer is outputting an all 0 S_out')
            print('==================================')

        return (S_out, S_in, S_out) if extra_outs else S_out

    def forward_old(self, S_in, Cov, extra_outs=False, layer=None):
        """Legacy forward pass kept for reference.

        NOTE(review): this method reads attributes that the ``__init__`` visible
        in this file never assigns (``s_in_norm``, ``coeffs_poly_s_fc``,
        ``coeffs_poly_s``, ``coeffs_poly_fc_s``, ``normalize_basis``,
        ``order_poly_fc_s``, ``component_norm``, ``component_norm_val``,
        ``where_channel_reduce``). Unless they are attached elsewhere, calling it
        raises ``AttributeError`` -- confirm whether it can be removed.

        Args:
            S_in: 4-D tensor, [c_in, batch_size, N, N] per the shape checks below.
            Cov: 3-D tensor, [batch_size, N, N].
            extra_outs: when True, also return the intermediate terms.
            layer: unused (only referenced in commented-out debugging code).

        Returns:
            ``S_out``, or ``(S_out, S_in, poly_s_fc, poly_s, poly_fc, poly_fc_s, temp)``
            when ``extra_outs`` is True. ``poly_fc`` / ``poly_fc_s`` are only bound
            when the corresponding ``order_*`` attribute is not None.
        """
        # S_in = [c_in, batch_size, N, N]
        # Cov = [batch_size, N, N]
        assert (len(S_in.shape) == 4) and (len(Cov.shape) == 3) and (S_in.shape[1] == Cov.shape[0])
        assert (S_in.shape[-1] == S_in.shape[-2]) and (Cov.shape[-1] == Cov.shape[-2])
        c_in = S_in.shape[0]
        assert (c_in == self.c_in), f'input has {c_in} input channels but expected {self.c_in} input channels'
        batch_size, N, _ = Cov.shape
        assert torch.all(self.tau >= 0).item(), f'tau is negative {self.tau}'


        # any slices all zeros?
        #zero_slice = (torch.sum(S_in.view(-1, N, N), dim= (1, 2)) < 1e-10)
        #if torch.any(zero_slice):
        #    print(f'Layer {layer}: some slice is all zeros')
        # always use power iteration for max eig
        # NOTE(review): self.s_in_norm is not set by the visible __init__.
        S_in = normalize_slices(S_in, which_norm=self.s_in_norm, extra='custom')

        # construct tensor which has 0 on slice diagonals, 1s everywhere else
        remove_diag = torch.ones((N, N), device=S_in.device)-torch.eye(N, device=S_in.device)
        zd = torch.broadcast_to(remove_diag, (self.c_out, self.c_in, batch_size, N, N))

        #Construct 2-4 constituent terms of output: {poly_s,  poly_s_fc, poly_fc, poly_fc_s}
        # each shape = [c_out, c_in, batch_size, N, N]
        # e.g. poly_s[2, 1] is the polynomial of the 1st channel of S_in (S_in[1])
        # using the coefficients coeffs_poly_s[2,:,1]


        poly_s_temp = mimo_tensor_polynomial(S_in, self.coeffs_poly_s_fc, cob=self.poly_basis, normalize_basis=self.normalize_basis)
        fc_bc = torch.broadcast_to(Cov, poly_s_temp.shape) #fc_broadcast
        # CHANGED: no longer divide by 2
        poly_s_fc = (torch.matmul(poly_s_temp, fc_bc) + torch.matmul(fc_bc, poly_s_temp)) #/2


        # for numerical stability/gradient stability we only use first order polynomials of S: [I, h_0*S]
        poly_s = mimo_tensor_polynomial(S_in, self.coeffs_poly_s, cob=self.poly_basis, normalize_basis=self.normalize_basis)

        # poly_fc
        if self.order_poly_fc is not None:
            poly_fc = mimo_tensor_polynomial(torch.broadcast_to(Cov, S_in.shape), self.coeffs_poly_fc, cob=self.poly_basis, normalize_basis=self.normalize_basis)

        # poly_fc_s
        if self.order_poly_fc_s is not None:
            fc_bc_to_s = torch.broadcast_to(Cov, S_in.shape) # fc broadcast to S_in
            poly_fc_temp = mimo_tensor_polynomial(fc_bc_to_s, self.coeffs_poly_fc_s, cob=self.poly_basis, normalize_basis=self.normalize_basis)
            poly_fc_s = (torch.matmul(poly_fc_temp, S_in) + torch.matmul(S_in, poly_fc_temp))/2 # mimo_tensor_polynomial(torch.broadcast_to(Cov, S_in.shape), self.coeffs_poly_fc_s, cob='cheb')

        # Normalize each component separately?
        if self.component_norm is not None:
            poly_s_fc = normalize_slices(poly_s_fc, which_norm=self.component_norm, extra=self.component_norm_val)
            poly_s = normalize_slices(poly_s, which_norm=self.component_norm, extra=self.component_norm_val)
            if self.order_poly_fc is not None:
                poly_fc = normalize_slices(poly_fc, which_norm=self.component_norm, extra=self.component_norm_val)
            if self.order_poly_fc_s is not None:
                poly_fc_s = normalize_slices(poly_fc_s, which_norm=self.component_norm, extra=self.component_norm_val)

        # CHANGED: Want A_in - grad_g()  = A_in - [-h_0*FC - h1(A_in*FC + FC*A_in)]
        # Need to append 0 onto S_in in order for dimensions to match up.
        temp = poly_s_fc + poly_s #poly_s - poly_s_fc
        if self.order_poly_fc is not None:
            temp += poly_fc
        if self.order_poly_fc_s is not None:
            temp += poly_fc_s

        # zero the diagonal of every slice
        temp = temp*zd
        # The input-channel axis (dim 1) is averaged at exactly one of three
        # points, selected by self.where_channel_reduce.
        if self.where_channel_reduce == 'before_norm':
            temp = torch.mean(temp, dim=1)

        # normalizing here instead of after shifted relu.
        # Why? Forces all taus to have same magnitude, bc now all thresholds
        #   are occurring on inputs of <1
        temp = normalize_slices(temp, which_norm=self.which_slice_normalization, extra=self.slice_normalization_val)

        if self.where_channel_reduce == 'after_norm':
            temp = torch.mean(temp, dim=1)

        # what percent of entries are being killed by tau?
        # temp = [c_out, c_in, bs, N, N]
        # num_killed := entries set to zero by shifted relu that were NOT zero before
        # (diagnostic disabled: the `if False` blocks below never execute)
        if False:
            zero_before = (temp == 0)

        if self.include_nonlinearity:
            temp = F.relu(temp-self.tau)
        else:
            temp = temp - self.tau

        if False:
            zero_after = (temp == 0)
            subset_exist_2_killed = ~zero_before & zero_after
            num_possible_edges = N**2 - N
            ave_killed_by_sr_per_graph = torch.sum(subset_exist_2_killed, dim=(3, 4))/num_possible_edges # (c_out, c_in, bs)
            self.ave_killed_by_sr_per_batch = torch.mean(ave_killed_by_sr_per_graph, dim=-1) # ave over batch, mean of channel # (c_out, c_in)

        # reduce out input channels. 0 dim => # out channels, 1 dim => input channels
        if self.where_channel_reduce == 'after_relu':
            temp = torch.mean(temp, dim=1)

        S_out = temp
        # Record (and loudly report) an all-zero output.
        self.output_zeros = False
        if torch.allclose(S_out, torch.zeros(S_out.shape, device=S_out.device)):  # and False:
            self.output_zeros = True
            print('==================================')
            print('WARNING: A prox_layer is outputting an all 0 S_out')
            #print(cdp(self.tau), 'tau: threshold cutoff')
            print('==================================')

        if extra_outs:
            # NOTE(review): poly_fc / poly_fc_s are unbound here if the matching
            # order_* attribute was None.
            return S_out, S_in, poly_s_fc, poly_s, poly_fc, poly_fc_s, temp
        else:
            return S_out

    def param_dict(self, gradients = False):
        """Flatten the layer's parameters (or their gradients) into a dict of scalars.

        The previous implementation referenced attributes and locals that this
        layer never defines (``coeffs_poly_s``, ``coeffs_poly_s_fc``,
        ``order_poly_s`` ...) and therefore always raised. This version reports
        the parameters the layer actually owns: ``alpha``, ``beta`` and
        ``coeffs_poly_fc``, plus ``tau`` when it is learned.

        Args:
            gradients: if True, report ``.grad`` of ``alpha``, ``beta`` and
                ``coeffs_poly_fc`` instead of their values. A parameter whose
                gradient has not been populated yet is left out of the result.
                ``tau`` is always reported by value.

        Returns:
            dict mapping string keys to numpy scalars:
              * ``'TAU-out{o}'`` (only when ``learn_tau`` is True)
              * ``'out{o}|in{i}-alpha'``, ``'out{o}|in{i}-beta'``
              * ``'out{o}|in{i}-poly_fc:{k}'`` for every stored coefficient
                ``k`` (all ``order_poly_fc + 1`` of them).
        """
        def _as_numpy(param):
            # Pick value or gradient; None means "no gradient available yet".
            source = param.grad if gradients else param
            if source is None:
                return None
            return source.clone().detach().cpu().numpy()

        d = {}
        if self.learn_tau:
            # tau is stored as (c_out, 1, 1, 1); flatten to one scalar per channel.
            tau = self.tau.clone().detach().cpu().numpy().reshape(self.c_out)
            for out_ch_idx in range(self.c_out):
                d[f'TAU-out{out_ch_idx}'] = tau[out_ch_idx]

        alpha = _as_numpy(self.alpha)
        beta = _as_numpy(self.beta)
        coeffs_poly_fc = _as_numpy(self.coeffs_poly_fc)

        for out_ch_idx in range(self.c_out):
            for in_ch_idx in range(self.c_in):
                prefix = f'out{out_ch_idx}|in{in_ch_idx}'
                if alpha is not None:
                    d[f'{prefix}-alpha'] = alpha[out_ch_idx, in_ch_idx]
                if beta is not None:
                    d[f'{prefix}-beta'] = beta[out_ch_idx, in_ch_idx]
                if coeffs_poly_fc is not None:
                    # axis 1 indexes the polynomial basis (order_poly_fc + 1 entries)
                    for coeff_idx in range(coeffs_poly_fc.shape[1]):
                        d[f'{prefix}-poly_fc:{coeff_idx}'] = \
                            coeffs_poly_fc[out_ch_idx, coeff_idx, in_ch_idx]

        return d


class plCovNN(pl.LightningModule):
    def __init__(self,
                 # architecture
                 channels,
                 share_parameters,
                 poly_fc_orders: List[int],
                 include_nonlinearity: bool = True,
                 poly_basis: Optional[str] = 'cheb',  # basis (standard, cheb) used for P(A_hat)
                 channel_reduction: str = 'mean',
                 # normalization procedure: where and what type
                 where_normalize_slices: str = 'after_reduction',
                 which_slice_normalization: Optional[str] = 'max_abs', slice_normalization_val: Optional[float] = 99,
                 norm_last_layer: bool = False,
                 # loss, optimizer
                 which_loss: str = 'hinge',
                 monitor: str = 'error',  # monitor is used by optimizer to check to reduce lr, stop running, etc
                 optimizer: str = 'adam',
                 learning_rate: float = .01,
                 adam_beta_1: float = .85, adam_beta_2: float = .99, momentum: float = .95,
                 hinge_margin: float = 1.0, hinge_slope: float = 1.0,
                 weight_decay: float = 0,
                 l2_strength: float = 0.0, l1_strength: float = 0.0, l2_cutoff: float = 0.0, l1_cutoff: float = 0.0,
                 # subnetwork of interest: None for synthetics.
                 # # 'full' means entire network. for brain data, other options are ['occipital', 'frontal', ..., 'full' ] etc
                 real_network: Optional[str] = 'full',
                 # tau
                 learn_tau: bool = True,
                 low_tau_val: float = .01,  #for init
                 max_tau_clamp_val: float = .99,  #max allowable tau val
                 log_regr: bool = False,
                 # prior
                 n_train : int = 68,
                 prior_construction: str = 'mean',
                 prior_frac_contains: float = 0.0,
                 # logging
                 logging=None,
                 # reproducability
                 rand_seed: int = 45,
                 # threshold
                 threshold_metric: str = 'acc', #which metric to use when choosing threshold value
                 threshold_metric_test_points: np.ndarray = np.arange(0, .6, .025) # which discrete points to check
                 ):
        """Store hyperparameters, register the prior/threshold buffers and build the layer stack.

        Notable behaviour:
          * ``seed_everything(rand_seed, workers=True)`` is called unless
            ``rand_seed`` is None.
          * ``channels`` may be a string of digits (one channel count per
            character) or a sequence of ints. It is copied, so the caller's
            object is never modified.
          * With ``share_parameters`` a single ``mimo_prox_layer`` is reused
            ``depth`` times and all channel counts must be equal. Otherwise one
            layer per consecutive channel pair is built and the first and last
            channel counts are forced to 1.
          * ``log_regr`` appends a ``logistic_regression`` module after the
            prox layers.
          * ``save_hyperparameters()`` is called last.

        Raises:
            AssertionError: on an unsupported ``channel_reduction`` or
                ``prior_construction``, when ``len(poly_fc_orders)`` differs
                from ``len(channels) - 1``, or when parameters are shared but
                the channel counts differ.
        """
        super().__init__()
        self.rand_seed = rand_seed
        if self.rand_seed is not None:
            seed_everything(self.rand_seed, workers=True)
        self.logging = None if logging == 'none' else logging
        #https://github.com/PyTorchLightning/pytorch-lightning/issues/1225
        # for logging of hparams and computational graph
        #self.example_input_array = torch.zeros(10, 68, 68)

        self.log_regr = log_regr

        self.share_parameters = share_parameters

        # Always work on our own list: the first/last entries may be overwritten
        # below and that must not leak into the caller's object. The local name
        # is rebound so the hyperparameters saved at the end reflect the
        # channel counts that were really used.
        if isinstance(channels, str):
            channels = [int(ch) for ch in channels]
        else:
            channels = list(channels)
        self.channels = channels
        self.depth = len(self.channels)-1 # channels define maps between layers


        assert channel_reduction in ['mean', 'sum']
        self.channel_reduction = torch.mean if 'mean' in channel_reduction else torch.sum
        self.where_normalize_slices = where_normalize_slices

        self.poly_basis = poly_basis
        self.poly_fc_orders = poly_fc_orders#filter_order_layers(poly_fc_orders, self.depth, order_if_none=1)
        assert len(poly_fc_orders) == len(self.channels)-1

        self.learn_tau = learn_tau
        self.include_nonlinearity = include_nonlinearity
        self.max_tau_clamp_val = max_tau_clamp_val
        self.low_tau_vals = np.repeat(low_tau_val, self.depth)
        self.which_loss = which_loss
        self.learning_rate, self.weight_decay, self.momentum = learning_rate, weight_decay, momentum
        self.adam_beta_1, self.adam_beta_2 = adam_beta_1, adam_beta_2
        self.monitor = monitor
        self.l2_strength, self.l1_strength = l2_strength, l1_strength
        # Fixed: the second target used to be self.l2_cutoff again, which left
        # l1_cutoff unset and silently replaced l2_cutoff with the l1 value.
        self.l2_cutoff, self.l1_cutoff = l2_cutoff, l1_cutoff

        self.which_slice_normalization = None if (which_slice_normalization == 'none') else which_slice_normalization
        self.slice_normalization_val = slice_normalization_val

        self.norm_last_layer = norm_last_layer

        self.real_network = real_network

        self.hinge_margin, self.hinge_slope = hinge_margin, hinge_slope


        #### Prior Construction and stuff to make model work ####
        self.prior_construction = 'zeros' if prior_construction in ['none', 'None', None, 'zeros'] else prior_construction
        assert self.prior_construction in ['zeros', 'ones', 'median', 'mean', 'multi', 'block', 'sbm']
        assert self.prior_construction != 'multi', f'must change channels and prior_prep to accomodate >1 prior channel'

        # register_buffer() wont save tensor in state_dict if fed None. Thus we must know what size graphs
        # are used for training to initialize training_prior tensor (so it can be save/loaded later on)
        self.n_train = n_train
        num_prior_channels = 1  # FORCE THIS TO BE
        self.register_buffer('training_prior', -1*torch.ones(size=(num_prior_channels, n_train, n_train)), persistent=True)
        # best threshold found on training set when training_prior was constructed
        self.register_buffer('training_prior_threshold', -1*torch.ones(num_prior_channels, 1), persistent=True)
        # best threshold found on validation set after training done, to be used for testing
        self.register_buffer('testing_prior_threshold', -1*torch.ones(num_prior_channels, 1), persistent=True)

        self.prior_channels, self.prior_channel_names = None, None

        #self.register_buffer('real_data_prior', torch.zeros(1, 68, 68), persistent=True)
        self.register_buffer('threshold', torch.tensor([-1.0]), persistent=True)
        self.test_threshold = None # placeholder, will save in checkpoint and load when needed
        #self.register_buffer('test_threshold', torch.tensor([-1.0]), persistent=True)

        self.threshold_metric, self.threshold_metric_test_points = \
            threshold_metric, threshold_metric_test_points
        # used for when construction prior with mean/median. What fraction of samples should contain an edge for us
        # to count it in prior (use in mean/median calculation)?
        self.prior_frac_contains = prior_frac_contains
        ######

        # on validation set
        self.list_of_metrics = []
        self.prior_metrics_val, self.prior_metrics_test = {}, {}

        def build_prox_layer(c_in, c_out, layer_idx):
            # Single place where per-layer options are forwarded, shared by
            # both the shared-parameter and per-layer branches.
            return mimo_prox_layer(c_in=c_in, c_out=c_out,
                                   poly_basis=self.poly_basis,
                                   order_poly_fc=self.poly_fc_orders[layer_idx],
                                   where_normalize_slices=self.where_normalize_slices,
                                   channel_reduction=self.channel_reduction,
                                   low_tau=self.low_tau_vals[layer_idx], learn_tau=self.learn_tau,
                                   which_slice_normalization=self.which_slice_normalization,
                                   slice_normalization_val=self.slice_normalization_val,
                                   include_nonlinearity=self.include_nonlinearity)

        if self.share_parameters:
            print(f'Given {self.channels} and Sharing Parameters.  All channels must be the same!')
            assert all([c == self.channels[0] for c in self.channels]), f'for shared parameters, all input/output must be same'
            print(f'Shared parameters. Looping over Single MIMO layer with c_in/c_out = {self.channels[0]}, {len(self.channels) - 1} times')
            c = self.channels[0] # they must all be the same
            # The very same module object is repeated, so its weights are shared.
            single_layer = build_prox_layer(c, c, 0)
            layers = [single_layer] * self.depth
        else:
            print('\tNot sharing parameters. First channel overridden for proper prior dim. Last channel dim will be overridden to reduce unused params')
            print(f'\t\tGIVEN {self.channels}', end="")
            self.channels[0], self.channels[-1] = 1, 1
            print(f' ---> USING {self.channels}')

            layers = [build_prox_layer(self.channels[layer], self.channels[layer+1], layer)
                      for layer in range(self.depth)]

        if self.log_regr:
            layers.append(logistic_regression())

        self.module_list = nn.ModuleList(layers)

        if self.log_regr:
            # everything but the trailing logistic-regression module
            self.prox_layers = self.module_list[:-1]
            self.log_regr_layer = self.module_list[-1]
        else:
            self.prox_layers = self.module_list

        # private variables to be used for printing/logging
        self.epoch_val_losses = []

        self.subnetwork_masks = None

        self.save_hyperparameters()
        #kwargs = {'prior': None}
        #self.save_hyperparameters(**kwargs)

    def setup(self, stage):
        """Evaluate the prior, used as a constant prediction, for the given stage.

        ``'fit'``: build the prior channels from the training dataloader, pick
        each channel's threshold on the training targets, then score it on the
        validation targets (stored in ``self.prior_metrics_val``).

        ``'test'``: pick each channel's threshold on the validation targets and
        score it on the test targets (stored in ``self.prior_metrics_test``).
        NOTE(review): this branch iterates ``self.prior_channels``, which the
        original author notes is None when testing from a checkpoint.
        """
        def prior_as_prediction(prior_channel, targets):
            # One prior slice repeated to match the shape of the targets.
            return torch.broadcast_to(prior_channel, targets.shape).detach()

        def pick_threshold(prior_channel, targets):
            return best_threshold_by_metric(thresholds=self.threshold_metric_test_points,
                                            adjs=targets.detach(),
                                            preds=prior_as_prediction(prior_channel, targets),
                                            metric='acc')

        def score_prior(prior_channel, targets, threshold):
            return prediction_metrics_for_each_subnetwork(y=targets.detach(),
                                                          y_hat=prior_as_prediction(prior_channel, targets),
                                                          threshold=threshold,
                                                          subnetwork_mask_dict=self.subnetwork_masks,
                                                          hinge_margin=self.hinge_margin,
                                                          hinge_slope=self.hinge_slope,
                                                          reduction='ave')

        def show(metrics):
            print_subnet_perf_dict(subnetwork_metrics_dict=metrics,
                                   indents=2, convert_to_percent=['acc', 'error'],
                                   metrics2print=['hinge', 'mse', 'mae', 'acc', 'error', 'mcc'])

        if stage in ['fit']:
            # The prior is constructed for the first time here.
            self.prior_channels, self.prior_channel_names = self.construct_prior_(prior_dl=self.train_dataloader())
            _fcs, train_scs, _ids, _dirs, _tasks = self.train_dataloader().dataset.full_ds()
            _fcs, val_scs, _ids, _dirs, _tasks = self.val_dataloader().dataset.full_ds()

            for idx, (channel, channel_name) in enumerate(zip(self.prior_channels, self.prior_channel_names)):
                # threshold chosen on the training set ...
                self.training_prior_threshold[idx] = pick_threshold(channel, train_scs)
                # ... then applied unchanged to the validation set
                self.prior_metrics_val[channel_name] = \
                    score_prior(channel, val_scs, self.training_prior_threshold[idx])

                print(f"Prior {channel_name} metrics using {self.training_prior_threshold[idx]}")
                print('ON VAL')
                show(self.prior_metrics_val[channel_name])

        if stage in ['test']:
            _fcs, val_scs, _val_ids, _dirs, _tasks = self.val_dataloader().dataset.full_ds()
            _fcs, test_scs, _test_ids, _dirs, _tasks = self.test_dataloader().dataset.full_ds()

            print(f'PERFORMANCE OF {self.prior_construction} PRIOR ON TEST SET USING VALIDATION SET TO FIND THRESHOLD')
            for idx, (channel, channel_name) in enumerate(zip(self.prior_channels, self.prior_channel_names)):
                # threshold chosen on the validation set ...
                self.testing_prior_threshold[idx] = pick_threshold(channel, val_scs)
                # ... then applied unchanged to the test set
                self.prior_metrics_test[channel_name] = \
                    score_prior(channel, test_scs, self.testing_prior_threshold[idx])

                print('PERF OF PRIOR ON TEST, USING VALIDATION TO FIND THRESHOLD')
                show(self.prior_metrics_test[channel_name])

    def on_train_start(self) -> None:
        """Log the stored prior metrics (validation first, then test).

        Keys have the form ``'<split>/<prior name> prior/<metric name>'``.
        NOTE(review): the subnetwork name is not part of the key, so metrics of
        different subnetworks are logged under the same name -- confirm this is
        intended.
        """
        splits = (('val', self.prior_metrics_val), ('test', self.prior_metrics_test))
        for split, metrics_by_prior in splits:
            for prior_channel_name, metrics_per_subnet in metrics_by_prior.items():
                prefix = f'{split}/{prior_channel_name} prior/'
                for subnetwork_metrics in metrics_per_subnet.values():
                    for metric_name, value in subnetwork_metrics.items():
                        self.log(name=prefix + metric_name, value=value)

    def forward(self, batch) -> Any:
        """Predict for one batch by delegating to ``shared_step`` with batch index 0."""
        prediction = self.shared_step(batch=batch, batch_idx=0)
        return prediction

    def predict_step(self, batch, batch_idx=0, dataloader_idx=None):
        """Return the model's prediction for ``batch``.

        ``batch_idx`` and ``dataloader_idx`` are accepted and ignored. They are
        here because Lightning's prediction loop passes them to this hook; with
        the old one-argument signature that call failed with a ``TypeError``.
        Both are defaulted, so direct calls with only ``batch`` still work.
        """
        return self(batch=batch)

    def shared_step(self, batch, batch_idx):
        """Push the prior through every prox layer and return output channel 0.

        ``batch`` must unpack into five items; only the first (``fcs``) is used
        here. Slice normalisation is always on, except in the last layer where
        it follows ``self.norm_last_layer``.
        """
        fcs, _adjs, _subject_ids, _scan_dirs, _tasks = batch
        batch_size, N = fcs.shape[:-1]
        last_layer_idx = self.depth - 1

        # The prior is the input to the first layer.
        s_out = self.prior_prep(batch_size, N)
        for i, prox_layer in enumerate(self.prox_layers):
            if i == last_layer_idx:
                use_norm = self.norm_last_layer
            else:
                use_norm = True
            s_out = prox_layer(s_out, fcs, layer=i, normalize=use_norm)

        # s_out is [c_out, bs, N, N] for the final layer's c_out. Without shared
        # parameters c_out is 1; with shared parameters it can be larger, so
        # only the 0th output channel is used as the prediction.
        return s_out[0]

    def training_step(self, batch, batch_idx):
        """Compute the training loss of one batch.

        Runs ``shared_step`` and hands its output, the labels, the batch size
        (``fcs.shape[0]``) and the graph size (``fcs.shape[1]``) to ``compute_loss``.
        """
        fcs, adjs, _subject_ids, _scan_dirs, _tasks = batch
        predictions = self.shared_step(batch, batch_idx)
        n_graphs, n_nodes = fcs.shape[0], fcs.shape[1]
        return self.compute_loss(adjs=adjs, adjs_hat=predictions, bs=n_graphs, N=n_nodes)

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        """Post-batch maintenance: revive an all-zero layer and clamp tau.

        None of the hook arguments are used.
        """
        # index of the shallowest layer outputting ALL zeros; -1 means there is none
        dead_layer = shallowest_layer_all_zero(self)
        if -1 < dead_layer:
            print(f'\n\t~~~~shallowest layer with zero outputs = {dead_layer}. Resample param vals of layer.~~~')
            resample_params(self.prox_layers[dead_layer])

        # keep tau inside its allowed range (upper bound given by max_tau_clamp_val)
        if self.learn_tau:
            clamp_tau(self, large_tau=self.max_tau_clamp_val)

    def training_epoch_end(self, train_step_outputs):
        """Log the mean of the per-batch training losses for the finished epoch.

        Per the original note, this hook must be overridden for batch outputs to
        reach callbacks:
        https://github.com/PyTorchLightning/pytorch-lightning/issues/4326
        """
        step_losses = [step_out['loss'] for step_out in train_step_outputs]
        epoch_mean = torch.stack(step_losses).mean()
        self.log(name=f'train/{self.which_loss}_epoch', value=epoch_mean, on_step=False, on_epoch=True)

    def on_validation_start(self) -> None:
        """Refresh the decision threshold once before each validation epoch.

        The threshold is searched on the TRAINING dataloader and written into
        slot 0 of ``self.threshold``, which is then logged.
        """
        best_threshold = self.find_threshold(
            dl=self.train_dataloader(),
            threshold_test_points=self.threshold_metric_test_points,
            metric2chooseThresh=self.threshold_metric,
        )
        self.threshold[0] = best_threshold
        self.log('threshold', self.threshold, prog_bar=True, on_epoch=True, on_step=False)

    #used by on_val/test_epoch end (and size_gen_eval) to accumulate info on batches at end of epoch
    def eval_step(self, batch, batch_idx, threshold, reduction='sum'):
        """Evaluate one batch and return its per-subnetwork metrics and its size.

        :param batch: 5-tuple ``(fcs, adjs, subject_ids, scan_dirs, tasks)``.
        :param threshold: threshold forwarded to the metric helper.
        :param reduction: reduction forwarded to the metric helper.
        :return: dict with keys ``'subnetwork_metrics_dict'`` and ``'batch_size'``
            (``len(fcs)`` as a tensor).
        """
        fcs, adjs, _subject_ids, _scan_dirs, _tasks = batch
        predictions = self.shared_step(batch, batch_idx)

        metric_kwargs = dict(y=adjs, y_hat=predictions, threshold=threshold,
                             subnetwork_mask_dict=self.subnetwork_masks,
                             hinge_margin=self.hinge_margin,
                             hinge_slope=self.hinge_slope,
                             reduction=reduction)
        per_subnetwork = prediction_metrics_for_each_subnetwork(**metric_kwargs)
        return {'subnetwork_metrics_dict': per_subnetwork, 'batch_size': torch.tensor(len(fcs))}

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch with the current ``self.threshold`` (sum-reduced)."""
        evaluate = self.eval_step
        return evaluate(batch, batch_idx, threshold=self.threshold, reduction='sum')

    def validation_epoch_end(self, val_step_outputs):
        """Aggregate the validation batches, log the epoch metrics and print a summary.

        Steps, in order:
          1. Sum every per-batch metric and divide by the total number of samples;
             derive ``error = 1 - acc`` for each subnetwork.
          2. Append the aggregated metrics (plus the current epoch) to
             ``self.list_of_metrics``.
          3. Log ``val/<subnetwork>/<metric>``, put a few metrics and the learning
             rate on the progress bar.
          4. Log the change of every metric relative to the prior baselines in
             ``self.prior_metrics_val``.
          5. Print, for error/mae/mse, the best epoch so far and the current epoch.

        :param val_step_outputs: list of the dicts returned by ``validation_step``
            (keys ``'subnetwork_metrics_dict'`` and ``'batch_size'``).
        """
        # aggregate all outputs from validation step batches
        #metric_names = ['hinge', 'mse', 'mae', 'macro_F1', 'acc', 'mcc']
        total_epoch_samples = torch.stack([x['batch_size'] for x in val_step_outputs]).sum()
        total_subnetwork_metrics_dict = {}
        subnetwork_names = val_step_outputs[0]['subnetwork_metrics_dict'].keys()
        for subnetwork_name in subnetwork_names:
            total_subnetwork_metrics_dict[subnetwork_name] = {}
            for metric in ['hinge', 'mse', 'mae', 'macro_F1', 'acc', 'mcc']:
                # per-batch values are sums (validation_step uses reduction='sum'), so dividing by the sample count gives the epoch average
                total_subnetwork_metrics_dict[subnetwork_name][metric] = torch.stack([x['subnetwork_metrics_dict'][subnetwork_name][metric] for x in val_step_outputs]).sum()/total_epoch_samples
            total_subnetwork_metrics_dict[subnetwork_name]['error'] = 1 - total_subnetwork_metrics_dict[subnetwork_name]['acc']

        # save running list of metrics as we go (read back by best_metrics)
        self.list_of_metrics.append({**total_subnetwork_metrics_dict, 'epoch': self.current_epoch})

        # focus on network specified in self.real_network
        assert self.real_network in ['full'], f"haven't tested with any other subnetwork {self.real_network}"
        #metrics = total_subnetwork_metrics_dict[self.real_network]

        # for logger and wandb
        # is this needed? will be included in 'full' subnetwork below.
        # maybe be needed for monitoring...
        stage = 'val'
        # log metrics for each subnetwork
        for subnetwork_name, subnetwork_metrics in total_subnetwork_metrics_dict.items():
            name = f'{stage}/{subnetwork_name}'
            for metric_name, value in subnetwork_metrics.items():
                self.log(name=name+'/'+metric_name, value=value)

        # for progress bar: error and mcc are scaled by 100, mse/mae are shown raw
        metrics_in_progress_bar = ['error', 'mcc', 'mse', 'mae']
        prog_bar_metrics_dict = {}
        for metric in metrics_in_progress_bar:
            value = 100*total_subnetwork_metrics_dict[self.real_network][metric] if metric not in ['mse', 'mae', 'hinge'] else total_subnetwork_metrics_dict[self.real_network][metric]
            prog_bar_metrics_dict[metric] = value
        self.log_dict(prog_bar_metrics_dict, logger=False, prog_bar=True)

        # to able to see/log lr, need to do this
        # (reads the lr of the first param group of the first optimizer)
        current_lr = self.trainer.optimizers[0].state_dict()['param_groups'][0]['lr']
        self.log("lr", round(current_lr, 10), logger=False, prog_bar=True)

        # how much better are we doing than prior? POSITIVE = GOOD
        perc_change_curr_priors_subnets = percent_change_metrics(prior_metrics=self.prior_metrics_val, prediction_metrics=total_subnetwork_metrics_dict)
        # NOTE(review): indexed by self.prior_construction, so this presumably only works when that name is also a key of prior_metrics_val — confirm for the 'multi' prior
        perc_change_curr = perc_change_curr_priors_subnets[self.prior_construction][self.real_network]

        for prior_channel_name, subnetwork_dicts in perc_change_curr_priors_subnets.items():
            for subnetwork_name, subnetwork_metrics in subnetwork_dicts.items():
                # Note that all positive changes are 'good', and all negative changes are 'bad'
                for metric_name, value in subnetwork_metrics.items():
                    name = f'{stage}/{subnetwork_name}/current vs {prior_channel_name} prior'
                    self.log(name=name+'/'+metric_name, value=value)

        # print summary of training results on full network: green good, red bad
        print(f'\nPercent reduction (using training set ONLY for train/threshold finding, and eval on validation) over subnetwork "{self.real_network}" using loss: *{self.which_loss}*: {format_color("Good", "green")}/{format_color("Bad", "red")}')
        logged_metrics = ['val/full/error', 'val/full/mae', 'val/full/mse']
        for log_metric in logged_metrics:
            metric = log_metric.split("/")[-1]
            # metrics whose name contains acc/mcc/f1 are maximized, everything else minimized
            maximize = any(m in metric for m in ['acc', 'mcc', 'f1'])
            # best epoch so far according to this metric (the current epoch is already in list_of_metrics)
            best_epoch_metrics = self.best_metrics(sort_metric=metric, sort_subnetwork=self.real_network, maximize=maximize)[0]
            perc_change_best_priors_subnets = percent_change_metrics(prior_metrics=self.prior_metrics_val, prediction_metrics=best_epoch_metrics)
            perc_change_best = perc_change_best_priors_subnets[self.prior_construction][self.real_network]
            print(f"{f'  {metric}: Best ':<15}", end="")
            num_in_color = format_color(f"{perc_change_best[metric].abs():.5f}% ({best_epoch_metrics['full'][metric]:.5f})", color='green' if perc_change_best[metric] > 0 else 'red')
            best_ep = f" on epoch {best_epoch_metrics['epoch']}"
            print(f"{f'{num_in_color}  {best_ep}':<15}", end="")
            #print(f"{s:<50}", end="")
            s = ' | Current: ' + format_color(f"{perc_change_curr[metric].abs():.5f}% ({total_subnetwork_metrics_dict['full'][metric]:.5f})", color='green' if perc_change_curr[metric] > 0 else 'red')
            print(s)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Store prior info, the metric history and a fresh test threshold in ``checkpoint``.

        The test threshold is searched on the VALIDATION dataloader at save time.
        """
        for key in ('prior_metrics_val', 'prior_channels', 'prior_channel_names', 'list_of_metrics'):
            checkpoint[key] = getattr(self, key)

        test_threshold = self.find_threshold(
            dl=self.val_dataloader(),
            threshold_test_points=self.threshold_metric_test_points,
            metric2chooseThresh=self.threshold_metric,
        )
        checkpoint['test_threshold'] = test_threshold
        print(f"\n\tsaving checkpoint: saving threshold found {test_threshold:.3f} using validation set during training")

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Restore prior info, the metric history and the test threshold from ``checkpoint``.

        Afterwards prints the restored threshold together with the 'full' error
        of the LAST entry of the restored metric history.
        """
        restored_keys = ('prior_metrics_val', 'prior_channels', 'prior_channel_names',
                         'list_of_metrics', 'test_threshold')
        for key in restored_keys:
            setattr(self, key, checkpoint[key])

        print(f"\nLoading threshold found using validaiton set during training: {self.test_threshold:.5f} which achieved {self.list_of_metrics[-1]['full']['error']*100:.4f}% error")

    """ 
    def on_test_start(self) -> None:
        # find best threshold only once before every validation epoch
        # use training set to optimize threshold during training
        # this is to be used to threshold the MODEL OUTPUTS
        self.test_threshold[0] = self.find_threshold(dl=self.val_dataloader(), threshold_test_points=self.threshold_metric_test_points, metric2chooseThresh=self.threshold_metric)
        self.log('test_threshold', self.test_threshold, prog_bar=True, on_epoch=True, on_step=False)
    """

    def test_step(self, batch, batch_idx):
        # we just finished training
        #if self.threshold is None:
        return self.eval_step(batch, batch_idx, threshold=self.test_threshold, reduction='sum')
        """ 
        fcs, adjs, subject_ids, scan_dirs, tasks = batch
        reduction = 'sum'
        adjs_hat = self.shared_step(batch, batch_idx)

        subnetwork_metrics_dict = \
            prediction_metrics_for_each_subnetwork(y=adjs, y_hat=adjs_hat,
                                                   threshold=self.test_threshold,
                                                   subnetwork_mask_dict=self.subnetwork_masks,
                                                   hinge_margin=self.hinge_margin,
                                                   hinge_slope=self.hinge_slope,
                                                   reduction=reduction)
        return {'subnetwork_metrics_dict': subnetwork_metrics_dict, 'batch_size': torch.tensor(len(fcs))}
        
        """

    def test_epoch_end(self, outputs: List[Any]) -> None:
        # aggregate all outputs from test step batches
        total_epoch_samples = torch.stack([x['batch_size'] for x in outputs]).sum()
        total_subnetwork_metrics_dict = {}
        subnetwork_names = outputs[0]['subnetwork_metrics_dict'].keys()
        for subnetwork_name in subnetwork_names:
            total_subnetwork_metrics_dict[subnetwork_name] = {}
            for metric in ['hinge', 'mse', 'mae', 'macro_F1', 'acc', 'mcc']:
                total_subnetwork_metrics_dict[subnetwork_name][metric] = torch.stack([x['subnetwork_metrics_dict'][subnetwork_name][metric] for x in outputs]).sum()/total_epoch_samples
            total_subnetwork_metrics_dict[subnetwork_name]['error'] = 1 - total_subnetwork_metrics_dict[subnetwork_name]['acc']

        stage = 'test'
        # focus on network specified in self.real_network
        assert self.real_network in ['full'], f"haven't tested with any other subnetwork {self.real_network}"

        # log metrics for each subnetwork
        for subnetwork_name, subnetwork_metrics in total_subnetwork_metrics_dict.items():
            name = f'{stage}/{subnetwork_name}'
            for metric_name, value in subnetwork_metrics.items():
                self.log(name=name+'/'+metric_name, value=value)

        # log % change of test metrics over prior metrics
        if (self.prior_metrics_test is not None) and len(self.prior_metrics_test) > 0:
            perc_change_curr_priors_subnets = percent_change_metrics(prior_metrics=self.prior_metrics_test, prediction_metrics=total_subnetwork_metrics_dict)
            for prior_channel_name, subnetwork_dicts in perc_change_curr_priors_subnets.items():
                for subnetwork_name, subnetwork_metrics in subnetwork_dicts.items():
                    # Note that all positive changes are 'good', and all negative changes are 'bad'
                    name = f'{stage}/{subnetwork_name}/current vs {prior_channel_name} prior'
                    for metric_name, value in subnetwork_metrics.items():
                        self.log(name=name+'/'+metric_name, value=value)

        # log BEST epoch vals: raw and % change
        if self.prior_metrics_val is not None:
            logged_metric_names = ['val/full/mse', 'val/full/mae', 'val/full/error', 'val/full/mcc', 'val/full/hinge']
            for prior_channel_name in self.prior_metrics_val.keys():
                for log_metric_name in logged_metric_names:
                    metric_name = log_metric_name.split("/")[-1]
                    maximize = any(m in metric_name for m in ['acc', 'mcc', 'f1'])
                    best_epoch_metrics = self.best_metrics(sort_metric=metric_name, sort_subnetwork=self.real_network, maximize=maximize)[0]
                    # positive = good change
                    perc_change_best_priors_subnets = percent_change_metrics(prior_metrics=self.prior_metrics_val, prediction_metrics=best_epoch_metrics)
                    perc_change_best_full = perc_change_best_priors_subnets[self.prior_construction][self.real_network]
                    for metric_names_best, values_best in perc_change_best_full.items():
                        # at the epoch where metric_name was minimized, how did all the other metrics do? raw and vs prior
                        self.log(name=f'best_val_by_{metric_name}/'+'full/' + metric_names_best, value=best_epoch_metrics['full'][metric_names_best])
                        self.log(name=f'best_val_by_{metric_name}/'+f'full/current vs {prior_channel_name} prior/' + metric_names_best, value=perc_change_best_full[metric_names_best])

        return

    """
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        lr_scheduler = \
            {'scheduler':
                 torch.optim.lr_scheduler.OneCycleLR(
                     optimizer,
                     max_lr=self.learning_rate,
                     steps_per_epoch=int(len(self.train_dataloader())),
                     epochs=self.hparams.epochs,
                     anneal_strategy="linear",
                     final_div_factor=10,
                ),
            'name': 'learning_rate',
            'interval': 'step',
            'frequency': 1
            }
        scheduler = torch.optim.lr_scheduler.CyclicLR

        return [optimizer], [lr_scheduler]

    """
    def configure_optimizers(self):
        """Build the optimizer selected by ``self.hparams.optimizer``.

        No learning-rate scheduler is configured; ``self.monitor`` is only
        validated. The previously computed but unused ``monitor``/``mode``/
        ``optimizers`` locals and the string-quoted scheduler experiments were
        removed.

        :return: a ``torch.optim.Adam`` (name contains ``'adam'``) or a
            ``torch.optim.SGD`` (name contains ``'sgd'``) over ``self.parameters()``.
        :raises ValueError: if ``self.monitor`` names no supported metric or the
            optimizer name is neither Adam nor SGD.
        """
        m = str.lower(self.monitor)
        # 'err' also covers 'error'
        supported_monitor_keys = ('acc', 'hinge', 'err', 'mse', 'mae', 'f1', 'mcc')
        if not any(key in m for key in supported_monitor_keys):
            raise ValueError(f'monitor {self.monitor} is unsupported')

        lr = self.hparams.learning_rate
        if 'adam' in self.hparams.optimizer:
            b1, b2 = self.hparams.adam_beta_1, self.hparams.adam_beta_2
            optimizer = torch.optim.Adam(self.parameters(), lr=lr, betas=(b1, b2))
        elif 'sgd' in self.hparams.optimizer:
            optimizer = torch.optim.SGD(self.parameters(), lr=lr, momentum=self.hparams.momentum)
        else:
            raise ValueError(f'only configured Adam and SGD optimizzer. Given {self.hparams.optimizer}')

        return optimizer

    ### HELPER METHODS ###
    @torch.no_grad()
    def test_large_graphs(self, val_dl, test_dl, use_val_for_threshold=False):
        #user responsibl
        # find threshold with val set
        if use_val_for_threshold:
            threshold = self.find_threshold(dl=val_dl, threshold_test_points=self.threshold_metric_test_points, metric2chooseThresh=self.threshold_metric)
            print(f'Used Larger Graph validation set to find threshold: {threshold:.5f}')
        else:
            # test_threshold is found using the validtion set (during training will smaller graphs)
            threshold = self.test_threshold
            if threshold < 0 or threshold is None: #we init this as none
                raise ValueError(f'Model must compute threshold_test using validation data during training.Currently uninitialized: {threshold:.5f}')
            print(f'Used threshold found during training.: {threshold:.5f}')

        # use this to get performance on test set
        outputs = []
        errors, mses, maes, mccs = [], [], [], []
        for batch_idx, batch in enumerate(iter(test_dl)):
            for i in range(3):
                batch[i] = batch[i].to(self.device) # move x, y, y_hat to cuda if needed
            x, y = batch[:2]
            y_hat = self.shared_step(batch, batch_idx)
            # batch graph metrics only takes binary inputs
            pr, re, f1, macro_f1, acc, mcc = batch_graph_metrics(x=(y_hat > threshold), y=(y > 0)) #argument 'x' is predicted graphs! Bad name.
            errors.append(1 - acc)
            mccs.append(mcc)
            diff = y_hat - y
            mses.append((diff**2).mean(dim=(1, 2)))
            maes.append(torch.abs(diff).mean(dim=(1, 2)))
            print(f'{batch_idx} ', end="")
            #outputs.append(self.eval_step(batch, batch_idx, threshold, reduction='sum'))
        print("\nDONE!")
        # metric for each sample in entire dataset
        metrics = {'errors': torch.cat(errors, dim=0), 'mses': torch.cat(mses, dim=0),
                   'maes': torch.cat(maes, dim=0), 'mcc': torch.cat(mccs, dim=0)}
        return metrics, threshold
        """ 
        total_epoch_samples = torch.stack([x['batch_size'] for x in outputs]).sum()
        # only to be used with 'full' network
        network_name, total_network_metrics_dict = 'full', {'mean': {}, 'standard_error': {}}
        for metric in ['hinge', 'mse', 'mae', 'macro_F1', 'acc', 'mcc']:
            all_samples = torch.stack([x['subnetwork_metrics_dict'][network_name][metric] for x in outputs])
            total_network_metrics_dict[metric]['mean'] = all_samples.sum().detach() / total_epoch_samples
            total_network_metrics_dict[metric]['standard_error'] = all_samples.sum().detach() / total_epoch_samples
            total_network_metrics_dict[metric] = torch.stack([x['subnetwork_metrics_dict'][network_name][metric] for x in outputs]).sum().detach() / total_epoch_samples
        total_network_metrics_dict['error'] = 1 - total_network_metrics_dict['acc']

        return total_network_metrics_dict, threshold
        """


    # Assigns no attributes on self. NOTE: this is an instance method (it reads self.device and calls self.shared_step), not a @staticmethod.
    @torch.no_grad()
    def find_threshold(self, dl, threshold_test_points, metric2chooseThresh):
        """Pick the candidate threshold that is best on ``dl`` by the given metric.

        Runs the model over every batch of ``dl`` (train or val), concatenates
        labels and predictions, and lets ``best_threshold_by_metric`` choose
        among ``threshold_test_points``.

        Each batch must be index-assignable: entries 0 and 1 are moved to
        ``self.device`` in place.
        """
        labels, predictions = [], []
        for batch_idx, batch in enumerate(dl):  # batch-wise so we don't run out of memory
            # move inputs and labels to the model device ourselves
            for slot in (0, 1):
                batch[slot] = batch[slot].to(self.device)
            labels.append(batch[1])
            predictions.append(self.shared_step(batch, batch_idx))

        all_labels = torch.cat(labels, dim=0)
        all_predictions = torch.cat(predictions, dim=0)

        # evaluate every candidate threshold and keep the one optimizing the metric
        return best_threshold_by_metric(thresholds=threshold_test_points,
                                        adjs=all_labels, preds=all_predictions,
                                        metric=metric2chooseThresh)

    def compute_loss(self, adjs, adjs_hat, N, bs):
        """Return the training loss, normalized per off-diagonal edge.

        The data term is selected by ``self.log_regr`` / ``self.which_loss`` and is
        always SUM-reduced; optional L1/L2 parameter penalties are added, and the
        total is divided by ``bs*N*N - bs*N`` (diagonals ignored). Note that the
        penalties are added before the division, so they are scaled by it too.

        Fix: the ``'cross entropy'`` branch used an integer target built with
        ``torch.where(adjs > 0, 1, 0)``, which BCE-with-logits rejects, and used
        the default mean reduction before the per-edge division. The target is
        now a float tensor and the reduction is ``'sum'`` like every other branch.

        :param adjs: ground-truth adjacency matrices.
        :param adjs_hat: model outputs (treated as logits by the BCE branches).
        :param N: number of nodes per graph.
        :param bs: batch size.
        :raises ValueError: if ``self.which_loss`` is not recognized.
        """
        if self.log_regr:
            loss = F.binary_cross_entropy_with_logits(adjs_hat, adjs, reduction='sum')
        elif self.which_loss == 'mse':
            loss = F.mse_loss(adjs_hat, adjs, reduction='sum')
        elif self.which_loss == 'mae':
            loss = F.l1_loss(adjs_hat, adjs, reduction='sum')
        elif self.which_loss == 'mse+mae':
            loss = F.mse_loss(adjs_hat, adjs, reduction='sum') + F.l1_loss(adjs_hat, adjs, reduction='sum')
        elif self.which_loss == 'cross entropy':
            # binarize the labels; the target must share the logits' float dtype
            binary_adjs = (adjs > 0).to(adjs_hat.dtype)
            loss = F.binary_cross_entropy_with_logits(input=adjs_hat, target=binary_adjs, reduction='sum')
        elif self.which_loss == 'hinge':
            loss = hinge_loss(y=adjs, y_hat=adjs_hat, margin=self.hinge_margin, slope=self.hinge_slope, per_edge=False).sum()
        elif self.which_loss == 'hinge+mse':
            loss = hinge_loss(y=adjs, y_hat=adjs_hat, margin=self.hinge_margin, slope=self.hinge_slope, per_edge=False).sum()
            loss = loss + F.mse_loss(adjs_hat, adjs, reduction='sum')
        else:
            raise ValueError(f'which_loss {self.which_loss} not recognized')

        # regularizers, following
        # https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/regression/linear_regression.py
        # L1 regularizer
        if self.l1_strength > 0:
            l1_reg = sum(param.abs().sum() for param in self.parameters())
            loss = loss + self.l1_strength * l1_reg

        # L2 regularizer
        if self.l2_strength > 0:
            l2_reg = sum(param.pow(2).sum() for param in self.parameters())
            loss = loss + self.l2_strength * l2_reg

        denom = bs * N * N - bs * N  # ignore diagonals
        per_edge_loss = loss / denom
        return per_edge_loss

    def construct_prior_(self, prior_dl=None, prior_dtype=torch.float32):
        """Build the prior channel(s) and store them in ``self.training_prior``.

        Data-driven priors (``'mean'``, ``'median'``, ``'multi'``) are computed from
        the dataset behind ``prior_dl``; the synthetic ones (``'block'``, ``'sbm'``,
        ``'zeros'``, ``'ones'``) only need the training graph size ``self.n_train``
        (``'sbm'`` additionally reads ``prior_dl.dataset.prob_matrix()``).

        Fix: ``'block'`` used a standalone ``if`` followed by a separate
        ``if/elif/else`` chain, so it always fell into the final ``else`` and
        raised ``ValueError``. The branches now form one chain.

        :param prior_dl: dataloader over the data the prior is built from.
        :param prior_dtype: dtype of the ``'sbm'``/``'zeros'``/``'ones'`` priors.
        :return: ``(prior_channels, prior_channel_names)``.
        :raises ValueError: for an unrecognized ``self.prior_construction``.
        """
        # constructing prior from data
        if self.prior_construction in ['mean', 'median', 'multi']:
            # create prior with prior_ds set (train set): prior_dl cannot be none
            assert prior_dl is not None, f'to construct prior from data, need train data'
            _, prior_scs, prior_subject_ids, _, _ = prior_dl.dataset.full_ds()
            unique_scs_train_set = filter_repeats(prior_scs, prior_subject_ids)
            if self.prior_construction in ['mean', 'median']:
                prior = construct_prior(unique_scs_train_set, frac_contains=self.prior_frac_contains, reduction=self.prior_construction)
                self.training_prior[0] = prior[0]
                prior_channels = [prior]
                prior_channel_names = [self.prior_construction]
            else:  # multi: one channel per reduction
                self.training_prior[0, 0] = construct_prior(unique_scs_train_set, frac_contains=self.prior_frac_contains, reduction='mean')
                self.training_prior[1, 0] = construct_prior(unique_scs_train_set, frac_contains=self.prior_frac_contains, reduction='median')
                prior_channels = [self.training_prior[0, 0], self.training_prior[1, 0]]
                prior_channel_names = ['mean', 'median']
        # these priors do not use real data. Reconstructed each time, only need graph size N.
        # This is useful when we want to test on different sized data than we trained on: simply feed in test set as
        # val_dl, and prior will be constructed appropriately.
        else:
            if self.prior_construction == 'block':
                block_scale = .35  # minimizes mse
                assert self.n_train % 2 == 0, f'for block prior, n must be even (or in general divisible by number of communities'
                ones = torch.ones((self.n_train // 2), (self.n_train // 2))
                prior_channels = [torch.block_diag(ones, ones).view(1, self.n_train, self.n_train) * block_scale]
                prior_channel_names = ['block']
            elif self.prior_construction == 'sbm':
                prob_matrix = prior_dl.dataset.prob_matrix()
                prior_channels = [prob_matrix.expand(1, self.n_train, self.n_train).to(prior_dtype)]
                prior_channel_names = ['sbm']
            elif self.prior_construction in ['zeros']:
                prior_channels = [torch.zeros(1, self.n_train, self.n_train, dtype=prior_dtype)]
                prior_channel_names = ['zeros']
            elif self.prior_construction in ['ones']:
                prior_channels = [torch.ones(1, self.n_train, self.n_train, dtype=prior_dtype)]
                prior_channel_names = ['ones']
            else:
                raise ValueError('unrecognized prior construction arg')
            self.training_prior[0] = prior_channels[0]

        return prior_channels, prior_channel_names

    def prior_prep(self, batch_size, N):
        """Return the prior expanded to ``(prior_channels, batch_size, N, N)``.

        Called from ``shared_step``. The synthetic priors (``'zeros'``, ``'ones'``,
        ``'block'``) are rebuilt for the requested ``N`` instead of reusing
        ``self.training_prior``, so the model can be run on graphs of a different
        size than it was trained on. The stored-prior branch (``'mean'``,
        ``'median'``, ``'sbm'``) requires ``N`` to match ``self.training_prior``.

        Fix: an unhandled ``self.prior_construction`` (e.g. ``'multi'``) used to
        fall off the end and silently return ``None``; it now raises ``ValueError``.

        :raises ValueError: if ``self.prior_construction`` is not handled here.
        """
        prior_channels = 1 if (not self.share_parameters) else self.prox_layers[0].c_in
        if self.prior_construction == 'zeros':
            return torch.zeros(size=(N, N), device=self.device).expand(prior_channels, batch_size, N, N)
        elif self.prior_construction == 'ones':
            # NOTE(review): scaled by 0.5 here, whereas construct_prior_ stores plain ones
            return 0.5*torch.ones(size=(N, N), device=self.device).expand(prior_channels, batch_size, N, N)
        elif self.prior_construction == 'block':
            block_scale = .35  # minimizes mse for brain graphs
            assert (N % 2) == 0, f'block diagram requires even N'
            ones = torch.ones(size=((N // 2), (N // 2)), device=self.device)
            return torch.block_diag(ones, ones).expand(prior_channels, batch_size, N, N) * block_scale
        elif self.prior_construction in ['mean', 'median', 'sbm']:
            # must use training_prior
            assert self.training_prior.shape[-1] == N, f'when using prior constructed from data, we cannot test on data of different size than it'
            return self.training_prior.expand(prior_channels, batch_size, N, N)

        raise ValueError(f'prior_prep does not support prior construction {self.prior_construction!r}')

    def best_metrics(self, sort_metric, sort_subnetwork='full', top_k=1, maximize=True):
        """Return the ``top_k`` entries of ``self.list_of_metrics`` ranked by one metric.

        Entries are ordered by ``entry[sort_subnetwork][sort_metric]``, largest
        first when ``maximize`` is True and smallest first otherwise.
        """
        def ranking_value(epoch_entry):
            return epoch_entry[sort_subnetwork][sort_metric]

        ranked = sorted(self.list_of_metrics, key=ranking_value, reverse=maximize)
        return ranked[:top_k]

    def extra_repr(self):
        """Return a short description holding the channel spec and the number of layers."""
        return 'channels {}# Layers: {}'.format(self.channels, self.depth)

    @torch.no_grad()
    def forward_intermed_outs(self, cov_raw):
        """Run the layers on ``cov_raw`` and collect every layer's intermediate tensors.

        Each prox layer is called with ``extra_outs=True`` and must return five
        values; the first one is passed through a ReLU and fed to the next layer.

        NOTE(review): this reads ``self.prior`` and calls the layers without
        ``layer=``/``normalize=``, unlike ``shared_step`` — confirm it is still
        in sync with the rest of the model.

        :param cov_raw: tensor of shape ``(batch_size, N, N)``.
        :return: dict with keys ``'S_ins'``, ``'covs_normed'``, ``'covs_raw'``,
            ``'H1s'``, ``'H2s'``, ``'temp_sym_zds'`` and ``'S_outs'``; ``'S_outs'``
            has ``depth + 1`` slices because slice 0 holds the prior.
        """
        batch_size, N, _ = cov_raw.shape

        # input normalization is currently disabled: the raw covariances are used as-is
        Cov = cov_raw
        S_in = torch.broadcast_to(self.prior, (batch_size, N, N))

        def new_buffer(n_slices):
            return torch.zeros([n_slices, batch_size, N, N], dtype=Cov.dtype)

        S_outs = new_buffer(self.depth + 1)
        H1s = new_buffer(self.depth)
        H2s = new_buffer(self.depth)
        S_ins = new_buffer(self.depth)
        temp_sym_zd = new_buffer(self.depth)

        S_outs[0] = S_in
        S_out = S_in
        for i, prox_layer in enumerate(self.prox_layers):
            pre_relu, S_ins[i], H1s[i], H2s[i], temp_sym_zd[i] = prox_layer(S_out, Cov, extra_outs=True)
            S_out = F.relu(pre_relu)
            # shifted by one because the prior occupies slice 0
            S_outs[i + 1] = S_out

        collected = {'S_ins': S_ins, 'covs_normed': Cov, 'covs_raw': cov_raw}
        collected.update({'H1s': H1s, 'H2s': H2s, 'temp_sym_zds': temp_sym_zd, 'S_outs': S_outs})
        return collected

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register this model's command-line options in a "Covariance_NN" argument group.

        NOTE(review): the flag is spelled ``--prior_contruction`` while the model
        reads ``self.prior_construction``, and ``--threshold_metric_test_points``
        uses ``np.ndarray`` as its argparse ``type`` — confirm both are intended.

        :param parent_parser: parser to extend.
        :return: the same ``parent_parser``.
        """
        group = parent_parser.add_argument_group("Covariance_NN")

        def add_typed(kind, *flags):
            # register one plain option per flag, all parsed with the same type
            for flag in flags:
                group.add_argument(flag, type=kind)

        # model
        group.add_argument('--channels', nargs='+')
        add_typed(str, '--where_channel_reduce')
        add_bool_arg(group, 'learn_tau', default=False)  # TODO remove default
        add_typed(float, '--low_tau_val', '--learning_rate', '--momentum',
                  '--adam_beta_1', '--adam_beta_2')

        # regularization
        add_typed(float, '--l1_strength', '--l2_strength', '--l1_cutoff', '--l2_cutoff')

        # normalization
        add_typed(str, '--component_norm')
        add_typed(float, '--component_norm_val')
        add_typed(str, '--which_slice_normalization')
        add_typed(float, '--slice_normalization_val')
        add_bool_arg(group, 'norm_last_layer')

        # thresholding
        add_bool_arg(group, 'log_regr', default=False)
        add_typed(str, '--threshold_metric')  # mcc
        add_typed(np.ndarray, '--threshold_metric_test_points')

        # prior
        add_typed(str, '--prior_contruction')
        add_typed(float, '--prior_frac_contains')

        # trainer
        add_typed(str, '--monitor')

        # reproducibility
        add_typed(int, '--rand_seed')
        add_bool_arg(group, 'train_deterministically', default=False)

        add_bool_arg(group, 'binarize_labels_for_train', default=False)

        # logging: [ None, 'only_scalars', 'all' ]
        add_typed(str, '--logging')

        return parent_parser


# Script entry point: running this module directly only prints a banner.
if __name__ == "__main__":
    print('pl_prox_model main loop')
