import pickle
import h5py
import numpy as np
import os
import sys
import torch


def _default_data_root():
    """Return the ``../data`` directory relative to this source file."""
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')


def _find_traj_files(data_root, data_dir):
    """Return the 'softgym_traj' files found in run directories under ``data_root``.

    A run directory is any sub-directory of ``data_root`` whose name contains
    ``data_dir`` as a substring. Directory and file names are sorted so the
    trajectory order in the output dataset is reproducible (``os.listdir``
    returns entries in arbitrary order).
    """
    entries = os.listdir(data_root)
    print("all_files: ", entries)

    traj_files = []
    for name in sorted(entries):
        run_dir = os.path.join(data_root, name)
        # Plain files whose name happens to contain data_dir are skipped;
        # calling os.listdir on them would raise NotADirectoryError.
        if data_dir not in name or not os.path.isdir(run_dir):
            continue
        traj_files.extend(
            os.path.join(run_dir, fname)
            for fname in sorted(os.listdir(run_dir))
            if 'softgym_traj' in fname
        )
    return traj_files


def _load_trajectories(traj_files):
    """Load each file and concatenate its image and action trajectories.

    Each loaded object ``d`` is indexed as ``d[0]`` (image trajectories) and
    ``d[1]`` (action trajectories); their items are appended in file order.
    """
    rgbd, actions = [], []
    for path in traj_files:
        # NOTE(review): torch.load unpickles arbitrary objects -- only use it on
        # trusted, locally generated files. On newer torch versions the default
        # weights_only=True may reject non-tensor payloads; confirm.
        d = torch.load(path)
        rgbd.extend(d[0])
        actions.extend(d[1])
    return rgbd, actions


def _write_hdf5(save_path, rgbd, actions):
    """Write trajectories to ``save_path`` as float32 'images' and 'actions' datasets.

    The per-trajectory layout is taken from the first trajectory: images are
    ``(traj_len, *frame_shape)`` and actions are ``(traj_len - 1, action_dim)``.
    Every other trajectory must match that layout exactly.

    Raises:
        ValueError: if the first action trajectory is not 2-D, or any
            trajectory's shape differs from the expected layout.
    """
    traj_num = len(rgbd)
    image_shape = np.asarray(rgbd[0], dtype=np.float32).shape
    first_actions = np.asarray(actions[0], dtype=np.float32)
    if first_actions.ndim != 2:
        raise ValueError(
            f"expected 2-D action trajectories (steps, action_dim), got shape {first_actions.shape}"
        )
    traj_len = image_shape[0]
    # Actions are stored with one step fewer than images.
    action_shape = (traj_len - 1, first_actions.shape[1])

    with h5py.File(save_path, 'w') as f:
        d_images = f.create_dataset('images', (traj_num,) + image_shape, dtype=np.float32)
        d_actions = f.create_dataset('actions', (traj_num,) + action_shape, dtype=np.float32)
        for i, (frames, acts) in enumerate(zip(rgbd, actions)):
            frames = np.asarray(frames, dtype=np.float32)
            acts = np.asarray(acts, dtype=np.float32)
            if frames.shape != image_shape or acts.shape != action_shape:
                raise ValueError(
                    f"trajectory {i}: images {frames.shape} / actions {acts.shape} "
                    f"do not match expected {image_shape} / {action_shape}"
                )
            d_images[i] = frames
            d_actions[i] = acts
            # Periodic flush limits how much buffered data is lost if the
            # process dies mid-write.
            if i % 100 == 0:
                f.flush()
        # Leaving the with-block flushes and closes the file.


def format_data(vv, data_root=None):
    """Merge 'softgym_traj' trajectory files into a single HDF5 dataset.

    Searches ``data_root`` for run directories whose name contains
    ``vv['data_dir']``, loads every 'softgym_traj' file inside them and writes
    all trajectories to ``<data_root>/<vv['data_dir']>/dataset.hdf5``. The
    file has two float32 datasets:

    * ``images``:  ``(traj_num, traj_len, *frame_shape)``
    * ``actions``: ``(traj_num, traj_len - 1, action_dim)``

    ``frame_shape`` and ``action_dim`` are inferred from the first trajectory.

    Args:
        vv: Config mapping; only ``vv['data_dir']`` is read.
        data_root: Directory to search and write into. Defaults to ``../data``
            relative to this file.

    Raises:
        FileNotFoundError: if no 'softgym_traj' file is found.
        ValueError: if no trajectories are loaded, the image and action
            trajectory counts differ, or trajectory shapes are inconsistent.
    """
    if data_root is None:
        data_root = _default_data_root()
    data_dir = vv['data_dir']

    traj_files = _find_traj_files(data_root, data_dir)
    print("data_file: ", traj_files, flush=True)
    if not traj_files:
        raise FileNotFoundError(
            f"no 'softgym_traj' files in directories matching {data_dir!r} under {data_root}"
        )

    rgbd, actions = _load_trajectories(traj_files)
    if not rgbd:
        raise ValueError(f"no trajectories found in {traj_files}")
    if len(rgbd) != len(actions):
        raise ValueError(
            f"{len(rgbd)} image trajectories but {len(actions)} action trajectories"
        )

    out_dir = os.path.join(data_root, data_dir)
    os.makedirs(out_dir, exist_ok=True)
    _write_hdf5(os.path.join(out_dir, 'dataset.hdf5'), rgbd, actions)

def run_task(vv, log_dir, exp_name):
    """Task entry point: build the HDF5 dataset described by ``vv``.

    ``log_dir`` and ``exp_name`` are accepted (presumably to satisfy an
    experiment launcher's calling convention) but are not used.
    """
    del log_dir, exp_name  # unused
    format_data(vv)


if __name__ == '__main__':
    # format_data requires a config mapping; the only key it reads is
    # 'data_dir', so take that from the command line.
    if len(sys.argv) != 2:
        sys.exit(f"usage: python {os.path.basename(sys.argv[0])} <data_dir>")
    format_data({'data_dir': sys.argv[1]})
