import argparse
import pickle
import os
import numpy as np
from tqdm import tqdm
from pathlib import Path
import random
import json
from scipy.spatial.transform import Rotation as R

# Fractions of samples assigned to the train / validation / test splits
# (in that order); they are expected to sum to 1.
SPLIT = [0.7, 0.15, 0.15]

# Names of the 17 keypoints in the order they are stored along the keypoint
# axis of the 3D keypoint arrays (index 0 is the nose).
COCO_KP = [
    "nose",
    "left_eye",
    "right_eye",
    "left_ear",
    "right_ear",
    "left_shoulder",
    "right_shoulder",
    "left_elbow",
    "right_elbow",
    "left_wrist",
    "right_wrist",
    "left_hip",
    "right_hip",
    "left_knee",
    "right_knee",
    "left_ankle",
    "right_ankle",
]

# Bone list as [a, b] pairs of 1-BASED indices into COCO_KP (1 = nose,
# 17 = right_ankle), i.e. subtract 1 before indexing a keypoint array.
SKELETON = [
    # legs: ankle-knee, knee-hip
    [16, 14], [14, 12], [17, 15], [15, 13],
    # pelvis and torso
    [12, 13], [6, 12], [7, 13], [6, 7],
    # arms: shoulder-elbow, elbow-wrist
    [6, 8], [7, 9], [8, 10], [9, 11],
    # face, and ear-shoulder links
    [2, 3], [1, 2], [1, 3], [2, 4], [3, 5], [4, 6], [5, 7],
]

# Genre code (the text following the leading 'g' in the first '_'-separated
# token of a sequence file name) -> human-readable genre / directory name.
GENRE_CODE = dict(
    BR='break',
    HO='house',
    JB='ballet_jazz',
    JS='street_jazz',
    KR='krump',
    LH='la_hip_hop',
    LO='lock',
    MH='middle_hip_hop',
    PO='pop',
    WA='wack',
)

# Sequences whose raw keypoints contain any coordinate with a larger absolute
# value are treated as outliers and skipped.
# NOTE(review): the coordinate units of 'keypoints3d' are not visible here.
MAX_ABS_COORD = 400


def load_pickle(path):
    """Load and return the object stored in the pickle file at ``path``."""
    # NOTE(review): pickle.load can execute arbitrary code; only use this on
    # trusted dataset files.
    with open(path, 'rb') as f:
        return pickle.load(f)


def yaw_sequence(root_rotvec, eps=1e-6):
    """Return the unwrapped per-frame yaw of a sequence of root orientations.

    Each orientation's rotated +z axis is projected onto the xz-plane, and the
    yaw is ``atan2(x, z)`` of that horizontal forward vector.

    Args:
        root_rotvec: array-like of shape (N, 3), axis-angle rotation vectors.
        eps: frames whose horizontal forward vector is not longer than this
            are degenerate and fall back to the default forward ``+z`` (yaw 0).

    Returns:
        np.ndarray of shape (N,), yaw in radians, unwrapped along frames so it
        has no 2*pi jumps.
    """
    rotvec = np.asarray(root_rotvec, dtype=float).reshape(-1, 3)
    rot = R.from_rotvec(rotvec).as_matrix()  # N x 3 x 3
    fwd = rot @ np.array([0.0, 0.0, 1.0])  # N x 3, rotated +z axis per frame
    fwd[:, 1] = 0.0  # drop the vertical component of every frame

    # Normalize each frame separately; degenerate frames get the default.
    norm = np.linalg.norm(fwd, axis=1)
    valid = norm > eps
    fwd[valid] /= norm[valid, np.newaxis]
    fwd[~valid] = (0.0, 0.0, 1.0)

    return np.unwrap(np.arctan2(fwd[:, 0], fwd[:, 2]))


