from __future__ import print_function, division
import numpy as np
from sklearn.metrics import accuracy_score
from numba import njit
from numba_optimizations.fast_funcs import return_matrix_diag
import math
import time
from typing import Union
import torch
DEBUG = False
import warnings
from torch.nn import functional as F


##########################
# These functions take two 1D vectors and output scalar metric
#@njit
def accuracy(tp, tn, total):
    """Return the fraction of predictions that were correct.

    Args:
        tp: true-positive count(s).
        tn: true-negative count(s).
        total: number of predictions the counts were drawn from.

    Returns:
        ``(tp + tn) / total``; element-wise when the inputs are arrays/tensors.
    """
    num_correct = tp + tn
    return num_correct / total

#@njit
def accuracy_raw(x: np.ndarray, y: np.ndarray):
    """Compute accuracy directly from raw prediction/label vectors.

    An element counts as a true positive when both ``x`` and ``y`` are
    strictly positive, and as a true negative when both are exactly zero.

    Args:
        x: 1D array of predictions.
        y: 1D array of labels, same length as ``x``.

    Returns:
        Fraction of positions that are true positives or true negatives.
    """
    both_positive = (x > 0) & (y > 0)
    both_zero = (x == 0) & (y == 0)
    return accuracy(np.sum(both_positive), np.sum(both_zero), len(x))

#@njit
def perfect_predictions(tp, tn, fp, fn):
    """Return True (element-wise) where every prediction was correct.

    That is, where the correct counts ``tp + tn`` account for all
    ``tp + tn + fp + fn`` predictions.
    """
    num_total = tp + tn + fp + fn
    num_correct = tp + tn
    return num_correct == num_total
#@njit
def all_incorrect_predictions(tp, tn, fp, fn):
    """Return True (element-wise) where every prediction was wrong.

    That is, where the error counts ``fp + fn`` account for all
    ``tp + tn + fp + fn`` predictions.
    """
    num_total = tp + tn + fp + fn
    num_wrong = fp + fn
    return num_wrong == num_total


#@njit
#def mcc_along_axes(tp, tn, fp, fn, reduction_axes):



#@njit

def matthew_corr_coef(tp, tn, fp, fn):
    """Dispatch the Matthews correlation coefficient to the right backend.

    All four counts must be of the same family:

    * torch tensors      -> ``matthew_corr_coef_torch`` (float32 output)
    * ``np.ndarray``     -> ``matthew_corr_coef_np`` (float32 output)
    * integer scalars    -> ``matthew_corr_coef_np`` (float32 output)

    Args:
        tp, tn, fp, fn: confusion-matrix counts.

    Returns:
        Whatever the selected backend returns.

    Raises:
        ValueError: if the four inputs do not all belong to one family.
    """
    counts = (tp, tn, fp, fn)

    if all(torch.is_tensor(c) for c in counts):
        return matthew_corr_coef_torch(tp=tp, tn=tn, fp=fp, fn=fn, out_type=torch.float32)

    if all(isinstance(c, np.ndarray) for c in counts):
        return matthew_corr_coef_np(tp=tp, tn=tn, fp=fp, fn=fn, out_type=np.float32)

    # `np.int` was only an alias of the builtin `int` and no longer exists in
    # current NumPy; `np.integer` is the abstract base of every NumPy integer
    # scalar type and also matches the builtin `int`.
    if all(np.issubdtype(type(c), np.integer) for c in counts):
        return matthew_corr_coef_np(tp=tp, tn=tn, fp=fp, fn=fn, out_type=np.float32)

    raise ValueError(f'inputs must be same type: tp/tn/fp/fn: {type(tp)}/{type(tn)}/{type(fp)}/{type(fn)}')


def matthew_corr_coef_torch(tp, tn, fp, fn, out_type=torch.float32, large_nums=True):
    """Element-wise Matthews correlation coefficient for torch count tensors.

    Formula: https://en.wikipedia.org/wiki/Matthews_correlation_coefficient

    Entries where every prediction is correct are defined as 1 and entries
    where every prediction is wrong are defined as -1 (the latter wins when
    both hold). Everywhere else the usual quotient is returned, which can
    still be NaN when the denominator is zero.

    Args:
        tp, tn, fp, fn: tensors of confusion-matrix counts.
        out_type: dtype of the tensor holding the +/-1 special-case values.
        large_nums: when True, always use the factored denominator that avoids
            forming the full product of the four marginal sums. When False,
            the direct product is tried first and the factored form is used
            only if the direct result contains NaN.

    Returns:
        Tensor of MCC values.
    """
    all_right = perfect_predictions(tp=tp, tn=tn, fp=fp, fn=fn)
    all_wrong = all_incorrect_predictions(tp=tp, tn=tn, fp=fp, fn=fn)
    # True wherever the value is NOT pinned to +/-1; division by 0 is still possible there.
    needs_formula = torch.logical_not(torch.logical_or(all_right, all_wrong))

    numerator = (tp * tn) - (fp * fn)
    marginals = ((tp + fp), (tp + fn), (tn + fp), (tn + fn))

    use_factored = large_nums
    if not large_nums:
        # The argument of sqrt can easily overflow for large graphs.
        denominator = torch.sqrt(marginals[0] * marginals[1] * marginals[2] * marginals[3])
        use_factored = bool(torch.any(torch.isnan(denominator)))
        if use_factored:
            print(f'WARNING: When computing MCC, got overflow. Attempting to recompute. Try reducing batch size.')

    if use_factored:
        # Row k holds the k-th marginal sum for every sample.
        stacked = torch.stack(marginals)
        # Pull the integer part of each square root out of the product so the
        # quantity left under the final sqrt stays small.
        root_floor = torch.floor(torch.sqrt(stacked))
        denominator = torch.prod(root_floor, dim=0) * torch.sqrt(torch.prod(stacked / root_floor**2, dim=0))

        if not large_nums:
            if torch.any(torch.isnan(denominator)):
                print(f'\tAttempted to recover, but still NAN in MCC computation')
            else:
                print(f'\tSuccesffuly recovered: no NAN values in MCC')

    fixed_values = torch.zeros_like(tp, dtype=out_type)
    fixed_values[all_right] = 1
    fixed_values[all_wrong] = -1

    return torch.where(needs_formula, numerator / denominator, fixed_values)
    #return torch.divide(mcc_numerator, mcc_denom, out=out, where=where_to_divide)

