from __future__ import print_function, division
import torch

import numpy as np
from model.metrics import batch_graph_metrics
from utils.util_funcs import sparsity
from data.brain_data.matlab_to_python_brain_data import apply_subnetwork_mask

# Module-level debug switch; none of the code visible in this section reads it.
DEBUG = False


def sparsity_of_subnetworks(x, subnetwork_mask_dict):
    """Compute the sparsity of `x` restricted to each named subnetwork.

    Args:
        x: Data handed to `apply_subnetwork_mask` together with each mask.
        subnetwork_mask_dict: Mapping from subnetwork name to its mask.

    Returns:
        Dict mapping every subnetwork name to `sparsity(...)` of the masked `x`.
    """
    return {
        mask_name: sparsity(apply_subnetwork_mask(x, mask))
        for mask_name, mask in subnetwork_mask_dict.items()
    }


# go through each subnetwork with given prior channel
def prediction_metrics_for_each_subnetwork(y, y_hat, threshold, subnetwork_mask_dict, hinge_margin, hinge_slope, reduction):
    """Evaluate `prediction_metrics` separately on every subnetwork.

    Both `y` and `y_hat` are passed through `apply_subnetwork_mask` with each
    mask before scoring. Whenever the resulting metrics contain an 'acc'
    entry, a matching 'error' entry (1 - acc) is added.

    Args:
        y: Ground-truth data.
        y_hat: Predicted data.
        threshold: Forwarded to `prediction_metrics`.
        subnetwork_mask_dict: Mapping from subnetwork name to its mask.
        hinge_margin: Forwarded to `prediction_metrics`.
        hinge_slope: Forwarded to `prediction_metrics`.
        reduction: Forwarded to `prediction_metrics`.

    Returns:
        Dict mapping subnetwork name to its metrics dict.
    """
    per_subnetwork = {}
    for mask_name, mask in subnetwork_mask_dict.items():
        masked_pred = apply_subnetwork_mask(y_hat, mask)
        masked_truth = apply_subnetwork_mask(y, mask)
        scores = prediction_metrics(
            y=masked_truth,
            y_hat=masked_pred,
            threshold=threshold,
            hinge_margin=hinge_margin,
            hinge_slope=hinge_slope,
            reduction=reduction,
        )
        if 'acc' in scores:
            scores['error'] = 1 - scores['acc']
        per_subnetwork[mask_name] = scores
    return per_subnetwork


def prediction_metrics(y, y_hat, threshold, hinge_margin, hinge_slope, reduction):
    """Score a batch of predicted square matrices against their targets.

    Args:
        y: Targets; must be 3-D with equal last two dimensions (asserted).
        y_hat: Predictions; must have the same shape as `y` (asserted).
        threshold: Forwarded to `score_graphs_batch` for the thresholded scores.
        hinge_margin: Margin forwarded to `hinge_loss`.
        hinge_slope: Slope forwarded to `hinge_loss`.
        reduction: 'ave' or 'mean' averages the per-scan mse/mae/hinge over the
            batch; any other value sums them. It is also forwarded unchanged
            to `score_graphs_batch` as `o`.

    Returns:
        Dict with keys 'mse', 'mae', 'hinge', 'acc', 'mcc' and 'macro_F1'.
        Precision, recall and F1 are computed by `score_graphs_batch` but are
        not part of the returned dict.
    """
    assert y.shape == y_hat.shape
    assert y.ndim == 3
    assert y.shape[-1] == y.shape[-2]

    residual = y_hat - y
    # One scalar per scan. The mean runs over every entry of the matrix, so the
    # diagonal is counted in the denominator.
    per_scan_mse = (residual ** 2).mean(axis=(1, 2))
    per_scan_mae = residual.abs().mean(axis=(1, 2))
    per_scan_hinge = hinge_loss(y=y, y_hat=y_hat, per_edge=True, margin=hinge_margin, slope=hinge_slope)

    if reduction in ['ave', 'mean']:
        mse = per_scan_mse.mean()
        mae = per_scan_mae.mean()
        hinge = per_scan_hinge.mean()
    else:
        mse = per_scan_mse.sum()
        mae = per_scan_mae.sum()
        hinge = per_scan_hinge.sum()

    # Scores computed on the thresholded prediction (acc/mcc/...).
    _, _, _, macro_F1, acc, mcc = score_graphs_batch(threshold=threshold, adjs=y, preds=y_hat, o=reduction)
    return {'mse': mse, 'mae': mae, 'hinge': hinge, 'acc': acc, 'mcc': mcc, 'macro_F1': macro_F1}


