from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
from einops import rearrange, reduce
import numpy as np
import lz4.frame
import io
import cv2
import pickle
from tqdm import tqdm
import torch as th
import torch.nn.functional as F


class MaskBBoxNumpy:
    """Compute axis-aligned bounding boxes of masks with NumPy.

    The coordinate grids are built once for a fixed image size and reused by
    every call to :meth:`compute`.

    Note: coordinates are generated with ``np.linspace(0, extent, extent)``,
    so they span ``[0, extent]`` inclusive instead of being the integer pixel
    indices ``0 .. extent - 1``.
    """

    def __init__(self, size):
        """Precompute the x/y coordinate grids.

        Args:
            size: ``(height, width)`` of the masks later passed to ``compute``.
        """
        height, width = size
        x_range = np.linspace(0, width, width)
        y_range = np.linspace(0, height, height)

        x_coords, y_coords = np.meshgrid(x_range, y_range)

        # Two leading singleton axes so the grids broadcast against
        # 4-D masks whose last two axes are (height, width).
        self.x_coords = x_coords[None, None, :, :]
        self.y_coords = y_coords[None, None, :, :]

    def compute(self, mask):
        """Return the bounding box of every mask in a batch.

        Args:
            mask: 4-D array whose last two axes match the ``size`` given to
                the constructor and whose second axis has length 1 (required
                by the final ``squeeze``). Values above 0.75 count as
                foreground.

        Returns:
            Array of shape ``(batch, 4)`` holding
            ``[x_min, y_min, x_max, y_max]`` per mask. A mask without any
            foreground pixel yields ``[inf, inf, -inf, -inf]``.
        """
        inside = mask > 0.75

        # Select by the thresholded mask itself rather than by
        # "coordinate * mask > 0": the first column/row has coordinate 0 and
        # would otherwise be dropped from the box.
        x_min = np.min(np.where(inside, self.x_coords, np.inf), axis=(2, 3))
        y_min = np.min(np.where(inside, self.y_coords, np.inf), axis=(2, 3))
        x_max = np.max(np.where(inside, self.x_coords, -np.inf), axis=(2, 3))
        y_max = np.max(np.where(inside, self.y_coords, -np.inf), axis=(2, 3))

        bbox = np.stack([x_min, y_min, x_max, y_max], axis=1).squeeze(2)

        return bbox

def compress_image(image, format='.jpg'):
    """Encode a channel-first image into a compressed buffer with OpenCV.

    Args:
        image: Array in (channels, height, width) layout. It is multiplied by
            255 before encoding (presumably values lie in [0, 1] -- confirm
            against callers).
        format: File extension selecting the OpenCV encoder, e.g. ``'.jpg'``
            or ``'.png'``.

    Returns:
        NumPy array holding the encoded bytes.

    Raises:
        Exception: If OpenCV reports that encoding failed.
    """
    # OpenCV expects (height, width, channels).
    pixels = image.transpose(1, 2, 0) * 255.0
    encoded_ok, encoded = cv2.imencode(format, pixels)
    if not encoded_ok:
        raise Exception("Failed to compress image")
    return np.array(encoded)

def decompress_image(buffer):
    """Decode a compressed image buffer into a channel-first array.

    Args:
        buffer: Encoded image bytes as accepted by ``cv2.imdecode``.

    Returns:
        Array in (channels, height, width) layout, divided by 255.
        Single-channel images get an explicit channel axis of length 1.

    Raises:
        ValueError: If OpenCV cannot decode ``buffer``.
    """
    image = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imdecode reports failure by returning None instead of raising.
        raise ValueError("Failed to decompress image")
    if image.ndim == 2:
        # Grayscale result: add the missing channel axis.
        image = image[:, :, None]
    return image.transpose(2, 0, 1) / 255.0

def _find_png_files(image_root):
    """Return the paths of all ``.png`` files found below ``image_root``."""
    return [
        os.path.join(root, filename)
        for root, _, filenames in os.walk(image_root)
        for filename in filenames
        if filename.endswith('.png')
    ]


def _read_image_unchanged(path):
    """Read ``path`` with OpenCV (all channels kept) as float32 divided by 255."""
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread returns None instead of raising for missing/unreadable files.
        raise OSError(f"Could not read image: {path}")
    return image.astype(np.float32) / 255.0


def _fit_to_size(array, size):
    """Bring a (channels, height, width) array to the spatial ``size``.

    Arrays exceeding ``size`` in height or width are resized bicubically
    (aspect ratio is not preserved); smaller ones are zero-padded and centred.
    """
    height, width = array.shape[1:]
    if height > size[0] or width > size[1]:
        batch = th.from_numpy(array).unsqueeze(0)
        resized = F.interpolate(batch, size=size, mode='bicubic', align_corners=True)
        return resized.squeeze(0).numpy()

    canvas = np.zeros((array.shape[0], size[0], size[1]), dtype=np.float32)
    top = (size[0] - height) // 2
    left = (size[1] - width) // 2
    canvas[:, top:top + height, left:left + width] = array
    return canvas


def _create_growable_dataset(hdf5_file, name, row_shape, dtype, chunked=False):
    """Create an empty gzip-compressed dataset that can grow along axis 0."""
    extra = {'chunks': (1,) + tuple(row_shape)} if chunked else {}
    return hdf5_file.create_dataset(
        name,
        (0,) + tuple(row_shape),
        maxshape=(None,) + tuple(row_shape),
        dtype=dtype,
        compression='gzip',
        compression_opts=5,
        **extra
    )


