import numpy as np
import open3d as o3d
from matplotlib import pyplot as plt
from matplotlib import image
import cv2
from skimage.io import imshow
from pathlib import Path
from utils import compute_categories_id, compute_id_good_occ


def fps(points, centroid, n_samples):
    """
    Farthest point sampling: greedily pick points that are far from each other.

    The first sample is always the point at index 0; every following sample
    is the remaining point whose squared Euclidean distance to its nearest
    already-selected point is largest.

    points: [N, 3] array containing the whole point cloud
    centroid: accepted for interface compatibility; it is not used by the
        sampling itself
    n_samples: samples you want in the sampled point cloud typically << N.
        Values larger than N are clamped to N (every point is returned,
        in selection order).

    Returns the coordinates of the selected points, shape [S, ...] with
    S = min(n_samples, N). S == 0 yields an empty selection.
    """
    points = np.array(points)
    n_points = len(points)

    # Asking for more samples than there are points used to crash on an
    # empty argmax; there is nothing more to select once every point is taken.
    n_samples = min(n_samples, n_points)

    # Initialise an array for the sampled indices
    sample_inds = np.zeros(n_samples, dtype='int')  # [S]
    if n_samples == 0:
        # Empty cloud or zero samples requested: empty selection.
        return points[sample_inds]

    # Represent the points by their indices in points
    points_left = np.arange(n_points)  # [P]

    # Nearest-selected-point distance for every point, initialised to inf
    dists = np.full(n_points, np.inf)  # [P]

    # Seed the selection with the first point and remove it from the pool
    selected = 0
    sample_inds[0] = points_left[selected]
    points_left = np.delete(points_left, selected)  # [P - 1]

    # Iteratively select the remaining samples
    for i in range(1, n_samples):
        # Squared distance from the last added point to every remaining point
        last_added = sample_inds[i - 1]
        dist_to_last_added_point = ((points[last_added] - points[points_left]) ** 2).sum(-1)  # [P - i]

        # If closer, update the nearest-selected-point distances
        dists[points_left] = np.minimum(dist_to_last_added_point, dists[points_left])  # [P - i]

        # Pick the remaining point with the largest nearest-neighbour
        # distance to the sampled points
        selected = np.argmax(dists[points_left])
        sample_inds[i] = points_left[selected]

        # Update points_left
        points_left = np.delete(points_left, selected)

    return points[sample_inds]


def labelDrawPoints(drawList):
    """Name the first eight 2D points of ``drawList`` as bounding-box corners.

    Each label is three letters: b/f (back, front), l/r (left, right),
    u/d (up, down). Points are consumed in the fixed order
    bld, blu, fld, flu, brd, bru, frd, fru and their first two components
    are truncated to ``int``.

    Returns a dict mapping each label to an ``(x, y)`` tuple of ints.
    """
    corner_labels = ('bld', 'blu', 'fld', 'flu', 'brd', 'bru', 'frd', 'fru')
    return {
        label: (int(drawList[idx][0]), int(drawList[idx][1]))
        for idx, label in enumerate(corner_labels)
    }


def drawPose(img, drawPoints, colour=(255, 0, 0)):
    """Draw the twelve edges of a bounding box onto ``img`` in place.

    img: image passed to ``cv2.line`` as the drawing target.
    drawPoints: dict of corner label -> (x, y), using the labels
        bld, blu, fld, flu, brd, bru, frd, fru.
    colour: line colour handed to ``cv2.line``; lines are 2 px thick.
    """
    box_edges = (
        ('bld', 'blu'), ('bld', 'fld'), ('bld', 'brd'),
        ('blu', 'flu'), ('blu', 'bru'),
        ('fld', 'flu'), ('fld', 'frd'),
        ('flu', 'fru'),
        ('fru', 'bru'), ('fru', 'frd'),
        ('frd', 'brd'),
        ('brd', 'bru'),
    )
    for start, end in box_edges:
        cv2.line(img, drawPoints[start], drawPoints[end], colour, 2)


