import numpy as np
import cv2
from util.LAP.cal_LAP import cal_metric as lap


def _draw_lines_128(lines, thickness):
    """
     @brief rasterize lines onto a fresh 128 * 128 uint8 mask (background 0, line pixels 1)
     @lines [ [x0, y0, x1, y1],  ... ], coordinates are rounded to the nearest int before drawing
     @thickness line thickness passed to cv2.line
    """
    heatmap = np.zeros((128, 128), np.uint8)
    for x0, y0, x1, y1 in lines:
        pt0 = (int(round(x0)), int(round(y0)))
        pt1 = (int(round(x1)), int(round(y1)))
        # lineType 8 == 8-connected line
        cv2.line(heatmap, pt0, pt1, (1, 1, 1), thickness, 8)
    return heatmap


def F1_score_128(pred_lines_128_list, gt_lines_128_list, thickness=3):
    """
     @brief heatmap F1 score: draw predicted and GT lines onto 128 * 128 masks and
            compare them pixel-wise
     @pred_lines_128_list [ [x0, y0, x1, y1],  ... ]  predicted lines
     @gt_lines_128_list [ [x0, y0, x1, y1],  ... ]  ground-truth lines
     @thickness line thickness (pixels) used when drawing both sets
     @return (fscore, recall, precision); each ratio has eps = 0.001 added to its
             denominator, so empty masks give 0 instead of dividing by zero
    """
    pred_heatmap = _draw_lines_128(pred_lines_128_list, thickness).astype(np.float32)
    gt_heatmap = _draw_lines_128(gt_lines_128_list, thickness).astype(np.float32)

    eps = 0.001
    # Masks are 0/1, so the product counts pixels covered by both.
    intersection = np.sum(gt_heatmap * pred_heatmap)

    recall = intersection / (np.sum(gt_heatmap) + eps)
    precision = intersection / (np.sum(pred_heatmap) + eps)

    fscore = (2 * precision * recall) / (precision + recall + eps)
    return fscore, recall, precision


def LAP(output_list, target_list):
    """
     @brief forward to `cal_metric` from util.LAP.cal_LAP (imported here as `lap`)
     @output_list predictions, passed through unchanged
     @target_list ground truth, passed through unchanged
     @return whatever `cal_metric` returns (presumably a line-AP score; confirm in util.LAP.cal_LAP)
    """
    result = lap(output_list, target_list)
    return result


def f_score(tp, fp):
    """
     @brief best F-score along a cumulative precision / recall curve
     @tp cumulative true positives, used directly as recall (sAP passes them
         already divided by the number of GT lines)
     @fp cumulative false positives, same normalization as tp
     @return max over the curve (padded with (R=0, P=0) and (R=1, P=0)) of
             2PR / (P + R + 1e-10)
    """
    raw_precision = tp / np.maximum(tp + fp, 1e-9)

    padded_recall = np.concatenate(([0.0], tp, [1.0]))
    padded_precision = np.concatenate(([0.0], raw_precision, [0.0]))

    numerator = 2 * padded_precision * padded_recall
    denominator = padded_precision + padded_recall + 0.0000000001
    return (numerator / denominator).max()


def sAP(output_list, target_list, threshold=5):
    """
     @brief structural AP (L-CNN style) and best F-score over a set of images
     @output_list per-image (pred_score, pred_line) pairs; pred_score is an array of
                  scores and pred_line must be reshapeable to (-1, 2, 2)
     @target_list per-image GT line arrays, reshapeable to (-1, 2, 2)
     @threshold a prediction matches a GT line when the summed squared distance of
                its two endpoints is below this value (see msTPFP)
     @return (ap, fscore); both are 0 when there are no images
    """
    n_gt = 0
    lcnn_tp, lcnn_fp, lcnn_scores = [], [], []
    for (pred_score, pred_line), gt_line in zip(output_list, target_list):
        # msTPFP matches greedily in input order, so feed highest-scoring lines first.
        score_idx = np.argsort(-pred_score)
        pred_line = pred_line[score_idx]
        pred_score = pred_score[score_idx]
        n_gt += gt_line.shape[0]

        # Reshape each line to 2x2: two endpoints of (x, y).
        tp, fp = msTPFP(pred_line.reshape(-1, 2, 2), gt_line.reshape(-1, 2, 2), threshold)
        lcnn_tp.append(tp)
        lcnn_fp.append(fp)
        lcnn_scores.append(pred_score)

    # np.concatenate raises on an empty list, so fall back to empty arrays.
    lcnn_tp = np.concatenate(lcnn_tp) if lcnn_tp else np.zeros(0, np.float32)
    lcnn_fp = np.concatenate(lcnn_fp) if lcnn_fp else np.zeros(0, np.float32)
    lcnn_scores = np.concatenate(lcnn_scores) if lcnn_scores else np.zeros(0, np.float32)

    # Rank every prediction in the dataset by score before accumulating.
    lcnn_index = np.argsort(-lcnn_scores)
    # With no GT lines all tp are 0; dividing by 1 keeps the curve finite (AP = 0).
    denom = max(n_gt, 1)
    lcnn_tp = np.cumsum(lcnn_tp[lcnn_index]) / denom
    lcnn_fp = np.cumsum(lcnn_fp[lcnn_index]) / denom
    return ap(lcnn_tp, lcnn_fp), f_score(lcnn_tp, lcnn_fp)


