from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle
import random
import lz4.frame
import io
import pickle
from multiprocessing import shared_memory
from data.datasets.utils import RamImage, CompressedArray
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm


class MaskCenterNumpy:
    """Mask-weighted mean coordinate (centre of mass) for batches of masks.

    A coordinate grid of shape ``(1, 1, height, width)`` is built once in the
    constructor and then broadcast against every mask passed to ``compute``.
    """

    def __init__(self, size, normalize=True):
        """Pre-compute the x/y coordinate grids.

        Args:
            size: ``(height, width)`` of the masks that will be processed.
            normalize: when true the coordinates span ``[-1, 1]`` on both
                axes; otherwise they span ``[0, width]`` and ``[0, height]``
                (``width`` / ``height`` evenly spaced samples, end points
                included).
        """
        rows, cols = size
        if normalize:
            x_lo, x_hi = -1, 1
            y_lo, y_hi = -1, 1
        else:
            x_lo, x_hi = 0, cols
            y_lo, y_hi = 0, rows

        grid_x, grid_y = np.meshgrid(
            np.linspace(x_lo, x_hi, cols),
            np.linspace(y_lo, y_hi, rows),
        )

        # Leading singleton axes let the grids broadcast over (N, C, H, W).
        self.x_coords = grid_x[None, None]
        self.y_coords = grid_y[None, None]

    def compute(self, mask):
        """Return the weighted centre of every mask.

        Args:
            mask: array whose last two axes are ``(height, width)`` and which
                has two leading axes (the reduction runs over axes 2 and 3).

        Returns:
            The x centres followed by the y centres, concatenated along the
            last axis.  A mask that sums to zero produces NaN entries because
            the weighted sum is divided by the mask sum.
        """
        spatial_axes = (2, 3)
        total_weight = np.sum(mask, axis=spatial_axes)

        centers = [
            np.sum(grid * mask, axis=spatial_axes) / total_weight
            for grid in (self.x_coords, self.y_coords)
        ]
        return np.concatenate(centers, axis=-1)

class MaskBBoxNumpy:
    """Axis-aligned bounding boxes for batches of (soft) masks.

    Coordinates are ``width`` / ``height`` evenly spaced samples in
    ``[0, width]`` / ``[0, height]`` (end points included), stored as grids of
    shape ``(1, 1, height, width)`` so they broadcast over ``(N, 1, H, W)``.
    """

    def __init__(self, size):
        """Pre-compute the coordinate grids for masks of ``size = (height, width)``."""
        height, width = size
        x_range = np.linspace(0, width, width)
        y_range = np.linspace(0, height, height)

        x_coords, y_coords = np.meshgrid(x_range, y_range)

        self.x_coords = x_coords[None, None, :, :]
        self.y_coords = y_coords[None, None, :, :]

    def compute(self, mask, threshold=0.75):
        """Return ``[x_min, y_min, x_max, y_max]`` for every mask.

        Args:
            mask: array of shape ``(N, 1, H, W)``; values strictly greater
                than ``threshold`` count as inside the object.
            threshold: binarisation threshold (defaults to the previously
                hard-coded 0.75).

        Returns:
            Array of shape ``(N, 4)``.  A mask with no pixel above the
            threshold yields ``[inf, inf, -inf, -inf]``.
        """
        inside = mask > threshold

        # Select on the mask itself rather than on ``coords * mask > 0``: the
        # first column/row has coordinate 0, so the product test silently
        # dropped every mask pixel on the left/top image border.
        x_min = np.min(np.where(inside, self.x_coords, np.inf), axis=(2, 3))
        y_min = np.min(np.where(inside, self.y_coords, np.inf), axis=(2, 3))
        x_max = np.max(np.where(inside, self.x_coords, -np.inf), axis=(2, 3))
        y_max = np.max(np.where(inside, self.y_coords, -np.inf), axis=(2, 3))

        # Each extreme is (N, 1); stack -> (N, 4, 1), squeeze -> (N, 4).
        bbox = np.stack([x_min, y_min, x_max, y_max], axis=1).squeeze(2)

        return bbox

