"""
Run this after generating segment wise pickle files for handpose frames using 1_generate_pose_data.py
This script will generate data (X) and labels (Y) for each split (train/val/test)
"""
import argparse
import os
import pickle

import numpy as np
from tqdm import tqdm


def gendata(data_path, out_path, _max_frame, part='train', label_type='action'):
    """Build the label pickle (Y) and the stacked pose array (X) for one split.

    Every entry of ``data_path`` is treated as a pickle holding one segment
    as a numpy array of shape (3, T, 21, 2), where T varies per segment.
    The file name is split on ``_``; field 1 carries the action class
    (``a<int>``), field 2 the verb class (``v<int>``), field 3 the noun class
    (``n<int>``) and field 4 the frame count (``len<int>``).

    Two files are written into ``out_path``:

    * ``<part>_label.pkl`` -- pickled tuple ``(sample_names, labels)``.
    * ``<part>_data_joint_<_max_frame>.npy`` -- float32 array of shape
      ``(num_segments, 3, _max_frame, 21, 2)``.

    Each segment is brought to exactly ``_max_frame`` frames: longer segments
    are truncated, shorter ones are repeated (looped) along the time axis
    until the window is full. A segment with zero frames is left as zeros.

    Args:
        data_path: Directory containing the segment pickles of one split.
        out_path: Output directory; created (with parents) if missing.
        _max_frame: Number of frames stored per segment.
        part: Split name, used only as a prefix of the output file names.
        label_type: Which label to extract: 'action', 'verb' or 'noun'
            (case-insensitive).

    Raises:
        KeyError: If ``label_type`` is not one of the supported values.
    """
    _max_body_true = 2  # 2 hands
    _num_joint = 21  # 21 landmarks per hand

    # makedirs also creates missing parents and tolerates an existing folder.
    os.makedirs(out_path, exist_ok=True)

    # Position of the label field in the '_'-separated file name and the
    # marker letter that precedes the class index inside that field.
    type_dic = {'action': 1, 'verb': 2, 'noun': 3}
    type_dic_marker = {'action': 'a', 'verb': 'v', 'noun': 'n'}
    label_key = label_type.lower()
    label_field = type_dic[label_key]
    label_marker = type_dic_marker[label_key]

    sample_name = []   # Segment identifiers, i.e. the pickle file names
    sample_label = []  # Class index of the requested label type per segment
    no_frames = []     # Number of hand-pose frames per segment (from the name)

    # Sort so that the row order of X and Y is reproducible across runs and
    # machines; os.listdir returns entries in arbitrary order.
    for filename in sorted(os.listdir(data_path)):
        fields = filename.split('_')
        sample_name.append(filename)
        sample_label.append(int(fields[label_field].split(label_marker)[1]))
        no_frames.append(int(fields[4].split('len')[1]))

    # Stats on number of frames for the segments (skipped for an empty split,
    # where np.max/np.min would raise).
    if no_frames:
        list_frames = np.array(no_frames)
        print('max=', np.max(list_frames), ', min=', np.min(list_frames),
              ', mean=', int(np.mean(list_frames)), ', med=', np.median(list_frames))
    else:
        print('No segment files found in', data_path)

    # Save labels against file names in a pickle (Y part)
    label_file = os.path.join(out_path, '{}_label.pkl'.format(part))
    with open(label_file, 'wb') as f:
        pickle.dump((sample_name, list(sample_label)), f)

    # Create data for the split (X part): all segment pickles joined into one
    # array of shape (total segments, 3, max frame, 21, 2).
    fp = np.zeros((len(sample_label), 3, _max_frame, _num_joint, _max_body_true), dtype=np.float32)
    for i, s in enumerate(tqdm(sample_name)):
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only run this on pickles produced by the trusted generation step.
        with open(os.path.join(data_path, s), 'rb') as f:
            data = pickle.load(f)

        n_frames = data.shape[1]
        if n_frames == 0:
            # Nothing to copy or repeat; keep the pre-allocated zeros.
            continue

        if n_frames < _max_frame:
            # Repeat the whole segment along the time axis (axis=1) often
            # enough to cover _max_frame frames (ceiling division).
            repeats = -(-_max_frame // n_frames)
            data = np.tile(data, (1, repeats, 1, 1))

        # data now has at least _max_frame frames: fill the complete window,
        # truncating whatever exceeds it.
        fp[i, :, :_max_frame, :, :] = data[:, :_max_frame, :, :]

    # Save data
    data_file = os.path.join(out_path, '{}_data_joint_{}.npy'.format(part, str(_max_frame)))
    np.save(data_file, fp)


if __name__ == '__main__':
    # Command-line entry point: convert the 'train' and 'validation' splits
    # found under --data_path into label pickles and stacked .npy arrays.

    parser = argparse.ArgumentParser(description='Data Converter. Generate Assembly101 Data Part 2.')
    parser.add_argument('--data_path', type=str,
        help = 'Take the directory for segment wise pickle files for handpose frames generated by 1_generate_pose_data.py',
        default='./RAW_contex25_thresh0')
    # str.lower keeps the option case-insensitive; choices rejects unsupported
    # values up front instead of failing later with an unbound out_folder.
    parser.add_argument('--type', type=str.lower, choices=['action', 'verb', 'noun'],
        help = 'Choose the label type: action/verb/noun',
        default='action')
    arg = parser.parse_args()

    # Output folder is a sibling of the data folder; only the suffix depends
    # on the label type.
    out_suffix = {'action': '', 'verb': '_verb', 'noun': '_noun'}
    # normpath strips a trailing separator so that dirname returns the parent
    # of the data folder rather than the data folder itself.
    base_dir = os.path.dirname(os.path.normpath(arg.data_path))
    out_folder = os.path.join(base_dir, 'share_contex25_thresh0' + out_suffix[arg.type])

    list_max_frames = [200] # maximum number of frames to take per segment
    # list_max_frames = [100, 200, 300]

    part = ['train', 'validation']

    for _max_frame in list_max_frames:
        for p in part:
            data_path = os.path.join(arg.data_path, p)
            # Generate data for the split
            gendata(data_path, out_folder, _max_frame, part=p, label_type=arg.type)
