from __future__ import print_function, division
import numpy as np
import torch
from utils import adj2vec

def perfect_predictions(tp, tn, fp, fn):
    """Tell whether every prediction was correct.

    Compares the number of correct predictions (tp + tn) with the total of
    all four confusion-matrix counts. Uses only ``+`` and ``==``, so the
    result is elementwise when the inputs are tensors/arrays.
    """
    n_correct = tp + tn
    n_total = tp + tn + fp + fn
    return n_correct == n_total


def all_incorrect_predictions(tp, tn, fp, fn):
    """Tell whether every prediction was wrong.

    Compares the number of wrong predictions (fp + fn) with the total of
    all four confusion-matrix counts. Uses only ``+`` and ``==``, so the
    result is elementwise when the inputs are tensors/arrays.
    """
    n_wrong = fp + fn
    n_total = tp + tn + fp + fn
    return n_wrong == n_total


# performance metrics: these functions take confusion-matrix counts (tp/tn/fp/fn) and compute the metric elementwise
def accuracy(tp, tn, total):
    """Return the fraction of correct predictions: (tp + tn) / total."""
    n_correct = tp + tn
    return n_correct / total


def matthew_corr_coef(tp, tn, fp, fn, out_type=torch.float32, large_nums=True):
    """Compute the Matthews correlation coefficient (MCC) elementwise.

    Formula: https://en.wikipedia.org/wiki/Matthews_correlation_coefficient

    Entries where every prediction is correct are set to 1 and entries where
    every prediction is wrong are set to -1 without performing the division.
    All remaining entries are ``numerator / denominator``; a zero denominator
    is still possible there, in which case the result is not finite.

    Args:
        tp, tn, fp, fn: tensors of confusion-matrix counts sharing one shape.
        out_type: dtype of the buffer that holds the predefined +1 / -1 values.
        large_nums: when True the denominator is always computed in factored
            form. When False the direct product is tried first and the
            factored form is only used (with printed warnings) if the direct
            result contains NaN.

    Returns:
        Tensor of MCC values, one per entry of the inputs.
    """
    numerator = (tp * tn) - (fp * fn)

    if large_nums:
        use_factored_denominator = True
    else:
        # the argument to torch.sqrt can easily overflow for large graphs. Compute is elementwise.
        denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        use_factored_denominator = bool(torch.any(torch.isnan(denominator)))

    if use_factored_denominator:
        if not large_nums:
            print(f'WARNING: When computing MCC, got overflow. Attempting to recompute. Try reducing batch size.')

        # row i holds the i-th marginal sum (e.g. tp + fp), one column per entry
        marginals = torch.stack((tp + fp, tp + fn, tn + fp, tn + fn))

        # sqrt(prod(s)) == prod(floor(sqrt(s))) * sqrt(prod(s / floor(sqrt(s))**2));
        # pulling out the integer part of each root keeps the product under the
        # square root small. A marginal of 0 gives 0/0 = NaN here.
        integer_roots = torch.floor(torch.sqrt(marginals))
        remainder = torch.sqrt(torch.prod(marginals / integer_roots**2, dim=0))
        denominator = torch.prod(integer_roots, dim=0) * remainder

        if not large_nums:
            still_nan = torch.any(torch.isnan(denominator))
            if still_nan:
                print(f'\tAttempted to recover, but still NAN in MCC computation')
            else:
                print(f'\tSuccesffuly recovered: no NAN values in MCC')

    # MCC is defined by hand on the two degenerate cases
    is_perfect = perfect_predictions(tp=tp, tn=tn, fp=fp, fn=fn)
    is_all_wrong = all_incorrect_predictions(tp=tp, tn=tn, fp=fp, fn=fn)

    predefined = torch.zeros_like(tp, dtype=out_type)
    predefined[is_perfect] = 1
    # written second, so it wins where both masks are True (all counts zero)
    predefined[is_all_wrong] = -1

    # True everywhere that has neither a perfect nor a worst-possible prediction
    needs_division = torch.logical_not(torch.logical_or(is_perfect, is_all_wrong))
    return torch.where(needs_division, numerator / denominator, predefined)


