import numpy as np
from numba import njit, prange
import math

# numba doesn't allow for axis arguments in reduction functions:
#  https://github.com/numba/numba/issues/1269


@njit(parallel = True)
def np_reduce_along_slice(func2d, arr, out_type):
    """Reduce every 2D slice ``arr[k]`` of a 3D tensor to a single value.

    Args:
        func2d: Reduction called as ``func2d(arr[k])`` on each slice.
        arr: 3D tensor whose first axis is the batch (slice) axis.
        out_type: dtype of the returned array; each reduced value is stored
            in (and therefore converted to) this dtype.

    Returns:
        1D array of length ``arr.shape[0]`` holding one value per slice.
    """
    assert arr.ndim == 3
    num_slices = arr.shape[0]
    reduced = np.empty(num_slices, dtype=out_type)
    # Every iteration touches only its own slice and its own output element,
    # so the slices are processed independently (prange requires this).
    for k in prange(num_slices):
        reduced[k] = func2d(arr[k])
    return reduced

@njit
def np_slice_mean(array, out_type):
    """Return the mean of each 2D slice ``array[i]`` as a 1D array.

    Delegates to ``np_reduce_along_slice`` with ``np.mean`` as the per-slice
    reduction; the means are stored in an output array of dtype ``out_type``.
    """
    # equivalent to np.mean(array, axis=(1, 2)), stored as out_type
    return np_reduce_along_slice(np.mean, array, out_type)

@njit
def np_slice_sum(array, out_type):
    """Return the sum of each 2D slice ``array[i]`` as a 1D array.

    Delegates to ``np_reduce_along_slice`` with ``np.sum`` as the per-slice
    reduction; the sums are stored in an output array of dtype ``out_type``.
    """
    # equivalent to np.sum(array, axis=(1, 2)), stored as out_type
    return np_reduce_along_slice(np.sum, array, out_type)
##############################

##############################
@njit(parallel = True)
def np_reduce_along_batch_dim(func1d, arr, out_type):
    """Reduce a 3D tensor over its first (batch) axis.

    Args:
        func1d: Reduction called as ``func1d(arr[:, i, j])`` for every
            position ``(i, j)`` of the trailing two axes.
        arr: 3D tensor whose first axis is the batch axis.
        out_type: dtype of the returned array.

    Returns:
        2D array of shape ``arr.shape[1:]`` with one reduced value per
        ``(i, j)`` position.
    """
    assert arr.ndim == 3
    num_rows, num_cols = arr.shape[1], arr.shape[2]
    reduced = np.empty(shape=(num_rows, num_cols), dtype=out_type)
    # Rows are distributed over prange; each row fills only its own entries.
    # NOTE: a symmetric shortcut (reduce the upper triangle only and mirror
    # it) was sketched here previously but is not enabled.
    for row in prange(num_rows):
        for col in range(num_cols):
            reduced[row, col] = func1d(arr[:, row, col])
    return reduced


@njit
def np_batch_dim_sum(array, out_type):#,  is_symmetric=False):
    """Return the sum over the first (batch) axis of a 3D tensor.

    Delegates to ``np_reduce_along_batch_dim`` with ``np.sum``; the result has
    shape ``array.shape[1:]`` and dtype ``out_type``.
    """
    # equivalent to np.sum(array, axis=0), stored as out_type
    return np_reduce_along_batch_dim(np.sum, array, out_type)#, is_symmetric=is_symmetric)
##############################

"""
# takes 3D tensor of square matrices
def return_matrix_diag(arr):
    assert arr.ndim == 3
    # ith row of b <==> ith diagonal of arr[i]
    b = np.diagonal(arr, axis1=1, axis2=2)
    diag_of_arr = np.zeros_like(arr)
    diag_indices = np.diag_indices(n=arr.shape[-1], ndim=2)
    for i in prange(len(arr)):
        diag_of_arr[i][diag_indices] = b[i]

    return diag_of_arr
"""

