import os
import json
import argparse
import pickle
import numpy as np
from utils.util import visualize_single

parser = argparse.ArgumentParser(
    description="Group generated motion samples by action label and pickle them."
)
parser.add_argument("--data_dir", type=str, default="data_generate",
                    help="Root directory containing <model_name>/<dataset>/ "
                         "(default: %(default)s).")
parser.add_argument("-d", "--dataset", type=str, default="humanact12",
                    help="Dataset sub-directory name (default: %(default)s).")
parser.add_argument("--concept_list", type=int, nargs="+")
# Restrict to the generator formats main() knows how to parse, so a typo is
# rejected here with a usage message instead of silently doing nothing.
parser.add_argument("--model_name", type=str, default="actor",
                    choices=["actor", "action2motion"],
                    help="Generator whose output format to parse "
                         "(default: %(default)s).")
parser.add_argument("--save_dir", type=str, default="processed")

args = parser.parse_args()

# Skeleton edges of the 24-joint HumanAct12 layout as (parent, child) joint
# index pairs: the edges form a tree rooted at joint 0 in which every joint
# 1-23 appears exactly once as a child.
# NOTE(review): presumably the SMPL kinematic tree - confirm against the data.
humanact12_limbs = [
    (0, 1), (1, 4), (4, 7), (7, 10),
    (0, 2), (2, 5), (5, 8), (8, 11),
    (0, 3), (3, 6), (6, 9), (9, 12), (12, 15),
    (9, 13), (13, 16), (16, 18), (18, 20), (20, 22),
    (9, 14), (14, 17), (17, 19), (19, 21), (21, 23),
]

# Action label id (stringified, in ascending order "0".."11") -> action name.
# A disabled alternative naming maps label "1" to "walk_nonorm".
humanact12_match = {
    str(label_id): action_name
    for label_id, action_name in enumerate([
        "warm_up",
        "walk",
        "run",
        "jump",
        "drink",
        "lift_dumbbell",
        "sit",
        "eat",
        "turn_steer_wheel",
        "phone",
        "boxing",
        "throw",
    ])
}

def _load_generated_file(data_dir, index):
    """Load and return the parsed JSON of ``generated_motions_{index}.json``."""
    path = os.path.join(data_dir, f"generated_motions_{index}.json")
    # json.load accepts a binary file and detects the text encoding itself.
    with open(path, "rb") as f:
        return json.load(f)


def _actor_sample_to_array(sample):
    """Convert one ACTOR ``output_xyz`` sample, moving its last axis to the front."""
    # NOTE(review): presumably (joints, 3, frames) -> (frames, joints, 3); confirm.
    return np.transpose(np.array(sample), (2, 0, 1))


def _action2motion_sample_to_array(sample):
    """Convert one Action2Motion sample into an array of shape (-1, 24, 3)."""
    return np.array(sample).reshape(-1, 24, 3)


# model_name -> (JSON key holding the labels, JSON key holding the samples,
#                converter applied to each selected sample)
_MODEL_FORMATS = {
    "actor": ("y", "output_xyz", _actor_sample_to_array),
    "action2motion": ("labels", "motions", _action2motion_sample_to_array),
}


def main(args, num_files=20, num_concepts=12):
    """Group generated motion samples by action label and pickle the result.

    Reads ``generated_motions_0.json`` .. ``generated_motions_{num_files-1}.json``
    from ``<args.data_dir>/<args.model_name>/<args.dataset>`` (each file is
    loaded exactly once), selects the samples whose label equals each requested
    concept, and writes ``saved_processed_data.pkl`` into that same directory.
    ``args.save_dir`` is not used.

    Args:
        args: Namespace with ``data_dir``, ``dataset``, ``model_name`` and,
            optionally, ``concept_list`` (label ids to extract; when missing or
            ``None``, labels ``0 .. num_concepts-1`` are used).
        num_files: Number of ``generated_motions_*.json`` files to read.
        num_concepts: Number of labels used when no concept list is given.

    Returns:
        ``{concept: {file_index: np.ndarray}}`` - each array stacks the
        converted samples of that concept from that file (shape ``(0,)`` when
        the file has none). The same dict is pickled to disk.

    Raises:
        ValueError: if ``args.model_name`` is not a supported format, or a file
            has a different number of labels than samples.
    """
    try:
        label_key, sample_key, to_array = _MODEL_FORMATS[args.model_name]
    except KeyError:
        raise ValueError(
            f"Unsupported model_name {args.model_name!r}; "
            f"expected one of {sorted(_MODEL_FORMATS)}"
        ) from None

    concepts = getattr(args, "concept_list", None)
    if concepts is None:
        concepts = range(num_concepts)

    # Computed once, independent of the loops, so the output path is always defined.
    data_dir = os.path.join(args.data_dir, args.model_name, args.dataset)
    saved_data = {concept: {} for concept in concepts}

    for file_index in range(num_files):
        data = _load_generated_file(data_dir, file_index)
        labels = data[label_key]
        samples = data[sample_key]
        if len(labels) != len(samples):
            raise ValueError(
                f"generated_motions_{file_index}.json has {len(labels)} labels "
                f"but {len(samples)} samples"
            )
        # Iterate the dict keys so a duplicated concept id is processed once.
        for concept in saved_data:
            saved_data[concept][file_index] = np.array(
                [to_array(sample) for label, sample in zip(labels, samples)
                 if label == concept]
            )

    with open(os.path.join(data_dir, "saved_processed_data.pkl"), "wb") as f:
        pickle.dump(saved_data, f)
    return saved_data

if __name__ == "__main__":
    # `args` was parsed from the command line at module import time.
    main(args=args)
