import xarray as xr
from xrspatial import proximity
import numpy as np
import fire


def nearest_neighbor_distance(gt_range_map):
    """Run xrspatial's ``proximity`` on a 2-D map and return a plain ndarray.

    The map is wrapped in an ``xarray.DataArray`` named ``"raster"`` with dims
    ``("y", "x")``, and ``proximity`` is called with its default arguments.
    Presumably the result holds, per cell, the distance to the nearest
    non-zero cell of ``gt_range_map`` — confirm against the xrspatial
    ``proximity`` documentation.

    gt_range_map: binary valued numpy array
    """
    return np.array(
        proximity(xr.DataArray(gt_range_map, name="raster", dims=["y", "x"]))
    )


def pwcd_dist(pred_range_map, gt_range_map, alpha=0.1, fp_only="False"):
    """Compute the PWCD penalty of a predicted range map.

    Each cell contributes ``1 - exp(-alpha * pred * d)``, where ``d`` is the
    value ``nearest_neighbor_distance(gt_range_map)`` returns for that cell
    (presumably its distance to the nearest ground-truth presence cell —
    confirm against xrspatial's ``proximity`` defaults). Cells with ``d == 0``
    contribute 0; large predictions at large ``d`` contribute close to 1.

    pred_range_map: numpy array (or nested list), same shape as gt_range_map
    gt_range_map: binary valued numpy array (or nested list)
    alpha: decay rate of the penalty in ``pred * d``
    fp_only: selects the return value:
        "True"  -> raw sum of the per-cell penalties
        "False" -> raw sum divided by the number of cells where gt == 0
        other   -> tuple ``(normalised, raw)``
        Python/numpy booleans are treated like their string forms, because
        python-fire parses ``--fp_only=True`` into a ``bool``.

    If ``gt_range_map`` has no zero cells, the normalised value is a division
    by zero (nan/inf, with a numpy warning).
    """
    # python-fire parses "[[0, 1], ...]" CLI arguments into nested lists;
    # convert those, but leave numpy/xarray inputs untouched.
    if isinstance(pred_range_map, (list, tuple)):
        pred_range_map = np.asarray(pred_range_map)
    if isinstance(gt_range_map, (list, tuple)):
        gt_range_map = np.asarray(gt_range_map)

    pwcd = 1 - np.exp(-alpha * pred_range_map * nearest_neighbor_distance(gt_range_map))
    total = np.sum(pwcd)

    # str() maps True/False (incl. numpy bools) onto the legacy string flags,
    # so both `fp_only="True"` and `fp_only=True` select the same branch.
    mode = str(fp_only)
    if mode == "True":
        return total
    normalised = total / np.sum(gt_range_map == 0)
    if mode == "False":
        return normalised
    return normalised, total


def average_precision_score_faster(y_true, y_scores, pwcd_fps=None):
    """Average precision with optionally distance-weighted false positives.

    Fast replacement for sklearn's ``average_precision_score``. Scores are
    sorted once in descending order, and AP is ``sum_n (R_n - R_{n-1}) * P_n``.
    Unlike sklearn, tied scores are not merged into a single threshold, so
    results can differ slightly when ``y_scores`` contains ties.

    y_true: 1-D binary labels (0/1 or bool)
    y_scores: 1-D prediction scores, same length as y_true
    pwcd_fps: optional 1-D per-element false-positive weights aligned with
        y_true. It replaces the usual ``1 - y_true`` false-positive count
        (e.g. per-cell PWCD values). When omitted, standard unweighted AP is
        returned.

    Returns the AP as a numpy float. If y_true contains no positives, recall
    is 0/0 and the result is nan (numpy emits a warning).
    """
    y_true = np.asarray(y_true)
    y_scores = np.asarray(y_scores)
    num_positives = y_true.sum()

    # Descending score order.
    inds = np.argsort(y_scores)[::-1]
    y_true_s = y_true[inds]
    if pwcd_fps is None:
        fp_weights = 1.0 - y_true_s.astype(np.float64)
    else:
        # Replace false positives with pwcd, reordered like the labels.
        fp_weights = np.asarray(pwcd_fps)[inds]

    false_pos_c = np.cumsum(fp_weights)
    true_pos_c = np.cumsum(y_true_s)
    recall = true_pos_c / num_positives
    # TP + FP at each cut-off; clamped away from 0 so precision stays finite.
    predicted_pos = np.maximum(true_pos_c + false_pos_c, np.finfo(np.float32).eps)
    precision = true_pos_c / predicted_pos

    # Recall increments R_n - R_{n-1} (with R_{-1} = 0), one per cut-off.
    recall_e = np.hstack((0, recall, 1))
    recall_e = (recall_e[1:] - recall_e[:-1])[:-1]
    return (recall_e * precision).sum()


def get_all_metrics(
    pred_range_map,
    gt_range_map,
    mask_inds,
    alpha=0.1,
):
    """Compute distance-weighted average precision and the normalised PWCD.

    pred_range_map: numpy array of predictions, same shape as gt_range_map
    gt_range_map: binary valued numpy array
    mask_inds: numpy index into the flattened maps, selecting the cells AP
        is evaluated on
    alpha: PWCD decay rate (same meaning as in pwcd_dist)

    Returns ``(ap, pwcd)``:
    - ap: AP over the masked cells, where each cell's false-positive count
      is replaced by its per-cell PWCD value.
    - pwcd: the PWCD sum over the full map divided by the number of cells
      where gt_range_map == 0 (same as pwcd_dist with fp_only="False").
    """
    # Per-cell PWCD map: the same quantity pwcd_dist sums. AP needs it
    # cell-by-cell, because average_precision_score_faster indexes the weights
    # by sort order. pwcd_dist only returns scalars, and indexing a scalar
    # raised IndexError. Computing the map here also runs proximity only once.
    # NOTE(review): uses the prediction-weighted per-cell PWCD as the FP
    # weight — confirm this is the intended weighting.
    pwcd_map = np.asarray(
        1 - np.exp(-alpha * pred_range_map * nearest_neighbor_distance(gt_range_map))
    )
    pwcd = np.sum(pwcd_map) / np.sum(gt_range_map == 0)

    preds = pred_range_map.reshape(-1)[mask_inds]
    gts = gt_range_map.reshape(-1)[mask_inds]
    # False-positive weights aligned element-for-element with preds/gts.
    pwcd_fps = pwcd_map.reshape(-1)[mask_inds]
    ap = average_precision_score_faster(gts, preds, pwcd_fps)
    return ap, pwcd


if __name__ == "__main__":
    # Expose pwcd_dist as a command-line tool; `fire` is imported at module
    # level, so no local re-import is needed.
    fire.Fire(pwcd_dist)
