import numpy as np
import open3d as o3d
from matplotlib import pyplot as plt
from matplotlib import image
import cv2
from skimage.io import imshow


def fps(points, centroid, n_samples):
    """Farthest point sampling (FPS).

    Greedily picks ``n_samples`` points so that every newly picked point is
    the one whose squared Euclidean distance to its nearest already-picked
    point is largest. The first sample is always ``points[0]``; ties are
    broken toward the lowest index.

    Args:
        points: ``[N, 3]`` array-like containing the whole point cloud.
        centroid: unused; kept only so existing callers keep working.
        n_samples: number of points to keep, typically ``<< N``.

    Returns:
        ``[n_samples, 3]`` array of the sampled points in selection order
        (same dtype as ``points``). An empty ``[0, 3]`` array when
        ``n_samples <= 0``.

    Raises:
        ValueError: if ``n_samples`` is larger than the number of points.
    """
    points = np.asarray(points)
    n_points = len(points)

    if n_samples <= 0:
        return points[:0]
    if n_samples > n_points:
        raise ValueError(
            f"n_samples ({n_samples}) exceeds the number of points ({n_points})")

    sample_inds = np.zeros(n_samples, dtype=int)

    # Squared distance from every point to its nearest sample so far.
    # Already-sampled points are marked with -inf: np.minimum keeps them at
    # -inf, so argmax can never pick them again (replaces np.delete, which
    # reallocated the candidate array on every iteration).
    min_dists = np.full(n_points, np.inf)
    min_dists[0] = -np.inf  # sample_inds[0] == 0 is the seed point

    for i in range(1, n_samples):
        last_added = points[sample_inds[i - 1]]
        dist_to_last = ((points - last_added) ** 2).sum(-1)
        np.minimum(min_dists, dist_to_last, out=min_dists)

        # Farthest remaining point from the current sample set; argmax returns
        # the first maximum, i.e. the lowest index among ties.
        selected = int(np.argmax(min_dists))
        sample_inds[i] = selected
        min_dists[selected] = -np.inf

    return points[sample_inds]


def labelDrawPoints(drawList):
    """Map the first 8 entries of ``drawList`` to named integer pixel points.

    Corner-name letters: b/f = back/front, l/r = left/right, u/d = up/down.
    ``drawList[k]`` provides the (x, y) of the k-th name in the order
    bld, blu, fld, flu, brd, bru, frd, fru. Coordinates are truncated with
    ``int``; entries past index 7 are ignored.
    """
    corner_names = ('bld', 'blu', 'fld', 'flu', 'brd', 'bru', 'frd', 'fru')
    return {
        name: (int(drawList[idx][0]), int(drawList[idx][1]))
        for idx, name in enumerate(corner_names)
    }


def drawPose(img, drawPoints, colour=(255, 0, 0)):
    """Draw the 12 edges of a 3-D bounding box onto ``img`` in place.

    ``drawPoints`` maps the corner names used by ``labelDrawPoints``
    (b/f = back/front, l/r = left/right, u/d = up/down) to pixel points.
    Every edge is drawn with ``cv2.line`` in ``colour`` at thickness 2.
    """
    box_edges = (
        ('bld', 'blu'), ('bld', 'fld'), ('bld', 'brd'),
        ('blu', 'flu'), ('blu', 'bru'),
        ('fld', 'flu'), ('fld', 'frd'),
        ('flu', 'fru'),
        ('fru', 'bru'), ('fru', 'frd'),
        ('frd', 'brd'),
        ('brd', 'bru'),
    )
    thickness = 2
    for start, end in box_edges:
        cv2.line(img, drawPoints[start], drawPoints[end], colour, thickness)


def showImage(img):
    """Draw ``img`` with skimage's ``imshow`` and display it with ``plt.show()``."""
    imshow(img)
    plt.show()


