import h5py
import os
import pickle
from PIL import Image
import io
import argparse
import tqdm
import numpy as np
import glob


def turns_and_steps_to_positions(turns, steps, is_tf=False):
    """Integrate per-step heading changes and step lengths into 2-D positions.

    Starting from the origin with heading 0, each time step first adds
    ``turn`` to the running heading and then advances ``step`` along the
    new heading (``cos``/``sin`` of the heading, i.e. radians).

    Args:
        turns: 1-D ``(horizon,)`` or 2-D ``(batch, horizon)`` array of
            heading increments.
        steps: Array of the same shape as ``turns`` holding the distance
            travelled at each time step.
        is_tf: When true, use TensorFlow ops instead of NumPy.

    Returns:
        A ``(positions, angles)`` tuple. ``positions`` has shape
        ``(horizon + 1, 2)`` for 1-D input or ``(batch, horizon + 1, 2)``
        for 2-D input and includes the origin as its first entry.
        ``angles`` is a list of ``horizon + 1`` arrays of shape
        ``(batch,)``; it is *not* stacked and keeps the batch axis even
        for 1-D input (batch == 1).

    Raises:
        ValueError: If the inputs are neither 1-D nor 2-D.
    """
    if is_tf:
        # NOTE(review): `tf` is not imported anywhere in this file, so this
        # branch only works if TensorFlow has been bound to the module-level
        # name `tf` by other means -- confirm before relying on it.
        get_shape = lambda x: tuple(x.shape.as_list())
        module = tf
    else:
        get_shape = lambda x: tuple(x.shape)
        module = np

    assert get_shape(turns) == get_shape(steps), \
        "turns and steps must have identical shapes"

    rank = len(turns.shape)
    if rank == 1:
        # Promote a single trajectory to a batch of one; undone before return.
        is_batch = False
        turns = turns[module.newaxis]
        steps = steps[module.newaxis]
    elif rank == 2:
        is_batch = True
    else:
        raise ValueError(
            f"turns and steps must be 1-D or 2-D, got {rank}-D input"
        )

    batch_size, horizon = get_shape(turns)
    angles = [module.zeros(batch_size)]
    positions = [module.zeros((batch_size, 2))]
    # Index the time axis directly rather than `split`-ting it into `horizon`
    # pieces: splitting into zero sections raises, so an empty trajectory
    # would crash instead of yielding just the origin.
    for t in range(horizon):
        turn = turns[:, t]
        step = steps[:, t]

        angle = angles[-1] + turn
        heading = module.stack([module.cos(angle), module.sin(angle)], axis=-1)
        position = positions[-1] + step[:, module.newaxis] * heading

        angles.append(angle)
        positions.append(position)
    positions = module.stack(positions, axis=1)

    if not is_batch:
        positions = positions[0]

    return positions, angles


def main(args: argparse.Namespace):
    """Unpack HDF5 trajectories into image folders plus a pickled pose table.

    For every entry in ``args.input_dir`` the front-camera frames are written
    to ``<output_dir>/<traj_name>/<i>.jpg`` and the trajectory's positions and
    yaws are collected into ``<output_dir>/pos_data.pkl`` as
    ``{traj_name: {"position": ..., "yaw": ...}}``.

    Args:
        args: Parsed CLI namespace providing ``input_dir``, ``output_dir`` and
            ``num_trajs`` (a negative value means "process everything").
    """
    recon_dir = args.input_dir
    output_dir = args.output_dir

    os.makedirs(output_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort so that --num-trajs
    # always selects the same subset.
    filenames = sorted(os.listdir(recon_dir))

    if args.num_trajs >= 0:
        filenames = filenames[: args.num_trajs]

    # processing loop
    pos_data = {}
    for filename in tqdm.tqdm(filenames):
        file_path = os.path.join(recon_dir, filename)
        # Name the trajectory after its path relative to the input dir, minus
        # the extension, with path separators flattened to "_". Slicing a
        # fixed number of leading path components (as before) yields an empty
        # name for plain listdir entries, making all trajectories collide.
        relative_stem = os.path.splitext(os.path.relpath(file_path, recon_dir))[0]
        traj_name = relative_stem.replace(os.sep, "_")

        try:
            h5_file = h5py.File(file_path, "r")
        except OSError:
            print(f"Error loading {filename}. Skipping...")
            continue

        # The context manager closes the HDF5 handle once this trajectory is done.
        with h5_file:
            jackal = h5_file["jackal"] if "jackal" in h5_file.keys() else None
            has_position = jackal is not None and "position" in jackal.keys()
            has_yaw = jackal is not None and "yaw" in jackal.keys()

            yaw_data = None
            if has_position:
                position_data = jackal["position"][:, :2]
            else:
                print(f"Error loading positions for {filename}. Converting turns...")
                turns = np.array(h5_file["commands"]["turn"])
                # NOTE(review): "dt" is passed as the per-step distance here;
                # confirm that is the intended quantity for this dataset.
                position_data, headings = turns_and_steps_to_positions(
                    turns,
                    np.array(h5_file["commands"]["dt"]),
                )
                # The helper returns one (batch,) array per time step; stack
                # them into a single array and drop the batch axis for a
                # single trajectory so yaw matches position_data's layout.
                yaw_data = np.stack(headings, axis=1)
                if turns.ndim == 1:
                    yaw_data = yaw_data[0]

            if has_yaw:
                yaw_data = jackal["yaw"][()]
            else:
                print(f"Error loading yaws for {filename}. Using interpolation")
                # Explicit None check: truth-testing an ndarray would raise.
                if yaw_data is None:
                    # Heading of each displacement between consecutive
                    # positions; the first sample has no predecessor, so it
                    # is assigned a yaw of 0.
                    deltas = np.diff(position_data, axis=0)
                    yaw_data = np.arctan2(deltas[:, 1], deltas[:, 0])
                    yaw_data = np.insert(yaw_data, 0, 0)

            # save the data to a dictionary
            pos_data[traj_name] = {"position": position_data, "yaw": yaw_data}

            # make a folder for the trajectory's frames
            traj_folder = os.path.join(output_dir, traj_name)
            os.makedirs(traj_folder, exist_ok=True)

            # decode each stored frame and save it to disk as a JPEG
            front_frames = h5_file["images"]["front"]
            for i in range(front_frames.shape[0]):
                img = Image.open(io.BytesIO(front_frames[i]))
                img.save(os.path.join(traj_folder, f"{i}.jpg"))

    # save the odom data to disk
    with open(os.path.join(output_dir, "pos_data.pkl"), "wb") as pkl_file:
        pickle.dump(pos_data, pkl_file)


if __name__ == "__main__":
    # Command-line options as (flags, keyword settings) pairs: where the
    # dataset lives, where the processed output goes, and an optional cap on
    # the number of trajectories.
    option_table = (
        (
            ("--input-dir", "-i"),
            {"type": str, "help": "path of the dataset", "required": True},
        ),
        (
            ("--output-dir", "-o"),
            {
                "default": "../data/datasets/recon/",
                "type": str,
                "help": "path for processed recon dataset (default: ../data/datasets/recon/)",
            },
        ),
        (
            ("--num-trajs", "-n"),
            {
                "default": -1,
                "type": int,
                "help": "number of trajectories to process (default: -1, all)",
            },
        ),
    )

    parser = argparse.ArgumentParser()
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    # Parse before announcing the start so that bad arguments exit silently.
    args = parser.parse_args()
    print("STARTING PROCESSING DATASET")
    main(args)
    print("FINISHED PROCESSING DATASET")