def matthew_corr_coef_np(tp, tn, fp, fn, out_type=np.float32):
    """Element-wise Matthews correlation coefficient for NumPy counts.

    Formula: https://en.wikipedia.org/wiki/Matthews_correlation_coefficient

    Entries where every prediction is correct are defined as 1 and entries
    where every prediction is wrong are defined as -1 (the latter wins when
    both hold, i.e. when all four counts are 0). Everywhere else the usual
    quotient is returned; it is NaN when the denominator is zero.

    Args:
        tp, tn, fp, fn: confusion-matrix counts (arrays or integer scalars).
        out_type: dtype of the returned array.

    Returns:
        ``np.ndarray`` of dtype ``out_type`` (0-d for scalar inputs).
    """
    # Work in float64: the product of the four marginal sums overflows int64
    # once the counts reach ~1e5 (and int32 far earlier), which silently
    # produced values outside [-1, 1].
    tp_f, tn_f, fp_f, fn_f = (np.asarray(c, dtype=np.float64) for c in (tp, tn, fp, fn))

    total = tp_f + tn_f + fp_f + fn_f
    perfect_predict_mask = (tp_f + tn_f) == total
    all_incorrect_predict_mask = (fp_f + fn_f) == total

    # True everywhere the value is not pinned to +/-1.
    # It is still possible to divide by 0 at those positions.
    where_to_divide = np.logical_not(np.logical_or(perfect_predict_mask, all_incorrect_predict_mask))

    mcc_numerator = (tp_f * tn_f) - (fp_f * fn_f)
    mcc_denom = np.sqrt((tp_f + fp_f) * (tp_f + fn_f) * (tn_f + fp_f) * (tn_f + fn_f))

    out = np.zeros(np.shape(total), dtype=out_type)
    out[perfect_predict_mask] = 1
    out[all_incorrect_predict_mask] = -1

    # Positions excluded by `where` keep the +/-1 values written above.
    return np.divide(mcc_numerator, mcc_denom, out=out, where=where_to_divide)


#@njit
def precision_(tp, fp, eps: float = 1e-12):
    """Return precision ``tp / (tp + fp)``, stabilised by ``eps``.

    ``eps`` is added to the denominator so the result is 0 rather than
    undefined when there are no positive predictions.
    """
    predicted_positive = tp + fp
    return tp / (predicted_positive + eps)


#@njit
def recall_(tp, fn, eps: float = 1e-12):
    """Return recall ``tp / (tp + fn)``, stabilised by ``eps``.

    ``eps`` is added to the denominator so the result is 0 rather than
    undefined when there are no actual positives.
    """
    actual_positive = tp + fn
    return tp / (actual_positive + eps)


#@njit
def f1_(tp, fp, fn, eps: float = 1e-12):
    """Return the F1 score ``tp / (tp + (fp + fn) / 2)``, stabilised by ``eps``.

    ``eps`` is added to the denominator so the result is 0 rather than
    undefined when all three counts are 0.
    """
    half_errors = .5 * (fp + fn)
    denominator = tp + half_errors + eps
    return tp / denominator


#@njit
def macro_f1_(tp, tn, fp, fn, eps: float = 1e-12):
    """Return the macro-averaged F1: the mean of the two per-class F1 scores.

    The positive-class score uses ``tp`` as the hit count; the negative-class
    score uses ``tn`` instead. Both use the same ``fp``/``fn`` error counts.
    """
    positive_class_f1 = f1_(tp=tp, fp=fp, fn=fn, eps=eps)
    negative_class_f1 = f1_(tp=tn, fp=fp, fn=fn, eps=eps)
    return (positive_class_f1 + negative_class_f1) / 2

#for this to work, might need subfunctions to be njitted
#@njit
def pr_re_f1s_acc(tp, tn, fp, fn):
    """Compute precision, recall, F1, macro-F1 and accuracy from counts.

    Args:
        tp, tn, fp, fn: confusion-matrix counts.

    Returns:
        Tuple ``(precision, recall, f1, macro_f1, accuracy)``.
    """
    num_predictions = tp + tn + fp + fn
    return (
        precision_(tp=tp, fp=fp),
        recall_(tp=tp, fn=fn),
        f1_(tp=tp, fp=fp, fn=fn),
        macro_f1_(tp=tp, tn=tn, fp=fp, fn=fn),
        accuracy(tp=tp, tn=tn, total=num_predictions),
    )

# x and y both 1D
def binary_classification_metrics(x: np.ndarray, y: np.ndarray):
    """Compute the full set of binary-classification metrics for two 1D vectors.

    Args:
        x: 1D array of predictions.
        y: 1D array of labels.

    Returns:
        Dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``,
        ``'macro-f1'`` and ``'mcc'``.
    """
    counts = dict(zip(('tp', 'tn', 'fp', 'fn'), confusion_matrix(x, y)))
    precision, recall, f1_score, macro_f1_score, acc = pr_re_f1s_acc(**counts)
    mcc_value = matthew_corr_coef(**counts)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1_score,
        'macro-f1': macro_f1_score,
        'mcc': mcc_value,
    }
