from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle
from tqdm import tqdm


class HDF5Dataset:
    """Newly created, write-mode HDF5 file pre-populated with empty, growable datasets.

    The constructor creates ``<root_path>/<dataset_name>-<type>-<size[0]>x<size[1]>.hdf5``
    (opened with mode ``"w"``, so an existing file of that name is replaced)
    and adds a fixed set of datasets that all start with zero rows and can be
    resized without limit along their first axis. A ``metadata`` group
    records ``dataset_name`` and ``type`` as attributes.

    Datasets are reached by name through item access, e.g.
    ``dataset["rgb_images"]``. Call :meth:`close` when finished, or use the
    instance as a context manager.
    """

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int]):
        """Create the HDF5 file and all of its (initially empty) datasets.

        Args:
            root_path: Directory in which the ``.hdf5`` file is created.
            dataset_name: Used in the file name and stored in the metadata.
            type: Split label (e.g. ``"train"``); used in the file name and
                stored in the metadata. The name shadows the builtin but is
                kept because it is part of the public signature.
            size: ``(size[0], size[1])`` become the two trailing dimensions of
                every image-like dataset and appear in the file name.
        """
        hdf5_file_path = os.path.join(root_path, f'{dataset_name}-{type}-{size[0]}x{size[1]}.hdf5')
        # data_path is only reported in the log line below; nothing is read
        # from it in this class.
        data_path      = os.path.join(root_path, dataset_name, type)
        print(f"Loading {dataset_name} {type} from {data_path}", flush=True)

        dim0, dim1 = size[0], size[1]

        # np.int64 replaces the former ``np.compat.long``: that name comes from
        # NumPy's deprecated Python-2 compatibility shim and is simply ``int``
        # under Python 3, so an explicit fixed-width dtype is portable across
        # NumPy versions and platforms.
        index_dtype = np.int64

        # (name, per-row shape, dtype, store exactly one row per chunk?)
        # Listed in creation order. Image-like datasets use one-row chunks so
        # a single image can be read or written without touching neighbours;
        # the small tabular datasets leave chunking to h5py.
        dataset_specs = (
            ("rgb_images",                 (3, dim0, dim1), np.float32,  True),
            ("raw_depth",                  (1, dim0, dim1), np.float32,  True),
            ("depth_images",               (1, dim0, dim1), np.float32,  True),
            ("forward_flow",               (2, dim0, dim1), np.float32,  True),
            ("backward_flow",              (2, dim0, dim1), np.float32,  True),
            ("foreground_mask",            (1, dim0, dim1), np.float32,  True),
            ("image_instance_indices",     (2,),            index_dtype, False),  # start index, number of instances
            ("instance_masks",             (1, dim0, dim1), np.float32,  True),
            ("instance_masks_images",      (1,),            index_dtype, False),
            ("instance_mask_bboxes",       (4,),            np.float32,  False),
            ("sequence_indices",           (2,),            index_dtype, False),  # start index, number of images
            ("camera_field_of_view",       (1,),            np.float32,  False),  # for each sequence
            ("camera_focal_length",        (1,),            np.float32,  False),  # for each sequence
            ("camera_position",            (3,),            np.float32,  False),
            ("camera_rotation_quaternion", (4,),            np.float32,  False),
            ("camera_sensor_width",        (1,),            np.float32,  False),  # for each sequence
        )

        # setup the hdf5 file
        hdf5_file = h5py.File(hdf5_file_path, "w")
        try:
            for name, row_shape, dtype, one_row_per_chunk in dataset_specs:
                self._create_growable_dataset(hdf5_file, name, row_shape, dtype, one_row_per_chunk)

            # Create a metadata group and set the attributes
            metadata_grp = hdf5_file.create_group("metadata")
            metadata_grp.attrs["dataset_name"] = dataset_name
            metadata_grp.attrs["type"] = type
        except BaseException:
            # Do not leave a dangling open handle if setup fails half-way.
            hdf5_file.close()
            raise

        self.hdf5_file = hdf5_file

    @staticmethod
    def _create_growable_dataset(hdf5_file, name, row_shape, dtype, one_row_per_chunk):
        """Add an empty gzip-compressed dataset that can grow along axis 0."""
        extra = {}
        if one_row_per_chunk:
            extra["chunks"] = (1, *row_shape)
        hdf5_file.create_dataset(
            name,
            (0, *row_shape),
            maxshape=(None, *row_shape),
            dtype=dtype,
            compression='gzip',
            compression_opts=5,
            **extra,
        )

    def close(self):
        """Flush pending writes and close the underlying HDF5 file."""
        self.hdf5_file.flush()
        self.hdf5_file.close()

    def __enter__(self):
        """Return ``self`` so the instance can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the file on leaving the ``with`` block; exceptions propagate."""
        self.close()

    def __getitem__(self, index):
        """Return the object stored under ``index`` (e.g. a dataset name) in the file."""
        return self.hdf5_file[index]


