from argparse import ArgumentParser
from pathlib import Path

import cv2
import numpy as np

if __name__ == "__main__":
    parser = ArgumentParser(description="Interactive visualization of CALVIN dataset")
    parser.add_argument("path", type=str, help="Path to dir containing scene_info.npy")
    parser.add_argument("-d", "--data", nargs="*", default=["rgb_static", "rgb_gripper"], help="Data to visualize")
    args = parser.parse_args()

    ep_start_end_ids = np.load(args.path + "/ep_start_end_ids.npy")
    print(ep_start_end_ids)
    scene_info = np.load(f"{args.path}/scene_info.npy", allow_pickle=True).item()
    for key, value in scene_info.items():
        # indices = list(range(value[0], value[1] + 1))
        print(value[0], value[1] + 1)
    assert 0 == 1
    if not Path(args.path).is_dir():
        print(f"Path {args.path} is either not a directory, or does not exist.")
        exit()

    indices = next(iter(np.load(f"{args.path}/scene_info.npy", allow_pickle=True).item().values()))
    indices = list(range(indices[0], indices[1] + 1))
    # print(len(indices))
    # print(indices[0].shape)
    # print(indices[1].shape)
    annotations = np.load(f"{args.path}/lang_annotations/auto_lang_ann.npy", allow_pickle=True).item()
    annotations = list(zip(annotations["info"]["indx"], annotations["language"]["ann"]))

    idx = 0
    ann_idx = -1

    while True:
        t = np.load(f"{args.path}/episode_{indices[idx]:07d}.npz", allow_pickle=True)
        keys = t.files
        # print(keys)
        # for key in keys:
        #     print(key, t[key].shape)
        # assert 0 == 1
        for d in args.data:
            if d not in t:
                print(f"Data {d} cannot be found in transition")
                continue

            cv2.imshow(d, t[d][:, :, ::-1])

        for n, ((low, high), ann) in enumerate(annotations):
            if indices[idx] >= low and indices[idx] <= high:
                if n != ann_idx:
                    print(f"{ann}")
                    ann_idx = n

        key = cv2.waitKey(0)
        if key == ord("q"):
            break
        elif key == 83:  # Right arrow
            idx = (idx + 1) % len(indices)
        elif key == 81:  # Left arrow
            idx = (len(indices) + idx - 1) % len(indices)
        else:
            print(f'Unrecognized keycode "{key}"')