##########################


##########################
# These functions apply above funcs to each batch_axis vectors (groups of same edges)
def edge_mse(Ar, Ag):
    """Return the per-entry mean squared error, averaged over the batch axis.

    Args:
        Ar: 3D array whose last two dimensions are equal.
        Ag: array of the same shape as ``Ar``.

    Returns:
        2D array: mean of ``(Ar - Ag)**2`` over axis 0.

    Raises:
        AssertionError: if the shapes differ, are not 3D, or are not square.
    """
    shape_r, shape_g = Ar.shape, Ag.shape
    assert shape_r == shape_g, f'inputs must be same size: {shape_r} vs {shape_g}'
    assert len(shape_r) == len(shape_g) == 3, f'inputs must be 3D'
    assert (shape_r[1] == shape_r[2]) and (shape_g[1] == shape_g[2]), f"inputs must be square: Ar {shape_r}, Ag {shape_g}"

    residual = Ar - Ag
    return np.mean(residual ** 2, axis=0)


def mse_score(Ar, Ag, ignore_diagonal=True):
    """Summarise the squared error between two batches of square matrices.

    Args:
        Ar: 3D array (batch, N, N) of estimated matrices.
        Ag: 3D array of the same shape holding the reference matrices.
        ignore_diagonal: when True, ``return_matrix_diag(...)`` is subtracted
            from both inputs first (presumably zeroing the diagonals -- confirm
            against that helper) and only ``N*(N-1)`` entries per graph are
            counted in the per-edge average.

    Returns:
        Tuple of
        * total squared error over the whole batch,
        * squared error averaged over every counted entry of every graph,
        * squared error averaged over graphs,
        * mean over graphs of ``||Ar - Ag||^2 / ||Ag||^2``.

    Raises:
        AssertionError: if the shapes differ, are not 3D, or are not square.
    """
    assert Ar.shape == Ag.shape, f'inputs must be same size: {Ar.shape} vs {Ag.shape}'
    assert len(Ar.shape) == len(Ag.shape) == 3, f'inputs must be 3D'
    assert (Ar.shape[1] == Ar.shape[2]) and (Ag.shape[1] == Ag.shape[2]), f"inputs must be square: Ar {Ar.shape}, Ag {Ag.shape}"
    num_graphs, N = Ar.shape[0], Ar.shape[1]

    if ignore_diagonal:
        # Rebind instead of using -=, which would modify the caller's arrays.
        # https://stackoverflow.com/questions/11585793/are-numpy-arrays-passed-by-reference
        Ar = Ar - return_matrix_diag(Ar)
        Ag = Ag - return_matrix_diag(Ag)
        edges_per_graph = N*(N-1)
    else:
        edges_per_graph = N*N

    squared_error = (Ar-Ag)**2
    total_squared_error = np.sum(squared_error)

    error_per_graph = np.sum(squared_error, axis=(1, 2))
    reference_energy_per_graph = np.sum(Ag**2, axis=(1, 2))

    return (
        total_squared_error,
        total_squared_error / (edges_per_graph * num_graphs),
        total_squared_error / num_graphs,
        np.mean(error_per_graph / reference_energy_per_graph),
    )