def apply_fps(pcd, fps_num):
    """Run farthest point sampling on the points of ``pcd``.

    Args:
        pcd: object with a ``.points`` attribute accepted by ``np.asarray``
            (an Open3D point cloud in this file).
        fps_num: number of points to sample.

    Returns:
        The array returned by ``fps``.
    """
    xyz = np.asarray(pcd.points)
    # The mean point is forwarded as fps()'s ``centroid`` argument.
    return fps(xyz, xyz.mean(0), fps_num)


def process2(pcd, R_exp, tVec, camera, img, vis=True):
    """Project 3-D points into the image and optionally plot them.

    Args:
        pcd: array-like of 3-D points; passed through ``np.asarray`` directly
            (callers in this file pass a NumPy array, not an Open3D cloud).
        R_exp: rotation passed as ``rvec`` to ``cv2.projectPoints`` (the caller
            in this file passes the 3x3 rotation block of a pose matrix).
        tVec: translation vector.
        camera: 3x3 camera intrinsic matrix.
        img: background image, used only when ``vis`` is true.
        vis: if true, show ``img`` with the projected points as red dots.

    Returns:
        The projected image points from ``cv2.projectPoints`` — an
        ``(N, 1, 2)`` array (one ``[[x, y]]`` per input point).
    """
    camera = np.array(camera)
    R_exp = np.array(R_exp, dtype="float64")
    tVec = np.array(tVec, dtype="float64")

    object_points = np.asarray(pcd)
    # All-zero distortion coefficients: no lens distortion is modelled.
    dist_coeffs = np.zeros(shape=[8, 1], dtype='float64')
    image_points, _jacobian = cv2.projectPoints(
        object_points, R_exp, tVec, camera, dist_coeffs)

    if vis:
        _, ax = plt.subplots()
        ax.imshow(img)
        # Single artist for all points; linestyle='none' so they are not joined.
        ax.plot(image_points[:, 0, 0], image_points[:, 0, 1],
                linestyle='none', marker='.', color="red")
        plt.show()
    return image_points


def process(pcd_box, pcd2, R_exp, tVec, camera, img):
    """Visualise a pose: projected model points and the projected 3-D box.

    Args:
        pcd_box: array-like of the 3-D box corners, in the corner order
            expected by ``labelDrawPoints`` (at least 8 points).
        pcd2: point cloud with a ``.points`` attribute (Open3D-style); every
            point is plotted as a red dot.
        R_exp: rotation passed as ``rvec`` to ``cv2.projectPoints``.
        tVec: translation vector.
        camera: 3x3 camera intrinsic matrix.
        img: background image; it is copied, never modified.

    The box is drawn with colour ``(0, 1, 0)`` — presumably green on a float
    image in [0, 1]; confirm for the image type you pass.
    """
    camera = np.array(camera)
    R_exp = np.array(R_exp, dtype="float64")
    tVec = np.array(tVec, dtype="float64")
    # All-zero distortion coefficients: no lens distortion is modelled.
    dist_coeffs = np.zeros(shape=[5, 1], dtype='float64')

    box_2d, _ = cv2.projectPoints(np.asarray(pcd_box, dtype="float64"),
                                  R_exp, tVec, camera, dist_coeffs)
    cloud_2d, _ = cv2.projectPoints(np.asarray(pcd2.points),
                                    R_exp, tVec, camera, dist_coeffs)

    _, ax = plt.subplots()
    # One artist for the whole cloud instead of one plt.plot call per point;
    # coordinates are truncated to int as before.
    ax.plot(cloud_2d[:, 0, 0].astype(int), cloud_2d[:, 0, 1].astype(int),
            linestyle='none', marker='.', color="red")

    # The background image is drawn once, by showImage, onto the same axes.
    box_img = img.copy()
    drawPose(box_img, labelDrawPoints(box_2d[:, 0, :]), (0, 1, 0))
    showImage(box_img)


# ==============================================================================