def ap(tp, fp):
    """
     @brief area under the all-point interpolated precision / recall curve
     @tp cumulative true positives, used directly as recall (sAP passes them
         already divided by the number of GT lines)
     @fp cumulative false positives, same normalization as tp
     @return sum over recall steps of (delta recall) * (interpolated precision)
    """
    recall = np.concatenate(([0.0], tp, [1.0]))
    precision = np.concatenate(([0.0], tp / np.maximum(tp + fp, 1e-9), [0.0]))

    # Sweep right-to-left so each precision becomes the max of itself and everything after it.
    for idx in reversed(range(1, precision.size)):
        precision[idx - 1] = max(precision[idx - 1], precision[idx])

    # Only positions where recall actually changes contribute area.
    steps = np.flatnonzero(recall[1:] != recall[:-1])
    return np.sum((recall[steps + 1] - recall[steps]) * precision[steps + 1])


def msTPFP(line_pred, line_gt, threshold):
    """
     @brief greedily label each predicted line as a true or false positive
     @line_pred predicted lines as an (N, 2, 2) array of endpoint pairs, processed in
                the given order (sAP passes them sorted by descending score)
     @line_gt GT lines as an (M, 2, 2) array of endpoint pairs
     @threshold a prediction matches its nearest GT line when the summed squared
                distance of both endpoints (best of the two endpoint orderings) is
                strictly below this value; each GT line can be matched only once
     @return (tp, fp) float32 arrays of length N holding 0/1 flags
    """
    tp = np.zeros(len(line_pred), np.float32)
    fp = np.zeros(len(line_pred), np.float32)
    if len(line_gt) == 0:
        # Nothing to match against; argmin/min over an empty GT axis would raise.
        fp[:] = 1
        return tp, fp

    # diff[i, j, a, b] = squared distance between endpoint a of pred i and endpoint b of GT j
    diff = ((line_pred[:, None, :, None] - line_gt[:, None]) ** 2).sum(-1)
    # Lines are undirected: take the cheaper of the two endpoint pairings.
    diff = np.minimum(
        diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
    )
    choice = np.argmin(diff, 1)
    dist = np.min(diff, 1)
    hit = np.zeros(len(line_gt), bool)
    for i in range(len(line_pred)):
        if dist[i] < threshold and not hit[choice[i]]:
            hit[choice[i]] = True
            tp[i] = 1
        else:
            fp[i] = 1
    return tp, fp


def metric(output_list, target_list):
    """
     @brief evaluate line predictions with sAP / F-score at several thresholds
     @output_list either
        - a list whose elements are lists: even-indexed entries are flattened and
          passed to sAP as-is (the "*_debug" results), odd-indexed entries are
          flattened and trimmed to item[:2]; or
        - a flat list of per-image items, each trimmed to item[:2]
     @target_list per-image ground-truth lines
     @return nested case: (sAP5_debug, sAP5, sAP10_debug, sAP10,
                           Fs5_debug, Fs5, Fs10_debug, Fs10, sAP15, Fs15)
             flat case:   (sAP5, sAP10, lap)
    """
    if isinstance(output_list[0], list):
        # Build each flattened prediction list once; sAP does not mutate its input.
        debug_preds = [item for sublist in output_list[::2] for item in sublist]
        final_preds = [item[:2] for sublist in output_list[1::2] for item in sublist]
        sAP5_debug, Fs5_debug = sAP(debug_preds, target_list)
        sAP5, Fs5 = sAP(final_preds, target_list)
        sAP10_debug, Fs10_debug = sAP(debug_preds, target_list, 10)
        sAP10, Fs10 = sAP(final_preds, target_list, 10)
        sAP15, Fs15 = sAP(final_preds, target_list, 15)
        # Slot 8 previously repeated Fs15, so the computed sAP15 was silently discarded.
        return (sAP5_debug, sAP5, sAP10_debug, sAP10,
                Fs5_debug, Fs5, Fs10_debug, Fs10, sAP15, Fs15)

    preds = [item[:2] for item in output_list]
    sAP5, _ = sAP(preds, target_list)
    sAP10, _ = sAP(preds, target_list, 10)
    # Named lap_score so it does not shadow the module-level `lap` import.
    lap_score = LAP(preds, target_list)
    return sAP5, sAP10, lap_score