def binary_classification_metrics_tests():
    """Self-test for ``binary_classification_metrics`` and ``accuracy_raw``.

    NumPy divide/invalid warnings are silenced while the checks run because
    several cases intentionally produce 0/0 (an undefined MCC). The error
    state is set back to 'warn' on exit even when an assertion fails, so a
    failing check no longer leaves warnings globally disabled.

    Raises:
        AssertionError: if any metric differs from its expected value.
    """
    def mcc_of(pred, label):
        return binary_classification_metrics(pred, label)['mcc']

    np.seterr(divide='ignore', invalid='ignore')
    try:
        #######################
        # mcc wiki ex: https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
        prediction = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1])
        actual = np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0])
        assert np.isclose(mcc_of(prediction, actual), 0.4780914)

        all_zero = np.array([0, 0, 0, 0])
        all_ones = np.array([1, 1, 1, 1])
        mixed = np.array([0, 0, 1, 1])
        opp_mixed = np.array([1, 1, 0, 0])

        # (prediction, label, expected MCC): all-correct -> 1, all-wrong -> -1
        defined_mcc_cases = [
            (all_zero, all_zero, 1),    # No Pos Pred & No Actual Pos
            (all_ones, all_ones, 1),    # No Neg Pred & No Actual Neg
            (all_ones, all_zero, -1),
            (all_zero, all_ones, -1),
            (mixed, mixed, 1),          # all correct
            (mixed, opp_mixed, -1),     # all wrong
            (opp_mixed, mixed, -1),     # all wrong & symmetric
        ]
        for pred, label, expected in defined_mcc_cases:
            assert np.isclose(mcc_of(pred, label), expected), f'mcc: pred {pred}, label {label}'

        # (prediction, label): MCC is undefined (NaN) when exactly one side is homogeneous
        undefined_mcc_cases = [
            (mixed, all_zero),     # No Actual Pos
            (mixed, all_ones),     # No Actual Neg
            (all_zero, mixed),     # No Pred Pos
            (all_ones, mixed),     # No Pred Neg
        ]
        for pred, label in undefined_mcc_cases:
            assert np.isnan(mcc_of(pred, label)), f'mcc should be nan: pred {pred}, label {label}'

        #######################
        # accuracy
        assert np.isclose(accuracy_raw(np.zeros(100), np.ones(100)), 0.0)
        assert np.isclose(accuracy_raw(np.ones(100), np.ones(100)), 1.0)
        half_ones = np.concatenate((np.ones(50), np.zeros(50)))
        assert np.isclose(accuracy_raw(half_ones, np.ones(100)), .5)
        assert np.isclose(accuracy_raw(half_ones, np.zeros(100)), .5)

        #######################
        # pr/re/f1/accuracy/macro-f1
        # (label, prediction, [precision, recall, f1], accuracy, macro-f1)
        metric_cases = [
            # labels all 1's
            ([1, 1, 1], [0, 0, 0], [0, 0, 0], 0, 0),
            ([1, 1, 1], [1, 0, 0], [1, 1 / 3, 1 / 2], 1 / 3, 1 / 4),
            ([1, 1, 1], [1, 1, 0], [1, 2 / 3, 4 / 5], 2 / 3, 2 / 5),
            ([1, 1, 1], [1, 1, 1], [1, 1, 1], 1, 1 / 2),
            # mixed labels
            ([0, 1, 1], [0, 0, 0], [0, 0, 0], 1 / 3, 1 / 4),
            ([0, 1, 1], [1, 0, 0], [0, 0, 0], 0, 0),
            ([0, 1, 1], [1, 1, 0], [1 / 2, 1 / 2, 1 / 2], 1 / 3, 1 / 4),
            ([0, 1, 1], [1, 1, 1], [2 / 3, 1, 4 / 5], 2 / 3, 2 / 5),
            ([0, 0, 1], [0, 0, 0], [0, 0, 0], 2 / 3, 2 / 5),
            ([0, 0, 1], [1, 0, 0], [0, 0, 0], 1 / 3, 1 / 4),
            ([0, 0, 1], [1, 1, 0], [0, 0, 0], 0, 0),
            ([0, 0, 1], [1, 1, 1], [1 / 3, 1, 1 / 2], 1 / 3, 1 / 4),
            # label all 0's
            ([0, 0, 0], [0, 0, 0], [0, 0, 0], 1, 1 / 2),
            ([0, 0, 0], [1, 0, 0], [0, 0, 0], 2 / 3, 2 / 5),
            ([0, 0, 0], [1, 1, 0], [0, 0, 0], 1 / 3, 1 / 4),
            ([0, 0, 0], [1, 1, 1], [0, 0, 0], 0, 0),
        ]
        for label, pred, pr_re_f1, acc, macro_f1 in metric_cases:
            metrics = binary_classification_metrics(np.array(pred), np.array(label))
            where = f'label {label}, pred {pred}'
            assert np.allclose([metrics['precision'], metrics['recall'], metrics['f1']], pr_re_f1), where
            assert np.isclose(metrics['accuracy'], acc), where
            assert np.isclose(metrics['macro-f1'], macro_f1), where
    finally:
        np.seterr(divide='warn', invalid='warn')
##########################

##########################
# These functions take two (adjacency) matrices output scalar metric

# behavior
# If a label graph has 0 edges, this will cause pr=re=0 --> 1/pr=1/re=inf --> f1=2/(inf+inf)=0.
# Fix this by adding a small epsilon to the divisions:
#  precision = 0/(|Ar|+eps) = 0
#  recall = 0/(eps) = 0
#       => 2/(1/(pr+eps) + 1/(re+eps)) ~= 0


# only tested on BINARY symmetric matrix slices (undirected graphs)
def batch_graph_metrics(x, y, ignore_diagonal: bool = True, graph_or_edge='graph'):
    """Dispatch batched graph metrics to the torch or NumPy implementation.

    Args:
        x: batch of predicted adjacency matrices (torch tensor or ndarray).
        y: batch of label adjacency matrices, same container type as ``x``.
        ignore_diagonal: forwarded to the backend.
        graph_or_edge: forwarded to the backend.

    Returns:
        Whatever the selected backend returns.

    Raises:
        ValueError: if ``x`` and ``y`` are not both tensors or both ndarrays.
    """
    both_tensors = torch.is_tensor(x) and torch.is_tensor(y)
    if both_tensors:
        return batch_graph_metrics_torch(x, y, ignore_diagonal, graph_or_edge)

    both_arrays = type(x) == np.ndarray and type(y) == np.ndarray
    if both_arrays:
        return batch_graph_metrics_np(x, y, ignore_diagonal, graph_or_edge)

    raise ValueError(f'x and y must be same type: x {type(x)} y: {type(y)}')


