from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle
import random
import lz4.frame
import io
import pickle
from multiprocessing import shared_memory
from data.datasets.utils import RamImage, CompressedArray


# Sequence directory names that kitti360_to_hdf5 iterates over when its
# ``type`` argument equals "train". Each name is joined with "cam0" and
# "cam1" to form the per-camera folders that get converted.
kitti_train = [
    "2013_05_28_drive_0000_sync",
    "2013_05_28_drive_0002_sync",
    "2013_05_28_drive_0003_sync",
    "2013_05_28_drive_0004_sync",
    "2013_05_28_drive_0005_sync",
    "2013_05_28_drive_0006_sync",
    "2013_05_28_drive_0007_sync",
    "2013_05_28_drive_0009_sync",
]
# Sequence directory names used for every other value of ``type``.
kitti_val = [
    "2013_05_28_drive_0010_sync",
]

class MaskCenterNumpy:
    """Mask-weighted mean (x, y) coordinate of 4-D masks.

    The coordinate grids are built once in the constructor and reused by
    every call to ``compute``.
    """

    def __init__(self, size, normalize=True):
        """Precompute the x / y coordinate grids.

        Args:
            size: ``(height, width)`` of the masks handed to ``compute``.
            normalize: when truthy, coordinates span ``[-1, 1]`` on both
                axes; otherwise x spans ``[0, width]`` and y spans
                ``[0, height]`` (``width`` / ``height`` evenly spaced
                samples, end point included).
        """
        height, width = size
        x_lo, x_hi = (-1, 1) if normalize else (0, width)
        y_lo, y_hi = (-1, 1) if normalize else (0, height)

        grid_x, grid_y = np.meshgrid(
            np.linspace(x_lo, x_hi, width),
            np.linspace(y_lo, y_hi, height),
        )

        # Two leading singleton axes let the (H, W) grids broadcast against
        # masks indexed as [:, :, H, W].
        self.x_coords = grid_x[None, None]
        self.y_coords = grid_y[None, None]

    def compute(self, mask):
        """Return the weighted centre of each mask.

        Args:
            mask: array whose axes 2 and 3 match the grid size; its values
                act as per-pixel weights.

        Returns:
            The per-mask x centres and y centres concatenated along the
            last axis. A mask that sums to zero divides by zero and yields
            non-finite entries.
        """
        spatial_axes = (2, 3)
        weight_total = np.sum(mask, axis=spatial_axes)
        weighted_x = np.sum(self.x_coords * mask, axis=spatial_axes)
        weighted_y = np.sum(self.y_coords * mask, axis=spatial_axes)

        return np.concatenate(
            (weighted_x / weight_total, weighted_y / weight_total), axis=-1
        )

class MaskBBoxNumpy:
    """Axis-aligned bounding boxes of thresholded 4-D masks.

    The coordinate grids are built once in the constructor and reused by
    every call to ``compute``.
    """

    def __init__(self, size):
        """Precompute the x / y coordinate grids.

        Args:
            size: ``(height, width)`` of the masks handed to ``compute``.
                Coordinates are ``np.linspace(0, width, width)`` and
                ``np.linspace(0, height, height)``, i.e. they run from 0 up
                to and including ``width`` / ``height``.
        """
        height, width = size
        x_range = np.linspace(0, width, width)
        y_range = np.linspace(0, height, height)

        x_coords, y_coords = np.meshgrid(x_range, y_range)

        # Two leading singleton axes let the (H, W) grids broadcast against
        # masks indexed as [:, :, H, W].
        self.x_coords = x_coords[None, None, :, :]
        self.y_coords = y_coords[None, None, :, :]

    def compute(self, mask):
        """Return ``[x_min, y_min, x_max, y_max]`` for every mask.

        Args:
            mask: array of shape ``(N, 1, H, W)``; pixels with a value above
                0.75 count as inside the mask.

        Returns:
            Array of shape ``(N, 4)``. A mask with no pixel above the
            threshold yields ``[inf, inf, -inf, -inf]``.
        """
        inside = mask > 0.75

        # Select pixels by the thresholded mask itself. Testing
        # ``coords * mask > 0`` instead would drop every pixel whose
        # coordinate is exactly 0 (first column / first row), shrinking the
        # box of any mask that touches the left or top image edge.
        x_min = np.min(np.where(inside, self.x_coords, np.inf), axis=(2, 3))
        y_min = np.min(np.where(inside, self.y_coords, np.inf), axis=(2, 3))
        x_max = np.max(np.where(inside, self.x_coords, -np.inf), axis=(2, 3))
        y_max = np.max(np.where(inside, self.y_coords, -np.inf), axis=(2, 3))

        # Each extreme is (N, 1); stacking gives (N, 4, 1), so drop the
        # trailing channel axis.
        bbox = np.stack([x_min, y_min, x_max, y_max], axis=1).squeeze(2)

        return bbox

