from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle
import random
import lz4.frame
import io
import pickle
import re
import sys
from PIL import Image
import torch.nn.functional as F
import torch


def video_to_numpy(path, crop_size = None, scale_size = None):
    """
    Decode a video file into a float32 RGB array.

    Parameters:
    path (str): path of the video file, passed to cv2.VideoCapture
    crop_size (tuple or None): (height, width) of a centered crop applied before scaling
    scale_size (tuple or None): (height, width) the frames are resized to (bicubic)

    Returns:
    np.array: frames of shape [T, 3, H, W], pixel values divided by 255

    Raises:
    ValueError: if not a single frame could be read from path
    """
    video = cv2.VideoCapture(path)
    frames = []

    try:
        while True:
            ret, frame = video.read()

            # ret is False once no further frame can be read
            if not ret:
                break

            # OpenCV delivers BGR, the rest of the pipeline works with RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(np.transpose(frame, (2, 0, 1)))  # From [H, W, 3] to [3, H, W]
    finally:
        # Release the capture even if decoding raised
        video.release()

    if not frames:
        # np.stack would otherwise fail with an unrelated "need at least one array" message
        raise ValueError(f"No frames could be read from video '{path}'")

    frames = np.stack(frames).astype(np.float32) / 255.0

    if crop_size is not None:
        top    = (frames.shape[2] - crop_size[0]) // 2
        bottom = (frames.shape[2] + crop_size[0]) // 2
        left   = (frames.shape[3] - crop_size[1]) // 2
        right  = (frames.shape[3] + crop_size[1]) // 2
        frames = frames[:, :, top:bottom, left:right]

    if scale_size is not None:
        frames = F.interpolate(torch.from_numpy(frames), size=scale_size, mode="bicubic", align_corners=False).numpy()

    return frames

def process_mask(mask, iterations):
    """
    Remove noise from a mask by eroding it and then dilating the result.

    Parameters:
    mask (np.array): a binary instance mask
    iterations (int): how often erosion and dilation are each applied

    Returns:
    np.array: the mask after erosion followed by dilation
    """
    # Both operations share the same 3x3 square structuring element
    structuring_element = np.ones((3, 3), np.uint8)

    return cv2.dilate(
        cv2.erode(mask, structuring_element, iterations=iterations),
        structuring_element,
        iterations=iterations,
    )

def extract_largest_connected_component(mask):
    """
    Keep only the largest connected foreground region of a mask.

    Parameters:
    mask (np.array): 2D mask of shape (height, width); non-zero entries are foreground

    Returns:
    np.array: boolean array of the same shape that is True on the largest
              connected component, or all False if the mask has no foreground
    """
    num_labels, labels_im = cv2.connectedComponents(mask.astype(np.uint8))

    # Label 0 is the background, so a single label means there is nothing to keep
    if num_labels < 2:
        return np.zeros(labels_im.shape, dtype=bool)

    # One pass over the label image instead of one full scan per label
    component_sizes = np.bincount(labels_im.ravel(), minlength=num_labels)

    largest_component_idx = np.argmax(component_sizes[1:]) + 1  # Exclude background
    return labels_im == largest_component_idx

def load_numpy(path, crop_size = None, scale_size = None, interpolation = "bicubic"):
    """
    Load an [H, W, T] array from a .npy / .npz file as float32 frames.

    Parameters:
    path (str): file to load; for ".npz" files the entry "arr_0" is used
    crop_size (tuple or None): (height, width) of a centered crop applied before scaling
    scale_size (tuple or None): (height, width) the frames are resized to
    interpolation (str): mode passed to torch.nn.functional.interpolate

    Returns:
    np.array: float32 array of shape [T, 1, H, W]
    """
    if path.endswith(".npz"):
        # np.load keeps the archive open; the context manager closes it again
        with np.load(path) as archive:
            frames = archive["arr_0"]
    else:
        frames = np.load(path)

    # [H, W, T] -> [T, 1, H, W]
    frames = np.expand_dims(frames.transpose(2, 0, 1), axis=1).astype(np.float32)

    if crop_size is not None:
        top    = (frames.shape[2] - crop_size[0]) // 2
        bottom = (frames.shape[2] + crop_size[0]) // 2
        left   = (frames.shape[3] - crop_size[1]) // 2
        right  = (frames.shape[3] + crop_size[1]) // 2
        frames = frames[:, :, top:bottom, left:right]

    if scale_size is not None:
        # align_corners is only accepted by the linear family of modes;
        # passing it to e.g. "nearest" or "area" makes interpolate raise
        extra_args = {}
        if interpolation in ("linear", "bilinear", "bicubic", "trilinear"):
            extra_args["align_corners"] = False

        frames = F.interpolate(torch.from_numpy(frames), size=scale_size, mode=interpolation, **extra_args).numpy()

    return frames