def batch_graph_metrics_torch(x: torch.tensor,
                        y: torch.tensor,
                        ignore_diagonal: bool = True,
                        graph_or_edge='graph'):
    """Classification metrics between batches of boolean adjacency matrices (torch).

    Args:
        x: bool tensor (batch, N, N) of predicted adjacency matrices; a single
            (N, N) matrix is treated as a batch of one.
        y: bool tensor of label adjacency matrices, same layout as ``x``.
        ignore_diagonal: when True, diagonal entries are forced to False in
            both inputs and, in 'graph' mode, excluded from the counts.
        graph_or_edge: 'graph' reduces over each matrix (one value per graph);
            'edge' reduces over the batch (one value per matrix entry).

    Returns:
        Tuple ``(precision, recall, f1, macro_f1, accuracy, mcc)``.

    Raises:
        AssertionError: if inputs are not bool, not same-shaped, not 3D, or
            not square.
        ValueError: if ``graph_or_edge`` is not 'graph' or 'edge'.
    """
    assert x.dtype == torch.bool and y.dtype == torch.bool, f'only take binary inputs -> must threshold first'

    # Promote single matrices to a batch of one BEFORE unpacking the shape;
    # unpacking a 2D shape into three names raised, so this was never reached.
    if x.ndim == 2:
        x = torch.unsqueeze(x, dim=0)
    if y.ndim == 2:
        y = torch.unsqueeze(y, dim=0)

    batch_size, N, _ = x.shape

    if ignore_diagonal:
        true_on_diag = torch.diag(torch.ones(N, dtype=torch.bool, device=x.device))
        true_on_diag_slice = torch.broadcast_to(true_on_diag, (batch_size, N, N)) # tensor with slice diag's==True
        false_on_diag_slice = ~true_on_diag_slice
        x = x & false_on_diag_slice # now diagonals of x are all false, anything that was true before, stays true
        y = y & false_on_diag_slice

    assert (x.shape == y.shape), 'inputs must be same size'
    assert (x.ndim == 3) and (y.ndim == 3), 'inputs must be 3D'
    assert (x.shape[1] == x.shape[2]) and (y.shape[1] == y.shape[2]), 'inputs must be square'

    num_graphs, n = x.shape[:2]
    if graph_or_edge == 'graph':
        tp, tn, fp, fn = confusion_matrix(x, y, reduction_axes=(1, 2))
        if ignore_diagonal:
            tn -= n  # x=y=0 on diagonal...subtract these out!
            total = n*(n-1) # possible edges without self loops
        else:
            total = n*n

    elif graph_or_edge == 'edge':
        tp, tn, fp, fn = confusion_matrix(x, y, reduction_axes=0)
        total = num_graphs

    else:
        # Raise instead of exiting the interpreter from library code.
        raise ValueError(f'unrecognized graph_or_edge {graph_or_edge}')

    pr = precision_(tp=tp, fp=fp)
    re = recall_(tp=tp, fn=fn)
    f1 = f1_(tp=tp, fp=fp, fn=fn)
    macro_f1 = macro_f1_(tp=tp, tn=tn, fp=fp, fn=fn)
    acc = accuracy(tp=tp, tn=tn, total=total)
    mcc = matthew_corr_coef(tp=tp, tn=tn, fp=fp, fn=fn)

    return pr, re, f1, macro_f1, acc, mcc

def batch_graph_metrics_np(x: np.ndarray,
                        y: np.ndarray,
                        ignore_diagonal: bool = True,
                        graph_or_edge='graph'):
    """Classification metrics between batches of boolean adjacency matrices (NumPy).

    Args:
        x: bool ndarray (batch_size, N, N) of recovered/learned/predicted
            adjacency matrices; a single (N, N) matrix is treated as a batch
            of one.
        y: bool ndarray of ground-truth/label adjacency matrices, same layout.
        ignore_diagonal: when True, diagonal entries are forced to False in
            both inputs and, in 'graph' mode, excluded from the counts.
        graph_or_edge: 'graph' reduces over each matrix (one value per graph);
            'edge' reduces over the batch (one value per matrix entry).

    Returns:
        Tuple ``(precision, recall, f1, macro_f1, accuracy, mcc)``.

    Raises:
        AssertionError: if inputs are not bool, not same-shaped, not 3D, or
            not square.
        ValueError: if ``graph_or_edge`` is not 'graph' or 'edge'.
    """
    assert x.dtype==np.dtype('bool') and y.dtype==np.dtype('bool'), f'only take binary inputs -> must threshold first'

    # Promote single matrices to a batch of one BEFORE unpacking the shape;
    # unpacking a 2D shape into three names raised, so this was never reached.
    if len(x.shape) == 2:
        x = np.expand_dims(x, axis=0)
    if len(y.shape) == 2:
        y = np.expand_dims(y, axis=0)

    batch_size, N, _ = x.shape

    # Assumptions:
    # diagonal of all elements are assumed to be 0 (and thus False in both x and y), and thus equal (thus contributing
    # to true negatives -> tn). Enforce this, and subtract off excess tn later.
    if ignore_diagonal:
        true_on_diag_slice = np.broadcast_to(np.diag(np.ones(N, dtype=bool)), (batch_size, N, N)) # tensor with slice diag's==True
        false_on_diag_slice = ~true_on_diag_slice
        x = x & false_on_diag_slice # now diagonals of x are all false, anything that was true before, stays true
        y = y & false_on_diag_slice

    # Compare the shape tuples directly so a rank mismatch fails this assert
    # rather than erroring inside an element-wise comparison.
    assert x.shape == y.shape, 'inputs must be same size'
    assert (len(x.shape) == 3) and (len(y.shape) == 3), 'inputs must be 3D'
    assert (x.shape[1] == x.shape[2]) and (y.shape[1] == y.shape[2]), 'inputs must be square'

    num_graphs, n = x.shape[:2]
    if graph_or_edge == 'graph':
        tp, tn, fp, fn = confusion_matrix(x, y, reduction_axes=(1, 2))
        if ignore_diagonal:
            # tn is a vector: a count for each slice. Subtract n true negatives out from each (corresponding to
            # the diagonal), and set the denominator as the total # of possible edges w/o self loops.
            # Note that for an undirected graph, we are double counting tp/tn/fp/fn's. Thus don't divide
            # the number of possible edges by 2 (for undirected graph). This will cancel double counting.
            tn -= n  # x=y=0 on diagonal...subtract these out!
            total = n*(n-1) # subtract out diags
        else:
            total = n*n

    elif graph_or_edge == 'edge':
        tp, tn, fp, fn = confusion_matrix(x, y, reduction_axes=0)
        total = num_graphs

    else:
        # Raise instead of exiting the interpreter from library code.
        raise ValueError(f'unrecognized graph_or_edge {graph_or_edge}')

    pr = precision_(tp=tp, fp=fp)
    re = recall_(tp=tp, fn=fn)
    f1 = f1_(tp=tp, fp=fp, fn=fn)
    macro_f1 = macro_f1_(tp=tp, tn=tn, fp=fp, fn=fn)
    acc = accuracy(tp=tp, tn=tn, total=total)
    mcc = matthew_corr_coef(tp=tp, tn=tn, fp=fp, fn=fn)

    return pr, re, f1, macro_f1, acc, mcc