# vectorized solution credit: https://stackoverflow.com/questions/67241824/3d-tensor-of-diagonal-matrices
def return_matrix_diag(arr):
    """Return a tensor that keeps only the main diagonal of each slice.

    Args:
        arr: 3D tensor of square matrices, shape ``(k, n, n)``.

    Returns:
        New ``(k, n, n)`` array whose slice ``i`` carries the main diagonal
        of ``arr[i]`` and zeros everywhere else.
    """
    assert arr.ndim == 3 and (arr.shape[-1] == arr.shape[-2])
    # Row i of `diagonals` is the main diagonal of arr[i].
    diagonals = np.diagonal(arr, axis1=1, axis2=2)
    num_slices, side = diagonals.shape

    # In a flattened (side x side) matrix the diagonal entries sit at every
    # (side + 1)-th position, so one strided assignment places all of them.
    flat = np.zeros_like(diagonals, shape=(num_slices, side * side))
    flat[:, ::side + 1] = diagonals
    return flat.reshape(num_slices, side, side)


def zero_matrix_diag(arr):
    """Zero the main diagonal of every 2D slice of a 3D tensor, in place.

    Implemented as a single vectorised NumPy assignment instead of a
    Python-level loop over slices, so it does not need numba's ``prange``
    and is not intended to be jitted.

    Args:
        arr: 3D array; modified in place. Slices need not be square, in which
            case the first ``min(rows, cols)`` diagonal entries are zeroed.

    Returns:
        The same ``arr`` object, for convenience.
    """
    assert arr.ndim == 3
    diag_idx = np.arange(min(arr.shape[1], arr.shape[2]))
    # Pairing the same index on both trailing axes addresses arr[:, d, d].
    arr[:, diag_idx, diag_idx] = 0
    return arr

# for speed comparison
def regular_zero_matrix_diag(arr):
    """Zero the main diagonal of each 2D slice with a plain Python loop.

    Kept as the simple reference implementation for speed comparisons.

    Args:
        arr: 3D array; modified in place.

    Returns:
        The same ``arr`` object.
    """
    assert arr.ndim == 3
    # Iterating the first axis yields views, so the fill writes into `arr`.
    for matrix in arr:
        np.fill_diagonal(matrix, 0)
    return arr

def vectors_to_diagonal_tensor(x):
    """Map each row of a 2D array to a diagonal matrix.

    Implemented as a single vectorised NumPy assignment instead of a
    Python-level loop over rows, so it does not need numba's ``prange``
    and is not intended to be jitted.

    Args:
        x: 2D array of shape ``(num_vecs, n)``.

    Returns:
        New array of shape ``(num_vecs, n, n)`` with the dtype of ``x``;
        slice ``i`` has ``x[i]`` on its main diagonal and zeros elsewhere.
    """
    assert x.ndim == 2
    num_vecs, n = x.shape
    D = np.zeros((num_vecs, n, n), dtype=x.dtype)
    diag_idx = np.arange(n)
    # D[:, d, d] has shape (num_vecs, n), matching x element for element.
    D[:, diag_idx, diag_idx] = x
    return D

def apply_function_to_diagonal(x, func=None):
    """Replace the main diagonal of every slice by a function of itself, in place.

    Implemented with vectorised NumPy indexing instead of a Python-level loop
    over slices, so it does not need numba's ``prange`` and is not intended
    to be jitted.

    Args:
        x: 3D tensor of square matrices, shape ``(k, n, n)``; modified in
            place. Off-diagonal entries are left untouched.
        func: Optional callable that receives the ``(k, n)`` array of
            diagonals (row ``i`` is the diagonal of ``x[i]``) and returns an
            array broadcastable to that shape. Defaults to the element-wise
            reciprocal square root ``1 / sqrt(d)``.

    Returns:
        None; the result is written into ``x``.
    """
    assert (x.ndim == 3) and (x.shape[1] == x.shape[2])
    diag_idx = np.arange(x.shape[1])
    # Advanced indexing returns a copy, so reading and writing cannot alias.
    diagonals = x[:, diag_idx, diag_idx]
    if func is None:
        transformed = np.divide(1, np.sqrt(diagonals))
    else:
        transformed = func(diagonals)
    x[:, diag_idx, diag_idx] = transformed

