from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle


class HDF5Dataset:
    """Freshly created HDF5 file pre-populated with empty, growable datasets.

    The file is written to
    ``<root_path>/<dataset_name>-<type>-<size[0]>x<size[1]>.hdf5`` and is
    opened in ``"w"`` mode, so an existing file at that path is truncated.

    Every dataset starts with zero rows and an unlimited first axis; callers
    grow it with ``resize`` before writing. All datasets are gzip-compressed
    (level 5). A ``metadata`` group stores ``dataset_name`` and ``type`` as
    attributes.

    Args:
        root_path: Directory in which the ``.hdf5`` file is created.
        dataset_name: Dataset identifier, used in the file name and metadata.
        type: Split identifier (e.g. ``"train"``), used in the file name and
            metadata.
        size: The two trailing dimensions of every per-pixel dataset.
            NOTE(review): presumably (height, width) -- confirm with callers.
    """

    # (name, second-axis width, dtype, per_pixel), in creation order.
    # per_pixel datasets have shape (N, width, size[0], size[1]) and are
    # chunked one row at a time; the others have shape (N, width) and leave
    # the chunk layout to h5py.
    # Index datasets use np.int64 explicitly: the previously used
    # np.compat.long alias is no longer available in NumPy 2.x.
    _DATASET_SPECS = (
        ("rgb_images", 3, np.float32, True),
        ("raw_depth", 1, np.float32, True),
        ("depth_images", 1, np.float32, True),
        ("forward_flow", 2, np.float32, True),
        ("backward_flow", 2, np.float32, True),
        ("foreground_mask", 1, np.float32, True),
        ("image_instance_indices", 2, np.int64, False),  # start index, number of instances
        ("instance_masks", 1, np.float32, True),
        ("instance_masks_images", 1, np.int64, False),
        ("instance_mask_bboxes", 4, np.float32, False),
        ("sequence_indices", 2, np.int64, False),  # start index, number of images
        ("camera_field_of_view", 1, np.float32, False),  # for each sequence
        ("camera_focal_length", 1, np.float32, False),  # for each sequence
        ("camera_position", 3, np.float32, False),
        ("camera_rotation_quaternion", 4, np.float32, False),
        ("camera_sensor_width", 1, np.float32, False),  # for each sequence
    )

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int]):

        hdf5_file_path = os.path.join(root_path, f'{dataset_name}-{type}-{size[0]}x{size[1]}.hdf5')
        data_path      = os.path.join(root_path, dataset_name, type)
        print(f"Loading {dataset_name} {type} from {data_path}", flush=True)

        # setup the hdf5 file
        hdf5_file = h5py.File(hdf5_file_path, "w")

        try:
            for name, width, dtype, per_pixel in self._DATASET_SPECS:
                row_shape = (width, size[0], size[1]) if per_pixel else (width,)
                options = {
                    "maxshape": (None,) + row_shape,
                    "dtype": dtype,
                    "compression": 'gzip',
                    "compression_opts": 5,
                }
                if per_pixel:
                    # One chunk per row so a single image can be read back
                    # without decompressing its neighbours.
                    options["chunks"] = (1,) + row_shape
                hdf5_file.create_dataset(name, (0,) + row_shape, **options)

            # Create a metadata group and set the attributes
            metadata_grp = hdf5_file.create_group("metadata")
            metadata_grp.attrs["dataset_name"] = dataset_name
            metadata_grp.attrs["type"] = type
        except Exception:
            # Do not leak the open handle when the layout cannot be created.
            hdf5_file.close()
            raise

        self.hdf5_file = hdf5_file

    def close(self):
        """Flush pending writes and close the underlying HDF5 file."""
        self.hdf5_file.flush()
        self.hdf5_file.close()

    def __getitem__(self, index):
        """Return the object stored under ``index`` in the HDF5 file."""
        return self.hdf5_file[index]