def are_scs_binary(scs):
    """Return whether every entry of `scs` is exactly 0 or exactly 1."""
    flat = scs.flatten()
    is_one = flat == 1
    is_zero = flat == 0
    return np.all(np.logical_or(is_one, is_zero))


##### Parameter Printing and Optimization Control #####
def print_model_params(model, vals=True, grads=True, gradients=None):
    """Print every parameter of each prox layer, optionally with its last gradient.

    Args:
        model: Object exposing `depth` and an indexable `prox_layers`; each
            layer must provide `named_parameters()`.
        vals: Not consulted; parameter values are always printed. Kept so
            existing callers that pass it keep working.
        grads: When True and `gradients` is None, print each parameter's own
            `.grad` (printed as `None` if no gradient has been populated).
        gradients: Optional per-layer sequence of dicts, looked up as
            `gradients[i][name]`. When given, these are printed instead of
            `param.grad`.

    Side effects:
        NumPy's print precision is set to 9 while printing and to 2 on exit,
        including when printing raises.
    """
    np.set_printoptions(precision=9)
    try:
        for i in range(model.depth):
            module = model.prox_layers[i]
            level_grads = gradients[i] if gradients is not None else None
            print(f'--prox_layer {i}--')
            for name, param in module.named_parameters():
                print(name)
                print(f'\tvalue    : {param.data.numpy()}')
                if gradients is not None:
                    print(f'\tlast grad : {level_grads[name]}')
                elif grads:
                    # A parameter that never took part in a backward pass has
                    # grad None; print that instead of failing on `.numpy()`.
                    last_grad = None if param.grad is None else param.grad.numpy()
                    print(f'\tlast grad  : {last_grad}')
            print('----------')
    finally:
        np.set_printoptions(precision=2)


def shallowest_layer_all_zero(model):
    """Find the first prox layer whose `output_zeros` flag is set.

    Layers are scanned from the start of `model.prox_layers`. When a flagged
    layer is found, its index is printed and returned.

    Returns:
        Index of the first flagged layer, or -1 when no layer is flagged.
    """
    first_hit = next(
        (depth for depth, layer in enumerate(model.prox_layers) if layer.output_zeros),
        -1,
    )
    if first_hit != -1:
        print(f"\t\t{first_hit}th layer!")
    return first_hit


def resample_params(module, stdv_scaling=1/3, alpha_mean=1, beta_mean=0):
    """Re-draw the parameters of `module` in place, selected by parameter name.

    For every entry of `module.named_parameters()`:
      * name contains 'tau' and `module.learn_tau` is truthy: scaled by 0.1;
      * name contains 'alpha': replaced by
        `module.stdv_scaling * N(0, 1) + module.init_alpha_sample_mean`;
      * name contains 'beta': replaced by
        `module.stdv_scaling * N(0, 1) + module.init_beta_sample_mean`;
      * name contains 'poly': replaced by `module.stdv_scaling * N(0, 1)` with
        index 1 along the second axis set to 1.
    The checks are independent, so a name matching several of them is
    processed by each in the order above.

    NOTE(review): the `stdv_scaling`, `alpha_mean` and `beta_mean` arguments
    are never read; the values actually used come from attributes on `module`.
    """
    with torch.no_grad():
        for param_name, tensor in module.named_parameters():
            if 'tau' in param_name and module.learn_tau:
                # shrink tau so fewer outputs are driven to all zeros
                tensor.copy_(tensor * 0.1)
            if 'alpha' in param_name:
                tensor.copy_(module.stdv_scaling * torch.randn_like(tensor) + module.init_alpha_sample_mean)
            if 'beta' in param_name:
                tensor.copy_(module.stdv_scaling * torch.randn_like(tensor) + module.init_beta_sample_mean)
            if 'poly' in param_name:
                draw = module.stdv_scaling * torch.randn_like(tensor)
                draw[:, 1, :] = 1
                tensor.copy_(draw)