class HDF5Dataset:
    """Freshly created HDF5 file holding a set of empty, growable datasets.

    The file ``<root_path>/<dataset_name>-<type>-<H>x<W>.hdf5`` is opened in
    ``"w"`` mode (an existing file of that name is truncated).  Every dataset
    starts with zero rows, is unlimited along axis 0 and is gzip-compressed
    (level 5).  Datasets are reached through ``instance[name]``.

    The instance can also be used as a context manager, which closes the file
    on exit.
    """

    # (name, channels or columns, dtype, is_image) in creation order.
    # Image datasets have rows of shape (channels, H, W) and are chunked one
    # row at a time; the others have rows of shape (columns,) and rely on
    # h5py's automatic chunking.
    # Index tables use an explicit int64 instead of the deprecated
    # ``np.compat.long`` alias (which is just Python ``int``).
    _LAYOUT = (
        ("rgb_images", 3, np.float32, True),
        ("raw_depth", 1, np.float32, True),
        ("depth_images", 1, np.float32, True),
        ("forward_flow", 2, np.float32, True),
        ("backward_flow", 2, np.float32, True),
        ("foreground_mask", 1, np.float32, True),
        ("image_instance_indices", 2, np.int64, False),  # start index, number of instances
        ("instance_masks", 1, np.float32, True),
        ("instance_masks_images", 1, np.int64, False),
        ("instance_mask_bboxes", 4, np.float32, False),
        ("sequence_indices", 2, np.int64, False),  # start index, number of images
        ("camera_field_of_view", 1, np.float32, False),  # for each sequence
        ("camera_focal_length", 1, np.float32, False),  # for each sequence
        ("camera_position", 3, np.float32, False),
        ("camera_rotation_quaternion", 4, np.float32, False),
        ("camera_sensor_width", 1, np.float32, False),  # for each sequence
    )

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int]):
        """Create the HDF5 file and all (empty) datasets.

        Args:
            root_path: directory in which the ``.hdf5`` file is created.
            dataset_name: dataset identifier; part of the file name and
                stored in the ``metadata`` group.
            type: split name; part of the file name and stored in the
                ``metadata`` group.  (The name shadows the builtin but is
                kept because callers may pass it by keyword.)
            size: ``(height, width)`` of the stored images.
        """
        height, width = size[0], size[1]

        hdf5_file_path = os.path.join(root_path, f'{dataset_name}-{type}-{size[0]}x{size[1]}.hdf5')
        data_path = os.path.join(root_path, dataset_name, type)
        print(f"Loading {dataset_name} {type} from {data_path}", flush=True)

        hdf5_file = h5py.File(hdf5_file_path, "w")
        try:
            for name, channels, dtype, is_image in self._LAYOUT:
                row_shape = (channels, height, width) if is_image else (channels,)
                options = dict(
                    maxshape=(None,) + row_shape,
                    dtype=dtype,
                    compression='gzip',
                    compression_opts=5,
                )
                if is_image:
                    # One image per chunk so single frames can be read back cheaply.
                    options["chunks"] = (1,) + row_shape
                hdf5_file.create_dataset(name, (0,) + row_shape, **options)

            # Record which dataset/split this file was generated from.
            metadata_grp = hdf5_file.create_group("metadata")
            metadata_grp.attrs["dataset_name"] = dataset_name
            metadata_grp.attrs["type"] = type
        except BaseException:
            # Do not leak the open handle when the layout cannot be created.
            hdf5_file.close()
            raise

        self.hdf5_file = hdf5_file

    def close(self):
        """Flush pending writes and close the underlying file."""
        self.hdf5_file.flush()
        self.hdf5_file.close()

    def __enter__(self):
        """Return ``self`` so the object can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the file; exceptions from the ``with`` body are not suppressed."""
        self.close()

    def __getitem__(self, index):
        """Return the HDF5 object (dataset or group) stored under ``index``."""
        return self.hdf5_file[index]

# Instances covering fewer pixels than this in a frame are not stored.
_MIN_INSTANCE_PIXELS = 10 * 10


def _append_rows(dataset, rows):
    """Grow ``dataset`` along axis 0, write ``rows`` into the new tail and
    return the index of the first row written."""
    rows = np.asarray(rows)
    start = dataset.shape[0]
    if rows.shape[0] == 0:
        return start
    dataset.resize((start + rows.shape[0],) + tuple(dataset.shape[1:]))
    dataset[start:] = rows
    return start