def load_depth_and_instance_masks(path, crop_size = None, scale_size = None):
    """
    Load the depth maps and instance-id maps of one sample directory.

    Reads "depth_raw.npz" (bicubic resize) and "inst_raw.npz" (nearest resize)
    from path and additionally computes a normalized depth: the log depth is
    standardized with the mean / standard deviation of the foreground pixels
    (instance id > 0) and squashed with a decreasing sigmoid.

    Parameters:
    path (str): directory containing "depth_raw.npz" and "inst_raw.npz"
    crop_size (tuple or None): forwarded to load_numpy
    scale_size (tuple or None): forwarded to load_numpy

    Returns:
    tuple: (depth, norm_sigmoid_depth, instance_masks)
    """
    depth = load_numpy(os.path.join(path, "depth_raw.npz"), crop_size, scale_size)
    instance_masks = load_numpy(os.path.join(path, "inst_raw.npz"), crop_size, scale_size, interpolation="nearest")

    fg_masks = (instance_masks > 0).astype(np.float32)

    # NOTE(review): non-positive depth values would yield -inf / NaN here — confirm the data range
    log_depth = np.log(depth)

    # Without a single foreground pixel the statistics would be 0 / 0 = NaN,
    # so fall back to statistics over all pixels in that case
    weights = fg_masks if np.any(fg_masks) else np.ones_like(fg_masks)
    weight_sum = np.sum(weights)

    depth_avg = np.sum(log_depth * weights) / weight_sum
    depth_std = np.sqrt(np.sum(weights * (log_depth - depth_avg) ** 2) / weight_sum)

    # A constant depth gives a standard deviation of 0; avoid dividing by it
    depth_std = max(depth_std, np.finfo(np.float32).eps)

    norm_sigmoid_depth = 1 / (1 + np.exp((log_depth - depth_avg) / (2*depth_std)))

    return depth, norm_sigmoid_depth, instance_masks

class MaskBBoxNumpy:
    """
    Compute axis-aligned bounding boxes of masks on a fixed (height, width) grid.

    The coordinate grids are built once in the constructor and reused for
    every call of compute.
    """

    def __init__(self, size):
        """
        Parameters:
        size (tuple): (height, width) of the masks that will be passed to compute
        """
        height, width = size

        # NOTE(review): linspace(0, n, n) spaces the coordinates n / (n - 1) apart, so they
        # are not integer pixel indices. Kept unchanged here — confirm consumers expect this.
        x_range = np.linspace(0, width, width)
        y_range = np.linspace(0, height, height)

        x_coords, y_coords = np.meshgrid(x_range, y_range)

        # Shape [1, 1, H, W] so the grids broadcast against [N, 1, H, W] masks
        self.x_coords = x_coords[None, None, :, :]
        self.y_coords = y_coords[None, None, :, :]

    def compute(self, mask):
        """
        Compute one bounding box per mask.

        Parameters:
        mask (np.array): masks of shape [N, 1, H, W]; values above 0.75 count as inside

        Returns:
        np.array: array of shape [N, 4] holding (x_min, y_min, x_max, y_max);
                  a mask without any inside pixel yields (inf, inf, -inf, -inf)
        """
        inside = mask > 0.75

        # Select by mask membership, not by "coordinate > 0": the first row and
        # column have coordinate 0 and would otherwise be dropped from the box
        x_min = np.min(np.where(inside, self.x_coords, np.inf), axis=(2, 3))
        y_min = np.min(np.where(inside, self.y_coords, np.inf), axis=(2, 3))
        x_max = np.max(np.where(inside, self.x_coords, -np.inf), axis=(2, 3))
        y_max = np.max(np.where(inside, self.y_coords, -np.inf), axis=(2, 3))

        bbox = np.stack([x_min, y_min, x_max, y_max], axis=1).squeeze(2)

        return bbox

def _extract_frame_instances(instance_ids, min_pixels = 100):
    """Split one [1, H, W] instance-id map into float32 [H, W] masks, one per kept instance."""
    instance_ids = instance_ids.astype(np.int32)
    unique_indices = np.unique(instance_ids)
    unique_indices = unique_indices[unique_indices > 0]  # id 0 is the background

    frame_masks = []
    for index in unique_indices:
        mask = (instance_ids == index).astype(np.uint8)
        mask = extract_largest_connected_component(mask[0]).astype(np.float32)

        # Drop instances whose largest component covers min_pixels pixels or fewer
        if np.sum(mask) > min_pixels:
            frame_masks.append(mask)

    return frame_masks