def clamp_tau(model, large_tau):
    """Clamp, in place, every parameter whose name contains 'tau' to [0, large_tau]."""
    tau_params = (param for name, param in model.named_parameters() if 'tau' in name)
    with torch.no_grad():
        for param in tau_params:
            param.clamp_(min=0, max=large_tau)


def construct_named_coeffs_dict(coeffs_dict):
    """Flatten a per-layer coefficient mapping into one entry per coefficient.

    Args:
        coeffs_dict: Mapping from a layer name (str) to an iterable of
            coefficients.

    Returns:
        Dict whose keys are 'c<i>_<layer name>', where i is the position of the
        coefficient within its layer, and whose values are the coefficients.
    """
    return {
        f'c{position}_' + layer_name: coeff
        for layer_name, layer_coeffs in coeffs_dict.items()
        for position, coeff in enumerate(layer_coeffs)
    }


def deep_copy_gradients(model):
    """Snapshot the current gradients of every prox layer as NumPy copies.

    Args:
        model: Object exposing `depth` and an indexable `prox_layers`; each
            layer must provide `named_parameters()`.

    Returns:
        A list with exactly one dict per layer, so entry `i` belongs to
        `model.prox_layers[i]`. Each dict maps a parameter name to a float64
        array holding a copy of `param.grad`, or to None when the parameter has
        no gradient.
    """
    per_layer = []

    with torch.no_grad():
        for i in range(model.depth):
            module = model.prox_layers[i]
            snapshot = {}
            for name, param in module.named_parameters():
                if param.grad is None:
                    # parameter has not received a gradient
                    snapshot[name] = None
                    continue
                # np.array(...) copies, so the snapshot does not share memory
                # with the live gradient; it also handles 0-d gradients.
                snapshot[name] = np.array(param.grad.numpy(), dtype=np.float64)
            # Append once per layer, not once per parameter, so list indices
            # line up with layer indices.
            per_layer.append(snapshot)
    return per_layer


def dict_param(model, option='param'):
    """Collect per-layer tau/coefficient values, or their gradients, as NumPy data.

    Args:
        model: Object exposing `prox_layers` (each layer with `coeffs_1`,
            `coeffs_2` and `tau` tensors) and a `log_regr` flag; when the flag
            is truthy, `model.log_regr_layer` must provide `a` and `bias`.
        option: 'param' records the tensors themselves; any other value
            records their `.grad`.

    Returns:
        Tuple `(tau_dict, coeffs_1_dict, coeffs_2_dict, log_regr_dict)`.
        The three layer dicts are keyed 'layer_<i>' for i <= 9 and
        'layer__<i>' (double underscore) for i > 9. Tau entries hold element 0
        of the tau array; in gradient mode a tau without gradient is recorded
        as 0. `log_regr_dict` is None unless `model.log_regr` is truthy.
    """
    want_values = option == 'param'

    def to_numpy(tensor):
        # clone first so the returned array does not share memory with the tensor
        return tensor.clone().detach().numpy()

    tau_dict, coeffs_1_dict, coeffs_2_dict = {}, {}, {}
    for idx, prox_layer in enumerate(model.prox_layers):
        key = f'layer__{idx}' if idx > 9 else f'layer_{idx}'

        if want_values:
            coeffs_1_dict[key] = to_numpy(prox_layer.coeffs_1)
            coeffs_2_dict[key] = to_numpy(prox_layer.coeffs_2)
            tau_dict[key] = to_numpy(prox_layer.tau)[0]
            continue

        coeffs_1_dict[key] = to_numpy(prox_layer.coeffs_1.grad)
        coeffs_2_dict[key] = to_numpy(prox_layer.coeffs_2.grad)
        # when tau is not being learned it has no gradient
        tau_grad = prox_layer.tau.grad
        tau_dict[key] = 0 if tau_grad is None else to_numpy(tau_grad)[0]

    log_regr_dict = None
    if model.log_regr:
        head = model.log_regr_layer
        if want_values:
            a_src, bias_src = head.a, head.bias
        else:
            a_src, bias_src = head.a.grad, head.bias.grad
        log_regr_dict = {'a': to_numpy(a_src), 'bias': to_numpy(bias_src)}

    return tau_dict, coeffs_1_dict, coeffs_2_dict, log_regr_dict