def TEST_return_matrix_diag():
    """Check ``return_matrix_diag`` on random batches of square matrices.

    For ten random matrix sizes, verifies for every slice that the main
    diagonal of the output equals that of the input and that every
    off-diagonal entry of the output is zero.
    """
    slices = 100
    for _ in range(10):
        n = np.random.randint(low=2, high=100) # size of slice
        tensor = np.random.rand(slices, n, n)
        tensor_diag = return_matrix_diag(tensor)
        b = (tensor == tensor_diag) # slice diagonals of this should be True
        for j in range(slices):
            # Index with the slice counter j so that every slice is checked.
            assert np.all(np.diagonal(b[j])), f' N {n} -> some matrix diagonal not equal'
        off_diagonal = ~np.eye(n, dtype=bool)
        assert not tensor_diag[:, off_diagonal].any(), f' N {n} -> some off-diagonal entry not zero'


def TEST_apply_function_to_diagonal():
    """Check the in-place reciprocal-square-root update of slice diagonals.

    Covers a batch holding a single matrix and a batch of three identical
    matrices; off-diagonal entries must stay unchanged.
    """
    # Batch with one matrix.
    single = np.array([[[1,3],[4,5]]], dtype=np.float32)
    expected_single = np.array([[[1,3],[4,1/math.sqrt(5)]]])
    apply_function_to_diagonal(single)
    assert np.allclose(single, expected_single)

    # Batch of three copies of the same matrix.
    base = np.array([[[.5, 3], [.67, 5]]], dtype=np.float32)
    batch = np.tile(base, (3, 1, 1))
    expected_slice = np.array([[1/math.sqrt(.5), 3], [.67, 1/math.sqrt(5)]])
    expected_batch = np.tile(expected_slice, (3, 1, 1))
    apply_function_to_diagonal(batch)
    assert np.allclose(batch, expected_batch)


def TEST_vectors_to_diagonal_tensor():
    """Check ``vectors_to_diagonal_tensor`` against per-row ``np.diag``.

    Runs on a (2, 3) input and on its (3, 2) transpose.
    """
    vectors = np.array([[1, 2, 3], [3, 4, 5]])

    for vecs in (vectors, np.transpose(vectors, axes=(1,0))):
        num_vecs, n = vecs.shape
        # Reference result: one np.diag per row.
        expected = np.zeros((num_vecs, n, n))
        for k, row in enumerate(vecs):
            expected[k] = np.diag(row)
        soln = vectors_to_diagonal_tensor(vecs)
        assert np.allclose(expected, soln)

######
@njit(parallel = True)
def confusion_matrix(x: np.ndarray, y: np.ndarray, edge_threshold=0, reduction_axes: np.ndarray=[1,2]):
    """Count true/false positives/negatives between predictions and labels.

    An entry is treated as an edge when it is strictly greater than
    ``edge_threshold``; anything else (including a value exactly equal to the
    threshold, or NaN, for which the comparison is False) is treated as no
    edge. Every element is therefore counted in exactly one of the four
    categories. The previous strict ``<`` test for negatives dropped entries
    equal to the threshold, e.g. every 0 when the threshold is 0.

    Args:
        x: Predictions; flattened before comparison.
        y: Labels; flattened before comparison. Must have as many elements
            as ``x``.
        edge_threshold: Values ``> edge_threshold`` count as an edge.
        reduction_axes: Currently unused; kept so existing callers keep
            working.

    Returns:
        Tuple ``(tp, tn, fp, fn)`` of counts over all elements.
    """
    x = np.reshape(x, -1)
    y = np.reshape(y, -1)
    assert len(x)==len(y), 'preds and labels must be same size'
    tp, tn, fp, fn = 0, 0, 0, 0

    # Each counter is only ever updated with +=, the same reduction pattern
    # the loop used before.
    for i in prange(len(x)):
        pred_edge = x[i] > edge_threshold
        label_edge = y[i] > edge_threshold
        if pred_edge and label_edge:
            tp += 1
        elif pred_edge:
            fp += 1
        elif label_edge:
            fn += 1
        else:
            tn += 1
    return tp, tn, fp, fn