def _dequantize(sample, key):
    """Rescale ``sample[key]`` from ``[0, 65535]`` to the ``(min, max)`` range
    stored under ``sample["metadata"][key + "_range"]`` and return it with the
    channel axis moved from last to second position."""
    minv, maxv = sample["metadata"][f"{key}_range"]
    return (sample[key] / 65535 * (maxv - minv) + minv).numpy().transpose(0, 3, 1, 2)


def _split_instances(frame_ids):
    """Split one frame of instance ids into per-instance binary masks.

    Returns ``(masks, foreground)``: ``masks`` is float32 with one leading
    entry per kept instance (id > 0 and at least ``_MIN_INSTANCE_PIXELS``
    pixels); ``foreground`` marks every pixel with id > 0, including pixels of
    instances that were too small to keep.
    """
    frame_ids = frame_ids.astype(np.int32)
    ids, counts = np.unique(frame_ids, return_counts=True)
    kept = ids[(ids > 0) & (counts >= _MIN_INSTANCE_PIXELS)]

    # Broadcasting yields shape (K,) + frame_ids.shape, also when K == 0.
    masks = (frame_ids[None] == kept[:, None, None, None]).astype(np.float32)
    foreground = (frame_ids > 0).astype(np.float32)
    return masks, foreground


def _convert_sequence(hdf5_dataset, mask_bboxes, sample):
    """Append one sequence (all frames of ``sample``) to ``hdf5_dataset``."""
    rgb = sample['video'].numpy().transpose(0, 3, 1, 2) / 255.0
    segmentations = sample['segmentations'].numpy().transpose(0, 3, 1, 2)

    raw_depth = _dequantize(sample, "depth")
    forward_flow = _dequantize(sample, "forward_flow")
    backward_flow = _dequantize(sample, "backward_flow")

    per_frame = [_split_instances(frame_ids) for frame_ids in segmentations]
    instances = [masks for masks, _ in per_frame]
    foreground_mask = np.stack([foreground for _, foreground in per_frame], axis=0)

    # Standardise log-depth with foreground-only statistics of the whole
    # sequence, then squash with sigmoid(-z): closer pixels get larger values.
    # NOTE(review): a sequence without any foreground pixel (or with constant
    # foreground depth) divides by zero here and stores NaNs -- confirm
    # whether such sequences can occur and how they should be handled.
    log_depth = np.log(raw_depth)
    foreground_pixels = np.sum(foreground_mask)
    depth_avg = np.sum(log_depth * foreground_mask) / foreground_pixels
    depth_std = np.sqrt(np.sum((log_depth - depth_avg) ** 2 * foreground_mask) / foreground_pixels)
    depth = 1 / (1 + np.exp((log_depth - depth_avg) / depth_std))

    # Per-frame image data; ``offset`` is the global index of the first frame.
    offset = _append_rows(hdf5_dataset["rgb_images"], rgb)
    _append_rows(hdf5_dataset["raw_depth"], raw_depth)
    _append_rows(hdf5_dataset["depth_images"], depth)
    _append_rows(hdf5_dataset["foreground_mask"], foreground_mask)
    _append_rows(hdf5_dataset["forward_flow"], forward_flow)
    _append_rows(hdf5_dataset["backward_flow"], backward_flow)

    # Per-sequence records.
    _append_rows(hdf5_dataset["sequence_indices"], [[offset, len(rgb)]])
    camera = sample["camera"]
    for dataset_name, camera_key in (
        ("camera_field_of_view", "field_of_view"),
        ("camera_focal_length", "focal_length"),
        ("camera_sensor_width", "sensor_width"),
    ):
        _append_rows(hdf5_dataset[dataset_name], np.reshape(camera[camera_key].numpy(), (1, 1)))

    # Per-frame camera pose.
    _append_rows(hdf5_dataset["camera_position"], camera["positions"].numpy())
    _append_rows(hdf5_dataset["camera_rotation_quaternion"], camera["quaternions"].numpy())

    for i, masks in enumerate(instances):
        # Taken before the branch so frames without instances still get a
        # well-defined start index (previously unbound or stale).
        mask_offset = hdf5_dataset["instance_masks"].shape[0]
        count = masks.shape[0]

        if count > 0:
            bboxes = mask_bboxes.compute(masks)
            _append_rows(hdf5_dataset["instance_masks"], masks)
            _append_rows(
                hdf5_dataset["instance_masks_images"],
                np.full((count, 1), offset + i, dtype=np.int64),
            )
            _append_rows(hdf5_dataset["instance_mask_bboxes"], bboxes)

            # Report degenerate boxes (zero or negative width/height).
            for bbox in bboxes:
                if bbox[2] - bbox[0] <= 0 or bbox[3] - bbox[1] <= 0:
                    print("invalid bbox", bbox)

        # One row per image: where its masks start and how many there are.
        _append_rows(hdf5_dataset["image_instance_indices"], [[mask_offset, count]])