##############


##### Experiment with threshold to optimize metric #####
# take a list of thresholds to try, and return the metrics at each threshold
def metrics_at_thresholds(thresholds, adjs, preds, o='ave'):
    """Evaluate `score_graphs_batch` once per candidate threshold.

    Args:
        thresholds: Sized iterable of thresholds to try.
        adjs: Ground-truth adjacency data (torch tensor or NumPy array).
        preds: Predictions, of the same container type as `adjs`.
        o: Reduction forwarded to `score_graphs_batch`. With 'raw' the results
            are stored in plain lists; otherwise in torch tensors or NumPy
            arrays matching the inputs.

    Returns:
        Tuple `(precisions, recalls, f1s, macro_f1s, accs, mccs)`; entry i of
        each container holds the score obtained at `thresholds[i]`.

    Raises:
        ValueError: `o` is not 'raw' and the inputs are not both torch tensors
            or both NumPy arrays.
    """
    num_points = len(thresholds)
    if o == 'raw':
        columns = [num_points * [None] for _ in range(6)]
    elif torch.is_tensor(adjs) and torch.is_tensor(preds):
        columns = [torch.zeros(num_points) for _ in range(6)]
    elif type(adjs) is np.ndarray and type(preds) is np.ndarray:
        columns = [np.zeros(num_points) for _ in range(6)]
    else:
        raise ValueError(f'adjs and preds must have same type: adjs {type(adjs)} preds {type(preds)}')

    for slot, threshold in enumerate(thresholds):
        # score_graphs_batch binarizes non-binary inputs internally for the calculation
        pr, re, f1, macro_f1, acc, mcc = score_graphs_batch(threshold, adjs=adjs, preds=preds, o=o)
        for column, score in zip(columns, (pr, re, f1, macro_f1, acc, mcc)):
            column[slot] = score

    precisions, recalls, f1s, macro_f1s, accs, mccs = columns
    return precisions, recalls, f1s, macro_f1s, accs, mccs


# Given list of thresholds, see which one optimizes given metric
def best_threshold_by_metric(thresholds, adjs, preds, metric='acc'):
    """Pick the threshold whose batch-averaged score is best for `metric`.

    Scores come from `metrics_at_thresholds(..., o='ave')`. NaN scores are
    ignored when locating the optimum.

    Args:
        thresholds: Indexable collection of candidate thresholds.
        adjs: Ground-truth adjacency data.
        preds: Predictions.
        metric: One of the accepted aliases for accuracy, error, mcc,
            macro-F1, F1, recall ('re') or precision ('pr'). Error is
            minimised; every other metric is maximised.

    Returns:
        The element of `thresholds` at the optimal position.

    Raises:
        ValueError: `metric` is not a recognised alias.
    """
    precisions, recalls, f1s, macro_f1s, accs, mccs = \
        metrics_at_thresholds(thresholds=thresholds, adjs=adjs, preds=preds, o='ave')

    if metric in ('err', 'error'):
        # error = 1 - accuracy is the only score that is minimised
        return thresholds[np.nanargmin(1 - accs)]

    maximised = (
        (('acc', 'accs'), accs),
        (('mcc', 'mccs'), mccs),
        (('macro_f1', 'macro-f1', 'macro_f1s', 'macro-f1s'), macro_f1s),
        (('f1', 'f1s', 'fmsr', 'f_msr', 'fmsrs', 'f_msrs'), f1s),
        (('re',), recalls),
        (('pr',), precisions),
    )
    for aliases, scores in maximised:
        if metric in aliases:
            # nanargmax skips NaN entries
            return thresholds[np.nanargmax(scores)]

    raise ValueError(f'unrecognized metric {metric}\n')
##############