class RamImage():
    """Encoded image file kept in memory as raw bytes and decoded on demand.

    Only the file's bytes are stored, so the in-memory footprint is that of
    the encoded file rather than the decoded pixel array.

    Args:
        path: Path of the image file to read.
    """

    def __init__(self, path):
        # The context manager releases the handle even if read() raises.
        with open(path, 'rb') as fd:
            img_str = fd.read()

        # View of the file contents as a flat uint8 buffer.
        self.img_raw = np.frombuffer(img_str, np.uint8)

    def to_numpy(self):
        """Decode the stored bytes with ``cv2.imdecode`` in IMREAD_COLOR mode.

        NOTE(review): OpenCV decodes colour images in BGR channel order and
        no conversion to RGB happens here -- confirm that consumers expect
        this.
        """
        return cv2.imdecode(self.img_raw, cv2.IMREAD_COLOR)

class ClevrerSample(data.Dataset):
    """One video sequence, held in memory as encoded frame images.

    Frames are the files in ``<root_path>/<data_path>/train/<size[0]>x<size[1]>``
    whose name starts with ``frame`` and ends with ``.jpg`` or ``.png``,
    taken in sorted path order.

    Args:
        root_path: Directory containing the sequence directory.
        data_path: Name of the sequence directory below ``root_path``.
        size: Frame size; ``size[0]`` and ``size[1]`` select the resolution
            sub-directory, and ``get_data`` allocates frames of shape
            ``(3, size[1], size[0])``.
    """

    # Minimum number of frame slots returned by get_data(); shorter sequences
    # are zero-padded up to this length.
    NUM_FRAMES = 128

    def __init__(self, root_path: str, data_path: str, size: Tuple[int, int]):
        super().__init__()

        data_path = os.path.join(root_path, data_path, "train", f'{size[0]}x{size[1]}')

        self.size = size

        frames = sorted(
            os.path.join(data_path, file)
            for file in os.listdir(data_path)
            if file.startswith("frame") and file.endswith((".jpg", ".png"))
        )
        self.imgs = [RamImage(path) for path in frames]

    def get_data(self):
        """Decode all frames into one float32 array scaled to [0, 1].

        Returns:
            Array of shape ``(T, 3, size[1], size[0])`` where ``T`` is
            ``max(NUM_FRAMES, number of frames)``. Slots beyond the last
            frame stay zero.
        """
        # Grow past NUM_FRAMES instead of failing with an IndexError when a
        # sequence has more frames than the default buffer.
        num_frames = max(self.NUM_FRAMES, len(self.imgs))
        frames = np.zeros((num_frames, 3, self.size[1], self.size[0]), dtype=np.float32)

        for i, ram_image in enumerate(self.imgs):
            img = ram_image.to_numpy()
            # Channels-last -> channels-first, then scale 8-bit values to [0, 1].
            frames[i] = img.transpose(2, 0, 1).astype(np.float32) / 255.0

        return frames