def showImage(img):  # displays image using plt
    """Show ``img``: draw it with ``skimage.io.imshow``, then call ``plt.show()``.

    img: image to display; it is passed unchanged to ``skimage.io.imshow``.
    """
    imshow(img)
    plt.show()


def apply_fps(pcd, fps_num):
    """Run ``fps`` on the points of ``pcd`` and return the sampled points.

    pcd: object exposing a ``points`` attribute convertible to an array
        (presumably an open3d point cloud - confirm against callers).
    fps_num: number of samples requested from ``fps``.

    The per-axis mean of the cloud is passed to ``fps`` as its centroid
    argument.
    """
    xyz = np.asarray(pcd.points)
    return fps(xyz, xyz.mean(0), fps_num)


def process2(pcd, R_exp, tVec, camera, img, vis=True):
    """Project 3D points into the image and optionally plot them.

    pcd: 3D points, converted with ``np.asarray`` and handed to
        ``cv2.projectPoints`` as object points.
    R_exp: rotation passed as ``rvec`` to ``cv2.projectPoints`` (cast to float64).
    tVec: translation passed as ``tvec`` (cast to float64).
    camera: camera matrix passed to ``cv2.projectPoints``.
    img: image the points are overlaid on; only used when ``vis`` is true.
    vis: when true, show the image with every projected point as a red dot.

    Zero distortion coefficients (5x1) are used for the projection.

    Returns the first element of the ``cv2.projectPoints`` result, i.e. the
    projected image points; this function indexes them as ``[n][0][0]`` /
    ``[n][0][1]`` for x / y.
    """
    camera = np.array(camera)
    R_exp = np.array(R_exp, dtype="float64")
    tVec = np.array(tVec, dtype="float64")

    object_points = np.asarray(pcd)
    no_distortion = np.zeros(shape=[5, 1], dtype='float64')
    image_points = cv2.projectPoints(object_points, R_exp, tVec, camera, no_distortion)[0]

    if vis:
        # NOTE(review): this swaps channels assuming a BGR input; generate_fps
        # in this file loads images with matplotlib (already RGB) - confirm
        # which channel order callers really pass.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        _, ax = plt.subplots()
        ax.imshow(img)
        # One plot call for all points; no connecting line, dots only.
        ax.plot(image_points[:, 0, 0], image_points[:, 0, 1],
                linestyle='none', marker='.', color="red")
        plt.show()
    return image_points


def process(pcd_box, pcd2, R_exp, tVec, camera, img):
    """Show ``img`` with the projected cloud points and a projected 3D box.

    pcd_box: 3D box corners, projected and drawn as a wireframe; the first
        eight projected points are labelled by ``labelDrawPoints``.
    pcd2: object with a ``points`` attribute; every point is projected and
        plotted as a red dot.
    R_exp, tVec, camera: pose and camera matrix forwarded to
        ``cv2.projectPoints`` (pose cast to float64, zero distortion).
    img: source image; it is copied before anything is drawn on it.
    """
    camera = np.array(camera)
    R_exp = np.array(R_exp, dtype="float64")
    tVec = np.array(tVec, dtype="float64")

    cloud_xyz = np.asarray(pcd2.points)

    box_px = cv2.projectPoints(pcd_box, R_exp, tVec, camera, np.zeros(shape=[5, 1], dtype='float64'))[0]
    cloud_px = cv2.projectPoints(cloud_xyz, R_exp, tVec, camera, np.zeros(shape=[5, 1], dtype='float64'))[0]

    # Projected points come back wrapped in a singleton axis; unwrap them.
    corners = [projected[0] for projected in box_px]

    fig, ax = plt.subplots()
    ax.imshow(img.copy())

    # Cloud points are drawn on the current axes, truncated to whole pixels.
    for projected in cloud_px:
        u, v = projected[0][0], projected[0][1]
        plt.plot(int(u), int(v), marker='.', color="red")

    # The wireframe goes onto a fresh copy so the caller's image stays untouched.
    boxed = img.copy()
    drawPose(boxed, labelDrawPoints(corners), (0, 1, 0))
    showImage(boxed)


