import os
import pickle
import numpy as np
import argparse


# Command-line options for the anchor-extraction script.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='data', help='data directory')
parser.add_argument('--num_joints', type=int, default=24, help='number of joints')
parser.add_argument('--norm_orientation', action='store_true', help='whether to norm orientation')

# Only read sys.argv when executed as a script; importing this module (tests,
# notebooks, other tools) gets the defaults instead of parsing foreign argv.
args = parser.parse_args(None if __name__ == "__main__" else [])

def _extract_anchors(input_keypoints, num_joints):
    """Split every keypoint sequence into per-part pose arrays and collect anchors.

    ``input_keypoints[i][k][j]`` is read as the frames of joint ``k`` in part
    ``j`` of sequence ``i`` (assumed to be a (T, 3) array -- TODO confirm against
    the code that produces ``concept_anchor_fit_param.pkl``). Every part
    contributes its first frame as an anchor. The last part of each sequence
    additionally contributes its final frame.

    Args:
        input_keypoints: nested sequence indexed as ``[seq][joint][part]``.
        num_joints: number of leading joints to read from each sequence.

    Returns:
        Tuple ``(anchors, start_anchors, end_anchors, start_idx, end_idx, anchor_dict)``:
        lists of ``(num_joints, 3)`` poses (all anchors, sequence-start anchors,
        sequence-end anchors), the anchor indices of the start/end anchors, and a
        dict mapping anchor index to its metadata.
    """
    anchors, start_anchors, end_anchors = [], [], []
    start_idx, end_idx = [], []
    anchor_dict = {}

    def add_anchor(pose, seq, i, j, part):
        # Anchor indices are positions in ``anchors``, so both always stay in sync.
        idx = len(anchors)
        anchors.append(pose)
        anchor_dict[idx] = {
            "anchor": pose,
            "keypointseq_idx": [i, j],
            "keypointseq": seq,
            "keypointseq_part": part,
        }
        return idx

    for i, seq in enumerate(input_keypoints):
        num_parts = len(seq[0])
        for j in range(num_parts):
            # (num_joints, T, 3) -> (T, num_joints, 3): one pose per frame.
            part = np.transpose(np.array([seq[k][j] for k in range(num_joints)]), (1, 0, 2))

            idx = add_anchor(part[0], seq, i, j, part)
            if j == 0:
                start_anchors.append(part[0])
                start_idx.append(idx)

            if j == num_parts - 1:
                idx = add_anchor(part[-1], seq, i, j, part)
                end_anchors.append(part[-1])
                end_idx.append(idx)

    return anchors, start_anchors, end_anchors, start_idx, end_idx, anchor_dict


def _norm_orientation(gather):
    """Rotate every pose about the original y axis to a canonical facing direction.

    Working in a frame with y and z swapped, each pose is rotated about that
    frame's z axis (the original y axis). After rotation, the normal
    ``cross(p1 - p3, p2 - p3)`` of joints 1, 2 and 3 has a (numerically) zero
    x component and a negative y component. The result is swapped back to the
    input axis order.

    Args:
        gather: poses of shape (N, J, 3) with J >= 4, or an empty array.

    Returns:
        Array of the same shape with the rotated poses.
    """
    gather = np.asarray(gather)
    if gather.size == 0:
        # Nothing to rotate (e.g. a concept without sequences); the joint
        # indexing below would fail on an empty 1-D array.
        return gather.copy()

    swapped = gather[..., [0, 2, 1]]
    normal = np.cross(swapped[:, 1, :] - swapped[:, 3, :], swapped[:, 2, :] - swapped[:, 3, :])
    # Heading angle in the swapped frame's (x, y) plane. atan2 covers all
    # quadrants and ignores the normal's z component, so the rotated normal has
    # exactly zero x; the extra pi makes it point towards -y.
    theta = np.arctan2(normal[:, 0], normal[:, 1]) + np.pi
    cos_t = np.cos(theta)[:, None]
    sin_t = np.sin(theta)[:, None]

    x, y, z = swapped[..., 0], swapped[..., 1], swapped[..., 2]
    # Per-pose rotation [[c, -s, 0], [s, c, 0], [0, 0, 1]] applied to every joint.
    rotated = np.stack([cos_t * x - sin_t * y, sin_t * x + cos_t * y, z], axis=-1)
    return rotated[..., [0, 2, 1]]