class ClevrerDataset(data.Dataset):
    """CLEVRER video sequences with fixed-size train / val / test splits.

    All sequences live in a single list; a split is a fixed index window
    into that list (train: 0-9999, val: 10000-14999, test: 15000-19999).
    The list is cached as a pickle next to the data and reused when present.

    Args:
        root_path: Directory containing ``data/data/video/<dataset_name>``.
        dataset_name: Name of the dataset directory.
        type: Split name: ``"train"``, ``"val"`` or ``"test"``.
        size: Frame size, forwarded to ``ClevrerSample`` and used as
            ``dsize`` when resizing the optional background image.

    Raises:
        FileNotFoundError: If no sequences are found.
    """

    TRAIN_SIZE = 10000
    EVAL_SIZE = 5000  # length of every split other than "train"

    # Index of the first sample of each split; splits not listed start at 0.
    _SPLIT_OFFSETS = {"val": 10000, "test": 15000}

    def save(self):
        """Pickle the loaded samples to the cache file."""
        with open(self.file, "wb") as outfile:
            pickle.dump(self.samples, outfile)

    def load(self):
        """Restore the samples from the cache file.

        NOTE: pickle can execute arbitrary code while loading; only use
        cache files this code wrote itself.
        """
        with open(self.file, "rb") as infile:
            self.samples = pickle.load(infile)

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int]):
        super().__init__()

        data_path  = f'data/data/video/{dataset_name}'
        data_path  = os.path.join(root_path, data_path)
        self.data_path = data_path
        self.file  = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}.pickle')
        self.size  = size
        self.type  = type

        self.samples = []

        loaded_from_cache = os.path.exists(self.file)
        if loaded_from_cache:
            self.load()
        else:
            # os.walk yields nothing for a missing directory; fall back to an
            # empty listing instead of leaking StopIteration from next().
            subdirs = next(os.walk(data_path), (data_path, [], []))[1]
            # Sorted so that the index windows of the splits are reproducible
            # (directory listing order is otherwise arbitrary).
            samples     = sorted(name for name in subdirs if name.startswith("0"))
            num_samples = len(samples)

            for i, sample_dir in enumerate(samples):
                self.samples.append(ClevrerSample(data_path, sample_dir, size))

                print(f"Loading CLEVRER [{i * 100 / num_samples:.2f}]", flush=True)

        # len(self) is a fixed split size, so emptiness has to be checked on
        # the sample list itself. Checked before saving so that an empty
        # cache file is never written.
        if not self.samples:
            raise FileNotFoundError(f'Found no dataset at {self.data_path}')

        if not loaded_from_cache:
            self.save()

        self.length     = len(self.samples)
        self.background = None
        if "background.jpg" in os.listdir(data_path):
            background = cv2.imread(os.path.join(data_path, "background.jpg"))
            background = cv2.resize(background, dsize=size, interpolation=cv2.INTER_CUBIC)
            background = background.transpose(2, 0, 1).astype(np.float32) / 255.0
            # Add a leading axis of length 1.
            self.background = background[np.newaxis]

        print(f"ClevrerDataset: {self.length}")

    def __len__(self):
        """Return the fixed size of the selected split."""
        if self.type == "train":
            return self.TRAIN_SIZE

        return self.EVAL_SIZE

    def __getitem__(self, index: int):
        """Return ``(frames, background)`` for sample ``index`` of the split.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``. Without
                this check an index past the end of one split would silently
                read from the next one.
        """
        if not 0 <= index < len(self):
            raise IndexError(f'index {index} out of range for split of size {len(self)}')

        sample = self.samples[index + self._SPLIT_OFFSETS.get(self.type, 0)]

        return (
            sample.get_data(),
            self.background,
        )

    def save_to_hdf5(self, hdf5_file_path):
        """Write every sequence of this split into a new HDF5 file.

        Frames are appended to ``rgb_images``; for each sequence one
        ``[start offset, number of frames]`` row is appended to
        ``sequence_indices``.

        Args:
            hdf5_file_path: Passed to ``HDF5Dataset`` as its root directory.
        """
        size = self.size
        hdf5_dataset = HDF5Dataset(hdf5_file_path, "CLEVRER", self.type, (size[1], size[0]))

        try:
            rgb_images       = hdf5_dataset["rgb_images"]
            sequence_indices = hdf5_dataset["sequence_indices"]
            num_sequences    = len(self)

            for index in range(num_sequences):
                rgb = self[index][0]

                offset = rgb_images.shape[0]
                rgb_images.resize((offset + len(rgb), 3, size[1], size[0]))
                rgb_images[offset:] = rgb

                sequence_indices.resize((sequence_indices.shape[0] + 1, 2))
                sequence_indices[-1] = [offset, len(rgb)]

                if index % 100 == 0:
                    print(f"Saving CLEVRER {self.type} [{index * 100 / num_sequences:.2f}]", flush=True)
        finally:
            # Close the file even when writing fails part-way through.
            hdf5_dataset.close()

if __name__ == "__main__":
    # Export each split to HDF5, in the order val, test, train.
    for split in ("val", "test", "train"):
        dataset = ClevrerDataset("./", "CLEVRER", split, (320, 240))
        dataset.save_to_hdf5("/media/chief/data/CLEVRER/")