def confusion_matrix(x: np.ndarray, y: np.ndarray, reduction_axes = 0):
    """Count true/false positives/negatives of ``x`` against labels ``y``.

    A position is "positive" when ``y > 0`` and "negative" when ``y == 0``;
    it is a hit when ``x == y`` and a miss when ``x != y``.

    Args:
        x: predictions; ndarray or torch tensor.
        y: labels; same container type as ``x``.
        reduction_axes: axis/axes (``dim`` for torch) summed over.

    Returns:
        Tuple ``(tp, tn, fp, fn)`` of counts.

    Raises:
        ValueError: if ``x`` and ``y`` are not both tensors or both ndarrays.
        AssertionError: if ``len(x) != len(y)``.
    """
    if torch.is_tensor(x) and torch.is_tensor(y):
        def count(mask):
            return torch.sum(mask, dim=reduction_axes)
    elif type(x) == np.ndarray and type(y) == np.ndarray:
        def count(mask):
            return np.sum(mask, axis=reduction_axes)
    else:
        raise ValueError(f'x and y must be same type: x {type(x)}, y {type(y)}')

    assert len(x) == len(y)

    agree, differ = (x == y), (x != y)
    label_pos, label_neg = (y > 0), (y == 0)

    tp = count(agree & label_pos)
    tn = count(agree & label_neg)
    fp = count(differ & label_neg)
    fn = count(differ & label_pos)
    return tp, tn, fp, fn

##########################

def numba_speedup_tests(n=10):
    """Time ``accuracy_raw`` against ``accuracy_score`` and time ``batch_graph_metrics``.

    Prints mean timings over ``n`` repetitions. Each timed function is called
    once beforehand so any one-off compilation cost is not measured.

    Args:
        n: number of timed repetitions per benchmark.

    Raises:
        AssertionError: if ``accuracy_raw`` and ``accuracy_score`` disagree.
    """
    # test numba speed up vs regular python
    custom_times = np.zeros(n)
    scipy_times = np.zeros(n)

    # warm-up call
    x, y = np.random.binomial(1, .5, 100), np.random.binomial(1, .5, 100)
    accuracy_raw(x, y)

    for i in range(n):
        x, y = np.random.binomial(1, .5, 100), np.random.binomial(1, .5, 100)
        start = time.time()
        a = accuracy_raw(x, y)
        custom_times[i] = time.time() - start

        start = time.time()
        b = accuracy_score(x, y)
        scipy_times[i] = time.time() - start
        assert np.isclose(a, b)
    print(f'accuracy custom mean time: {custom_times.mean():.8f}\n') #njit makes it slower!
    # Report the mean (not the sum) so the two lines are comparable.
    print(f'accuracy scipy mean time: {scipy_times.mean():.8f}\n\n\n')

    #######################
    size = (1000, 68, 68)

    def random_bool_batch():
        # The NumPy path of batch_graph_metrics asserts bool dtype, so the
        # integer 0/1 draws must be converted before the call.
        return np.random.binomial(1, .5, size=size).astype(bool)

    # call once for compilation time
    batch_graph_metrics(random_bool_batch(), random_bool_batch(), graph_or_edge='graph')
    times = np.zeros(n)
    for i in range(n):
        x = random_bool_batch()
        y = random_bool_batch()
        start = time.time()
        batch_graph_metrics(x, y, graph_or_edge='graph')
        times[i] = time.time() - start
    print(f'pr/re/f1/acc mean time: {times.mean():.8f}\n')
    print(f'pr/re/f1/acc total time: {times.sum():.8f}\n')
    #######################
    #######################

def np_torch_compare(pr, pr_t, re, re_t, f1, f1_t, macro_f1, macro_f1_t, acc, acc_t, mcc, mcc_t):
    """Assert that NumPy metric results match their torch counterparts.

    Each ``*_t`` argument is the torch result for the preceding NumPy one.
    MCC may contain NaN, so it is compared in two steps: the NaN patterns
    must match, and the remaining finite values must be close.

    Raises:
        AssertionError: on any mismatch.
    """
    paired_metrics = ((pr, pr_t), (re, re_t), (f1, f1_t), (macro_f1, macro_f1_t), (acc, acc_t))
    assert all(np.allclose(np_value, torch_value) for np_value, torch_value in paired_metrics)

    nan_in_np = np.isnan(mcc)
    same_nan_pattern = np.allclose(nan_in_np, np.isnan(mcc_t))
    assert same_nan_pattern and np.allclose(mcc[~nan_in_np], mcc_t[~torch.isnan(mcc_t)])

