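'''
    Rigidly align THUman scans to the SMPL male template (./smpl_male.obj).

    For every sample id listed in ./sample_id_128.txt, estimate the rotation
    and translation that best maps the sample's SMPL fit (smpl.obj) onto the
    template, using only template vertices with -0.2 < y < 0. The same
    transform is applied to the sample's fused scan (mesh.obj), and the result
    is exported to ./registrations128/thuman/pose/.
    Please read the README.txt file for requirements.

'''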

import pickle
import os
import trimesh
import numpy as np
import sys
# from sklearn.decomposition import PCA
# Seed NumPy's global RNG so any random draws are reproducible across runs.
# NOTE(review): nothing in this file draws from np.random, so this looks like
# a leftover from the code it was adapted from; confirm before removing.
np.random.seed(0)

def solver_Rt(P, Q, *, verbose=True):
    ''' Solve the rigid (Kabsch) alignment  min_{R, t} || R @ Q + t - P ||^2.

    R is constrained to be a proper rotation (det(R) = +1), so reflections are
    never returned.

    Args:
        P: [3, N] array-like, the target shape to be aligned to
        Q: [3, N] array-like, the source shape; column i of Q is matched to
            column i of P
        verbose: if True (default), print the residual error `err`
    Returns:
        R_opt: [3, 3] rotation matrix
        t_opt: [3, 1] translation vector
        err: sum of squared residuals of R_opt @ Q + t_opt - P
    Raises:
        ValueError: if P or Q is not of shape [3, N], or their N differ
    '''
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    # Raise rather than assert so the validation survives `python -O`.
    if P.ndim != 2 or Q.ndim != 2 or P.shape[0] != 3 or Q.shape[0] != 3:
        raise ValueError(
            f'P and Q must have shape [3, N], got {P.shape} and {Q.shape}')
    if P.shape[1] != Q.shape[1]:
        raise ValueError(
            f'P and Q must have the same number of points, '
            f'got {P.shape[1]} and {Q.shape[1]}')

    p_mean = P.mean(axis=-1, keepdims=True)  # [3, 1]
    q_mean = Q.mean(axis=-1, keepdims=True)  # [3, 1]
    X = P - p_mean
    Y = Q - q_mean
    S = Y @ X.T  # [3, 3] cross-covariance of the centred point sets
    U, _, Vh = np.linalg.svd(S, full_matrices=True)
    # Reflection correction: pick d so R = V diag(1, 1, d) U^T has det +1.
    # det(U) * det(Vh) is exactly +-1, whereas the sign of det(S) is
    # meaningless when S is (near-)singular, e.g. for coplanar points.
    d = 1.0 if np.linalg.det(U) * np.linalg.det(Vh) > 0 else -1.0
    R_opt = Vh.T @ np.diag([1.0, 1.0, d]) @ U.T
    t_opt = p_mean - R_opt @ q_mean  # [3, 1]
    # sanity check: R_opt is a proper rotation by construction
    assert np.linalg.det(R_opt) > 0
    diff = R_opt @ Q + t_opt - P
    err = np.sum(diff * diff)
    if verbose:
        print(err)
    return R_opt, t_opt, err




if __name__ == '__main__':
    # Rigidly align each listed THUman sample to the SMPL template and export
    # the transformed fused scan.
    data_dir = '/path/to/the/THUman/dataset/'

    # One sample id per line; blank lines are skipped so they never become
    # empty ids (which would resolve to data_dir//smpl.obj).
    with open('./sample_id_128.txt', 'r') as f:
        fid_list = [line.strip() for line in f if line.strip()]

    mesh_template = trimesh.load('./smpl_male.obj', process=False)

    # Fit the rigid transform only on template vertices whose second
    # coordinate lies in (-0.2, 0).
    # NOTE(review): the same boolean mask indexes each sample's smpl.obj, so
    # those meshes are assumed to share the template's vertex count and
    # ordering — confirm for the dataset.
    template_y = mesh_template.vertices[:, 1]
    opt_ids = np.logical_and(-0.2 < template_y, template_y < 0)

    out_dir = './registrations128/thuman/pose/'
    os.makedirs(out_dir, exist_ok=True)

    for fid in fid_list:
        mesh_smpl = trimesh.load(f"{data_dir}/{fid}/smpl.obj", process=False)
        mesh_fuse = trimesh.load(f"{data_dir}/{fid}/mesh.obj", process=False)
        # (R_i, t_i) maps the sample's SMPL fit onto the template
        # (R_i @ smpl + t_i ~= template) over the selected vertices.
        R_i, t_i, _ = solver_Rt(mesh_template.vertices[opt_ids, :].T,
                                mesh_smpl.vertices[opt_ids, :].T)

        # Apply the same rigid transform to the fused scan.
        fuse_v_i_trans = R_i @ mesh_fuse.vertices.T + t_i
        # NOTE(review): unlike the loads above, process=True lets trimesh
        # modify the vertex set when constructing the mesh — confirm intended.
        mesh_fuse_new = trimesh.Trimesh(vertices=fuse_v_i_trans.T,
                                        faces=mesh_fuse.faces, process=True)

        # Ids may contain '/', flatten them into a single file name.
        out_name = fid.replace('/', '_')
        mesh_fuse_new.export(f"{out_dir}/{out_name}.obj")
