import numpy as np
import open3d as o3d
from matplotlib import pyplot as plt
from matplotlib import image
import cv2
from skimage.io import imshow

# with open('label_cen/0.txt') as f:
#     labels = f.readline().split(' ')[1:19]

# Frame to annotate. NOTE(review): absolute, machine-specific path — adjust locally.
IMAGE_PATH = "/home/mahmoud/PycharmProjects/data/GUIMOD_low/RGB/4578.png"
# cv2.resize takes dsize as (width, height).
IMAGE_SIZE = (640, 480)

img = cv2.resize(plt.imread(IMAGE_PATH), IMAGE_SIZE)
# plt.imread returns PNGs as float arrays in [0, 1], with a 4th alpha channel for
# RGBA files. cv2.line pads a 3-component colour with 0 for the 4th channel, which
# would make lines drawn later fully transparent, so keep only the RGB channels.
# ascontiguousarray keeps the array usable as a cv2 drawing target.
if img.ndim == 3 and img.shape[2] == 4:
    img = np.ascontiguousarray(img[..., :3])


# fig, ax = plt.subplots()
# ax.imshow(img)
#
# for n in range(9):
#     x = float(labels[n*2])*640.0
#     y = float(labels[(n*2)+1])*480.0
#     ax.plot(x, y, marker='.', color="red")
#
# plt.imshow(img)
# plt.show()


def labelDrawPoints(drawList):  # (b, f = back, front), (l, r = left, right), (u, d = up, down)
    """Map the first eight 2D box corners to named integer pixel tuples.

    Corner names combine three letters: b/f = back/front, l/r = left/right,
    u/d = up/down. ``drawList[i]`` supplies the corner at position ``i`` of
    this order: bld, blu, fld, flu, brd, bru, frd, fru.

    Args:
        drawList: Indexable sequence of at least eight 2D points; only the
            first two components of each of the first eight points are used,
            and any further points are ignored.

    Returns:
        dict mapping each corner name (in the order above) to
        ``(int(p[0]), int(p[1]))``. ``int`` truncates toward zero; it does
        not round.

    Raises:
        IndexError: if a list/array with fewer than eight points is given.
    """
    corner_names = ('bld', 'blu', 'fld', 'flu', 'brd', 'bru', 'frd', 'fru')
    return {
        name: (int(drawList[idx][0]), int(drawList[idx][1]))
        for idx, name in enumerate(corner_names)
    }


def drawPose(img, drawPoints, colour=(255, 0, 0), thickness=2):  # draw bounding box
    """Draw the 12 edges of a 3D bounding box onto ``img`` in place.

    Args:
        img: Image array that ``cv2.line`` draws into; it is modified in place.
        drawPoints: Mapping from the eight corner keys produced by
            ``labelDrawPoints`` ('bld', 'blu', 'fld', 'flu', 'brd', 'bru',
            'frd', 'fru') to integer 2D points accepted by ``cv2.line``.
        colour: Line colour passed straight to ``cv2.line``; its scale must
            match the image dtype, e.g. (0, 1, 0) for float images in [0, 1].
            NOTE(review): the default (255, 0, 0) is red only if the image is
            RGB (blue for BGR) and saturates on float images — confirm intent.
        thickness: Line thickness in pixels. The default of 2 matches the
            previously hard-coded value.

    Returns:
        None; drawing happens in place.

    Raises:
        KeyError: if ``drawPoints`` lacks one of the eight corner keys.
    """
    # Each corner connects to the three corners that differ from it in exactly
    # one of back/front, left/right, up/down: 12 edges in total.
    edges = (
        ('bld', 'blu'), ('bld', 'fld'), ('bld', 'brd'),
        ('blu', 'flu'), ('blu', 'bru'),
        ('fld', 'flu'), ('fld', 'frd'),
        ('flu', 'fru'),
        ('fru', 'bru'), ('fru', 'frd'),
        ('frd', 'brd'),
        ('brd', 'bru'),
    )
    for start, end in edges:
        cv2.line(img, drawPoints[start], drawPoints[end], colour, thickness)


def showImage(img):  # displays image using plt
    """Display ``img`` in a blocking matplotlib window.

    Calls ``plt.imshow`` directly rather than ``skimage.io.imshow``, which is
    deprecated in recent scikit-image releases. ``cmap='gray'`` renders 2-D
    (single-channel) input in grey-scale; matplotlib ignores it for RGB(A)
    input. ``interpolation='nearest'`` shows pixels without smoothing.

    Args:
        img: Image array accepted by ``plt.imshow``.

    Returns:
        None.
    """
    plt.imshow(img, cmap='gray', interpolation='nearest')
    plt.show()


# 2D keypoints in pixel coordinates (presumably model output — confirm).
keypoint_2d = [
    [194.79502406, 300.25210315],
    [309.04170512, 413.46240758],
    [224.71282568, 405.5704591],
    [264.29383834, 437.88191945],
    [221.23326482, 345.66924298],
    [251.84921935, 391.24423743],
    [199.0563269, 341.23706149],
    [211.03029425, 374.63247381],
]

# Reference keypoints in pixel coordinates (presumably ground truth — confirm).
gt = [
    [306.51963122, 414.64622534],
    [224.6858743, 404.66037888],
    [262.87388012, 437.42722657],
    [222.16654847, 346.77501763],
    [251.48638828, 392.17221527],
    [200.20518731, 340.78415411],
    [211.58938474, 373.8630176],
    [234.1159544, 375.90692153],
]

# Normalized label pairs; `check` scales the first component by 480 and the
# second by 640.
# NOTE(review): the commented-out label loader near the top of the file scales
# x by 640 and y by 480 — confirm which convention these labels follow.
# NOTE(review): check[1:] numerically matches gt[:7]; confirm that check[0] is
# really a box corner before it is drawn as 'bld'.
_CHECK_NORMALIZED = [
    (4.097115187633312838e-01, 4.704518039073152247e-01),
    (6.385825650449137303e-01, 6.478847270905010447e-01),
    (4.680955714493511555e-01, 6.322818420027060959e-01),
    (5.476539169239853511e-01, 6.834800415222384018e-01),
    (4.628469759865801447e-01, 5.418359650516325621e-01),
    (5.239299755790750579e-01, 6.127690863584746017e-01),
    (4.170941402320869629e-01, 5.324752408015428484e-01),
    (4.408112182124990786e-01, 5.841609649977563823e-01),
]
check = [[first * 480, second * 640] for first, second in _CHECK_NORMALIZED]


# out = np.zeros((img.shape[0], img.shape[1], 16))
# fig, ax = plt.subplots()
# ax.imshow(img)
# for n in range(8):
#     point = keypoint_2d[n]
#     ax.plot(point[0], point[1], marker='.', color="red")
#     point_gt = gt[n]
#     ax.plot(point_gt[0], point_gt[1], marker='.', color="blue")
#
# plt.imshow(img)
# plt.show()

# Render the `check` box in green on a copy so `img` itself stays untouched.
# Colour components are on a 0-1 scale to match the float image plt.imread
# returns for PNG files.
copy_img = img.copy()
check_corners = labelDrawPoints(check)
drawPose(copy_img, check_corners, colour=(0, 1, 0))
showImage(copy_img)