def normalize_keypoints(keypoint_sequence, yaw):
    """Center keypoints on the hip midpoint and undo each frame's yaw.

    Args:
        keypoint_sequence: array of shape (N, 17, 3) ordered like COCO_KP.
        yaw: array of shape (N,), yaw in radians (see ``yaw_sequence``).

    Returns:
        np.ndarray of shape (N, 17, 3). Each frame is translated so the
        midpoint of keypoints 11 and 12 (left_hip, right_hip) is at the origin,
        then rotated about +y by ``-yaw`` so that its forward direction maps
        onto +z.
    """
    cos_y, sin_y = np.cos(-yaw), np.sin(-yaw)
    zeros, ones = np.zeros_like(cos_y), np.ones_like(cos_y)
    # Per-frame rotation about +y by -yaw, stacked as N x 3 x 3.
    rot_inv = np.stack([
        np.stack([cos_y, zeros, sin_y], axis=-1),
        np.stack([zeros, ones, zeros], axis=-1),
        np.stack([-sin_y, zeros, cos_y], axis=-1),
    ], axis=1)

    root = (keypoint_sequence[:, 11, :] + keypoint_sequence[:, 12, :]) / 2
    centered = keypoint_sequence - root[:, np.newaxis, :]
    # Points are row vectors, so p' = p @ M^T applies M to every keypoint.
    return centered @ np.transpose(rot_inv, (0, 2, 1))


def window_samples(name, kp_raw, kp_norm, yaw, n_history, n_horizon):
    """Yield one JSON-ready sample per sliding window over a sequence.

    Windows are ``n_history + n_horizon`` consecutive frames with stride 1.
    The first ``n_history`` frames form the history and the rest the horizon.

    Args:
        name: sequence name stored in every sample.
        kp_raw: (N, 17, 3) raw keypoints.
        kp_norm: (N, 17, 3) centered, de-rotated keypoints.
        yaw: (N,) yaw angles in radians.
        n_history: number of history frames per window.
        n_horizon: number of horizon frames per window.

    Yields:
        dict with the sequence name and history/horizon slices of the raw and
        normalized keypoints and of the yaw cosine/sine, all as nested lists.
        Nothing is yielded if the sequence is shorter than one window.
    """
    cos_yaw, sin_yaw = np.cos(yaw), np.sin(yaw)
    window = n_history + n_horizon
    for start in range(len(kp_raw) - window + 1):
        hist = slice(start, start + n_history)
        hor = slice(start + n_history, start + window)
        yield {
            'name': name,
            'history_kp_raw': kp_raw[hist].tolist(),
            'horizon_kp_raw': kp_raw[hor].tolist(),
            'history_kp_norm': kp_norm[hist].tolist(),  # centered, y-axis rotation removed
            'horizon_kp_norm': kp_norm[hor].tolist(),  # centered, y-axis rotation removed
            'history_rt_cos': cos_yaw[hist].tolist(),
            'history_rt_sin': sin_yaw[hist].tolist(),
            'horizon_rt_cos': cos_yaw[hor].tolist(),
            'horizon_rt_sin': sin_yaw[hor].tolist(),
        }


def split_keys(n_samples, split):
    """Shuffle sample indices and cut them into train/val/test index lists.

    Args:
        n_samples: number of samples; indices are ``0 .. n_samples - 1``.
        split: (train, val, test) fractions that must sum to 1.

    Returns:
        (train, val, test) lists of indices. Train and val sizes are
        ``int(n_samples * fraction)``; test receives the remainder.

    Raises:
        ValueError: if ``split`` does not have three entries summing to 1.

    Uses the global ``random`` state, so results are controlled by its seed.
    """
    if len(split) != 3 or not np.isclose(sum(split), 1.0):
        raise ValueError(f'split must be three fractions summing to 1, got {split}')
    keys = list(range(n_samples))
    random.shuffle(keys)
    n_tr, n_val = int(n_samples * split[0]), int(n_samples * split[1])
    return keys[:n_tr], keys[n_tr:n_tr + n_val], keys[n_tr + n_val:]