# ==============================================================================

def generate_fps(data_name, camera, Nb_camera, Nb_world, list_categories, occ_target, vis=False):
    """Write normalised 2D keypoint files for every usable (image, category) pair.

    For each world ``i`` in 1..Nb_world and camera ``j`` in 1..Nb_camera the
    image index is ``p = (i - 1) * Nb_camera + j``. For every category that
    has exactly one instance in ``compute_id_good_occ``'s result for that
    image, the category's 3D keypoints are projected with the pose stored in
    ``{data_name}/Generated/Pose_transformed/{category}/{p}.npy`` and saved to
    ``{data_name}/Generated/FPS/{category}/{p}.txt``.

    Each output file holds one row of 401 values:
    ``[id, x0 / W, y0 / H, x1 / W, y1 / H, ...]`` where W and H are the image
    width and height; slots without a keypoint stay 0.

    data_name: dataset root directory (used as a path prefix).
    camera: camera matrix forwarded to ``cv2.projectPoints`` via ``process2``.
    Nb_camera: number of cameras per world.
    Nb_world: number of worlds.
    list_categories: category names to process.
    occ_target: forwarded to ``compute_id_good_occ``.
    vis: forwarded to ``process2``; when true every projection is displayed.

    Raises IndexError if a category has more keypoints than fit in a row.
    """
    # NOTE: the 3D keypoint files read below are not produced here. The
    # generation step that used to live in this function (commented out) was:
    #     pcd = o3d.io.read_point_cloud(
    #         f'{data_name}/Generated/Models/{categories}/{categories.lower()}.ply')
    #     fps_points = apply_fps(pcd, 200)
    #     np.savetxt(f'{data_name}/Generated/FPS/{categories}_fps_3d.txt', fps_points)

    # 1 id slot + 2 coordinates per keypoint (presumably 200 keypoints,
    # matching the generation step above).
    out_width = 401

    # The 3D keypoints only depend on the category, so read each file once.
    fps_cache = {}

    for i in range(1, Nb_world + 1):  # worlds
        id_to_cat, cat_to_id = compute_categories_id(data_name, i)
        for j in range(1, Nb_camera + 1):  # cameras
            p = ((i - 1) * Nb_camera) + j
            occ_by_category = compute_id_good_occ(data_name, p, id_to_cat, cat_to_id, occ_target)
            for categories in list_categories:
                # Only images with exactly one instance of the category are used.
                if len(occ_by_category[categories]) != 1:
                    continue

                img = image.imread(f"{data_name}/RGB/{p}.png")

                np.set_printoptions(precision=15)
                pose = np.load(f'{data_name}/Generated/Pose_transformed/{categories}/{p}.npy')
                R_exp = pose[0:3, 0:3]
                tVec = pose[0:3, 3]

                if categories not in fps_cache:
                    fps_cache[categories] = np.loadtxt(f'{data_name}/Generated/FPS/{categories}_fps_3d.txt')
                fps_points = fps_cache[categories]

                points = np.asarray(process2(fps_points, R_exp, tVec, camera, img, vis))
                n_points = len(points)
                if 1 + 2 * n_points > out_width:
                    raise IndexError(
                        f"{n_points} keypoints do not fit in an output row of {out_width} values"
                    )

                out = np.zeros((1, out_width))
                # Only column 0 carries the object id. Assigning to the whole
                # row would broadcast the id into every unused slot.
                out[0, :1] = occ_by_category[categories]
                # Interleave x / y, normalised by image width / height.
                out[0, 1:1 + 2 * n_points:2] = points[:, 0, 0] / img.shape[1]
                out[0, 2:2 + 2 * n_points:2] = points[:, 0, 1] / img.shape[0]
                np.savetxt(f'{data_name}/Generated/FPS/{categories}/{p}.txt', out)