def main(args):
    """Extract per-concept anchor poses and save them under ``args.data_dir``.

    Reads ``<data_dir>/concept_anchor_fit_param.pkl``, which must map each concept
    to a dict with an ``"input_keypoints"`` entry. For every concept it writes:

    - ``anchor_data/anchor_gather_<concept>.npy`` and
      ``anchor_data/anchor_keypointseq_dict_<concept>.pkl``
    - ``start_anchor_data/start_anchor_gather_<concept>.npy`` and
      ``start_anchor_data/start_anchor_idx_record_<concept>.npy``
    - ``end_anchor_data/end_anchor_gather_<concept>.npy`` and
      ``end_anchor_data/end_anchor_idx_record_<concept>.npy``

    With ``args.norm_orientation`` it also writes the orientation-normalized
    ``rotated_*_gather_<concept>.npy`` counterparts next to each gather file.

    Args:
        args: namespace with ``data_dir`` (str), ``num_joints`` (int) and
            ``norm_orientation`` (bool).
    """
    # NOTE: pickle.load can execute arbitrary code; only run this on trusted files.
    with open(os.path.join(args.data_dir, 'concept_anchor_fit_param.pkl'), 'rb') as f:
        concept_anchor_fit_param = pickle.load(f)

    anchor_dir = os.path.join(args.data_dir, "anchor_data")
    start_dir = os.path.join(args.data_dir, "start_anchor_data")
    end_dir = os.path.join(args.data_dir, "end_anchor_data")
    for directory in (anchor_dir, start_dir, end_dir):
        os.makedirs(directory, exist_ok=True)

    for concept, params in concept_anchor_fit_param.items():
        (anchors, start_anchors, end_anchors,
         start_idx, end_idx, anchor_keypointseq_dict) = _extract_anchors(
            params["input_keypoints"], args.num_joints)

        anchor_gather = np.array(anchors)
        start_anchor_gather = np.array(start_anchors)
        end_anchor_gather = np.array(end_anchors)
        start_anchor_idx_record = np.array(start_idx)
        end_anchor_idx_record = np.array(end_idx)

        name = str(concept)
        np.save(os.path.join(anchor_dir, "anchor_gather_" + name + ".npy"), anchor_gather)
        np.save(os.path.join(start_dir, "start_anchor_gather_" + name + ".npy"), start_anchor_gather)
        np.save(os.path.join(end_dir, "end_anchor_gather_" + name + ".npy"), end_anchor_gather)
        np.save(os.path.join(start_dir, "start_anchor_idx_record_" + name + ".npy"), start_anchor_idx_record)
        np.save(os.path.join(end_dir, "end_anchor_idx_record_" + name + ".npy"), end_anchor_idx_record)
        with open(os.path.join(anchor_dir, "anchor_keypointseq_dict_" + name + ".pkl"), 'wb') as f:
            pickle.dump(anchor_keypointseq_dict, f)

        if args.norm_orientation:
            # The same normalization is applied to all three gathers so that
            # start/end anchors match their rows in the full anchor set.
            np.save(os.path.join(anchor_dir, "rotated_anchor_gather_" + name + ".npy"),
                    _norm_orientation(anchor_gather))
            np.save(os.path.join(start_dir, "rotated_start_anchor_gather_" + name + ".npy"),
                    _norm_orientation(start_anchor_gather))
            np.save(os.path.join(end_dir, "rotated_end_anchor_gather_" + name + ".npy"),
                    _norm_orientation(end_anchor_gather))

        print("==> concept:", concept, "anchor_gather:", anchor_gather.shape, "start_anchor_gather:", start_anchor_gather.shape, "end_anchor_gather:", end_anchor_gather.shape, "start_anchor_idx_record:", start_anchor_idx_record.shape, "end_anchor_idx_record:", end_anchor_idx_record.shape)


if __name__ == "__main__":
    # Script entry point: run anchor extraction with the module-level CLI arguments.
    main(args=args)