def object_sample_to_hdf5(hdf5_dataset, data_path: str, crop_size: Tuple[int, int], scale_size: Tuple[int, int]):
    """
    Append one sample (video sequence) to the resizable datasets of an hdf5 file.

    Reads "rgb.avi", "depth_raw.npz" and "inst_raw.npz" from data_path, derives
    per-frame foreground masks, per-instance masks and their bounding boxes,
    and appends everything to the datasets "rgb_images", "raw_depth",
    "depth_images", "foreground_mask", "sequence_indices", "instance_masks",
    "instance_masks_images", "instance_mask_bboxes" and "image_instance_indices".

    Parameters:
    hdf5_dataset: open hdf5 file / group that already contains the datasets above
    data_path (str): directory of the sample
    crop_size (tuple): (height, width) of the centered crop applied before scaling
    scale_size (tuple): (height, width) of the stored images

    Raises:
    ValueError: if video, depth and instance maps differ in their number of frames
    """
    size = scale_size
    mask_bboxes = MaskBBoxNumpy(scale_size)

    rgb = video_to_numpy(os.path.join(data_path, "rgb.avi"), crop_size, scale_size)
    raw_depth, depth, all_instance_masks = load_depth_and_instance_masks(data_path, crop_size, scale_size)

    # All per-frame datasets are indexed with the same offset, so a sample with
    # differing frame counts would silently misalign every following sample
    if not (len(rgb) == len(raw_depth) == len(all_instance_masks)):
        raise ValueError(
            f"Frame count mismatch in '{data_path}': rgb={len(rgb)}, "
            f"depth={len(raw_depth)}, instance masks={len(all_instance_masks)}"
        )

    foreground_mask = (all_instance_masks > 0).astype(np.float32)

    # One [K, 1, H, W] array per frame (K = number of kept instances, possibly 0)
    instances = []
    for frame_ids in all_instance_masks:
        frame_masks = _extract_frame_instances(frame_ids)
        if frame_masks:
            instances.append(np.expand_dims(np.stack(frame_masks, axis=0), axis=1))
        else:
            instances.append(np.zeros((0, 1, size[0], size[1]), dtype=np.float32))

    # Grow the per-frame datasets and append the new frames
    offset = hdf5_dataset["rgb_images"].shape[0]
    per_frame_data = (
        ("rgb_images",      rgb,             3),
        ("raw_depth",       raw_depth,       1),
        ("depth_images",    depth,           1),
        ("foreground_mask", foreground_mask, 1),
    )
    for name, frames, channels in per_frame_data:
        dataset = hdf5_dataset[name]
        dataset.resize((dataset.shape[0] + len(frames), channels, size[0], size[1]))
        dataset[offset:] = frames

    # Record (start index, number of images) of this sequence
    sequence_indices = hdf5_dataset["sequence_indices"]
    sequence_row = sequence_indices.shape[0]
    sequence_indices.resize((sequence_row + 1, 2))
    sequence_indices[sequence_row] = [offset, len(rgb)]

    for i, masks in enumerate(instances):
        print(f"filling hdf5: {i * 100 / len(instances):.2f}%\r", end="")

        # Read the offset for every frame: frames without instances must not
        # reuse the offset of an earlier frame (or fail if there was none yet)
        mask_offset = hdf5_dataset["instance_masks"].shape[0]

        if masks.shape[0] > 0:
            bboxes = mask_bboxes.compute(masks)
            new_length = mask_offset + masks.shape[0]

            hdf5_dataset["instance_masks"].resize((new_length, 1, size[0], size[1]))
            hdf5_dataset["instance_masks_images"].resize((hdf5_dataset["instance_masks_images"].shape[0] + masks.shape[0], 1))
            hdf5_dataset["instance_mask_bboxes"].resize((hdf5_dataset["instance_mask_bboxes"].shape[0] + masks.shape[0], 4))

            hdf5_dataset["instance_masks"][mask_offset:] = masks
            # Every mask stores the index of the image it belongs to
            hdf5_dataset["instance_masks_images"][mask_offset:] = np.full((masks.shape[0], 1), offset + i)
            hdf5_dataset["instance_mask_bboxes"][mask_offset:] = bboxes

        # Record (start index, number of instances) for this image
        image_instance_indices = hdf5_dataset["image_instance_indices"]
        image_instance_indices.resize((image_instance_indices.shape[0] + 1, 2))
        image_instance_indices[offset + i] = [mask_offset, masks.shape[0]]

    print("Filled hdf5                                      \r", end="")


def _create_growable_dataset(hdf5_file, name, row_shape, dtype, chunk_per_row = False):
    """Create an empty gzip-compressed dataset that can grow along its first axis."""
    extra_args = {"chunks": (1,) + tuple(row_shape)} if chunk_per_row else {}
    hdf5_file.create_dataset(
        name,
        (0,) + tuple(row_shape),
        maxshape=(None,) + tuple(row_shape),
        dtype=dtype,
        compression='gzip',
        compression_opts=5,
        **extra_args
    )