def main():
    """Build per-genre train/val/test JSON files of sliding-window samples.

    For every genre in GENRE_CODE, reads the keypoint and motion pickles under
    ``--raw_pth``, cuts each usable sequence into windows of
    ``n_history + n_horizon`` frames, shuffles the windows with ``--seed`` and
    writes train.json, val.json and test.json to
    ``<save_dir>/<genre_name>/h<n_history>_H<n_horizon>/``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--raw_pth', type=str, default='raw_data', help='Path to raw data.')
    parser.add_argument('--save_dir', type=str, default='datasets/aistpp_seqp', help='Directory to save processed JSON data.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--n_history', type=int, default=10, help='Number of frames required for prediction.')
    parser.add_argument('--n_horizon', type=int, default=5, help='Number of frames to predict.')
    args = parser.parse_args()

    np.random.seed(args.seed)
    random.seed(args.seed)

    # Sorted so the sample order, and thus the seeded shuffle, does not depend
    # on the arbitrary order in which os.listdir returns entries.
    sequences = sorted(os.listdir(os.path.join(args.raw_pth, 'keypoints3d')))
    # The genre code follows the 'g' in the first '_'-separated token.
    genre_labels = [s.split('_')[0].split('g')[1] for s in sequences]
    window_dir = f'h{args.n_history}_H{args.n_horizon}'

    for genre, genre_name in GENRE_CODE.items():
        print(f'\nProcessing {genre_name} sequences...')
        genre_dir = os.path.join(args.save_dir, genre_name, window_dir)
        os.makedirs(genre_dir, exist_ok=True)

        genre_sequences = [s for s, label in zip(sequences, genre_labels) if label == genre]

        samples = []
        data_min, data_max = np.inf, -np.inf
        for s in tqdm(genre_sequences):
            keypoints = load_pickle(os.path.join(args.raw_pth, 'keypoints3d', s))
            keypoint_sequence = np.asarray(keypoints['keypoints3d'])  # N x 17 x 3

            # Skip empty sequences and those with NaNs or outlier coordinates.
            if len(keypoint_sequence) == 0 or np.isnan(keypoint_sequence).any():
                continue
            if np.max(np.abs(keypoint_sequence)) > MAX_ABS_COORD:
                continue

            # The first 3 SMPL pose values are the root orientation (axis-angle).
            motions = load_pickle(os.path.join(args.raw_pth, 'motions', s))
            yaw = yaw_sequence(motions['smpl_poses'][:, :3])
            nm_keypoint_sequence = normalize_keypoints(keypoint_sequence, yaw)

            data_min = min(data_min, np.min(nm_keypoint_sequence))
            data_max = max(data_max, np.max(nm_keypoint_sequence))

            samples.extend(window_samples(
                Path(s).stem, keypoint_sequence, nm_keypoint_sequence, yaw,
                args.n_history, args.n_horizon,
            ))

        # NOTE(review): windows are split individually, so overlapping windows
        # of the same sequence can land in different splits (train/test leakage).
        n_samples = len(samples)
        tr_keys, val_keys, te_keys = split_keys(n_samples, SPLIT)
        print(f"Splitting {n_samples} samples into {len(tr_keys)} train, {len(val_keys)} validation, and {len(te_keys)} test samples.")

        print(f"Overall: {n_samples} samples")
        splits = []
        for split_name, file_name, keys in (('Train', 'train.json', tr_keys),
                                            ('Val', 'val.json', val_keys),
                                            ('Test', 'test.json', te_keys)):
            split_data = {
                'data': {k: samples[k] for k in keys},
                'samples': len(keys),
                # sorted for deterministic output (set order depends on hash seed)
                'names': sorted({samples[k]['name'] for k in keys}),
                'genre_code': genre,
                'genre_name': genre_name,
                'coco_kp': COCO_KP,
                'skeleton': SKELETON,
                'n_history': args.n_history,
                'n_horizon': args.n_horizon,
            }
            print(f"{split_name}: {split_data['samples']} samples")
            splits.append((file_name, split_data))

        print(f'Min: {data_min}, Max: {data_max}')

        # Save the three splits as JSON files in the genre directory.
        for file_name, split_data in splits:
            with open(os.path.join(genre_dir, file_name), 'w') as f:
                json.dump(split_data, f)


if __name__ == '__main__':
    main()