if __name__ == "__main__":
    # Run the self-tests, then time the numba per-slice sum against the
    # equivalent NumPy reduction on random binary tensors.
    import time

    TEST_return_matrix_diag()
    TEST_vectors_to_diagonal_tensor()
    TEST_apply_function_to_diagonal()
    print('passing vec2diagtensor test')

    n = 50  # number of timed trials per implementation

    size = (100, 68, 68)  # (slices, rows, cols)
    # Number of diagonal entries per slice. The tn correction below refers to
    # diagonal entries, so it must use this rather than the trial count n.
    num_diag = size[-1]

    ######################

    """
    # call once for compilation time
    x = np.random.binomial(1, .5, size=size)
    zero_matrix_diag(x)
    numba_times = np.zeros(n)
    numpy_times = np.zeros(n)
    for i in range(n):
        x1 = np.random.binomial(1, .5, size=size)#.astype(np.int8)
        x2 = np.copy(x1)

        start = time.time()
        zero_matrix_diag(x1)
        numba_times[i] = time.time() - start

        start = time.time()
        regular_zero_matrix_diag(x2)
        numpy_times[i] = time.time() - start
        assert np.allclose(np.diagonal(x1, axis1=1, axis2=2), 0)
        assert np.allclose(np.diagonal(x2, axis1=1, axis2=2), 0)

    print(f'zero_diag numba mean time: {numba_times.mean():.8f}\n')
    print(f'zero_diag numpy mean time: {numpy_times.mean():.8f}\n')
    """

    # Warm-up call so JIT compilation is not part of the timings. It uses the
    # same argument types as the timed calls (boolean mask plus the dtype of
    # the binomial samples); a warm-up with different types would leave the
    # compilation for the first timed trial.
    x = np.random.binomial(1, .5, size=size)
    y = np.random.binomial(1, .5, size=size)
    np_slice_sum((x == y) & (y > 0), x.dtype)

    numba_times = np.zeros(n)
    numpy_times = np.zeros(n)
    for i in range(n):
        x = np.random.binomial(1, .5, size=size)
        y = np.random.binomial(1, .5, size=size)

        # perf_counter is a monotonic, high-resolution clock meant for timing.
        start = time.perf_counter()
        tp = np_slice_sum((x == y) & (y > 0), x.dtype)
        tn = np_slice_sum((x == y) & (y == 0), x.dtype) - num_diag
        fp = np_slice_sum((x != y) & (y == 0), x.dtype)
        fn = np_slice_sum((x != y) & (y > 0), x.dtype)
        numba_times[i] = time.perf_counter() - start

    reduction_axes = (1, 2)
    for i in range(n):
        x = np.random.binomial(1, .5, size=size)
        y = np.random.binomial(1, .5, size=size)

        start = time.perf_counter()
        tp = np.sum((x == y) & (y > 0), axis=reduction_axes)
        tn = np.sum((x == y) & (y == 0), axis=reduction_axes) - num_diag  # x=y=0 on diagonal...subtract these out!
        fp = np.sum((x != y) & (y == 0), axis=reduction_axes)
        fn = np.sum((x != y) & (y > 0), axis=reduction_axes)
        numpy_times[i] = time.perf_counter() - start

    print("Slice sum")
    print(f'numba mean/med: {numba_times.mean():.5f}/{np.median(numba_times):.5f}\n')
    print(f'numpy mean/med: {numpy_times.mean():.5f}/{np.median(numpy_times):.5f}\n')