def video_to_numpy(video_path):
    """Decode every frame of a video into one float32 array.

    Args:
        video_path: Path passed to ``cv2.VideoCapture``.

    Returns:
        Array with axes ``[time, channels, height, width]`` and dtype
        ``float32``, holding the decoded pixel values divided by 255.

    Raises:
        IOError: If the capture cannot be opened.
        ValueError: If the capture opens but yields no frames.
    """
    # NOTE(review): no colour conversion is applied, so the channel order is
    # whatever OpenCV delivers (BGR by default per the OpenCV docs), while the
    # caller stores the result as "rgb_images" -- confirm this is intended.
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise IOError("Video file cannot be opened")

        # Read until the decoder reports that no further frame is available.
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the capture even when opening or decoding raised.
        cap.release()

    if not frames:
        # Without this guard the transpose below fails with an opaque
        # "axes don't match array" ValueError.
        raise ValueError(f"No frames could be decoded from {video_path!r}")

    # Stack to [time, height, width, channels] ...
    frames_array = np.stack(frames)

    # ... and reorder the dimensions to [time, channels, height, width]
    frames_array = np.transpose(frames_array, (0, 3, 1, 2))

    return (frames_array / 255.0).astype(np.float32)

if __name__ == "__main__":
    # Convert the CLEVRER videos into one HDF5 file per split. For every video
    # its frames are appended to "rgb_images" and a row
    # [first frame index, number of frames] is appended to "sequence_indices".

    # Root directory holding the video sub-folders; the .hdf5 files are
    # written here as well.
    hdf5_file_path = "/media/chief/data/CLEVRER-HR/"
    size = (480, 320)  # the dataset is created with (size[1], size[0])

    # (split name, video indices belonging to it), in processing order.
    splits = (
        ("test", range(15000, 20000)),
        ("validation", range(10000, 15000)),
        ("train", range(0, 10000)),
    )

    for split, video_indices in splits:
        hdf5_dataset = HDF5Dataset(hdf5_file_path, "CLEVRER", split, (size[1], size[0]))

        # Close the file even if a video fails to load, so the handle is not
        # leaked and data written so far is flushed.
        try:
            rgb_images = hdf5_dataset["rgb_images"]
            sequence_indices = hdf5_dataset["sequence_indices"]

            for i in tqdm(video_indices):
                # Videos are grouped in folders of 1000, e.g. video_15000-16000.
                first_in_dir = (i // 1000) * 1000
                subdir = f"video_{first_in_dir:05d}-{first_in_dir + 1000:05d}"
                video  = f"video_{i:05d}.mp4"
                rgb = video_to_numpy(os.path.join(hdf5_file_path, subdir, video))

                # Grow along the frame axis and write the new frames at the end.
                offset = rgb_images.shape[0]
                rgb_images.resize(offset + len(rgb), axis=0)
                rgb_images[offset:] = rgb

                # Record where this sequence starts and how many frames it has.
                row = sequence_indices.shape[0]
                sequence_indices.resize(row + 1, axis=0)
                sequence_indices[row] = [offset, len(rgb)]
        finally:
            hdf5_dataset.close()