def graph_classification_metric_tests():
    """Self-test for ``batch_graph_metrics`` in 'graph' mode (NumPy and torch).

    Every case is evaluated with both backends; the two results are
    cross-checked with ``np_torch_compare`` and the NumPy result is compared
    with hand-computed expected values.

    NumPy divide/invalid warnings are silenced while the checks run because
    several cases intentionally produce an undefined (NaN) MCC. The error
    state is set back to 'warn' on exit even when an assertion fails, so a
    failing check no longer leaves warnings globally disabled.

    The disabled ``auto_flip_f1`` cases that used to sit in a string literal
    are omitted: ``batch_graph_metrics`` has no such parameter.

    Raises:
        AssertionError: if any metric differs from its expected value.
    """
    def check(x, y, expected_rows, combine, msg):
        # expected_rows order: precision, recall, f1, macro-f1, accuracy, mcc
        pr, re, f1, macro_f1, acc, mcc = batch_graph_metrics(x, y, graph_or_edge='graph')
        pr_t, re_t, f1_t, macro_f1_t, acc_t, mcc_t = batch_graph_metrics(torch.tensor(x), torch.tensor(y), graph_or_edge='graph')
        np_torch_compare(pr, pr_t, re, re_t, f1, f1_t, macro_f1, macro_f1_t, acc, acc_t, mcc, mcc_t)
        results = combine((pr, re, f1, macro_f1, acc, mcc))
        solns = combine([np.asarray(row, dtype=float) for row in expected_rows])
        assert np.allclose(results, solns, equal_nan=True), msg

    np.seterr(divide='ignore', invalid='ignore')
    try:
        #######################
        # batch_graph_metrics
        no_edges = np.zeros((3, 3), dtype=bool)
        fully_connected = (np.ones((3, 3)) - np.eye(3)) > 0
        mix_majority_1 = np.array([
                        [0, 1, 0],
                        [1, 0, 1],
                        [0, 1, 0]]) > 0
        mix_majority_0 = np.array([
                        [0, 0, 1],
                        [0, 0, 0],
                        [1, 0, 0]]) > 0

        # predictions: one graph of each kind, compared against a constant label
        x = np.stack((no_edges, mix_majority_1, mix_majority_0, fully_connected))

        # (label graph, [pr, re, f1, macro_f1, acc, mcc] expected per prediction, message)
        label_cases = [
            (fully_connected,
             [[0, 1, 1, 1],
              [0, 2/3, 1/3, 1],
              [0, 4/5, 1/2, 1],
              [0, 2/5, 1/4, 1/2],
              [0, 2/3, 1/3, 1],
              [-1, np.nan, np.nan, 1]],
             f'label all 1s.'),
            (mix_majority_1,
             [[0, 1, 0, 2/3],
              [0, 1, 0, 1],
              [0, 1, 0, 4/5],
              [1/4, 1, 0, 2/5],
              [1/3, 1, 0, 2/3],
              [np.nan, 1, -1, np.nan]],
             f'label mix: majority 1s.'),
            (mix_majority_0,
             [[0, 0, 1, 1/3],
              [0, 0, 1, 1],
              [0, 0, 1, 1/2],
              [2/5, 0, 1, 1/4],
              [2/3, 0, 1, 1/3],
              [np.nan, -1, 1, np.nan]],
             f'label mix: majority 0s.'),
            (no_edges,
             [[0, 0, 0, 0],
              [0, 0, 0, 0],
              [0, 0, 0, 0],
              [1/2, 1/4, 2/5, 0],
              [1, 1/3, 2/3, 0],
              [1, np.nan, np.nan, -1]],
             f'label mix: all 0s.'),
        ]
        for label, expected_rows, msg in label_cases:
            y = np.tile(label[np.newaxis, :], (len(x), 1, 1))
            check(x, y, expected_rows, np.concatenate, msg)

        ## extend to tensors: each prediction has its own label
        # intersect at subset
        input1 = [[0, 1, 0], [1, 0, 0], [0, 0, 0]]
        label1 = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
        # intersect all
        input2 = [[0, 1, 0], [1, 0, 0], [0, 0, 0]]
        label2 = [[0, 1, 0], [1, 0, 0], [0, 0, 0]]
        # intersect nowhere: zero recovered graph
        input3 = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
        label3 = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
        # intersect nowhere: both non-zero
        input4 = [[0, 0, 0], [0, 0, 1], [0, 1, 0]]
        label4 = [[0, 1, 0], [1, 0, 0], [0, 0, 0]]

        X = np.array([input1, input2, input3, input4], dtype=bool)
        Y = np.array([label1, label2, label3, label4], dtype=bool)
        misc_expected = [
            [1, 1, 0, 0],               # precision
            [1/2, 1, 0, 0],             # recall
            [2/3, 1, 0, 0],             # f1
            [2/3, 1, 1/4, 1/4],         # macro-f1
            [2/3, 1, 1/3, 1/3],         # accuracy
            [1/2, 1, np.nan, -1/2],     # mcc
        ]
        check(X, Y, misc_expected, np.stack, f'misc ex')
    finally:
        np.seterr(divide='warn', invalid='warn')

def _edge_metrics_np_torch_checked(x, y):
    """Compute per-edge metrics on numpy inputs and cross-check them with torch.

    Calls ``batch_graph_metrics(..., graph_or_edge='edge')`` once on the numpy
    arrays and once on ``torch.tensor`` copies of them, then hands the paired
    results to ``np_torch_compare``.

    Args:
        x: boolean numpy array of predictions (a stack of 3x3 graphs in the
            callers below).
        y: boolean numpy array of labels (same shape as ``x`` in the callers
            below).

    Returns:
        The six numpy results ``(pr, re, f1, macro_f1, acc, mcc)``.
    """
    pr, re, f1, macro_f1, acc, mcc = batch_graph_metrics(x, y, graph_or_edge='edge')
    pr_t, re_t, f1_t, macro_f1_t, acc_t, mcc_t = batch_graph_metrics(
        torch.tensor(x), torch.tensor(y), graph_or_edge='edge')
    np_torch_compare(pr, pr_t, re, re_t, f1, f1_t, macro_f1, macro_f1_t, acc, acc_t, mcc, mcc_t)
    return pr, re, f1, macro_f1, acc, mcc


