"""
Run this after generating the segment-wise pickle files of hand-pose frames with 1_generate_pose_data.py.
This script stacks those pickles into one fixed-length data array (X) for the test split; it produces no labels (Y), since the test data is the same for every label type.
"""
import argparse
import os
import pickle

import numpy as np
from tqdm import tqdm


def gendata(data_path, out_path, _max_frame, part='train'):
    """
    Merge the per-segment hand-pose pickles of one split into a single .npy array (X part).

    Every segment is resized in time to exactly ``_max_frame`` frames. Longer segments
    are truncated. Shorter ones are looped (frames repeated from the start) until the
    length is reached. Segments with zero frames are left all-zero.

    Args:
        data_path: directory containing one pickle per segment (and nothing else).
            Each pickle holds a numpy array of shape (3, T, 21, 2), where T varies per
            segment. File names are parsed positionally:
            - The integer after the first two characters, up to the first '_', is the
              segment id used for ordering.
            - The text after 'len' in the second '_'-separated field (extension
              removed) is the frame count reported in the printed statistics.
            Presumably names look like ``xx<id>_len<T>.pkl``; confirm against
            1_generate_pose_data.py.
        out_path: output directory; created (with any missing parents) if needed.
        _max_frame: number of frames every segment gets in the output.
        part: split name, used as the prefix of the output file name.

    Writes ``<out_path>/<part>_data_joint_<_max_frame>.npy``. This is a float32 array
    of shape (number of segments, 3, _max_frame, 21, 2), with rows ordered by
    segment id.

    Raises:
        ValueError: if ``data_path`` contains no files.
    """
    _max_body_true = 2 # 2 hands
    _num_joint = 21 # 21 landmarks per hand

    # makedirs(exist_ok=True) also creates missing parents. It cannot race with another
    # process creating the directory between an existence check and the mkdir.
    os.makedirs(out_path, exist_ok=True)

    def takeID(elem):
        return int(elem[2:].split('_')[0])

    # Segment identifiers are the pickle file names, processed in segment-id order.
    sample_name = sorted(os.listdir(data_path), key=takeID)
    if not sample_name:
        # Fail with a clear message instead of numpy's "zero-size array" error below.
        raise ValueError('No segment files found in {}'.format(data_path))

    # Number of handpose frames per segment, as encoded in the file name
    no_frames = [int(name.split('.')[0].split('_')[1].split('len')[1]) for name in sample_name]

    # Stats on number of frames for the segments
    list_frames = np.array(no_frames)
    print('max=', np.max(list_frames), ', min=', np.min(list_frames),
          ', mean=', int(np.mean(list_frames)), ', med=', np.median(list_frames))

    # Create data for the split (X part) by joining all segment pickles into one array.
    # To give every segment the same length (_max_frame), output frame k is segment
    # frame (k mod T). This truncates segments longer than _max_frame and repeats
    # shorter ones from the start, the same result as tiling then cropping.
    fp = np.zeros((len(sample_name), 3, _max_frame, _num_joint, _max_body_true), dtype=np.float32)
    for i, s in enumerate(tqdm(sample_name)):
        # pickle is only safe on trusted input; these files come from 1_generate_pose_data.py.
        with open(os.path.join(data_path, s), 'rb') as f:
            data = pickle.load(f)
        n_frames = data.shape[1]
        if n_frames == 0:
            # Empty segment: keep its zero-filled slot. This avoids a division/modulo by zero.
            continue
        fp[i] = data[:, np.arange(_max_frame) % n_frames]

    # Save data
    np.save(os.path.join(out_path, '{}_data_joint_{}.npy'.format(part, _max_frame)), fp)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Data Converter. Generate Assembly101 Test Data Part 2.')
    parser.add_argument('--data_path', type=str,
        help = 'Take the directory for segment wise pickle files for handpose frames generated by 1_generate_pose_data.py',
        default='./RAW_contex25_thresh0')
    # Lower-casing before the `choices` check keeps the original case-insensitive matching.
    # Rejecting unknown values here, with a usage message, prevents an unbound
    # `out_folder` (NameError) further down.
    parser.add_argument('--type', type=str.lower, choices=['action', 'verb', 'noun'],
        help = 'Choose the label type: action/verb/noun. This only indicate the output dir. Test data is the same for all.',
        default='action')
    arg = parser.parse_args()

    # Output goes next to the input directory; the folder name only encodes the label type.
    out_dir_names = {
        'action': 'share_contex25_thresh0',
        'verb': 'share_contex25_thresh0_verb',
        'noun': 'share_contex25_thresh0_noun',
    }
    # normpath strips a trailing separator. Without it, dirname('RAW/') == 'RAW' and the
    # output would land inside the input directory instead of beside it.
    data_root = os.path.normpath(arg.data_path)
    out_folder = os.path.join(os.path.dirname(data_root), out_dir_names[arg.type])

    # Maximum number of frames to take per segment; list several (e.g. [100, 200, 300])
    # to produce one output file per length.
    list_max_frames = [200]
    parts = ['test']  # this script only generates the test split

    for _max_frame in list_max_frames:
        for p in parts:
            # Generate data for the split
            gendata(os.path.join(arg.data_path, p), out_folder, _max_frame, part=p)
