from icpData import globalTest as glt
import numpy as np
import os
from scipy.io import savemat
from natsort import natsorted
from scipy.spatial.transform import Rotation as Rot


# Dataset used to build the pairwise measurement matrices.
# Must be one of the keys of CONFIG below: 'happy', 'dragon' or 'arma'.
option = 'happy'

# Per-dataset settings:
#   folder_path -- directory scanned for .ply files
#   angle       -- rotation step (degrees) multiplied by the index gap of a
#                  scan pair to form the initial rotation guess
CONFIG = {
    'happy': dict(folder_path='happy/data/happy_stand', angle=24),
    'dragon': dict(folder_path='dragon_stand', angle=24),
    'arma': dict(folder_path='Armadillo', angle=30),
}

if option not in CONFIG:
    raise ValueError(f"Unknown option={option}. Choose from {list(CONFIG)}")

folder_path, angle = (CONFIG[option][key] for key in ('folder_path', 'angle'))




def construct_pyMeasurement_rotinit(sourcePly, targetPly, theta, voxel_size=0.005):
    """Register two scans with ICP, seeded by a fixed rotation about the y axis.

    Parameters
    ----------
    sourcePly, targetPly : str
        File paths of the source and target point clouds, forwarded to
        ``glt.prepare_dataset``.
    theta : float
        Angle in degrees. The ICP initial guess is the transpose of the
        rotation by ``+theta`` about the y axis (i.e. a rotation by
        ``-theta``), with zero translation.
    voxel_size : float, optional
        Voxel size forwarded to ``glt.prepare_dataset`` and
        ``glt.refine_registration``. It is expressed in the same units as
        the point coordinates. Defaults to 0.005, the previously hard-coded
        value.

    Returns
    -------
    A
        ``result_icp.transformation`` of the refined registration.
    fit
        ``result_icp.fitness``.
    rmse
        ``result_icp.inlier_rmse``.
    setsize : int
        Number of rows in ``result_icp.correspondence_set``.
    """
    _source, _target, source_down, target_down, source_fpfh, target_fpfh = glt.prepare_dataset(
        sourcePly, targetPly, voxel_size)

    # theta is in degrees; theta * pi / 360 is the quaternion half-angle in
    # radians. SciPy quaternions are scalar-last: [x, y, z, w].
    half_angle = theta * np.pi / 360
    rotation = Rot.from_quat([0, np.sin(half_angle), 0, np.cos(half_angle)])

    # Homogeneous 4x4 initial guess: transposed rotation, zero translation.
    init_transform = np.eye(4)
    init_transform[:3, :3] = rotation.as_matrix().T

    # Local refinement starting from the rotation guess.
    result_icp = glt.refine_registration(source_down, target_down, source_fpfh, target_fpfh,
                                         voxel_size, init_transform)

    setsize = np.asarray(result_icp.correspondence_set).shape[0]
    return result_icp.transformation, result_icp.fitness, result_icp.inlier_rmse, setsize



def construct_pyMeasurement_noinit(sourcePly, targetPly, voxel_size=0.005):
    """Register two scans with global registration followed by ICP refinement.

    No initial guess is supplied. The result of
    ``glt.execute_global_registration`` seeds ``glt.refine_registration``.

    Parameters
    ----------
    sourcePly, targetPly : str
        File paths of the source and target point clouds, forwarded to
        ``glt.prepare_dataset``.
    voxel_size : float, optional
        Voxel size forwarded to every ``glt`` call. It is expressed in the
        same units as the point coordinates. Defaults to 0.005, the
        previously hard-coded value.

    Returns
    -------
    A
        ``result_icp.transformation`` of the refined registration.
    fit
        ``result_icp.fitness``.
    rmse
        ``result_icp.inlier_rmse``.
    setsize : int
        Number of rows in ``result_icp.correspondence_set``.
    """
    _source, _target, source_down, target_down, source_fpfh, target_fpfh = glt.prepare_dataset(
        sourcePly, targetPly, voxel_size)

    # Global (RANSAC) registration provides the initial transform for ICP.
    result_ransac = glt.execute_global_registration(source_down, target_down,
                                                    source_fpfh, target_fpfh,
                                                    voxel_size)

    # Local refinement.
    result_icp = glt.refine_registration(source_down, target_down, source_fpfh, target_fpfh,
                                         voxel_size, result_ransac.transformation)

    setsize = np.asarray(result_icp.correspondence_set).shape[0]
    return result_icp.transformation, result_icp.fitness, result_icp.inlier_rmse, setsize





# Gather every .ply scan in the dataset folder, naturally sorted
# (so that e.g. 'x2.ply' comes before 'x10.ply').
ply_files = natsorted(f for f in os.listdir(folder_path) if f.endswith('.ply'))
num_files = len(ply_files)

# Pairwise result tables indexed [i, j]. Only entries with j >= i are written
# below; all other entries keep their initial value of 0.
meapyMat = np.zeros((num_files, num_files, 4, 4))  # one 4x4 matrix per pair
fitpyMat, rmsepyMat, setsizepyMat = (np.zeros((num_files, num_files)) for _ in range(3))

for i, first_name in enumerate(ply_files):
    file_path1 = os.path.join(folder_path, first_name)

    # j starts at i, so every scan is also paired with itself (theta == 0).
    for j in range(i, num_files):
        file_path2 = os.path.join(folder_path, ply_files[j])

        # The initial rotation guess grows linearly with the index gap.
        theta = angle * (j - i)

        A, fit, rmse, setsize = construct_pyMeasurement_rotinit(file_path1, file_path2, theta)

        # Store the closed-form inverse [R^T | -R^T t; 0 0 0 1] of A.
        # NOTE(review): this formula is only the true inverse when A's 3x3
        # block is orthonormal (a rigid transform) -- confirm for the ICP used.
        R = A[0:3, 0:3].T
        t = -R @ A[0:3, 3].reshape((3, 1))
        Rt_temp = np.hstack((R, t))
        A_temp = np.vstack((Rt_temp, [[0, 0, 0, 1]]))

        meapyMat[i, j] = A_temp
        fitpyMat[i, j], rmsepyMat[i, j], setsizepyMat[i, j] = fit, rmse, setsize

        print(f'Finish pairing ({i},{j}) ')

# Each table goes to its own .mat file. The table name is used both as the
# file stem and as the MATLAB variable name.
for var_name, table in (('meapyMat', meapyMat),
                        ('fitpyMat', fitpyMat),
                        ('rmsepyMat', rmsepyMat),
                        ('setsizepyMat', setsizepyMat)):
    savemat(f'{var_name}.mat', {var_name: table})