# returns hinge_loss of each scan in tensor
# FOR y in {0,+1} NOT {-1, +1}
def hinge_loss(y, y_hat, margin=0.2, already_bin=False, per_edge=True, slope=1):
    """Hinge loss of each scan, for labels in {0, 1} (not {-1, +1}).

    An entry labelled 0 costs `max(0, y_hat - margin)`; an entry labelled 1
    costs `max(0, (1 - margin) - y_hat)`. Costs are multiplied by `slope` and
    summed over the last two dimensions.

    Args:
        y: Targets of shape (batch, N, N). Unless `already_bin` is set, entries
            greater than 0 are treated as label 1.
        y_hat: Predictions of shape (batch, N, N); the computation uses torch
            ops on this argument. The label-0 branch assumes y_hat >= 0.
        margin: Hinge margin.
        already_bin: When True, `y` is used directly as the boolean label mask.
        per_edge: When True, each scan's loss is divided by N * N - N.
        slope: Multiplier on the loss; a larger slope punishes errors more.

    Returns:
        Tensor with one loss value per scan.

    NOTE(review): the sum runs over all N * N entries, diagonal included,
    while the `per_edge` denominator N * N - N leaves the diagonal out.
    """
    y_bin = y if already_bin else (y > 0)
    if torch.is_tensor(y_bin):
        assert y_bin.dtype == torch.bool
    else:
        # np.bool_ exists in every NumPy release; the plain `np.bool` alias was
        # removed in NumPy 1.24.
        assert y_bin.dtype == np.bool_
    assert (y.ndim == 3) and (y.shape[-2] == y.shape[-1])
    assert (y_hat.ndim == 3) and (y_hat.shape[-2] == y_hat.shape[-1])

    zeros = torch.zeros_like(y_hat)
    loss_when_label_zero = torch.maximum(zeros, y_hat - margin)  # assume all y_hat >= 0
    loss_when_label_one = torch.maximum(zeros, -y_hat + (1 - margin))
    # picks the label-one loss wherever y_bin is True
    per_entry_loss = slope * torch.where(y_bin, loss_when_label_one, loss_when_label_zero)
    loss_per_scan = torch.sum(per_entry_loss, dim=(1, 2))

    if not per_edge:
        return loss_per_scan

    # normalise by the number of possible edges (off-diagonal entries)
    num_nodes = y.shape[1]
    total_possible_edges = num_nodes * num_nodes - num_nodes
    return loss_per_scan / total_possible_edges


###### scoring graphs - precision, recall, f-measure, acc, error, etc ######
# TODO: add sparsity!!
def score_graphs_batch(threshold:float, adjs, preds, o="ave", ignore_diagonal=True):
    """Dispatch graph scoring to the torch or NumPy implementation.

    Args:
        threshold: Threshold forwarded to the backend.
        adjs: Ground truth, a torch tensor or NumPy array.
        preds: Predictions, of the same container type as `adjs`.
        o: Reduction option; 'mean' is rewritten to 'ave' before dispatch.
        ignore_diagonal: Forwarded to the backend.

    Returns:
        Whatever the selected backend returns.

    Raises:
        ValueError: the inputs are not both torch tensors or both NumPy arrays.
    """
    if o == 'mean':
        o = 'ave'

    if torch.is_tensor(adjs) and torch.is_tensor(preds):
        backend = score_graphs_batch_torch
    elif type(adjs) is np.ndarray and type(preds) is np.ndarray:
        backend = score_graphs_batch_np
    else:
        raise ValueError(f'adjs and preds must be same type: adjs {type(adjs)} preds {type(preds)}')

    return backend(threshold=threshold, adjs=adjs, preds=preds, o=o, ignore_diagonal=ignore_diagonal)