def edge_classification_metric_tests() -> None:
    """Check ``batch_graph_metrics`` in ``graph_or_edge='edge'`` mode.

    Four scenarios are exercised on stacks of 3x3 boolean graphs, each for
    both numpy and torch inputs:

    1. predictions equal labels, mixed graphs in the batch;
    2. predictions equal labels, every graph fully connected;
    3. predictions equal labels, every graph empty;
    4. predictions differ from labels (only f1, macro-f1 and mcc are checked).

    Several expected values are NaN, so numpy's divide/invalid warnings are
    silenced while the checks run. The previous floating-point error settings
    are restored on exit, including when an assertion fails.

    Raises:
        AssertionError: if a computed metric does not match its expected value.
    """
    # errstate (unlike a seterr pair) restores the caller's settings even if
    # one of the assertions below raises.
    with np.errstate(divide='ignore', invalid='ignore'):
        #######################
        # batch_graph_metrics
        no_edges = np.zeros((3, 3), dtype=bool)
        fully_connected = (np.ones((3, 3)) - np.eye(3)) > 0
        mix_majority_1 = np.array([[0, 1, 0],
                                   [1, 0, 1],
                                   [0, 1, 0]], dtype=bool)
        mix_majority_0 = np.array([[0, 0, 1],
                                   [0, 0, 0],
                                   [1, 0, 0]], dtype=bool)

        # True everywhere except on the diagonal
        ones_dr = (np.ones_like(no_edges) - np.eye(3)) > 0

        # input and output the same, mixed graphs in the batch
        x = np.stack((no_edges, mix_majority_1, mix_majority_0, fully_connected))
        y = x
        pr_soln = re_soln = f1_soln = ones_dr
        macro_f1_soln = ones_dr + .5 * np.eye(3)
        acc_soln = mcc_soln = np.ones_like(ones_dr)
        solns = np.concatenate((pr_soln, re_soln, f1_soln, macro_f1_soln, acc_soln, mcc_soln))
        results = np.concatenate(_edge_metrics_np_torch_checked(x, y))
        assert np.allclose(results, solns, equal_nan=True), 'input and output same, all mixed'

        # input and output the same, all homogeneous: every graph fully connected
        x = np.stack((fully_connected, fully_connected, fully_connected, fully_connected))
        y = x
        pr_soln = re_soln = f1_soln = ones_dr
        macro_f1_soln = .5 * np.ones_like(pr_soln)
        acc_soln = mcc_soln = np.ones_like(pr_soln)
        solns = np.concatenate((pr_soln, re_soln, f1_soln, macro_f1_soln, acc_soln, mcc_soln))
        results = np.concatenate(_edge_metrics_np_torch_checked(x, y))
        assert np.allclose(results, solns, equal_nan=True), 'input and output all ones'

        # input and output the same, all homogeneous: every graph empty
        x = np.stack((no_edges, no_edges, no_edges, no_edges))
        y = x
        pr_soln = re_soln = f1_soln = np.zeros_like(no_edges)
        macro_f1_soln = .5 * np.ones_like(pr_soln)
        acc_soln = mcc_soln = np.ones_like(pr_soln)
        solns = np.concatenate((pr_soln, re_soln, f1_soln, macro_f1_soln, acc_soln, mcc_soln))
        results = np.concatenate(_edge_metrics_np_torch_checked(x, y))
        assert np.allclose(results, solns, equal_nan=True), 'input and output all zeros'

        # input and output not same
        # include all edge cases: 4 ways mcc can fail and 2 user defined cases
        x = np.array([
            [[0, 1, 0],
             [1, 0, 1],
             [0, 0, 0]],
            [[0, 0, 0],
             [1, 0, 1],
             [0, 1, 0]],
            [[0, 1, 0],
             [1, 0, 1],
             [0, 0, 0]]], dtype=bool)
        y = np.zeros_like(x)
        # (graph, row, col) positions of the label edges
        label_edges = ((0, 0, 1), (0, 1, 2),
                       (1, 0, 1), (1, 1, 0), (1, 2, 0), (1, 1, 2),
                       (2, 0, 1), (2, 1, 2))
        for edge in label_edges:
            y[edge] = True
        f1_soln = np.zeros_like(x[0], dtype=np.float32)
        f1_soln[1, 0], f1_soln[0, 1], f1_soln[1, 2] = 1/2, 4/5, 1
        macro_f1_soln = .5 * np.ones_like(f1_soln)
        macro_f1_soln[1, 0] = 1/4
        macro_f1_soln[0, 1] = macro_f1_soln[2, 0] = macro_f1_soln[2, 1] = 2/5
        mcc_soln = np.ones_like(f1_soln)
        mcc_soln[0, 1] = mcc_soln[1, 0] = mcc_soln[2, 0] = mcc_soln[2, 1] = np.nan
        solns = np.concatenate((f1_soln, macro_f1_soln, mcc_soln))
        # precision, recall and accuracy are still cross-checked numpy-vs-torch
        # by the helper, but have no expected values in this case
        _, _, f1, macro_f1, _, mcc = _edge_metrics_np_torch_checked(x, y)
        results = np.concatenate((f1, macro_f1, mcc))
        assert np.allclose(results, solns, equal_nan=True), 'input and output not same'

        # TODO: port the legacy cases that used to sit here inside a string
        #  literal (and therefore never ran). They compared the predictions
        #  stack(no_edges, mix_majority_1, mix_majority_0, fully_connected)
        #  against labels tiled from a single graph (fully_connected,
        #  mix_majority_1, mix_majority_0, no_edges), both plainly and with
        #  auto_flip_f1=True, where a batch-axis column was meant to be flipped
        #  when the majority of its labels are 0. They unpacked four return
        #  values from batch_graph_metrics, whereas the live cases above unpack
        #  six, so they need updating before they can be re-enabled.

if __name__ == "__main__":
    # Script entry point: run each self-check suite in turn; the first failing
    # assertion aborts the run.
    for run_suite in (
        binary_classification_metrics_tests,
        # numba_speedup_tests(n=10) is currently disabled
        graph_classification_metric_tests,
        edge_classification_metric_tests,
    ):
        run_suite()
