from util.LAP.lib.Evaluator_line import *
from util.LAP.lib.utils import *
import os

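# Legacy file-based evaluation (reads ground-truth / detection .txt files from disk),
# kept for reference only; the active in-memory implementation follows below.
# def getBoundingBoxes(gt_path, pred_path):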
#     """Read txt files containing bounding boxes (ground truth and detections)."""
#     allBoundingBoxes = BoundingBoxes()
#     import glob
#     import os
#     # Read ground truths
#     currentPath = os.path.dirname(os.path.abspath(__file__))
#     folderGT = os.path.join(currentPath, gt_path)#'groundtruths_york/new')
#     os.chdir(folderGT)
#     files = glob.glob("*.txt")
#     files.sort()
#     # Class representing bounding boxes (ground truths and detections)
#     allBoundingBoxes = BoundingBoxes()
#     for f in files:
#         nameOfImage = f.replace(".txt", "")
#         fh1 = open(f, "r")
#         for line in fh1:
#             line = line.replace("\n", "")
#             if line.replace(' ', '') == '':
#                 continue
#             splitLine = line.split(" ")
#             idClass = splitLine[0]  # class
#             x = float(splitLine[1])  # first coordinate (ground-truth lines carry no confidence)
#             y = float(splitLine[2])
#             w = float(splitLine[3])
#             h = float(splitLine[4])
#             bb = BoundingBox(
#                 nameOfImage,
#                 idClass,
#                 x,
#                 y,
#                 w,
#                 h,
#                 CoordinatesType.Absolute, (128, 128),
#                 BBType.GroundTruth,
#                 format='GT')
#             allBoundingBoxes.addBoundingBox(bb)
#         fh1.close()
#     # Read detections
#     folderDet = os.path.join(currentPath, pred_path)#'TP_F_M_0.5/new')
#     print(folderDet)
#     os.chdir(folderDet)
#     files = glob.glob("*.txt")
#     files.sort()
#
#     for f in files:
#         # nameOfImage = f.replace("_det.txt","")
#         nameOfImage = f.replace(".txt", "")
#         # Read detections from txt file
#         fh1 = open(f, "r")
#         for line in fh1:
#             line = line.replace("\n", "")
#             if line.replace(' ', '') == '':
#                 continue
#             splitLine = line.split(" ")
#             idClass = splitLine[0]  # class
#             confidence = float(splitLine[1])  # confidence
#             x = float(splitLine[2])
#             y = float(splitLine[3])
#             w = float(splitLine[4])
#             h = float(splitLine[5])
#             bb = BoundingBox(
#                 nameOfImage,
#                 idClass,
#                 x,
#                 y,
#                 w,
#                 h,
#                 CoordinatesType.Absolute, (128, 128),
#                 BBType.Detected,
#                 confidence,
#                 format='GT')
#             allBoundingBoxes.addBoundingBox(bb)
#         fh1.close()
#     return allBoundingBoxes
#
# wire_gt = 'groundtruths_wire_line'
# york_gt = 'groundtruths_york_line'
# save_all = 'result/'
# os.makedirs(save_all, exist_ok=True)
#
# def cal_metric(gt_path, pred_path, save_file, thres):
#     boundingboxes = getBoundingBoxes(gt_path, pred_path)
#     save_file = save_all + save_file
#     save_path = save_all + save_file.split('.')[0] + '.png'
#     evaluator = Evaluator(save_file)
#     evaluator.PlotPrecisionRecallCurve(
#         boundingboxes,  # Object containing all bounding boxes (ground truths and detections)
#         Threshold=thres,  # LMS threshold
#         method=MethodAveragePrecision.EveryPointInterpolation,
#         showAP=True,  # Show Average Precision in the title of the plot
#         showInterpolatedPrecision=True,
#         savePath = save_path)  # Plot the interpolated precision curve
#     metricsPerClass = evaluator.results
#     print("Average precision values per class:\n")
#     for mc in metricsPerClass:
#         # Get metric values per each class
#         c = mc['class']
#         average_precision = mc['AP']
#         # Print AP per class
#         print('%s: %f' % (c, average_precision))
#
#
# Thress = [0.5]
# for thres in Thress:
#     wire_path = 'TP-LSD/wire'  # path for result
#     save_file = 'TP-LSD-wire-' + str(thres) + '.npz'
#     cal_metric(wire_gt, wire_path, save_file, thres)


### ----------------------------------------------------------------------
def getBoundingBoxes(target_list, output_list, img_size=(128, 128)):  # x1, y1, x2, y2, no w, h
    """Collect ground-truth and predicted line segments into a ``BoundingBoxes`` container.

    Every segment is stored as a ``BoundingBox`` of class ``'line'`` whose four
    coordinates are taken from entries 0..3 of the segment (the file's convention is
    x1, y1, x2, y2 rather than x, y, w, h; ``format='GT'`` is passed to the box).

    Args:
        target_list: per-image sequences of ground-truth segments; each segment must
            be indexable with at least four numeric entries.
        output_list: per-image ``(scores, segments)`` pairs aligned with
            ``target_list``; ``scores[k]`` is used as the confidence of ``segments[k]``.
        img_size: image size handed to every ``BoundingBox``; defaults to the
            previously hard-coded ``(128, 128)``.

    Returns:
        A ``BoundingBoxes`` object in which all boxes of the i-th image share the
        image name ``str(i)``. Images are paired with ``zip``, so trailing entries of
        the longer list are ignored.
    """
    allBoundingBoxes = BoundingBoxes()
    # Single class for every segment. Defined before the loops so the detection loop
    # no longer depends on the ground-truth loop having executed at least once
    # (an image without ground-truth lines previously raised NameError).
    idClass = 'line'

    for i, ((pred_score, pred_line), gt_line) in enumerate(zip(output_list, target_list)):
        image_name = str(i)

        # Ground-truth segments (no confidence).
        for gline in gt_line:
            x1 = float(gline[0])
            y1 = float(gline[1])
            x2 = float(gline[2])
            y2 = float(gline[3])
            bb = BoundingBox(
                image_name,
                idClass,
                x1,
                y1,
                x2,
                y2,
                CoordinatesType.Absolute, img_size,
                BBType.GroundTruth,
                format='GT')
            allBoundingBoxes.addBoundingBox(bb)

        # Predicted segments, each paired with its score.
        for confidence, pline in zip(pred_score, pred_line):
            x1 = float(pline[0])
            y1 = float(pline[1])
            x2 = float(pline[2])
            y2 = float(pline[3])
            bb = BoundingBox(
                image_name,
                idClass,
                x1,
                y1,
                x2,
                y2,
                CoordinatesType.Absolute, img_size,
                BBType.Detected,
                confidence,
                format='GT')
            allBoundingBoxes.addBoundingBox(bb)
    return allBoundingBoxes


def cal_metric(output_list, target_list, save_file=None, thres=0.5):
    """Compute the average precision (AP) of predicted line segments.

    Args:
        output_list: per-image ``(scores, segments)`` pairs (see ``getBoundingBoxes``).
        target_list: per-image ground-truth segments, aligned with ``output_list``.
        save_file: forwarded unchanged to the ``Evaluator`` constructor.
        thres: matching threshold forwarded to the evaluator as ``Threshold``
            (the original author labels it the LMS threshold).

    Returns:
        The ``'AP'`` value of the last entry in ``evaluator.results``. With the
        single ``'line'`` class built by ``getBoundingBoxes`` this is the overall AP.
        Returns ``0.0`` when the evaluator reports no classes at all, instead of
        raising ``UnboundLocalError``.
    """
    boundingboxes = getBoundingBoxes(target_list, output_list)
    evaluator = Evaluator(save_file)
    # Populates evaluator.results (read below); AP/interpolated-curve overlays are disabled.
    evaluator.PlotPrecisionRecallCurve(
        boundingboxes,  # all ground-truth and detected boxes
        Threshold=thres,  # LMS threshold
        method=MethodAveragePrecision.EveryPointInterpolation,
        showAP=False,
        showInterpolatedPrecision=False,
    )
    metricsPerClass = evaluator.results

    # Default for the "no classes reported" case; otherwise keep the last class's AP,
    # matching the original loop's behavior.
    average_precision = 0.0
    for mc in metricsPerClass:
        average_precision = mc['AP']
    return average_precision

if __name__ == "__main__":
    # Self-consistency check: evaluate the wireframe validation ground truth against
    # itself (every "prediction" is a ground-truth line with confidence 1), so the
    # printed AP is expected to be 1.0 if matching works correctly.
    import glob
    import sys

    import numpy as np

    # Dataset root may be overridden by the first CLI argument; default unchanged.
    rootdir = sys.argv[1] if len(sys.argv) > 1 else "/disk2/dataset/wireframe"
    # Sorted so image ids (str(i) in getBoundingBoxes) are reproducible across runs.
    filelist = sorted(glob.glob(f"{rootdir}/valid/*_0_label.npz"))
    if not filelist:
        sys.exit(f"no *_0_label.npz files found under {rootdir}/valid")

    # Seeded generator so the shuffled line order is reproducible.
    rng = np.random.default_rng(0)
    output_list = []
    target_list = []
    for name in filelist:
        with np.load(name) as npz:
            # Shuffle the line order, then pick the endpoints' (x, y) from their (y, x)
            # entries and flatten to one (x1, y1, x2, y2) row per line.
            # NOTE(review): assumes npz["lpos"] is (N, 2, >=2) with (y, x) first — confirm.
            lines = rng.permutation(npz["lpos"])[:, :, [1, 0]].reshape(-1, 4)
        # Canonicalize endpoint order: first make y1 <= y2, then make x1 <= x2.
        # The second swap wins, so the final guarantee is left-to-right ordering.
        swap = lines[:, 1] > lines[:, 3]
        lines[swap] = lines[swap][:, [2, 3, 0, 1]]
        swap = lines[:, 0] > lines[:, 2]
        lines[swap] = lines[swap][:, [2, 3, 0, 1]]
        target_list.append(lines)
        output_list.append((np.ones_like(lines[:, 0]), lines))
    print(cal_metric(output_list, target_list))