import h5py
import os
import pickle
from PIL import Image
import io
import argparse
import tqdm
import numpy as np
import glob

def main(args: argparse.Namespace):
    """Convert a directory of RECON HDF5 files into JPEG frames plus a pickle.

    Every entry of ``args.input_dir`` is processed, sorted by name and
    optionally truncated to the first ``args.num_trajs``. For each HDF5 file:

    * the ``"position"`` and ``"yaw"`` datasets are read into NumPy arrays;
    * each element of the ``"images"`` dataset is decoded with PIL and
      written to ``<output_dir>/<traj_name>/<i>.jpg``.

    The position/yaw arrays of all successfully opened files are collected
    in a dict keyed by trajectory name (the filename without its final
    extension). That dict is pickled to ``<output_dir>/pos_data.pkl``.

    Files that h5py fails to open (``OSError``) are reported and skipped.

    Args:
        args: Parsed CLI namespace providing ``input_dir``, ``output_dir``
            and ``num_trajs``. A negative ``num_trajs`` means "process all".
    """
    recon_dir = args.input_dir
    output_dir = args.output_dir

    # create output dir if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort them so that
    # --num-trajs selects a reproducible subset across runs and machines.
    filenames = sorted(os.listdir(recon_dir))
    if args.num_trajs >= 0:
        filenames = filenames[: args.num_trajs]

    # processing loop
    pos_data = {}
    for filename in tqdm.tqdm(filenames):
        # splitext strips only the final extension; split(".")[0] would turn
        # "a.b.hdf5" into "a" and could merge distinct trajectories.
        traj_name = os.path.splitext(filename)[0]
        try:
            f = h5py.File(os.path.join(recon_dir, filename), "r")
        except OSError:
            print(f"Error loading {filename}. Skipping...")
            continue

        # The context manager closes the HDF5 handle even if a dataset is
        # missing or an image fails to decode.
        with f:
            # save the odometry data to the in-memory dictionary
            pos_data[traj_name] = {
                "position": np.array(f["position"]),
                "yaw": np.array(f["yaw"]),
            }

            # one output folder per trajectory
            traj_folder = os.path.join(output_dir, traj_name)
            os.makedirs(traj_folder, exist_ok=True)

            # decode each stored image blob and write it to disk as JPEG
            images = f["images"]
            for i in range(images.shape[0]):
                with Image.open(io.BytesIO(images[i])) as img:
                    img.save(os.path.join(traj_folder, f"{i}.jpg"))

    # save the odom data to disk
    with open(os.path.join(output_dir, "pos_data.pkl"), "wb") as fi:
        pickle.dump(pos_data, fi)


if __name__ == "__main__":
    # Command-line interface for the RECON conversion script.
    arg_parser = argparse.ArgumentParser()

    # Where the raw RECON HDF5 files live (mandatory).
    arg_parser.add_argument(
        "--input-dir", "-i",
        required=True,
        type=str,
        help="path of the dataset",
    )

    # Where the per-trajectory frames and pos_data.pkl are written.
    arg_parser.add_argument(
        "--output-dir", "-o",
        type=str,
        default="../data/datasets/recon/",
        help="path for processed recon dataset (default: ../data/datasets/recon/)",
    )

    # How many trajectories to convert; a negative value means all of them.
    arg_parser.add_argument(
        "--num-trajs", "-n",
        type=int,
        default=-1,
        help="number of trajectories to process (default: -1, all)",
    )

    args = arg_parser.parse_args()

    print("STARTING PROCESSING DATASET")
    main(args)
    print("FINISHED PROCESSING DATASET")