import dq_sync as dqs
import numpy as np
from scipy.io import loadmat
from scipy.io import savemat
import time
import open3d as o3d
from natsort import natsorted
import os


# Load the pose measurements for the bunny scans. Downstream, the last axis of
# `rawdata` is split into translation (columns 0-2) and rotation (columns 3+).
# os.path.join replaces the hard-coded backslash. That backslash formed an
# invalid escape sequence ("\m") in a non-raw string literal, and it is not a
# path separator on POSIX systems.
mat_data = loadmat(os.path.join('icpData', 'meapyMat_bunny_sparse.mat'))
rawdata = mat_data['meapyMat_q_refine']

# Number of generalized-power-method iterations run below.
maxiter = 30



# Build the noisy, Hermitian dual-quaternion measurement matrix C_gpm.
# The last axis of rawdata holds translation in columns 0-2 and rotation in
# the remaining columns (presumably quaternions -- confirm against dq_sync).
t, r = rawdata[:, :, :3], rawdata[:, :, 3:]
n = t.shape[0]

# Noise-free measurements: dual quaternions -> 4x4 blocks -> Hermitian.
C_partial = dqs.dq2dqmat(dqs.rt2dq(r, t))
C = dqs._hermitizeDQmat(C_partial)

# Random (n, n) perturbation. Rotation sigma is dqs.angle2radians(1)
# (presumably 1 degree); translation sigma is 0.001. Diagonal blocks are reset
# to the 4x4 identity so the self-pairs receive no perturbation.
noise_rt = dqs.randomGaussianGaussian((n, n), sigma_r=dqs.angle2radians(1), sigma_t=0.001)
noise = dqs.dq2dqmat(dqs.rt2dq(*noise_rt))
inds = np.diag_indices(n)
noise[inds] = np.eye(4)

# Perturb, restore Hermitian structure, then convert with mb2bm into the
# matrix the power method iterates on.
# NOTE(review): assumes C and noise are (n, n, 4, 4) arrays, so `@`
# multiplies matching 4x4 blocks -- confirm.
C = dqs._hermitizeDQmat(C @ noise)
C_gpm = dqs.mb2bm(C)





# do gpm
#
# Generalized power method on C_gpm. After every iteration:
#   1. the current estimate is compared with ground truth,
#   2. it is aligned to ground truth with the best global aligner,
#   3. all scans are rendered together in one Open3D window.


def _ground_truth_dqmat(gt_vec, num_views):
    """Return the ground-truth poses of the first ``num_views`` views as a DQ matrix.

    Parameters
    ----------
    gt_vec : numpy.ndarray
        2-D array with one row per view. Columns 0-2 are used as the
        translation and columns 3-6 as the rotation quaternion.
    num_views : int
        Number of leading rows to use.

    Returns
    -------
    The result of ``dqs.dq2dqmat(dqs.rt2dq(r, t))`` for the selected poses.

    Raises
    ------
    ValueError
        If ``gt_vec`` has fewer than ``num_views`` rows.
    """
    if gt_vec.shape[0] < num_views:
        raise ValueError(
            f"ground truth has {gt_vec.shape[0]} poses but {num_views} are needed"
        )
    r_gt = np.zeros((num_views, 1, 4))
    t_gt = np.zeros((num_views, 1, 3))
    r_gt[:, 0, :] = gt_vec[:num_views, 3:7]
    t_gt[:, 0, :] = gt_vec[:num_views, 0:3]
    return dqs.dq2dqmat(dqs.rt2dq(r_gt, t_gt))