def load_compressed_array_from_path(array_path):
    """Read a NumPy array from an LZ4-frame-compressed file.

    Args:
        array_path: path of the compressed file; it is opened in binary
            read mode through ``lz4.frame.open``.

    Returns:
        Whatever ``np.load`` produces from the decompressed stream.
    """
    with lz4.frame.open(array_path, mode='rb') as stream:
        return np.load(stream)

def kitti360_sample_to_hdf5(hdf5_dataset, data_path: str, size: Tuple[int, int]):
    """Append all frames of one folder to the resizable HDF5 datasets.

    Every file in ``data_path`` whose name contains "instance_mask" marks
    one frame; the frame number is the integer before the first dot of that
    file name. For each frame the RGB image, depth map, confidence map and
    foreground mask are read from sibling ``.jpg`` files, the instance masks
    are loaded with ``load_compressed_array_from_path``, and masks plus the
    foreground are multiplied by the confidence map before being stored.

    Args:
        hdf5_dataset: mapping of dataset name to a resizable dataset. It
            must provide "sequence_indices", "rgb_images", "depth_images",
            "foreground_mask", "instance_masks", "instance_masks_images",
            "instance_mask_bboxes" and "image_instance_indices".
        data_path: folder holding the per-frame files.
        size: ``(height, width)`` used for the dataset shapes and for the
            bounding-box grid.

    NOTE(review): ``cv2.imread`` returns BGR channel order by default while
    the target dataset is called "rgb_images" - confirm what readers expect.
    NOTE(review): the mask file is assumed to hold an (N, H, W) array -
    confirm against the producer of these files.
    """

    def grow(name, *item_shape):
        # Lengthen the named dataset by exactly one entry along axis 0.
        target = hdf5_dataset[name]
        target.resize((target.shape[0] + 1, *item_shape))

    def frame_file(pattern, number):
        # Path of a per-frame file that sits next to the mask file.
        return os.path.join(data_path, pattern.format(number))

    def read_gray(path):
        # Grayscale image scaled to [0, 1] with a leading channel axis.
        gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255.0
        return np.expand_dims(gray, 0)

    mask_files = sorted(name for name in os.listdir(data_path) if "instance_mask" in name)
    total_frames = len(mask_files)

    bbox_extractor = MaskBBoxNumpy(size)
    height, width = size[0], size[1]

    # Record where this sequence starts in the image datasets and how many
    # frames it contributes.
    grow("sequence_indices", 2)
    hdf5_dataset["sequence_indices"][-1] = np.array(
        [hdf5_dataset["rgb_images"].shape[0], total_frames]
    )

    for done, mask_file in enumerate(mask_files, start=1):
        frame_number = int(mask_file.split(".")[0])

        rgb_path = frame_file("{:010d}.jpg", frame_number)
        depth_path = frame_file("{:010d}.depth.jpg", frame_number)
        confidence_path = frame_file("{:010d}.confidence.jpg", frame_number)
        fg_path = frame_file("{:010d}.fg_mask.jpg", frame_number)
        masks_path = os.path.join(data_path, mask_file)

        # Channels-first colour frame in [0, 1].
        frame = cv2.imread(rgb_path).astype(np.float32).transpose(2, 0, 1) / 255.0
        depth = read_gray(depth_path)
        confidence = read_gray(confidence_path)
        # Confidence-weighted masks, with a channel axis inserted at position 1.
        instance_masks = np.expand_dims(
            load_compressed_array_from_path(masks_path) * confidence, 1
        )
        fg_mask = read_gray(fg_path) * confidence

        bboxes = bbox_extractor.compute(instance_masks)

        # Index the new image will occupy once the datasets are extended.
        image_index = hdf5_dataset["rgb_images"].shape[0]
        grow("rgb_images", 3, height, width)
        grow("depth_images", 1, height, width)
        grow("foreground_mask", 1, height, width)

        hdf5_dataset["rgb_images"][image_index:] = frame[None]
        hdf5_dataset["depth_images"][image_index:] = depth[None]
        hdf5_dataset["foreground_mask"][image_index:] = fg_mask[None]

        first_mask_index = hdf5_dataset["instance_masks"].shape[0]

        kept = 0
        for j in range(instance_masks.shape[0]):
            single_mask = instance_masks[j:j + 1]

            # Skip masks with at most 10 * 10 pixels above the 0.75 threshold.
            if np.sum((single_mask > 0.75).astype(np.float32)) <= (10 * 10):
                continue

            kept += 1

            grow("instance_masks", 1, height, width)
            grow("instance_masks_images", 1)
            grow("instance_mask_bboxes", 4)

            hdf5_dataset["instance_masks"][-1:] = single_mask
            # Back-reference from the mask to the image it belongs to.
            hdf5_dataset["instance_masks_images"][-1:] = np.ones((1, 1)) * image_index
            hdf5_dataset["instance_mask_bboxes"][-1:] = bboxes[j:j + 1]

        # Per-image lookup: first stored mask index and number of stored masks.
        grow("image_instance_indices", 2)
        hdf5_dataset["image_instance_indices"][image_index:] = [first_mask_index, kept]

        print(f"Loaded {done}/{total_frames}", end="\r", flush=True)

