import os
import json
import argparse
import pickle
import numpy as np
from utils.util import visualize_single

parser = argparse.ArgumentParser()
# Root directory that holds the generated motion pickles.
parser.add_argument("--data_dir", dest="data_dir", default="data_generate", type=str)
# Dataset name; "uestc" is read from an extra "uestc_xyz" sub-directory in main().
parser.add_argument("-d", "--dataset", dest="dataset", default="uestc", type=str)
# Generator model whose outputs are being processed (main() only handles "actor").
parser.add_argument("--model_name", dest="model_name", default="actor", type=str)

# NOTE(review): arguments are parsed at import time, so importing this module
# consumes sys.argv.
args = parser.parse_args()

# Skeleton edges over joint indices 0..23 of the HumanAct12 layout, presumably
# (parent, child) pairs of the kinematic tree -- confirm against the renderer.
# Listed one chain per line, all rooted at joint 0 or branching from joint 9.
humanact12_limbs = [
    (0, 1), (1, 4), (4, 7), (7, 10),
    (0, 2), (2, 5), (5, 8), (8, 11),
    (0, 3), (3, 6), (6, 9), (9, 12), (12, 15),
    (9, 13), (13, 16), (16, 18), (18, 20), (20, 22),
    (9, 14), (14, 17), (17, 19), (19, 21), (21, 23),
]

# HumanAct12 action-class index (as a string, "0".."11") -> action name.
# Index 1 was once labelled "walk_nonorm"; it is now plain "walk".
humanact12_match = {
    str(class_idx): action
    for class_idx, action in enumerate((
        "warm_up",
        "walk",
        "run",
        "jump",
        "drink",
        "lift_dumbbell",
        "sit",
        "eat",
        "turn_steer_wheel",
        "phone",
        "boxing",
        "throw",
    ))
}

def _resolve_data_dir(args):
    """Return the directory holding the ``generated_motions_*.pkl`` files.

    UESTC outputs live in an extra ``uestc_xyz`` sub-directory; every other
    dataset uses ``<data_dir>/<model_name>/<dataset>`` directly.
    """
    if args.dataset == "uestc":
        return os.path.join(args.data_dir, args.model_name, args.dataset, "uestc_xyz")
    return os.path.join(args.data_dir, args.model_name, args.dataset)


def _select_concept(labels, output_xyz, gt_xyz, concept):
    """Collect the generated and ground-truth sequences labelled ``concept``.

    Each selected sequence is converted with ``np.array`` and its axes are
    reordered with ``np.transpose(..., (2, 0, 1))`` (last axis moved first;
    presumably this puts frames first -- confirm against the generator).

    Returns:
        ``(generated, gt)``: each is ``np.array`` of the selected sequences,
        or an empty ``(0,)`` array when no label matches.
    """
    generated, gt = [], []
    for j, label in enumerate(labels):
        if label == concept:
            # Index the other containers (rather than zip) so a length
            # mismatch still raises instead of being silently truncated.
            generated.append(np.transpose(np.array(output_xyz[j]), (2, 0, 1)))
            gt.append(np.transpose(np.array(gt_xyz[j]), (2, 0, 1)))
    return np.array(generated), np.array(gt)


def main(args):
    """Regroup ACTOR-generated motion pickles by action class, per split.

    For each split in ("train", "test"), read the files
    ``generated_motions_{split}{i}.pkl`` for ``i`` in ``range(num_files)``.
    Each is a mapping with keys ``'y'`` (labels), ``'output_xyz'`` and ``'gt'``.
    Write ``<dataset dir>/<split>/saved_processed_data.pkl`` containing
    ``{concept: {i: {'generated': ndarray, 'gt': ndarray}}}``.

    Args:
        args: namespace with ``data_dir``, ``dataset`` and ``model_name``.
            Optional ``num_concepts`` (default 40) and ``num_files``
            (default 20) override the number of action classes and generated
            files per split.

    Only ``model_name == "actor"`` is supported; any other value prints a
    notice and returns without writing anything.
    """
    if args.model_name != "actor":
        print("==> Unsupported model_name: ", args.model_name, "nothing to do.")
        return

    num_concepts = getattr(args, "num_concepts", 40)
    num_files = getattr(args, "num_files", 20)
    data_dir = _resolve_data_dir(args)

    for split in ("train", "test"):
        # Fresh containers per split so one split's results never leak into
        # the other's output file.
        saved_data = {concept: {} for concept in range(num_concepts)}
        totals = {concept: [0, 0] for concept in range(num_concepts)}

        # Load each pickle exactly once and bucket every concept from it,
        # instead of re-reading all files once per concept.
        for i in range(num_files):
            path = os.path.join(data_dir, f"generated_motions_{split}{i}.pkl")
            # NOTE: pickle.load can execute arbitrary code -- only point this
            # script at trusted, locally generated files.
            with open(path, "rb") as f:
                data = pickle.load(f)
            labels, output_xyz, gt_xyz = data['y'], data['output_xyz'], data['gt']
            for concept in range(num_concepts):
                generated, gt = _select_concept(labels, output_xyz, gt_xyz, concept)
                saved_data[concept][i] = {'generated': generated, 'gt': gt}
                totals[concept][0] += len(generated)
                totals[concept][1] += len(gt)

        for concept in range(num_concepts):
            n_gen, n_gt = totals[concept]
            print("==> Split: ", split, "concept: ", concept, "generated count: ", n_gen, "gt count: ", n_gt, "saved!")

        out_dir = os.path.join(data_dir, split)
        # The split sub-directory is not guaranteed to exist beforehand.
        os.makedirs(out_dir, exist_ok=True)
        out_path = os.path.join(out_dir, "saved_processed_data.pkl")
        with open(out_path, "wb") as f:
            pickle.dump(saved_data, f)
        print("==> Split: ", split, "saved at: ", out_path)
        

if __name__ == "__main__":
    # Script entry point: run the conversion with the options parsed above.
    main(args=args)
