import numpy as np
from numba import jit, float32, boolean, int32, float64
from nowcasting.hko_evaluation import rainfall_to_pixel
from nowcasting.config import cfg

@jit(float32(float32, float32, boolean))
def get_GDL_numba(prediction, truth, mask):
    """Accelerated version of get_GDL using numba(http://numba.pydata.org/)

    Parameters
    ----------
    prediction
    truth
    mask

    Returns
    -------
    gdl
    """
    seqlen, batch_size, _, height, width = prediction.shape
    gdl = np.zeros(shape=(seqlen, batch_size), dtype=np.float32)
    for i in range(seqlen):
        for j in range(batch_size):
            for m in range(height):
                for n in range(width):
                    if m + 1 < height:
                        if mask[i][j][0][m+1][n] and mask[i][j][0][m][n]:
                            pred_diff_h = abs(prediction[i][j][0][m+1][n] -
                                              prediction[i][j][0][m][n])
                            gt_diff_h = abs(truth[i][j][0][m+1][n] - truth[i][j][0][m][n])
                            gdl[i][j] += abs(pred_diff_h - gt_diff_h)
                    if n + 1 < width:
                        if mask[i][j][0][m][n+1] and mask[i][j][0][m][n]:
                            pred_diff_w = abs(prediction[i][j][0][m][n+1] -
                                              prediction[i][j][0][m][n])
                            gt_diff_w = abs(truth[i][j][0][m][n+1] - truth[i][j][0][m][n])
                            gdl[i][j] += abs(pred_diff_w - gt_diff_w)
    return gdl


def get_hit_miss_counts_numba(prediction, truth, mask, thresholds=None):
    """This function calculates the overall hits and misses for the prediction, which could be used
    to get the skill scores and threat scores:


    This function assumes the input, i.e, prediction and truth are 3-dim tensors, (timestep, row, col)
    and all inputs should be between 0~1

    Parameters
    ----------
    prediction : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    truth : np.ndarray
        Shape: (seq_len, batch_size, 1, height, width)
    mask : np.ndarray or None
        Shape: (seq_len, batch_size, 1, height, width)
        0 --> not use
        1 --> use
    thresholds : list or tuple

    Returns
    -------
    hits : np.ndarray
        (seq_len, batch_size, len(thresholds))
        TP
    misses : np.ndarray
        (seq_len, batch_size, len(thresholds))
        FN
    false_alarms : np.ndarray
        (seq_len, batch_size, len(thresholds))
        FP
    correct_negatives : np.ndarray
        (seq_len, batch_size, len(thresholds))
        TN
    """
    if thresholds is None:
        thresholds = cfg.HKO.EVALUATION.THRESHOLDS
    assert 5 == prediction.ndim
    assert 5 == truth.ndim
    assert prediction.shape == truth.shape
    assert prediction.shape[2] == 1
    thresholds = [rainfall_to_pixel(thresholds[i]) for i in range(len(thresholds))]
    thresholds = sorted(thresholds)
    ret = _get_hit_miss_counts_numba(prediction=prediction,
                                     truth=truth,
                                     mask=mask,
                                     thresholds=thresholds)
    return ret[:, :, :, 0], ret[:, :, :, 1], ret[:, :, :, 2], ret[:, :, :, 3]


@jit(int32(float32, float32, boolean, float32))
def _get_hit_miss_counts_numba(prediction, truth, mask, thresholds):
    seqlen, batch_size, _, height, width = prediction.shape
    threshold_num = len(thresholds)
    ret = np.zeros(shape=(seqlen, batch_size, threshold_num, 4), dtype=np.int32)

    for i in range(seqlen):
        for j in range(batch_size):
            for m in range(height):
                for n in range(width):
                    if mask[i][j][0][m][n]:
                        for k in range(threshold_num):
                            bpred = prediction[i][j][0][m][n] >= thresholds[k]
                            btruth = truth[i][j][0][m][n] >= thresholds[k]
                            ind = (1 - btruth) * 2 + (1 - bpred)
                            ret[i][j][k][ind] += 1
                            # The above code is the same as:
                            # ret[i][j][k][0] += bpred * btruth
                            # ret[i][j][k][1] += (1 - bpred) * btruth
                            # ret[i][j][k][2] += bpred * (1 - btruth)
                            # ret[i][j][k][3] += (1 - bpred) * (1- btruth)
    return ret


