import argparse
import pickle
import os
import numpy as np
from tqdm import tqdm
from pathlib import Path
import random
import json
from scipy.spatial.transform import Rotation as R
import archetypes as arch
import warnings

from utils.archetypes import *

# Per-dataset (min, max) keypoint bounds passed to ``normalize_data`` to map the
# '*_kp_norm' arrays into [0, 1] before archetypal decomposition. The values are
# hard-coded; presumably they were measured offline per dataset and must match
# the bounds used when the archetype objects were fitted -- verify before
# editing. Key order is the order shown by ``--help``.
DATA_RANGES = {
    'break': (-217.49, 286.76),
    'house': (-78.08, 81.47),
    'ballet_jazz': (-308.86, 327.77),
    'street_jazz': (-102.57, 101.90),
    'krump': (-85.32, 110.89),
    'la_hip_hop': (-77.41, 100.26),
    'lock': (-109.02, 95.57),
    'middle_hip_hop': (-91.20, 92.65),
    'pop': (-132.20, 103.33),
    'wack': (-283.42, 413.92),
}

# Splits processed, in order; each is read from and written to '<split>.json'.
SPLITS = ('train', 'val', 'test')

# Each frame is flattened to a row of N_JOINTS * N_COORDS values before it is
# handed to ``decompose_hvf_data``.
N_JOINTS = 17
N_COORDS = 3


def parse_args(argv=None):
    """Parse the command-line options of the archetype-labelling script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace. ``dataset`` is always one of ``DATA_RANGES``' keys.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Required: the old default ('scheie') was not a valid choice -- argparse does
    # not check string defaults against `choices` -- so omitting the flag always
    # ended in NotImplementedError after creating a stray output directory.
    parser.add_argument('--dataset', type=str, choices=list(DATA_RANGES), required=True, help='Dataset.')
    parser.add_argument('--src_dir', type=str, default='datasets/aistpp_seqp', help='Path to data without aa labels.')
    parser.add_argument('--save_dir', type=str, default='datasets/aistpp_seqp_aa', help='Path to data with aa labels.')
    parser.add_argument('--aa_dir', type=str, default='archetypes', help='Directory containing <dataset>_aa_object.pkl.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--n_history', type=int, default=10, help='Number of frames required for prediction.')
    parser.add_argument('--n_horizon', type=int, default=5, help='Number of frames to predict.')
    return parser.parse_args(argv)


def load_split(data_dir, name):
    """Read and return the JSON object stored at ``<data_dir>/<name>.json``."""
    with open(os.path.join(data_dir, f'{name}.json'), encoding='utf-8') as f:
        return json.load(f)


def annotate_split(split_dict, aa, data_range):
    """Attach archetype coefficients to every sample of one split, in place.

    Each value of ``split_dict['data']`` must provide 'history_kp_norm' and
    'horizon_kp_norm'; it gains 'history_aa' and 'horizon_aa', the ``.tolist()``
    of what ``decompose_hvf_data`` returns for the normalized, flattened frames.

    Args:
        split_dict: Parsed split JSON; only its 'data' mapping is modified.
        aa: Archetype object loaded from disk, passed to ``decompose_hvf_data``.
        data_range: (low, high) bounds given to ``normalize_data`` to rescale
            keypoints to [0, 1].
    """
    low, high = data_range
    row_width = N_JOINTS * N_COORDS
    # Only the sample dicts are mutated (no keys added/removed), so iterating
    # the live values view is safe.
    for sample in tqdm(split_dict['data'].values()):
        # Assumed (frames, 17, 3) -- TODO confirm; the reshape only requires
        # 51 values per frame.
        history_kp = np.array(sample['history_kp_norm'])
        horizon_kp = np.array(sample['horizon_kp_norm'])

        history_flat = normalize_data(history_kp, low, high, 0, 1).reshape((-1, row_width))
        horizon_flat = normalize_data(horizon_kp, low, high, 0, 1).reshape((-1, row_width))

        with warnings.catch_warnings():
            # Suppress the FutureWarning about BaseEstimator._validate_data deprecation.
            warnings.filterwarnings('ignore', category=FutureWarning)
            history_alphas = decompose_hvf_data(aa, history_flat)
            horizon_alphas = decompose_hvf_data(aa, horizon_flat)

        sample.update({
            'history_aa': history_alphas.tolist(),
            'horizon_aa': horizon_alphas.tolist(),
        })


def main(argv=None):
    """Label every split of one dataset with archetype coefficients and save it.

    Reads ``<src_dir>/<dataset>/h<n_history>_H<n_horizon>/{train,val,test}.json``
    and ``<aa_dir>/<dataset>_aa_object.pkl``. Writes the annotated splits to
    ``<save_dir>/<dataset>/h<n_history>_H<n_horizon>/``.
    """
    args = parse_args(argv)
    data_range = DATA_RANGES[args.dataset]

    # --seed used to be parsed but never applied; seed the global RNGs so any
    # stochastic step downstream is reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)

    run_dir = f'h{args.n_history}_H{args.n_horizon}'
    data_dir = os.path.join(args.src_dir, args.dataset, run_dir)
    save_dir = os.path.join(args.save_dir, args.dataset, run_dir)

    # Read every split up front so a missing input fails before any processing.
    splits = {name: load_split(data_dir, name) for name in SPLITS}

    print('Reading archetype matrix...')
    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
    with open(os.path.join(args.aa_dir, f'{args.dataset}_aa_object.pkl'), 'rb') as f:
        aa = pickle.load(f)

    # Created only after inputs loaded, so failures don't leave empty dirs.
    os.makedirs(save_dir, exist_ok=True)

    for name, split_dict in splits.items():
        print(name)
        annotate_split(split_dict, aa, data_range)
        with open(os.path.join(save_dir, f'{name}.json'), 'w', encoding='utf-8') as f:
            json.dump(split_dict, f)


if __name__ == '__main__':
    main()