def precision(tp, fp, eps: float = 1e-12):
    """Return tp / (tp + fp), with ``eps`` added to the denominator to avoid division by zero."""
    predicted_positives = tp + fp + eps
    return tp / predicted_positives


def recall(tp, fn, eps: float = 1e-12):
    """Return tp / (tp + fn), with ``eps`` added to the denominator to avoid division by zero."""
    actual_positives = tp + fn + eps
    return tp / actual_positives


def f1(tp, fp, fn, eps: float = 1e-12):
    """Return the F1 score tp / (tp + (fp + fn) / 2), with ``eps`` added to the denominator."""
    half_errors = .5 * (fp + fn)
    denominator = tp + half_errors + eps
    return tp / denominator


def macro_f1(tp, tn, fp, fn, eps: float = 1e-12):
    """Return the unweighted mean of the F1 scores of the positive and negative class.

    The negative-class score reuses ``f1`` with ``tn`` in the role of ``tp``;
    ``fp`` and ``fn`` are passed unchanged because ``f1`` only uses their sum.
    """
    positive_class_f1 = f1(tp=tp, fp=fp, fn=fn, eps=eps)
    negative_class_f1 = f1(tp=tn, fp=fp, fn=fn, eps=eps)
    return (positive_class_f1 + negative_class_f1) / 2


def fdr(tp, fp, eps: float = 1e-12):
    """Return the false discovery rate fp / (fp + tp), with ``eps`` added to the denominator."""
    predicted_positives = fp + tp + eps
    return fp / predicted_positives


def fpr(tn, fp, eps: float = 1e-12):
    """Return the false positive rate fp / (fp + tn), with ``eps`` added to the denominator."""
    actual_negatives = fp + tn + eps
    return fp / actual_negatives





# From raw prediction and labels, compute tp/tn/fp/fn
def confusion_matrix_unsigned(y_hat, y, reduction_axes):
    """Count tp/tn/fp/fn between predictions and labels, treating 0 as the negative class.

    An element is a "positive" when its label is non-zero. A prediction is
    correct only when it equals the label exactly.

    Args:
        y_hat: predictions, either a torch tensor or a numpy array.
        y: labels, same container type and same shape as ``y_hat``.
        reduction_axes: axis/axes summed over (passed as ``dim`` to
            ``torch.sum`` or ``axis`` to ``np.sum``).

    Returns:
        Tuple ``(tp, tn, fp, fn)`` of counts.

    Raises:
        ValueError: if the inputs are not both tensors or both numpy arrays.
        AssertionError: if the shapes differ.
    """
    if torch.is_tensor(y_hat) and torch.is_tensor(y):
        def count(mask):
            return torch.sum(mask, dim=reduction_axes)
    elif isinstance(y_hat, np.ndarray) and isinstance(y, np.ndarray):
        def count(mask):
            return np.sum(mask, axis=reduction_axes)
    else:
        raise ValueError(f'y_hat and y must be same type: y_hat {type(y_hat)}, y {type(y)}')

    # Compare full shapes for both backends: checking only len() lets e.g.
    # (n,) vs (n, 1) through, which then broadcasts into wrong counts.
    assert y_hat.shape == y.shape, f'shape mismatch: y_hat {y_hat.shape}, y {y.shape}'

    correct, wrong = (y_hat == y), (y_hat != y)
    label_pos, label_neg = (y != 0), (y == 0)

    tp = count(correct & label_pos)
    tn = count(correct & label_neg)
    fp = count(wrong & label_neg)
    fn = count(wrong & label_pos)
    return tp, tn, fp, fn


