import numpy as np
import os
import random
from common.paramUtil import *
# Settings read by the commented-out test-pair generation code below.
dataset_names = ['bfa', 'cmu', 'xia']  # processed datasets iterated over
config = dict(
    motion_length=160,  # minimum clip length kept (non-xia datasets)
    data_dir='processed_',  # prefix concatenated with each dataset name
)
root = "../motion_transfer_data/"  # base directory joined with data_dir+name
content_num = 20  # content clips sampled per dataset
style_pernum = 1  # style clips sampled per style label

# for dataset_name in dataset_names:
#     mdataset = np.load(os.path.join(root, config['data_dir']+dataset_name, 'test_data.npy'), allow_pickle=True).item()
#     motions, labels, actions, uids = [], [], [], []
#     motion_dict = dict()
#     mlen = config['motion_length'] if dataset_name != 'xia' else 16
#     for key, value in mdataset.items():
#         if len(value)>=mlen:
#             uids.append(key)
#             motions.append(value)
#             if "cmu" in dataset_name:
#                 label = 0
#             else:
#                 label = eval(key.split("#")[-1])

#             labels.append(label)

#             if "xia" in dataset_name:
#                 actions.append(eval(key.split("#")[-2]))
#             else:
#                 actions.append(0)
            
#             motion_dict[label] = motion_dict.get(label, []) + [key]
#     if dataset_name == 'bfa':
#         style_dict = motion_dict.copy()

#     gen_dataset = []
#     sample_contents_idxes = random.sample(range(len(uids)), content_num)
#     for content_idx in sample_contents_idxes:
#         for k in style_dict.keys():
#             if k == labels[content_idx] and "cmu" not in dataset_name:
#                 continue
#             sample_styls = random.sample(style_dict[k], style_pernum)
#             for sample_style in sample_styls:
#                 gen_dataset.append({'cnt': uids[content_idx], 'sty': sample_style})

#     print(gen_dataset)
#     np.savez("gen_dataset_"+dataset_name, data=gen_dataset)

# gen_data = np.load("gen_dataset_bfa.npz", allow_pickle=True)["data"]
# clean_uids = []
# for gen_pair in gen_data:
#     cnt_uid = "_CNT_" + gen_pair["cnt"]
#     sty_uid = "_STY_" + gen_pair["sty"]
#     clean_uids.append(cnt_uid+sty_uid)

# random_selected_uids = random.sample(clean_uids, 64)
# np.savez("random_selected_uids_bfa", data=random_selected_uids)

# Print the array stored under "data" in random_selected_uids_bfa.npz
# (presumably written by the commented-out sampling step above -- confirm).
# np.load on an .npz returns an NpzFile holding an open file handle, so use
# it as a context manager to guarantee the handle is closed.
with np.load("random_selected_uids_bfa.npz", allow_pickle=True) as selected_npz:
    print(selected_npz["data"])