def _aligned_cloud(file_path, quat, trans, ba_rot_matrix, ba_trans, color):
    """Load one .ply scan, undo its estimated pose and the best aligner, and paint it.

    Parameters
    ----------
    file_path : str
        Path to the .ply file.
    quat, trans : numpy.ndarray
        Estimated rotation (quaternion passed to
        ``get_rotation_matrix_from_quaternion``) and translation of this view.
    ba_rot_matrix : numpy.ndarray
        3x3 rotation matrix of the best aligner.
    ba_trans : numpy.ndarray
        Translation of the best aligner.
    color : array-like
        RGB triple in [0, 1], applied uniformly to every point.

    Returns
    -------
    open3d.geometry.PointCloud
        The transformed, coloured point cloud.
    """
    cloud = o3d.io.read_point_cloud(file_path)
    rot_matrix = o3d.geometry.get_rotation_matrix_from_quaternion(quat)

    # Applies x -> R^T (x - t), the inverse of x -> R x + t.
    cloud.translate(-trans)
    cloud.rotate(rot_matrix.transpose(), center=(0, 0, 0))

    # Same inverse mapping for the best aligner.
    cloud.translate(-ba_trans)
    cloud.rotate(ba_rot_matrix.transpose(), center=(0, 0, 0))

    # One uniform colour per scan (per the original note, omitting this makes
    # the figure render with a gradient colour).
    cloud.colors = o3d.utility.Vector3dVector(np.tile(color, (len(cloud.points), 1)))
    return cloud


shape = (C_gpm.shape[0] // 4, 1)
out = dqs.mb2bm(dqs.dq2dqmat(dqs.rt2dq(*dqs.randomUniformGaussian(shape))))

# These inputs do not change between iterations, so they are loaded once.
# os.path.join keeps the paths portable. The previous 'gtData\gtVec_bunny.mat'
# relied on Windows separators and an invalid "\g" escape sequence.
ply_folder = os.path.join(os.getcwd(), 'bunny', 'data')
ply_files = natsorted([f for f in os.listdir(ply_folder) if f.endswith('.ply')])
gt_data = loadmat(os.path.join('gtData', 'gtVec_bunny.mat'))
raw_data = gt_data['gtVec2']

colors = [
    [0, 0, 1],  # Blue
    [1, 1, 0],  # Yellow
    [1, 0.5, 0],  # Orange
    [0.5, 0, 1],  # Purple
    [0, 0.5, 1],  # Light Blue
    [1, 0, 0],  # Red
    [0, 1, 0],  # Green
    [1, 0, 1],  # Magenta
    [0, 1, 1],  # Cyan
    [0.5, 0.5, 0.5],  # Gray
]

t1 = time.time()

for i in range(maxiter):
    # The iterate is decoded BEFORE the update below. So iteration i displays
    # the estimate produced by iteration i - 1, and iteration 0 shows the
    # random initialisation.
    out_check = dqs.dqmat2dq(dqs.normalize(dqs.bm2mb(out)))
    rvec, tvec = dqs.dq2rt(out_check)

    # One power-method step.
    out = C_gpm @ out
    out = dqs._normalizeVector(out)

    print(i)

    n = rvec.shape[0]
    if len(ply_files) < n:
        raise ValueError(
            f"found {len(ply_files)} .ply files in {ply_folder} but {n} poses are estimated"
        )

    # Best global aligner between the current estimate and ground truth.
    origin = _ground_truth_dqmat(raw_data, n)
    estimate = dqs.dq2dqmat(dqs.rt2dq(rvec, tvec))
    ba_dqmat = dqs.calculateBestAligner(estimate, origin)
    ba_r, ba_t = dqs.dq2rt(dqs.dqmat2dq(ba_dqmat))
    ba_r = ba_r.reshape(-1)
    ba_t = ba_t.reshape(-1)
    ba_r_matrix = o3d.geometry.get_rotation_matrix_from_quaternion(ba_r)

    all_point_clouds = [
        _aligned_cloud(
            os.path.join(ply_folder, ply_files[s]),
            rvec[s, 0, :],
            tvec[s, 0, :],
            ba_r_matrix,
            ba_t,
            np.array(colors[s % len(colors)]),  # cycle through the palette
        )
        for s in range(n)
    ]

    # Visualize all point clouds in one Open3D window (once per iteration).
    o3d.visualization.draw_geometries(all_point_clouds)