def confusion_matrix(y_hat, y, reduction_axes):
    """Count tp/tn/fp/fn for signed edge labels.

    Labels ``y`` must take values in {-1, 0, 1}; 0 is the negative class.

    - tp: label is non-zero and the prediction equals it (sign included)
    - tn: label and prediction are both zero
    - fp: label is zero, prediction is non-zero
    - fn: label is non-zero, prediction is zero

    A non-zero prediction whose sign differs from a non-zero label is not
    counted in any of the four cells.

    Args:
        y_hat: tensor of predictions, same shape as ``y``.
        y: tensor of labels.
        reduction_axes: dimension(s) summed over.

    Returns:
        Tuple ``(tp, tn, fp, fn)`` of count tensors.
    """
    assert y_hat.shape == y.shape
    label_in_range = (y == 1) | (y == -1) | (y == 0)
    assert label_in_range.all(), 'only values y can take are {0, -1, 1}'

    def count(mask):
        return torch.sum(mask, dim=reduction_axes)

    label_is_zero = y == 0
    pred_is_zero = y_hat == 0
    label_is_edge = torch.logical_not(label_is_zero)
    pred_is_edge = torch.logical_not(pred_is_zero)

    tp = count((y_hat == y) & label_is_edge)
    tn = count(pred_is_zero & label_is_zero)
    fp = count(pred_is_edge & label_is_zero)
    fn = count(pred_is_zero & label_is_edge)
    return tp, tn, fp, fn


# Aggregate relevant metrics into a dict
def classification_metrics(y_hat: torch.tensor, y: torch.tensor, threshold: float, non_neg: bool, graph_or_edge='graph'):
    """Threshold predictions and compute classification metrics against the labels.

    Args:
        y_hat: raw predictions; thresholded inside this function.
        y: labels, same shape as ``y_hat`` after squeezing.
        threshold: magnitude cut-off applied to ``y_hat``.
        non_neg: if False, edges are treated as signed and only a sign-agreement
            accuracy is returned (tp/tn/fp/fn-based metrics are not well
            defined when values have more than two states). If True, the full
            set of unsigned metrics is returned.
        graph_or_edge: ``'graph'`` reduces over dim 1 (one value per row of the
            batch); ``'edge'`` reduces over dim 0 (one value per column).
            Only used when ``non_neg`` is True.

    Returns:
        dict of tensors. Signed case: keys ``'acc'``, ``'error'``. Unsigned
        case: keys ``'pr'``, ``'re'``, ``'f1'``, ``'macro_f1'``, ``'acc'``,
        ``'error'``, ``'mcc'``.

    Raises:
        ValueError: if ``graph_or_edge`` is neither ``'graph'`` nor ``'edge'``
            (unsigned case only).
    """
    # NOTE(review): squeeze() without a dim also removes a batch dimension of
    # size 1 — confirm callers never pass a single-sample batch.
    y_hat, y = y_hat.squeeze(), y.squeeze()
    assert (y_hat.shape == y.shape), 'inputs must be same size'

    if not non_neg:
        # first use threshold to remove small edge values (in pos or neg dir), then compare signs
        y_hat_thresh = y_hat.clone()
        y_hat_thresh[y_hat_thresh.abs() < threshold] = 0.0
        reduction_axes = (1, 2) if y.ndim == 3 else 1
        sign_agreement = (y.sign() == y_hat_thresh.sign())
        # mean() is only defined for floating dtypes, so integer/bool labels
        # must not dictate the dtype of the averaged mask.
        acc_dtype = y.dtype if y.is_floating_point() else torch.float32
        signed_acc = sign_agreement.to(acc_dtype).mean(dim=reduction_axes)
        return {'acc': signed_acc, 'error': 1-signed_acc}

    # any y_hat values in [-thresh, thresh] -> 0
    y_hat, y = y_hat.sign() * (y_hat.abs() > threshold), y.sign()
    y_hat, y = y_hat.to(torch.int), y.to(torch.int)

    if y.ndim == 3:
        y, y_hat = adj2vec(y), adj2vec(y_hat)
    assert y_hat.ndim == y.ndim == 2
    num_graphs, num_possible_edges = y_hat.shape

    # the number of elements summed per output entry is the size of the reduced dim
    if graph_or_edge == 'graph':
        reduction_axes, total = 1, num_possible_edges
    elif graph_or_edge == 'edge':
        reduction_axes, total = 0, num_graphs
    else:
        # raise instead of print + exit(): a library function should not
        # terminate the interpreter of its caller
        raise ValueError(f'unrecognized graph_or_edge {graph_or_edge}')

    tp, tn, fp, fn = confusion_matrix(y_hat=y_hat, y=y, reduction_axes=reduction_axes)

    # recall = true positive rate (tpr), precision = positive predictive value (ppv)
    # unsigned metrics
    acc = accuracy(tp=tp, tn=tn, total=total)
    return {'pr': precision(tp=tp, fp=fp), 're': recall(tp=tp, fn=fn), 'f1': f1(tp=tp, fp=fp, fn=fn),
            'macro_f1': macro_f1(tp=tp, tn=tn, fp=fp, fn=fn),
            'acc': acc, 'error': 1 - acc,
            'mcc': matthew_corr_coef(tp=tp, tn=tn, fp=fp, fn=fn)}


