import os
import pickle
import numpy as np
import argparse
import matplotlib.pyplot as plt


# Command-line options. ``args`` is parsed at import time and consumed by main().
parser = argparse.ArgumentParser(
    description='Render per-frame 3D skeleton images, highlighting anchor frames.'
)
parser.add_argument('--data_dir', type=str,
                    default='data/HumanAct12_filtered/saved_anchor_data_humanact12_r400',
                    help='directory containing concept_anchor_fit_param.pkl; '
                         'images are written under <data_dir>/visualize_anchor_pos')
# NOTE(review): --num_joints is not read by main() as written; kept for CLI compatibility.
parser.add_argument('--num_joints', type=int, default=24, help='number of joints')
# Restrict to the skeletons main() knows how to draw, so a typo fails immediately
# instead of silently producing no images.
parser.add_argument('--dataset_name', type=str, default='humanact12',
                    choices=['humanact12', 'uestc'],
                    help='dataset name (selects the skeleton limb layout)')

args = parser.parse_args()

# Skeleton edges as (joint_a, joint_b) index pairs. Each list is built by linking
# consecutive joints along the chains below; the chain order reproduces the
# original flat edge list exactly. HumanAct12 uses joint indices 0-23.
humanact12_limbs = [
    (a, b)
    for chain in (
        (0, 1, 4, 7, 10),
        (0, 2, 5, 8, 11),
        (0, 3, 6, 9, 12, 15),
        (9, 13, 16, 18, 20, 22),
        (9, 14, 17, 19, 21, 23),
    )
    for a, b in zip(chain, chain[1:])
]
# UESTC skeleton: joint indices 0-17, same chain-to-edge construction.
uestc_limbs = [
    (a, b)
    for chain in (
        (0, 1),
        (0, 9, 10, 11, 16),
        (0, 12, 13, 14, 15),
        (1, 2, 3, 4),
        (1, 5, 6, 7),
        (1, 8, 17),
    )
    for a, b in zip(chain, chain[1:])
]


def visualize_anchor_pos(save_dir, limbs, keypoints, anchor_pos, idx, axis_limit=100):
    """Render every frame of a keypoint sequence as a 3D skeleton PNG.

    Frames whose index appears in ``anchor_pos`` are drawn in red with the
    legend label ``'anchor'``; all other frames are drawn in blue with the label
    ``'non-anchor'``. One image per frame is written to
    ``<save_dir>/<idx>/<frame index>.png``.

    Args:
        save_dir: Root output directory.
        limbs: Iterable of ``(joint_a, joint_b)`` index pairs to connect with lines.
        keypoints: Array (e.g. NumPy) indexed as ``keypoints[frame, joint, coord]``;
            only coordinates 0, 1 and 2 are used.
        anchor_pos: Collection of frame indices to highlight as anchors.
        idx: Name of the per-sample sub-directory (a string path component).
        axis_limit: Half-width of the ``[-axis_limit, axis_limit]`` range used for
            all three axes. Defaults to 100, the previously hard-coded value.
    """
    out_dir = os.path.join(save_dir, idx)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(out_dir, exist_ok=True)
    anchor_set = set(anchor_pos)  # O(1) membership per frame instead of a list scan

    for i in range(keypoints.shape[0]):
        # Plotted (X, Y, Z) = (-coord0, coord2, -coord1) of the stored keypoints.
        x, y, z = -keypoints[i, :, 0], keypoints[i, :, 2], -keypoints[i, :, 1]
        color, label = ('r', 'anchor') if i in anchor_set else ('b', 'non-anchor')

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(x, y, z, c=color, s=50, label=label)
        for a, b in limbs:
            ax.plot([x[a], x[b]], [y[a], y[b]], [z[a], z[b]], c=color)

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_xlim(-axis_limit, axis_limit)
        ax.set_ylim(-axis_limit, axis_limit)
        ax.set_zlim(-axis_limit, axis_limit)

        ax.legend()
        fig.savefig(os.path.join(out_dir, '{}.png'.format(i)))
        # Close exactly this figure. Calling plt.clf() after plt.close() (as
        # before) re-created and leaked an empty figure via gcf().
        plt.close(fig)


def _stitch_segments(sample):
    """Concatenate one sample's per-segment keypoint tracks into a single sequence.

    ``sample`` is indexed as ``sample[joint][segment]``. For each segment the
    joints are gathered and reordered to ``(frame, joint, coord)``. The last frame
    of every segment except the final one is dropped before concatenation
    (presumably because it duplicates the next segment's first frame — confirm).

    Returns:
        ``(full_keypoints, anchor_pos)``. ``anchor_pos[j]`` is the index in
        ``full_keypoints`` where segment ``j``'s final frame lands. For non-final
        segments that is the first frame of segment ``j + 1``; for the final
        segment it is the last frame.
        NOTE(review): index 0 is only an anchor if segment 0 has one frame.
    """
    num_segments = len(sample[0])
    pieces = []
    anchor_pos = []
    anchor_offset = 0
    for j in range(num_segments):
        # (joint, frame, coord) -> (frame, joint, coord)
        segment = np.transpose(np.array([joint_track[j] for joint_track in sample]), (1, 0, 2))
        is_last = j == num_segments - 1
        pieces.append(segment if is_last else segment[:-1])
        anchor_offset += segment.shape[0] - 1
        anchor_pos.append(anchor_offset)
    return np.concatenate(pieces, axis=0), anchor_pos


def main(args, max_samples=20):
    """Render anchor-highlighted skeleton frames for each concept in the fit pickle.

    Loads ``<data_dir>/concept_anchor_fit_param.pkl``, a mapping keyed by concept
    whose values contain an ``"input_keypoints"`` entry. Up to ``max_samples``
    samples per concept are stitched into full keypoint sequences, and one PNG
    per frame is written under
    ``<data_dir>/visualize_anchor_pos/<concept>/<sample index>/``.

    Args:
        args: Parsed CLI namespace; ``data_dir`` and ``dataset_name`` are read.
        max_samples: Maximum number of samples rendered per concept
            (previously hard-coded to 20).

    Raises:
        ValueError: If ``args.dataset_name`` is not ``'humanact12'`` or ``'uestc'``.
    """
    # Resolve the skeleton before any I/O. Previously an unknown dataset name
    # fell through the if/elif and silently wrote nothing.
    if args.dataset_name == 'humanact12':
        limbs = humanact12_limbs
    elif args.dataset_name == 'uestc':
        limbs = uestc_limbs
    else:
        raise ValueError("Unsupported dataset_name {!r}; expected 'humanact12' or 'uestc'".format(args.dataset_name))

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(os.path.join(args.data_dir, 'concept_anchor_fit_param.pkl'), 'rb') as f:
        concept_anchor_fit_param = pickle.load(f)

    out_root = os.path.join(args.data_dir, 'visualize_anchor_pos')
    for concept, params in concept_anchor_fit_param.items():
        input_keypoints = params["input_keypoints"]
        for i in range(min(len(input_keypoints), max_samples)):
            full_keypoints, anchor_pos = _stitch_segments(input_keypoints[i])
            visualize_anchor_pos(os.path.join(out_root, concept), limbs, full_keypoints, anchor_pos, str(i))


if __name__ == "__main__":
    # Script entry point: run with the module-level parse_args() result.
    main(args)
    