def save_to_hdf5(hdf5_file_path, data_path, type: str = 'FULL', size: Tuple[int, int] = (256, 256), max_files: Union[int, None] = 1000):
    """Pack RGBA renderings and their depth maps into one HDF5 file.

    Every ``.png`` below ``<data_path>/image`` is read together with its depth
    counterpart (same path with ``'image'`` replaced by ``'depth'``). The alpha
    channel is used as foreground/instance mask, image and depth are multiplied
    by that mask, and everything is brought to ``size``.

    Args:
        hdf5_file_path: Output file; it is opened in ``"w"`` mode, so an
            existing file is overwritten.
        data_path: Directory containing the ``image`` sub-directory.
        type: Stored verbatim in the ``metadata`` group's ``type`` attribute.
        size: Target ``(height, width)`` of all stored images and masks.
        max_files: Only the first ``max_files`` discovered images are
            processed (default 1000, the previously hard-coded limit). Pass
            ``None`` to process every image.

    Raises:
        OSError: If an image or depth file cannot be read.
        ValueError: If an image has no alpha channel.
    """
    mask_bboxes = MaskBBoxNumpy(size)

    files = _find_png_files(os.path.join(data_path, 'image'))
    if max_files is not None:
        files = files[:max_files]

    mask_shape = (1, size[0], size[1])

    with h5py.File(hdf5_file_path, "w") as hdf5_file:

        # Compressed image bytes: one variable-length uint8 entry per image.
        rgb_images = hdf5_file.create_dataset(
            "rgb_images",
            (len(files),),
            dtype=h5py.vlen_dtype(np.dtype('uint8'))
        )
        depth_images = hdf5_file.create_dataset(
            "depth_images",
            (len(files),),
            dtype=h5py.vlen_dtype(np.dtype('uint8'))
        )

        foreground_mask = _create_growable_dataset(hdf5_file, "foreground_mask", mask_shape, np.float32, chunked=True)
        # Rows are (start index, number of instances).
        # np.int64 replaces np.long, which was removed in NumPy 1.24.
        image_instance_indices = _create_growable_dataset(hdf5_file, "image_instance_indices", (2,), np.int64)
        instance_masks = _create_growable_dataset(hdf5_file, "instance_masks", mask_shape, np.float32, chunked=True)
        instance_masks_images = _create_growable_dataset(hdf5_file, "instance_masks_images", (1,), np.int64)
        instance_mask_bboxes = _create_growable_dataset(hdf5_file, "instance_mask_bboxes", (4,), np.float32)
        # Rows are (start index, number of images). The dataset is created so
        # the file layout stays the same; this function never writes to it.
        _create_growable_dataset(hdf5_file, "sequence_indices", (2,), np.int64)

        # Create a metadata group and set the attributes
        metadata_grp = hdf5_file.create_group("metadata")
        metadata_grp.attrs["dataset_name"] = 'ShapeNetRendering'
        metadata_grp.attrs["type"] = type

        for index, file in enumerate(tqdm(files)):
            # NOTE(review): replaces every occurrence of 'image' in the path,
            # not only the directory component -- confirm this matches the
            # on-disk naming of the depth files.
            depth_path = file.replace('image', 'depth')

            # Read the image including its alpha channel.
            raw_rgba = _read_image_unchanged(file)
            if raw_rgba.ndim != 3 or raw_rgba.shape[2] < 4:
                raise ValueError(f"Expected an image with alpha channel: {file}")
            raw_img = raw_rgba[:, :, 0:3].transpose(2, 0, 1)
            raw_mask = raw_rgba[:, :, 3:].transpose(2, 0, 1)

            raw_depth = _read_image_unchanged(depth_path)
            if raw_depth.ndim == 2:
                raw_depth = raw_depth[:, :, None]
            # Only the first channel is used, and it is inverted.
            raw_depth = 1 - raw_depth[:, :, :1].transpose(2, 0, 1)

            # The size decision is made on (height, width); the channel axis
            # is not part of the comparison.
            img = _fit_to_size(raw_img, size)
            mask = _fit_to_size(raw_mask, size)
            depth = _fit_to_size(raw_depth, size)

            # Make the background black.
            img = img * mask
            depth = depth * mask

            bboxes = mask_bboxes.compute(np.expand_dims(mask, axis=0))

            # Grow the per-instance datasets by one row for the new data.
            foreground_mask.resize(foreground_mask.shape[0] + 1, axis=0)
            instance_masks.resize(instance_masks.shape[0] + 1, axis=0)
            instance_mask_bboxes.resize(instance_mask_bboxes.shape[0] + 1, axis=0)

            rgb_images[index] = compress_image(img, '.jpg')
            depth_images[index] = compress_image(depth, '.png')
            foreground_mask[-1] = mask
            instance_masks[-1] = mask
            instance_mask_bboxes[-1] = bboxes[0]

            # Each image contributes exactly one instance: the one just added.
            image_instance_indices.resize((image_instance_indices.shape[0] + 1, 2))
            image_instance_indices[-1] = [instance_masks.shape[0] - 1, 1]

            # Map the instance back to its image. rgb_images is pre-allocated,
            # so its length cannot serve as a running image counter; use the
            # loop index instead.
            instance_masks_images.resize((instance_masks_images.shape[0] + 1, 1))
            instance_masks_images[-1] = index

if __name__ == "__main__":
    # Script entry point: convert the ShapeNet renderings into one HDF5 file.
    # NOTE(review): the output file name says 128x128 while the requested
    # size is 896x896 -- confirm which one is intended.
    dataset_root = "/media/chief/HDD8TB/ShapeNetRendering-Depth/"
    output_file = "/media/chief/HDD8TB/ShapeNetRendering-Depth/shapenet-renderings-FULL-128x128.hdf5"
    target_size = (896, 896)
    save_to_hdf5(output_file, dataset_root, "FULL", target_size)