def regression_metrics(y_hat: torch.tensor, y: torch.tensor, self_loops=False):
    """Compute per-sample regression errors between predictions and labels.

    Args:
        y_hat: predictions.
        y: labels, same shape as ``y_hat`` (after squeezing 3-D inputs).
        self_loops: if True the inputs must be 3-D (full adjacency matrices)
            and errors are reduced over dims (1, 2), which includes the
            diagonal and double counts edges of undirected graphs. If False
            the inputs must be 2-D (vectorized adjacency) and errors are
            reduced over dim 1.

    Returns:
        dict with keys ``'se'`` (summed squared error), ``'ae'`` (summed
        absolute error), ``'nse'`` (``se`` divided by the summed squared
        labels), ``'10_log_nse'``, ``'se_per_edge'`` and ``'ae_per_edge'``
        (means instead of sums).
    """
    if y_hat.ndim == 3:
        y_hat = y_hat.squeeze()
    if y.ndim == 3:
        y = y.squeeze()
    assert (y_hat.shape == y.shape), 'inputs must be same size, subtraction operator broadcasts incorrectly if not'

    expected_ndim = 3 if self_loops else 2
    assert y.ndim == expected_ndim

    reduction_axes = (1, 2) if self_loops else 1
    residual = y_hat - y
    squared_residual = torch.square(residual)
    absolute_residual = torch.abs(residual)

    se = squared_residual.sum(dim=reduction_axes)
    ae = absolute_residual.sum(dim=reduction_axes)
    label_energy = (y ** 2).sum(dim=reduction_axes)
    nse = se / label_energy

    return {
        'se': se,
        'ae': ae,
        'nse': nse,
        '10_log_nse': 10*torch.log10(nse),
        'se_per_edge': squared_residual.mean(dim=reduction_axes),
        'ae_per_edge': absolute_residual.mean(dim=reduction_axes),
    }


def compute_metrics(y_hat: torch.tensor, y: torch.tensor, threshold, self_loops, non_neg=True):
    """Compute classification and regression metrics and merge them into one dict.

    Two cases are covered:
      1) signed edges (``non_neg=False``): predicting +1 when the label is -1
         (and vice versa) is not well defined as tp/tn/fp/fn, so only a subset
         of the classification metrics is used.
      2) non-negative edges (``non_neg=True``): all metrics are well defined.
    "non_neg" really means homogeneous sign: all values of interest are either
    non-negative or non-positive.

    When ``self_loops`` is False and the inputs are 3-D, they are converted
    with ``adj2vec`` before the metrics are computed. On duplicate keys the
    regression value overrides the classification value.
    """
    y_hat = y_hat.squeeze()
    y = y.squeeze()
    assert y_hat.shape == y.shape

    needs_vectorizing = y.ndim == 3 and not self_loops
    if needs_vectorizing:
        y, y_hat = adj2vec(y), adj2vec(y_hat)

    merged = {}
    merged.update(classification_metrics(y_hat=y_hat, y=y, threshold=threshold, non_neg=non_neg))
    merged.update(regression_metrics(y_hat=y_hat, y=y, self_loops=self_loops))
    return merged


