'''
    helper functions to read the TSM_feature lmdb
    run this with a command line argument describing the path to the lmdb
    e.g. python read_lmdb.py TSM_features/C10095_rgb
'''
import os
import sys
import lmdb
import numpy as np
import pickle as pkl

# path to the lmdb file you want to read as a command line argument
# lmdb_path = sys.argv[1]
lmdb_path = "/data/peiyao/all_data/Assembly101/TSM_features/"

# iterate over the entire lmdb and output all files
def extract_all_features(env, view):
    '''
        input:
            env: lmdb environment loaded (see main function)
        output: a dictionary with key as the path_to_frame and value as the TSM feature (2048-D np-array)
                the lmdb key format is '{sequence_name}/{view_name}/{view_name}_{frame_no:010d}.jpg'
                e.g. nusar-2021_action_both_9011-a01_9011_user_id_2021-02-01_153724/C10095_rgb/C10095_rgb_0000000001.jpg
    '''
    # ALL THE FRAME NUMBERS ARE AT 30FPS !!!

    results = {}  # set()
    print('Iterating over the entire lmdb. This may take some time...')
    i= 0
    a = 0
    with env.begin() as e: #env: lmdb_path + view
        cursor = e.cursor()
        for file, data in cursor:
            # if i>100000: break
            # i+=1
            vid = file.decode("utf-8")  # vid: 'nusar-2021_action_both_9011-a01_9011_user_id_2021-02-01_153724/C10095_rgb/C10095_rgb_0000000001.jpg'
            data = np.frombuffer(data, dtype=np.float32) ## each frame
            if data.shape[0] == 2048:
                vid_name, view_name, img_name = vid.split('/')[0], vid.split('/')[1], vid.split('/')[2]
                assert view_name == view
                if vid_name not in results:
                    results[vid_name] = {}
                if view_name not in results[vid_name]:
                    results[vid_name][view_name] = []
                frame_id = img_name.replace(view_name, '').replace('_', '').split('.')[0]
                if int(frame_id) ==1:
                    a += 1
                results[vid_name][view_name].append(int(frame_id))
            else:
                print(vid, data.shape)

    final_results = {}
    minimum_f, maximum_f = [], []
    for vid_name in results.keys(): ## each video
        if vid_name not in final_results:
            final_results[vid_name] = {}
        for view_name in results[vid_name]: ##each view
            if view_name in final_results[vid_name]:
                print('files in views should not be handled yet')
                exit(1)
            max_id, min_id = max(results[vid_name][view_name]), min(results[vid_name][view_name])  ## frame id in each video in a view
            maximum_f.append(max_id)
            minimum_f.append(min_id)
            assert (max_id - min_id + 1) == len(results[vid_name][view_name])
            final_results[vid_name][view_name] = [min_id, max_id + 1] #frame id from image to vide

    print(f'Features [view {view} ]: {len(final_results)} tasks loaded.')
    print('videos start at:', np.unique(np.array(minimum_f)))
    print('max length video with frames end at', np.max(np.array(maximum_f)))

    return final_results

def merge_dict(A, B):
    # B=> A
    if len(A) == 0:
        return B
    if len(B) == 0:
        return A
    for vid_name in B.keys():
        assert len(B[vid_name]) == 1
        view = list(B[vid_name].keys())[0]
        if vid_name in A:
            assert view not in A[vid_name]
            A[vid_name][view] = B[vid_name][view]
        else:
            print(vid_name, "not in the current dict")
            A[vid_name] = B[vid_name]
    return A



def save_statistic_input(views):
    total = {}
    for view in views:
        env = lmdb.open(lmdb_path + view, readonly=True, lock=False)
        final_results = extract_all_features(env, view)
        total = merge_dict(total, final_results)

    with open('statistic_input.pkl', 'wb') as f:
        pkl.dump(total, f)


# main function
if __name__ == '__main__':
    views = ['C10095_rgb', 'C10115_rgb', 'C10118_rgb', 'C10119_rgb', 'C10379_rgb', 'C10390_rgb', 'C10395_rgb', 'C10404_rgb',
             'HMC_21176875_mono10bit', 'HMC_84346135_mono10bit', 'HMC_21176623_mono10bit', 'HMC_84347414_mono10bit',
             'HMC_21110305_mono10bit', 'HMC_84355350_mono10bit','HMC_21179183_mono10bit', 'HMC_84358933_mono10bit']

    save_statistic_input(views)