def get_balancing_weights_numba(data, mask, base_balancing_weights=None, thresholds=None):
    """Get the balancing weights
    
    Parameters
    ----------
    data
    mask
    base_balancing_weights
    thresholds

    Returns
    -------

    """
    if thresholds is None:
        thresholds = cfg.HKO.EVALUATION.THRESHOLDS
    if base_balancing_weights is None:
        base_balancing_weights = cfg.HKO.EVALUATION.BALANCING_WEIGHTS
    assert data.shape[2] == 1
    thresholds = [rainfall_to_pixel(thresholds[i]) for i in range(len(thresholds))]
    thresholds = sorted(thresholds)
    ret = _get_balancing_weights_numba(data=data,
                                       mask=mask,
                                       base_balancing_weights=base_balancing_weights,
                                       thresholds=thresholds)
    return ret


@jit(float32(float32, boolean, float32, float32))
def _get_balancing_weights_numba(data, mask, base_balancing_weights, thresholds):
    seqlen, batch_size, _, height, width = data.shape
    threshold_num = len(thresholds)
    ret = np.zeros(shape=(seqlen, batch_size, 1, height, width), dtype=np.float32)

    for i in range(seqlen):
        for j in range(batch_size):
            for m in range(height):
                for n in range(width):
                    if mask[i][j][0][m][n]:
                        ele = data[i][j][0][m][n]
                        for k in range(threshold_num):
                            if ele < thresholds[k]:
                                ret[i][j][0][m][n] = base_balancing_weights[k]
                                break
                        if ele >= thresholds[threshold_num - 1]:
                            ret[i][j][0][m][n] = base_balancing_weights[threshold_num]
    return ret

if __name__ == '__main__':
    from nowcasting.hko_evaluation import get_GDL, get_hit_miss_counts, get_balancing_weights
    from numpy.testing import assert_allclose, assert_almost_equal

    prediction = np.random.uniform(size=(10, 16, 1, 480, 480))
    truth = np.random.uniform(size=(10, 16, 1, 480, 480))
    mask = np.random.randint(low=0, high=2, size=(10, 16, 1, 480, 480)).astype(np.bool)
    import time

    begin = time.time()
    gdl = get_GDL(prediction=prediction, truth=truth, mask=mask)
    end = time.time()

    print("numpy gdl:", end - begin)
    begin = time.time()
    gdl_numba = get_GDL_numba(prediction=prediction, truth=truth, mask=mask)
    end = time.time()
    print("numba gdl:", end - begin)
    # gdl_mx = mx_get_GDL(prediction=prediction, truth=truth, mask=mask)
    # print gdl_mx
    assert_allclose(gdl, gdl_numba, rtol=1E-4, atol=1E-3)

    begin = time.time()
    for i in range(5):
        hits, misses, false_alarms, correct_negatives = get_hit_miss_counts(prediction=prediction,
                                                                            truth=truth,
                                                                            mask=mask)
    end = time.time()
    print("numpy hits misses:", end - begin)

    begin = time.time()
    for i in range(5):
        hits_numba, misses_numba, false_alarms_numba, correct_negatives_numba = get_hit_miss_counts_numba(
            prediction=prediction,
            truth=truth,
            mask=mask)
    end = time.time()
    print("numba hits misses:", end - begin)
    print(np.abs(hits - hits_numba).max())
    print(np.abs(misses - misses_numba).max(), np.abs(misses - misses_numba).argmax())
    print(np.abs(false_alarms - false_alarms_numba).max(),
          np.abs(false_alarms - false_alarms_numba).argmax())
    print(np.abs(correct_negatives - correct_negatives_numba).max(),
          np.abs(correct_negatives - correct_negatives_numba).argmax())

    begin = time.time()
    for i in range(5):
        weights_npy = get_balancing_weights(data=truth, mask=mask,
                                            base_balancing_weights=None, thresholds=None)
    end = time.time()
    print("numpy balancing weights:", end - begin)

    begin = time.time()
    for i in range(5):
        weights_numba = get_balancing_weights_numba(data=truth, mask=mask,
                                                    base_balancing_weights=None, thresholds=None)
    end = time.time()
    print("numba balancing weights:", end - begin)
    print("Inconsistent Number:", (np.abs(weights_npy - weights_numba) > 1E-5).sum())