def hinge_loss(y, y_hat, margin, per_edge=True, slope=1):
    """Return the hinge loss of each row (scan) of a batch.

    Labels are in {0, +1} (given as booleans), NOT {-1, +1}.

    - label False: loss is ``max(0, y_hat - margin)`` (assumes all y_hat >= 0)
    - label True:  loss is ``max(0, (1 - margin) - y_hat)``

    Args:
        y: boolean label tensor, 2-D, or 3-D with singleton dims that squeeze to 2-D.
        y_hat: predictions, same shape as ``y`` after squeezing.
        margin: offset of the hinge, see formulas above.
        per_edge: if True average over dim 1, otherwise sum over dim 1.
        slope: multiplier on the loss; more slope = more punishment of error.

    Returns:
        1-D tensor with one loss value per row.
    """
    assert y.dtype == torch.bool
    if y_hat.ndim == 3:
        y_hat = y_hat.squeeze()
    # Squeeze labels under the same condition as predictions. Squeezing 2-D
    # labels collapsed a (1, E) batch to 1-D and left 3-D labels untouched,
    # and both cases then failed the shape assertion below.
    if y.ndim == 3:
        y = y.squeeze()
    assert y.shape == y_hat.shape and y.ndim == 2, f'adjs must be in row form'

    zeros = torch.zeros_like(y_hat)
    loss_when_label_zero = torch.maximum(zeros, y_hat - margin)
    loss_when_label_one = torch.maximum(zeros, -y_hat + (1 - margin))
    # picks the first branch where y is True
    per_element_loss = slope * torch.where(y, loss_when_label_one, loss_when_label_zero)

    if per_edge:
        return torch.mean(per_element_loss, dim=1)
    return torch.sum(per_element_loss, dim=1)


def best_threshold_by_metric(y_hat: torch.tensor, y: torch.tensor, thresholds: torch.tensor, metric: str, non_neg: bool= True):
    """Return the threshold from ``thresholds`` that optimizes ``metric``.

    For every candidate threshold the metric is computed with
    ``classification_metrics`` and averaged over its entries. The threshold
    with the smallest mean is returned for ``'error'``, the one with the
    largest mean for every other metric. NaN means (mcc returns nan if all
    predictions are True or all are False) are skipped by the NaN-aware argmin/argmax.
    """
    assert metric in ['acc', 'error', 'f1', 'pr', 're', 'mcc'], f'given {metric}'

    mean_scores = torch.zeros(len(thresholds))
    for idx, candidate in enumerate(thresholds):
        # one metric value per sample -> reduce to a single score
        per_sample = classification_metrics(y_hat=y_hat, y=y, threshold=candidate, non_neg=non_neg)[metric]
        mean_scores[idx] = per_sample.mean()

    lower_is_better = metric in ['err', 'error']
    pick_best = np.nanargmin if lower_is_better else np.nanargmax
    return thresholds[pick_best(mean_scores)]


######  TESTS  #####
def confusion_matrix_tests():
    """Check ``confusion_matrix`` on a boolean and on a signed example."""
    # unsigned input: one element in each of the four cells
    labels = torch.tensor([0, 0, 1, 1]).view(1, -1) > 0
    preds = torch.tensor([0, 1, 0, 1]).view(1, -1) > 0
    expected = torch.tensor([1, 1, 1, 1])  # tp, tn, fp, fn
    counts = confusion_matrix(y_hat=preds, y=labels, reduction_axes=1)
    assert torch.allclose(expected, torch.tensor(list(counts)))

    # signed input: every (label, prediction) pair over {-1, 0, 1}, two identical rows
    signed_labels = torch.tensor([0, 0, 0, -1, -1, -1, 1, 1, 1]).view(1, -1).expand(2, -1)
    signed_preds = torch.tensor([0, -1, 1, 0, -1, 1, 0, -1, 1]).view(1, -1).expand(2, -1)
    expected = torch.tensor([2, 1, 2, 2]).repeat(2, 1)  # tp, tn, fp, fn per row
    counts = confusion_matrix(y_hat=signed_preds, y=signed_labels, reduction_axes=1)
    assert torch.allclose(expected, torch.stack(counts).t())