def generate_fps(data_name, camera, vis=False, resize=False,
                 objects=("Banana",), num_images=4995, n_keypoints=8,
                 out_dir='label_cen'):
    """Write per-frame keypoint label files built from FPS keypoints.

    For each object, ``n_keypoints`` points are sampled from its model point
    cloud with farthest point sampling and their mean is appended as one
    extra keypoint. For each frame, every object's keypoints are projected
    with that object's pose and written to ``{out_dir}/{i}.txt`` with one row
    per object: ``obj_id x0 y0 x1 y1 ...``, where x is divided by the image
    width and y by the image height.

    Args:
        data_name: dataset root; reads ``Models/{obj}/{obj.lower()}.ply``,
            ``RGB/{i}.png`` and ``Pose_transformed/{obj}/{i}.npy`` under it.
        camera: 3x3 intrinsic matrix; should match the image size after the
            optional resize.
        vis: forwarded to ``process2`` to display every projection.
        resize: resize each frame to 640x480 before projecting.
        objects: object names; an object's index here is its class id.
        num_images: number of frames, ``0 .. num_images - 1``.
        n_keypoints: FPS keypoints per object, before the mean is appended.
        out_dir: output directory for label files; created if missing.
    """
    import os

    # Keypoints depend only on the object model, so sample them once.
    keypoints_per_obj = []
    for obj in objects:
        pcd = o3d.io.read_point_cloud(f'{data_name}/Models/{obj}/{obj.lower()}.ply')
        fps_points = apply_fps(pcd, n_keypoints)
        # Append the mean of the sampled keypoints as an extra keypoint.
        fps_points = np.append(fps_points, [fps_points.mean(0)], axis=0)
        keypoints_per_obj.append(fps_points)

    os.makedirs(out_dir, exist_ok=True)
    for i in range(num_images):
        img = image.imread(f"{data_name}/RGB/{i}.png")
        if resize:
            img = cv2.resize(img, (640, 480))
        height, width = img.shape[:2]
        # Normalisation factors for interleaved (x, y) pixel pairs.
        scale = np.array([width, height], dtype='float64')

        # One row per object, written once per frame so that several objects
        # in the same frame do not overwrite each other's labels.
        rows = []
        for obj_id, (obj, fps_points) in enumerate(zip(objects, keypoints_per_obj)):
            pose = np.load(f'{data_name}/Pose_transformed/{obj}/{i}.npy')
            R_exp = pose[0:3, 0:3]
            tVec = pose[0:3, 3]

            points = process2(fps_points, R_exp, tVec, camera, img, vis)
            xy = np.asarray(points).reshape(-1, 2) / scale
            rows.append(np.concatenate(([float(obj_id)], xy.ravel())))

        np.savetxt(f'{out_dir}/{i}.txt', np.array(rows))


if __name__ == '__main__':
    choice = "low"
    data_options = {"high": "ground_truth_rgb",
                    "low": "ground_truth_depth"}

    # Not used further below, but the lookup rejects an unknown `choice`.
    dataset_type = data_options[choice]

    dataset_name = f"GUIMOD_{choice}"

    # Camera intrinsics and native image size (width, height) per dataset.
    # NOTE(review): native sizes are inferred from the principal points
    # (cx, cy) ~ (W / 2, H / 2); the 'low' size matches the original 0.5 and
    # 2/3 resize factors — confirm 'high' against the actual images.
    if choice == 'high':
        camera = np.array([[1386.4138492513919, 0.0, 960.5],
                           [0.0, 1386.4138492513919, 540.5],
                           [0.0, 0.0, 1.0]])
        native_size = (1920, 1080)
    else:
        camera = np.array([[1086.5054444841007, 0.0, 640.5],
                           [0.0, 1086.5054444841007, 360.5],
                           [0.0, 0.0, 1.0]])
        native_size = (1280, 720)

    # generate_fps(..., resize=True) resizes every frame to 640x480, so scale
    # the intrinsics by the same per-axis factors (derived, not hard-coded,
    # so they are also correct for the 'high' dataset).
    target_size = (640, 480)
    trans = np.diag([target_size[0] / native_size[0],
                     target_size[1] / native_size[1],
                     1.0])

    camera_new = trans @ camera

    generate_fps(dataset_name, camera_new, False, True)