def score_graphs_batch_torch(threshold: float, adjs: torch.Tensor, preds: torch.Tensor, o="ave", ignore_diagonal=True):
    """Score thresholded predictions against ground-truth graphs (torch inputs).

    Predictions are binarized with `preds > threshold` and the ground truth
    with `adjs > 0`; both masks are passed to `batch_graph_metrics`.

    Args:
        threshold: Cut-off applied to `preds`.
        adjs: Ground truth; must have the same shape as `preds`. A 2-D input
            is treated as a batch of one.
        preds: Predictions.
        o: 'ave'/'mean', 'sum' or 'median' reduce each metric over the batch
            (NaN mcc values are skipped); 'raw' returns the metrics as produced
            by `batch_graph_metrics`.
        ignore_diagonal: Forwarded to `batch_graph_metrics`.

    Returns:
        Tuple `(pr, re, f1, macro_f1, acc, mcc)`.

    Raises:
        ValueError: `o` is not one of the supported options. This is checked
            before any computation and replaces the former print + `exit(1)`,
            so a bad option no longer terminates the interpreter.
    """
    if o not in ("ave", "mean", "sum", "median", "raw"):
        raise ValueError(f'score_graphs_batch: invalid option o given {o}')
    # Compare the shapes directly; comparing them as tensors could pass through
    # broadcasting when the ranks differ.
    assert adjs.shape == preds.shape, f'shape mismatch: adjs {adjs.shape} preds {preds.shape}'
    adjs = adjs.detach()
    preds = preds.detach()

    if adjs.ndim == 2:
        # Shapes are equal, so preds is 2-D as well: give both the batch axis.
        adjs = torch.unsqueeze(adjs, dim=0)
        preds = torch.unsqueeze(preds, dim=0)

    Ar = (preds > threshold)  # recovered = predicted
    Ag = (adjs > 0)  # ground truth

    pr, re, f1, macro_f1, acc, mcc = batch_graph_metrics(x=Ar, y=Ag, ignore_diagonal=ignore_diagonal,
                                                         graph_or_edge='graph')
    if o == "raw":
        return pr, re, f1, macro_f1, acc, mcc
    if o in ("ave", "mean"):
        # mean of mcc over the non-NaN entries only
        num_non_nan_mcc = torch.sum(~torch.isnan(mcc))
        return (torch.mean(pr), torch.mean(re), torch.mean(f1), torch.mean(macro_f1), torch.mean(acc),
                torch.nansum(mcc) / num_non_nan_mcc)
    if o == "sum":
        return torch.sum(pr), torch.sum(re), torch.sum(f1), torch.sum(macro_f1), torch.sum(acc), torch.nansum(mcc)
    # o == "median"
    return (torch.median(pr), torch.median(re), torch.median(f1), torch.median(macro_f1), torch.median(acc),
            torch.nanmedian(mcc))


def score_graphs_batch_np(threshold:float, adjs: np.ndarray, preds: np.ndarray, o="ave", ignore_diagonal=True):
    """Score thresholded predictions against ground-truth graphs (NumPy inputs).

    Predictions are binarized with `preds > threshold` and the ground truth
    with `adjs > 0`; both masks are passed to `batch_graph_metrics`.

    Args:
        threshold: Cut-off applied to `preds`.
        adjs: Ground truth; must have the same shape as `preds`. A 2-D input
            is treated as a batch of one.
        preds: Predictions.
        o: 'ave' (or its alias 'mean', as in the torch variant), 'sum' or
            'median' reduce each metric over the batch (NaN mcc values are
            skipped); 'raw' returns the metrics as produced by
            `batch_graph_metrics`.
        ignore_diagonal: Forwarded to `batch_graph_metrics`.

    Returns:
        Tuple `(pr, re, f1, macro_f1, acc, mcc)`.

    Raises:
        ValueError: `o` is not one of the supported options. This is checked
            before any computation and replaces the former print + `exit(1)`,
            so a bad option no longer terminates the interpreter.
    """
    if o not in ("ave", "mean", "sum", "median", "raw"):
        raise ValueError(f'score_graphs_batch: invalid option o given {o}')
    assert adjs.shape == preds.shape, f'shape mismatch: adjs {adjs.shape} preds {preds.shape}'

    if adjs.ndim == 2:
        # Shapes are equal, so preds is 2-D as well: give both the batch axis.
        adjs = np.expand_dims(adjs, axis=0)
        preds = np.expand_dims(preds, axis=0)

    Ar = (preds > threshold)  # recovered = predicted
    Ag = (adjs > 0)  # ground truth

    pr, re, f1, macro_f1, acc, mcc = batch_graph_metrics(x=Ar, y=Ag, ignore_diagonal=ignore_diagonal, graph_or_edge='graph')
    if o == "raw":
        return pr, re, f1, macro_f1, acc, mcc
    if o in ("ave", "mean"):
        return np.mean(pr), np.mean(re), np.mean(f1), np.mean(macro_f1), np.mean(acc), np.nanmean(mcc)
    if o == "sum":
        return np.sum(pr), np.sum(re), np.sum(f1), np.sum(macro_f1), np.sum(acc), np.nansum(mcc)
    # o == "median"
    return np.median(pr), np.median(re), np.median(f1), np.median(macro_f1), np.median(acc), np.nanmedian(mcc)