def convert_dataset(hdf5_root_path: str, root_path: str, dataset_name: str, type: str, size: Tuple[int, int]) -> None:
    """Convert one split of a TFDS dataset on disk into a single HDF5 file.

    The TFDS builder is loaded from
    ``<root_path>/<dataset_name>/<H>x<W>/1.0.0``; the output file is created by
    ``HDF5Dataset`` inside ``hdf5_root_path``.  For every sample (one video
    sequence) the frames, depth, optical flow, foreground/instance masks,
    instance bounding boxes and camera parameters are appended.

    Args:
        hdf5_root_path: directory that receives the ``.hdf5`` file.
        root_path: directory containing the TFDS datasets.
        dataset_name: name of the dataset sub-directory.
        type: split to convert (passed to ``builder.as_dataset``).
        size: ``(height, width)`` of the frames.
    """
    hdf5_dataset = HDF5Dataset(hdf5_root_path, dataset_name, type, size)
    try:
        mask_bboxes = MaskBBoxNumpy(size)

        builder = tfds.builder_from_directory(os.path.join(root_path, dataset_name, f"{size[0]}x{size[1]}/1.0.0"))
        print("\n\nSplits: ", builder.info.splits.keys(), "\n\n")
        ds = builder.as_dataset(split=type)

        total = builder.info.splits[type].num_examples
        for sample in tqdm(ds, total=total):
            _convert_sequence(hdf5_dataset, mask_bboxes, sample)
    finally:
        # Flush and close even when conversion fails part-way so the handle
        # is not leaked.
        hdf5_dataset.close()


def _parse_cli_args():
    """Parse the command line of the conversion script.

    Running without arguments reproduces the previously hard-coded job:
    ``movi-f`` at 512x512, splits ``validation`` and ``train``, reading from
    and writing to ``/media/chief/HDD8TB/Kubric-Datasets/``.

    Earlier (previously commented-out) runs can be expressed as, e.g.::

        --dataset movi-a --size 256 256 \
            --hdf5-root /media/chief/data/Kubric-Datasets/ \
            --tfds-root /media/chief/data/Kubric-Datasets/
        --dataset movi-c --size 256 256 --splits validation test train
    """
    import argparse  # local import: only needed when run as a script

    default_root = "/media/chief/HDD8TB/Kubric-Datasets/"

    parser = argparse.ArgumentParser(description="Convert a TFDS dataset split into an HDF5 file.")
    parser.add_argument("--hdf5-root", default=default_root,
                        help="directory that receives the generated .hdf5 files")
    parser.add_argument("--tfds-root", default=default_root,
                        help="directory containing the TFDS datasets")
    parser.add_argument("--dataset", default="movi-f",
                        help="dataset name (sub-directory of the TFDS root)")
    parser.add_argument("--splits", nargs="+", default=["validation", "train"],
                        help="splits to convert, in order")
    parser.add_argument("--size", nargs=2, type=int, default=[512, 512],
                        metavar=("HEIGHT", "WIDTH"), help="frame resolution")
    return parser.parse_args()


if __name__ == "__main__":

    cli_args = _parse_cli_args()
    for split_name in cli_args.splits:
        convert_dataset(
            cli_args.hdf5_root,
            cli_args.tfds_root,
            cli_args.dataset,
            split_name,
            tuple(cli_args.size),
        )
