from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
from einops import rearrange, reduce
import numpy as np
import lz4.frame
import io
import cv2
import pickle
from tqdm import tqdm
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class MaskBBoxNumpy:
    """Compute axis-aligned bounding boxes of soft masks with NumPy.

    A coordinate grid is built once for a fixed ``(height, width)`` and reused
    for every call to :meth:`compute`.

    The grid is created with ``np.linspace(0, width, width)`` (and likewise
    for the height), i.e. the coordinates run from ``0`` to ``width``
    inclusive in ``width`` steps rather than being integer pixel indices.
    """

    def __init__(self, size):
        """Precompute the coordinate grids.

        Args:
            size: ``(height, width)`` of the masks passed to :meth:`compute`.
        """
        height, width = size
        x_range = np.linspace(0, width, width)
        y_range = np.linspace(0, height, height)

        x_coords, y_coords = np.meshgrid(x_range, y_range)

        # Add leading batch and channel axes so the grids broadcast against
        # masks of shape (batch, 1, height, width).
        self.x_coords = x_coords[None, None, :, :]
        self.y_coords = y_coords[None, None, :, :]

    def compute(self, mask):
        """Return the bounding box of every mask in a batch.

        Args:
            mask: array of shape ``(batch, 1, height, width)``; values above
                ``0.75`` count as inside the mask.

        Returns:
            Array of shape ``(batch, 4)`` holding
            ``[x_min, y_min, x_max, y_max]`` per mask. A mask without any
            value above the threshold yields ``[inf, inf, -inf, -inf]``.
        """
        inside = mask > 0.75

        # Select on the mask itself rather than on "coordinate * mask > 0":
        # the first row/column has coordinate 0.0, so the product test would
        # silently drop masked pixels that touch the top or left border.
        x_min = np.min(np.where(inside, self.x_coords, np.inf), axis=(2, 3))
        y_min = np.min(np.where(inside, self.y_coords, np.inf), axis=(2, 3))
        x_max = np.max(np.where(inside, self.x_coords, -np.inf), axis=(2, 3))
        y_max = np.max(np.where(inside, self.y_coords, -np.inf), axis=(2, 3))

        # Each extreme is (batch, 1): stack to (batch, 4, 1), then drop the
        # channel axis.
        bbox = np.stack([x_min, y_min, x_max, y_max], axis=1).squeeze(2)

        return bbox

def _collect_png_files(data_path):
    """Return the paths of all ``.png`` files found recursively below ``data_path``."""
    return [
        os.path.join(root, filename)
        for root, _, filenames in os.walk(data_path)
        for filename in filenames
        if filename.endswith('.png')
    ]


def _create_datasets(hdf5_file, num_images, height, width):
    """Create the fixed-size, gzip-compressed datasets with ``num_images`` rows each."""
    # (name, shape of one row, dtype, store one row per chunk?)
    # np.int64 replaces np.long: that alias is missing from some NumPy
    # releases and its width is platform dependent in others.
    specs = (
        ("rgb_images", (3, height, width), np.float32, True),
        ("depth_images", (1, height, width), np.float32, True),
        ("foreground_mask", (1, height, width), np.float32, True),
        ("image_instance_indices", (2,), np.int64, False),  # start index, number of instances
        ("instance_masks", (1, height, width), np.float32, True),
        ("instance_masks_images", (1,), np.int64, False),
        ("instance_mask_bboxes", (4,), np.float32, False),
        ("sequence_indices", (2,), np.int64, False),  # start index, number of images
    )
    for name, row_shape, dtype, chunk_per_row in specs:
        shape = (num_images,) + row_shape
        hdf5_file.create_dataset(
            name,
            shape,
            maxshape=shape,
            dtype=dtype,
            compression='gzip',
            compression_opts=5,
            # None lets h5py choose the chunk layout itself.
            chunks=(1,) + row_shape if chunk_per_row else None,
        )


def _load_image_and_mask(path):
    """Load a 4-channel image as channel-first float32 ``(color, alpha)`` scaled by 1/255.

    The three color channels are kept in the order OpenCV delivers them.
    """
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if raw is None:
        # cv2.imread signals an unreadable file by returning None.
        raise ValueError(f"could not read image: {path}")
    if raw.ndim != 3 or raw.shape[2] < 4:
        raise ValueError(f"image has no alpha channel: {path}")

    raw = raw.astype(np.float32) / 255.0
    img = raw[:, :, 0:3].transpose(2, 0, 1)
    mask = raw[:, :, 3:].transpose(2, 0, 1)
    return img, mask


def _fit_to_size(img, mask, size, mask_bboxes):
    """Bring a channel-first image and mask to ``size`` and return ``(img, mask, bbox)``.

    Inputs larger than ``size`` in either dimension are resized with bicubic
    interpolation; all other inputs are pasted into the center of a zero
    canvas.
    """
    height, width = size
    raw_height, raw_width = img.shape[1:]

    if raw_height > height or raw_width > width:
        img = F.interpolate(th.from_numpy(img).unsqueeze(0), size=size, mode='bicubic', align_corners=True).squeeze(0).numpy()
        mask = F.interpolate(th.from_numpy(mask).unsqueeze(0), size=size, mode='bicubic', align_corners=True).squeeze(0).numpy()
        # NOTE(review): resized images get the full-image box instead of one
        # derived from the mask -- confirm this is intended.
        # Boxes are [x_min, y_min, x_max, y_max], hence width before height.
        bbox = np.array([0, 0, width, height])
        return img, mask, bbox

    top = (height - raw_height) // 2
    left = (width - raw_width) // 2
    padded_img = np.zeros((3, height, width), dtype=np.float32)
    padded_mask = np.zeros((1, height, width), dtype=np.float32)
    padded_img[:, top:top + raw_height, left:left + raw_width] = img
    padded_mask[:, top:top + raw_height, left:left + raw_width] = mask
    bbox = mask_bboxes.compute(np.expand_dims(padded_mask, axis=0))[0]
    return padded_img, padded_mask, bbox