def avoe_to_hdf5(root_path: str, dataset_name: str, type: str, crop_size: Tuple[int, int], size: Tuple[int, int]):
    """
    Convert one split of the AVoE dataset into a single hdf5 file.

    Collects the sample directories below <root_path>/<trial>/<type> for the
    five trial categories and writes them to
    <root_path>/<dataset_name>-<type with '/' replaced by '-'>-<H>x<W>.hdf5.

    Parameters:
    root_path (str): directory containing the trial folders; also receives the hdf5 file
    dataset_name (str): used for the file name and stored as metadata
    type (str): split sub-path, e.g. "train/expected"
    crop_size (tuple): (height, width) of the centered crop applied before scaling
    size (tuple): (height, width) of the stored images
    """
    hdf5_file_path = os.path.join(root_path, f"{dataset_name}-{type.replace('/','-')}-{size[0]}x{size[1]}.hdf5")
    data_path      = root_path
    print(f"Loading AVoE from {data_path}", flush=True)

    samples = []
    for trial in ['A_support', 'B_occlusion', 'C_container', 'D_collision', 'E_barrier']:
        split_path = os.path.join(data_path, trial, type)
        # os.listdir returns entries in arbitrary order; sort for a reproducible file layout
        for entry in sorted(os.listdir(split_path)):
            sample_path = os.path.join(split_path, entry)
            # Stray files next to the sample folders are not samples
            if os.path.isdir(sample_path):
                samples.append(sample_path)

    num_samples = len(samples)

    # setup the hdf5 file
    print(f"Creating hdf5 file {hdf5_file_path}", flush=True)
    with h5py.File(hdf5_file_path, "w") as hdf5_file:

        height, width = size[0], size[1]

        # Image-like datasets are chunked per image; index datasets use explicit
        # 64-bit integers (np.compat.long is only an alias of Python's int)
        _create_growable_dataset(hdf5_file, "rgb_images",      (3, height, width), np.float32, chunk_per_row=True)
        _create_growable_dataset(hdf5_file, "raw_depth",       (1, height, width), np.float32, chunk_per_row=True)
        _create_growable_dataset(hdf5_file, "depth_images",    (1, height, width), np.float32, chunk_per_row=True)
        _create_growable_dataset(hdf5_file, "foreground_mask", (1, height, width), np.float32, chunk_per_row=True)
        _create_growable_dataset(hdf5_file, "image_instance_indices", (2,), np.int64)  # start index, number of instances
        _create_growable_dataset(hdf5_file, "instance_masks",  (1, height, width), np.float32, chunk_per_row=True)
        _create_growable_dataset(hdf5_file, "instance_masks_images", (1,), np.int64)
        _create_growable_dataset(hdf5_file, "instance_mask_bboxes",  (4,), np.float32)
        _create_growable_dataset(hdf5_file, "sequence_indices", (2,), np.int64)  # start index, number of images

        # Create a metadata group and set the attributes
        metadata_grp = hdf5_file.create_group("metadata")
        metadata_grp.attrs["dataset_name"] = dataset_name
        metadata_grp.attrs["type"] = type

        for i, sample_path in enumerate(samples):
            object_sample_to_hdf5(hdf5_file, sample_path, crop_size, size)
            print(f"Loading AVoE [{i * 100 / num_samples:.2f}]", flush=True)

if __name__ == "__main__":

    # Convert every split in turn; frames are center-cropped to 540x945 and scaled to 256x448
    avoe_root = '/media/chief/data/AVoE/'
    avoe_splits = (
        'test/surprising',
        'test/expected',
        'validation/surprising',
        'validation/expected',
        'train/expected',
    )
    for split in avoe_splits:
        avoe_to_hdf5(avoe_root, 'AVoE', split, (540, 945), (256, 448))

    """
    depth, instance_masks = load_depth_and_instance_masks('/media/chief/data/AVoE/E_barrier/train/expected/trial_141/')
    #depth, instance_masks = load_depth_and_instance_masks('/media/chief/data/AVoE/E_barrier/train/expected/trial_151/')
    #depth, instance_masks = load_depth_and_instance_masks('/media/chief/data/AVoE/A_support/train/expected/trial_1/')

    # save the depth image
    for i in range(depth.shape[0]):
        cv2.imwrite(f'depth_{i:02d}.png', depth[i, 0, :, :]*255)


    fg = (instance_masks > 0).astype(np.float32)

    # save the instance masks
    for i in range(fg.shape[0]):
        cv2.imwrite(f'fg_{i:02d}.png', fg[i, 0, :, :]*255)

        unique_indices = np.unique(instance_masks[i, 0, :, :])

        for j, idx in enumerate(unique_indices):
            mask = extract_largest_connected_component((instance_masks[i, 0, :, :] == idx).astype(np.uint8))
            cv2.imwrite(f'mask_{i:02d}_{j:02d}.png', mask*255)
    """