##### General matrix/tensor utils #####
def is_matrix_symmetric(a, rtol=1e-05, atol=1e-08):
    """Check whether matrix `a` equals its transpose within a tolerance.

    Args:
        a: 2-D array.
        rtol: Relative tolerance passed to `np.allclose`.
        atol: Absolute tolerance passed to `np.allclose`.

    Returns:
        Tuple `(is_symmetric, indices, diff)`. For a symmetric matrix `indices`
        is an empty array and `diff` is all zeros with the shape of `a`.
        Otherwise `diff` is `|a - a.T|` and `indices` lists, one (row, col)
        pair per row, every position where `diff` is non-zero. That includes
        differences smaller than the tolerance.
    """
    if np.allclose(a, a.T, rtol=rtol, atol=atol):
        return True, np.array([]), np.zeros(a.shape)

    mismatch = np.absolute(np.subtract(a, a.T))
    rows, cols = np.nonzero(mismatch)
    offenders = np.column_stack((rows, cols))
    return False, offenders, mismatch


# prefix used in the assertion message of is_tensor_symmetric
its = 'is_tensor_symmetric:'


def is_tensor_symmetric(a, rtol=1e-05, atol=1e-08):
    """Check that every slice `a[i]` of a 3-D array is symmetric.

    Slices are tested in order with `is_matrix_symmetric`, using the given
    tolerances, and the scan stops at the first asymmetric slice.

    Args:
        a: 3-D array indexed as (batch, row, col).
        rtol: Relative tolerance forwarded to `is_matrix_symmetric`.
        atol: Absolute tolerance forwarded to `is_matrix_symmetric`.

    Returns:
        Tuple `(is_symmetric, indices, diff)`. On failure these are the values
        reported by `is_matrix_symmetric` for the first asymmetric slice. On
        success `indices` and `diff` are those of the last slice tested; for an
        empty batch they are an empty array and zeros of `a`'s shape.
    """
    assert len(a.shape) == 3, f'{its} passed matrix to tensor func'
    indxs = np.array([])
    diff = np.zeros(a.shape)

    for i in range(a.shape[0]):
        # Forward the caller's tolerances so they are actually applied.
        isSym, indxs, diff = is_matrix_symmetric(a[i, :, :], rtol=rtol, atol=atol)
        if not isSym:
            return isSym, indxs, diff

    return True, indxs, diff


def construct_prior(train_scs: torch.Tensor, frac_contains: float = 1.0, reduction: str = "median", plt_prior: bool = False):
    """Build a prior matrix from the training set only.

    The prior can then be used to predict structure directly on validation and
    test sets as a baseline.

    Args:
        train_scs: Training matrices stacked along dim 0.
        frac_contains: Fraction in [0, 1] of training matrices that must have
            a positive entry at a position for that position to be kept.
        reduction: 'median' or 'mean'; how the kept positions are valued
            across the training set.
        plt_prior: Not used by this function.

    Returns:
        Prior of shape (1, rows, cols): the reduced value where the edge mask
        holds and zero elsewhere, with the dtype and device of the reduced
        training data.

    Raises:
        ValueError: `reduction` is neither 'median' nor 'mean'.
    """
    # positions that are positive in at least the required share of scans
    prior_edge_mask = (train_scs > 0).sum(dim=0) >= (frac_contains * len(train_scs))

    # value of every position, reduced over the training set
    if reduction == 'median':
        prior_edge_values, _ = train_scs.median(dim=0)
    elif reduction == 'mean':
        prior_edge_values = train_scs.mean(dim=0)
    else:
        raise ValueError(f'reduction {reduction} not implimented')

    # Keep values only inside the mask. The zero fill is created with the dtype
    # and device of the values, so torch.where also works for training data
    # that is not a CPU float32 tensor.
    prior = torch.where(prior_edge_mask, prior_edge_values, torch.zeros_like(prior_edge_values))

    return prior.view(1, prior.shape[-2], prior.shape[-1])


# Executing this file as a script only prints a marker line; no other work is done here.
if __name__ == "__main__":
    print('model_utils main loop')