def classification_metrics_tests():
    """Check confusion counts and the acc/error entries returned by ``compute_metrics``."""
    # unsigned input
    y = (torch.tensor([0, 0, 1, 1])>0).view(1, -1).expand(2, -1)
    p = (torch.tensor([0, 1, 0, 1])>0).view(1, -1).expand(2, -1)
    tp_, tn_, fp_, fn_ = 1, 1, 1, 1
    true = torch.tensor([tp_, tn_, fp_, fn_]).repeat(2, 1)
    tp, tn, fp, fn = confusion_matrix(y_hat=p, y=y, reduction_axes=1)
    out = torch.stack((tp, tn, fp, fn)).t()
    assert torch.allclose(true, out)
    # compute_metrics requires self_loops; inputs are 2-D (vectorized), so self_loops=False.
    # Cast to float: the regression metrics take means, which are undefined for bool/int tensors.
    cm = compute_metrics(y_hat=p.to(torch.float32), y=y.to(torch.float32), threshold=0,
                         self_loops=False, non_neg=True)
    true = torch.tensor([1/2, 1/2]).repeat(2, 1)
    out = torch.stack([cm['acc'], cm['error']]).t()
    assert torch.allclose(true, out)

    # signed input
    y = torch.tensor([0, 0, 0, -1, -1, -1, 1, 1, 1]).view(1, -1).repeat(2, 1)
    p = torch.tensor([0, -1, 1, 0, -1, 1, 0, -1, 1]).view(1, -1).repeat(2, 1)
    tp_, tn_, fp_, fn_ = 2, 1, 2, 2 # fn is key here -> pred -1 but actually +1...not included in fn
    true = torch.tensor([tp_, tn_, fp_, fn_]).repeat(2, 1)
    tp, tn, fp, fn = confusion_matrix(y_hat=p, y=y, reduction_axes=1)
    out = torch.stack((tp, tn, fp, fn)).t()
    assert torch.allclose(true, out)
    # now make continuous valued, include thresholding
    p = p.to(torch.float32)
    p_contin = torch.zeros_like(p)
    p_contin[p < 0] = p[p < 0] - 1.67
    p_contin[p > 0] = p[p > 0] + 1.67
    p_contin[p == 0] = p[p == 0] + .02  # below the threshold -> treated as 0
    # float labels so that the signed accuracy (a mean) is well defined
    cm = compute_metrics(y_hat=p_contin, y=y.to(torch.float32), threshold=.03,
                         self_loops=False, non_neg=False)
    # signs agree on 3 of the 9 entries
    true = torch.tensor([1/3, 2/3]).repeat(2, 1)
    out = torch.stack([cm['acc'], cm['error']]).t()
    assert torch.allclose(true, out)


if __name__ == "__main__":
    # run the ad-hoc tests, then print the metrics of a small random example
    confusion_matrix_tests()
    classification_metrics_tests()

    torch.manual_seed(50)
    a = torch.randint(low=0, high=2, size=(2, 5))
    b = a.clone()
    b[0, 0] = b[1, 1] = 1
    # smoke check only; the counts themselves are not used below
    tp, tn, fp, fn = confusion_matrix(a, b, reduction_axes=1)
    # threshold and non_neg are required arguments. The inputs are 0/1 valued,
    # so a threshold of 0 keeps every non-zero entry as an edge.
    metrics = classification_metrics(a, b, threshold=0, non_neg=True, graph_or_edge='graph')

    print(a,'\n', b)
    print(metrics)