def save_to_hdf5(hdf5_file_path, data_path, type: str = 'FULL', size: Tuple[int, int] = (384, 384)):
    """Pack every ``.png`` found below ``data_path`` into one HDF5 file.

    Each image is read with its alpha channel and scaled by 1/255. The alpha
    channel becomes the foreground mask, the image is resized or
    center-padded to ``size``, and its color channels are multiplied by the
    mask so the background is black. Row ``i`` of every dataset belongs to
    the ``i``-th file found.

    Datasets written: ``rgb_images``, ``foreground_mask``, ``instance_masks``
    (same content as ``foreground_mask``), ``instance_mask_bboxes``,
    ``image_instance_indices`` and ``instance_masks_images``.
    ``depth_images`` and ``sequence_indices`` are created but not filled by
    this function. The ``metadata`` group stores ``dataset_name`` and
    ``type`` as attributes.

    Args:
        hdf5_file_path: output file; an existing file is overwritten.
        data_path: directory searched recursively for ``.png`` files.
        type: value stored in the ``type`` metadata attribute.
        size: target ``(height, width)`` of the stored images.

    Raises:
        ValueError: if no ``.png`` file is found, or if a file cannot be
            read or has no alpha channel.
    """
    mask_bboxes = MaskBBoxNumpy(size)
    height, width = size

    files = _collect_png_files(data_path)
    if not files:
        # Fail before opening the output so an existing file is not truncated.
        raise ValueError(f"no .png files found under {data_path}")

    with h5py.File(hdf5_file_path, "w") as hdf5_file:
        _create_datasets(hdf5_file, len(files), height, width)

        # Create a metadata group and set the attributes
        metadata_grp = hdf5_file.create_group("metadata")
        metadata_grp.attrs["dataset_name"] = 'ShapeNetRendering'
        metadata_grp.attrs["type"] = type

        rgb_images = hdf5_file['rgb_images']
        foreground_mask = hdf5_file['foreground_mask']
        instance_masks = hdf5_file['instance_masks']
        instance_mask_bboxes = hdf5_file['instance_mask_bboxes']
        image_instance_indices = hdf5_file['image_instance_indices']
        instance_masks_images = hdf5_file['instance_masks_images']

        for index, file in enumerate(tqdm(files)):
            img, mask = _load_image_and_mask(file)
            img, mask, bbox = _fit_to_size(img, mask, size, mask_bboxes)

            # make the image background black
            img = img * mask

            rgb_images[index] = img
            foreground_mask[index] = mask
            instance_masks[index] = mask
            instance_mask_bboxes[index] = bbox

            # The datasets are preallocated, so the rows of this image are
            # addressed by `index`; `shape[0] - 1` would always be the last
            # row of the file. Every image has exactly one instance mask.
            image_instance_indices[index] = [index, 1]
            instance_masks_images[index] = index

            # Debugging aid: dump the stored arrays back to image files.
            #
            # cv2.imwrite(f"rgb_image{index}.png", rgb_images[index].transpose(1, 2, 0) * 255)
            # cv2.imwrite(f"foreground_mask{index}.png", foreground_mask[index].transpose(1, 2, 0) * 255)
            #
            # debug_mask = (instance_masks[index].transpose(1, 2, 0) > 0.75).astype(np.float32)
            # debug_bbox = instance_mask_bboxes[index]
            #
            # mask_center_image = np.repeat(debug_mask, 3, axis=2)
            #
            # # draw box on mask in red
            # mask_center_image = cv2.rectangle(
            #     mask_center_image,
            #     (int(debug_bbox[0]), int(debug_bbox[1])),
            #     (int(debug_bbox[2]), int(debug_bbox[3])),
            #     (0, 0, 1), 1)
            #
            # cv2.imwrite(f"mask_bbox_{index}_{int(np.sum(debug_mask))}.jpg", mask_center_image * 255)

if __name__ == "__main__":
    # Script entry point: export the 64x64 variant of the dataset.
    # Previously used configurations, kept for reference:
    #save_to_hdf5("/media/chief/HDD8TB/ShapeNetRendering/shapenet-renderings-FULL-384x384.hdf5", "/media/chief/HDD8TB/ShapeNetRendering/")
    #save_to_hdf5("/media/chief/data/ShapeNetRendering/shapenet-renderings-FULL-128x128.hdf5", "/media/chief/data/ShapeNetRendering/", "FULL", [128, 128])
    save_to_hdf5(
        hdf5_file_path="/media/chief/data/ShapeNetRendering/shapenet-renderings-FULL-64x64.hdf5",
        data_path="/media/chief/data/ShapeNetRendering/",
        type="FULL",
        size=[64, 64],
    )