def _create_growable_dataset(hdf5_file, name, item_shape, dtype, chunked=False):
    """Create an empty, gzip-compressed dataset that can grow along axis 0.

    Args:
        hdf5_file: open, writable h5py file (or group).
        name: dataset name.
        item_shape: shape of a single entry, without the leading axis.
        dtype: element type of the dataset.
        chunked: when True, use one entry per chunk; otherwise the chunk
            layout is left to h5py.
    """
    options = {
        "maxshape": (None, *item_shape),
        "dtype": dtype,
        "compression": 'gzip',
        "compression_opts": 5,
    }
    if chunked:
        options["chunks"] = (1, *item_shape)
    hdf5_file.create_dataset(name, (0, *item_shape), **options)


def kitti360_to_hdf5(root_path: str, dataset_name: str, type: str, size: Tuple[int, int], objects = False, background = False):
    """Build the HDF5 file for one KITTI-360 split.

    The file is (re)created at
    ``<root_path>/data/data/video/<dataset_name>/dataset-objects-lightning-<type>-<W>x<H>.hdf5``
    with empty, growable datasets and a "metadata" group. When ``objects``
    is truthy, the "cam0" and "cam1" folders of every sequence of the split
    are appended through ``kitti360_sample_to_hdf5``.

    Args:
        root_path: directory prefix used for the output file.
        dataset_name: folder name below ``data/data/video``.
        type: "train" selects ``kitti_train``; any other value selects
            ``kitti_val``.
        size: ``(height, width)`` of the stored images.
        objects: when falsy, only the empty file structure is written.
        background: accepted but not used by this function.

    NOTE(review): the per-sequence input folders are resolved from the
    relative ``data/data/video/...`` path (i.e. against the current working
    directory), while only the output file is prefixed with ``root_path`` -
    confirm that this asymmetry is intended.
    """
    data_path      = f'data/data/video/{dataset_name}'
    hdf5_file_path = os.path.join(root_path, data_path, f'dataset-objects-lightning-{type}-{size[1]}x{size[0]}.hdf5')
    print(f"Loading KITTI360 {type} from {data_path}", flush=True)

    sequences = kitti_train if type == "train" else kitti_val
    height, width = size[0], size[1]

    # Index datasets use a fixed 64-bit integer type. ``np.long`` is not
    # available on NumPy 1.24-1.26, and on NumPy >= 2.0 it is the platform
    # C ``long`` (32-bit on Windows).
    index_dtype = np.int64

    # setup the hdf5 file
    with h5py.File(hdf5_file_path, "w") as hdf5_file:

        # Per-image data, one image per chunk.
        _create_growable_dataset(hdf5_file, "rgb_images", (3, height, width), np.float32, chunked=True)
        _create_growable_dataset(hdf5_file, "depth_images", (1, height, width), np.float32, chunked=True)
        _create_growable_dataset(hdf5_file, "foreground_mask", (1, height, width), np.float32, chunked=True)
        # Per image: start index, number of instances.
        _create_growable_dataset(hdf5_file, "image_instance_indices", (2,), index_dtype)
        # Per-instance data.
        _create_growable_dataset(hdf5_file, "instance_masks", (1, height, width), np.float32, chunked=True)
        _create_growable_dataset(hdf5_file, "instance_masks_images", (1,), index_dtype)
        _create_growable_dataset(hdf5_file, "instance_mask_bboxes", (4,), np.float32)
        # Per sequence: start index, number of images.
        _create_growable_dataset(hdf5_file, "sequence_indices", (2,), index_dtype)

        # Create a metadata group and set the attributes
        metadata_grp = hdf5_file.create_group("metadata")
        metadata_grp.attrs["dataset_name"] = dataset_name
        metadata_grp.attrs["type"] = type

        if objects:
            for i, sequence in enumerate(sequences):
                # Each camera accounts for half of a sequence's progress.
                for step, camera in enumerate(("cam0", "cam1"), start=1):
                    sample_data_path = os.path.join(data_path, sequence, camera)
                    kitti360_sample_to_hdf5(hdf5_file, sample_data_path, size)
                    print(f"Loading KITTI {type} [{(i + 0.5 * step) * 100 / len(sequences):.2f}]", flush=True)
