import os
import h5py
import argparse
import numpy as np
from tqdm import tqdm

from keyframe_selection import dp_keyframe_selection, greedy_keyframe_selection


def main(args):
    """Run greedy keyframe selection over a range of episodes and save the results.

    For each index ``i`` in ``[args.start_idx, args.end_idx]`` (inclusive), opens
    ``<args.dataset>/episode_{i}.hdf5`` in read/write mode and reads
    ``/observations/qpos``. It then calls ``greedy_keyframe_selection`` on that
    array and writes the returned keyframes to the ``/keyframes`` dataset of
    the same file.

    If ``/keyframes`` already exists, the user is asked on stdin whether to
    overwrite it. Any answer other than "y" (case-insensitive, surrounding
    whitespace ignored) leaves the existing dataset untouched. Per-episode and
    average frame/keyframe counts are printed.

    Args:
        args: Namespace with ``dataset`` (directory holding the episode files),
            ``start_idx`` and ``end_idx`` (inclusive episode range) and
            ``err_threshold`` (forwarded to ``greedy_keyframe_selection``).
    """
    keyframes_key = "/keyframes"
    num_keyframes = []
    num_frames = []

    for i in tqdm(range(args.start_idx, args.end_idx + 1)):
        dataset_path = os.path.join(args.dataset, f'episode_{i}.hdf5')
        # The context manager closes the file; no explicit close() is needed.
        with h5py.File(dataset_path, 'r+') as root:
            qpos = root['/observations/qpos'][()]

            # qpos is passed as both the actions and the ground-truth states.
            keyframes = greedy_keyframe_selection(
                env=None,
                actions=qpos,
                gt_states=qpos,
                err_threshold=args.err_threshold,
                pos_only=True,
            )
            n_frames = len(qpos)
            n_keyframes = len(keyframes)
            # Guard against an empty selection so the ratio cannot raise ZeroDivisionError.
            ratio = n_frames / n_keyframes if n_keyframes else float("inf")
            print(f"Episode {i}: {n_frames} frames -> {n_keyframes} keyframes (ratio: {ratio:.2f})")
            num_keyframes.append(n_keyframes)
            num_frames.append(n_frames)

            # Check for an existing dataset explicitly instead of catching every
            # exception from the write, which would also hide unrelated errors.
            if keyframes_key in root:
                print("Keyframes dataset already exists. Overwrite? (y/n)")
                if input().strip().lower() == "y":
                    del root[keyframes_key]
                    root[keyframes_key] = keyframes
            else:
                root[keyframes_key] = keyframes

    # An empty range (start_idx > end_idx) would otherwise average empty lists.
    if not num_frames:
        print("No episodes processed.")
        return

    mean_keyframes = np.mean(num_keyframes)
    mean_frames = np.mean(num_frames)
    mean_ratio = mean_frames / mean_keyframes if mean_keyframes else float("inf")
    print(f"Average number of keyframes: {mean_keyframes} \tAverage number of frames: {mean_frames} \tratio: {mean_ratio}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Select keyframes for a range of hdf5 episodes and store them under /keyframes.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        # Other dataset directories this script has been used with:
        #   data/act/sim_insertion_scripted_copy
        #   data/act/sim_transfer_cube_human_copy
        #   data/act/sim_insertion_human_copy
        #   data/act/aloha_screw_driver
        #   data/act/aloha_tape
        #   data/act/aloha_coffee
        #   data/act/aloha_towel
        default="data/act/sim_transfer_cube_scripted_copy",
        help="path to hdf5 dataset",
    )

    # Episodes start_idx..end_idx (inclusive) are processed.
    parser.add_argument(
        "--start_idx",
        type=int,
        default=0,
        help="(optional) index of the first episode to process",
    )

    parser.add_argument(
        "--end_idx",
        type=int,
        default=49,
        help="(optional) index of the last episode to process (inclusive)",
    )

    # Forwarded to greedy_keyframe_selection as err_threshold.
    parser.add_argument(
        "--err_threshold",
        type=float,
        default=0.01,
        help="(optional) error threshold for reconstructing the trajectory",
    )

    args = parser.parse_args()
    # Reject an inverted range up front rather than silently processing nothing.
    if args.end_idx < args.start_idx:
        parser.error("--end_idx must be greater than or equal to --start_idx")
    main(args)
