from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
from einops import rearrange, reduce
import numpy as np
import lz4.frame
import io
import cv2
import pickle


class MaskBBoxNumpy:
    """Compute axis-aligned bounding boxes of thresholded masks with NumPy.

    The coordinate grids for a fixed image size are built once in the
    constructor and reused by every call to :meth:`compute`.
    """

    def __init__(self, size):
        """Precompute the per-pixel coordinate grids.

        Args:
            size: ``(height, width)`` of the masks later passed to
                :meth:`compute`.
        """
        height, width = size
        # NOTE(review): linspace(0, n, n) places n samples on [0, n] inclusive,
        # i.e. n / (n - 1) apart, so these are not integer pixel indices.
        # Kept unchanged so produced boxes stay comparable with earlier
        # exports -- confirm whether np.arange(n) was intended.
        x_range = np.linspace(0, width, width)
        y_range = np.linspace(0, height, height)

        x_coords, y_coords = np.meshgrid(x_range, y_range)

        # Two leading singleton axes let the grids broadcast against 4-D masks.
        self.x_coords = x_coords[None, None, :, :]
        self.y_coords = y_coords[None, None, :, :]

    def compute(self, mask):
        """Return one ``[x_min, y_min, x_max, y_max]`` box per mask.

        Args:
            mask: array of shape ``(N, 1, height, width)``; values strictly
                above 0.75 count as inside the object.

        Returns:
            Array of shape ``(N, 4)``. A mask without any pixel above the
            threshold yields ``[inf, inf, -inf, -inf]``.
        """
        inside = mask > 0.75

        # Select by the boolean mask itself rather than by "coordinate * mask
        # > 0": the latter silently dropped every pixel in column 0 / row 0,
        # whose coordinate is exactly 0.
        x_min = np.min(np.where(inside, self.x_coords, np.inf), axis=(2, 3))
        y_min = np.min(np.where(inside, self.y_coords, np.inf), axis=(2, 3))
        x_max = np.max(np.where(inside, self.x_coords, -np.inf), axis=(2, 3))
        y_max = np.max(np.where(inside, self.y_coords, -np.inf), axis=(2, 3))

        # Stacking gives (N, 4, 1); drop the singleton channel axis.
        bbox = np.stack([x_min, y_min, x_max, y_max], axis=1).squeeze(2)

        return bbox

class RamImage():
    """Keep an encoded image file in memory and decode it on demand."""

    def __init__(self, path):
        """Read the raw (still encoded) bytes of the file at ``path``.

        Args:
            path: location of the image file.

        Raises:
            OSError: if the file cannot be opened or read.
        """
        # The context manager closes the handle even when read() raises.
        with open(path, 'rb') as fd:
            img_str = fd.read()

        self.img_raw = np.frombuffer(img_str, np.uint8)

    def to_numpy(self):
        """Decode the stored bytes into a float image.

        Returns:
            float32 array of shape ``(1, 3, height, width)`` scaled by
            ``1 / 255``; channels are in the order OpenCV decodes them.

        Raises:
            ValueError: if OpenCV cannot decode the stored bytes.
        """
        img = cv2.imdecode(self.img_raw, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals failure by returning None instead of raising.
            raise ValueError('stored bytes could not be decoded as an image')

        # HWC uint8 -> CHW float32, then prepend a batch axis.
        img = img.astype(np.float32).transpose(2, 0, 1) / 255.0
        return np.expand_dims(img, axis=0)

class CompressedArray():
    """In-memory copy of a NumPy array, stored as an LZ4 frame."""

    def __init__(self, array):
        """Serialize ``array`` with ``np.save`` through an LZ4 writer.

        Args:
            array: the NumPy array to keep in compressed form.
        """
        buffer = io.BytesIO()

        with lz4.frame.open(buffer, mode='wb', compression_level=3) as sink:
            np.save(sink, array)

        # Rewind so the first read starts at the beginning of the frame.
        buffer.seek(0)
        self.compressed_array = buffer

    def to_numpy(self):
        """Decompress the buffer and return the stored array."""
        buffer = self.compressed_array

        with lz4.frame.open(buffer, mode='rb') as source:
            restored = np.load(source)

        # Rewind again so the buffer can be decoded on the next call too.
        buffer.seek(0)

        return restored

class ClevrDataset(data.Dataset):
    """CLEVR images with per-object masks, cached in RAM and exportable to HDF5.

    On first use the images and masks are read from disk, kept in memory in
    compressed form and pickled next to the data; later instantiations load
    that pickle instead of re-reading the individual files.

    Each item is the tuple ``(img, depth, mask)`` of decoded NumPy arrays.
    """

    def save(self):
        """Pickle the in-memory dataset to ``self.file``."""
        payload = {
            'imgs': self.imgs,
            'masks': self.masks,
            'depths': self.depths,
            'max_index': self.max_index,
        }
        with open(self.file, "wb") as outfile:
            pickle.dump(payload, outfile)

    def load(self):
        """Restore the dataset from the pickle at ``self.file``.

        NOTE: unpickling executes code embedded in the file, so only load
        cache files that were produced by :meth:`save` on a trusted machine.
        """
        with open(self.file, "rb") as infile:
            payload = pickle.load(infile)

        print(payload.keys())
        self.imgs = payload['imgs']
        self.masks = payload['masks']
        self.depths = payload['depths']
        self.max_index = payload['max_index']

    def __init__(self, type: str = 'train', root: str = '/media/chief/data/CLEVR'):
        """Load the cached split, or build it from the image files and cache it.

        Args:
            type: ``'train'`` selects scene indices 0..89999; any other value
                selects 90000..99999. (The name shadows the builtin but is
                kept because it is part of the public signature.)
            root: directory holding the ``train/320x240`` image folder and
                the ``dataset-<type>.pickle`` cache files.
        """
        super().__init__()

        # Both splits read from the same image folder; only the index range
        # differs.
        self.path = os.path.join(root, 'train', '320x240')
        self.file = os.path.join(root, f'dataset-{type}.pickle')
        self.type = type

        if os.path.exists(self.file):
            self.load()
            return

        self.imgs = []
        self.masks = []
        self.depths = []
        self.max_index = []

        first, last = (0, 90000) if type == 'train' else (90000, 100000)
        total = last - first
        for index in range(first, last):
            masks, depth, num_objects = self._read_scene_masks(index)

            self.imgs.append(RamImage(f'{self.path}/image{index:010d}.jpg'))
            self.masks.append(CompressedArray(masks))
            self.depths.append(CompressedArray(depth))
            self.max_index.append(num_objects)

            print(f'loading CLEVR {type} {index} [{(index - first) * 100 / total:.2f}%]')

        self.save()

    def _read_scene_masks(self, index):
        """Read the 11 mask files of one scene and derive its depth map.

        Mask file 00 is loaded but neither stored nor used for the statistics
        (presumably the background -- confirm against the data generator).

        Returns:
            ``(masks, depth, num_objects)``: the object masks up to the last
            non-empty one, shape ``(1, num_objects, 240, 320)``; the depth
            map, shape ``(1, 1, 240, 320)``; and the number of kept masks.

        Raises:
            FileNotFoundError: if a mask image cannot be read.
        """
        masks = np.zeros((1, 11, 240, 320), dtype=np.float32)
        for i in range(11):
            mask_path = f'{self.path}/mask{index:010d}-{i:02d}.png'
            raw = cv2.imread(mask_path)
            if raw is None:
                # cv2.imread reports failure by returning None.
                raise FileNotFoundError(f'cannot read mask image: {mask_path}')
            m = (raw / 255.0).astype(np.float32)
            masks[0, i] = reduce(m, 'h w c -> h w', 'mean').astype(np.float32)

        objects = masks[:, 1:]
        # Mean coverage of every object mask, shape (1, 10, 1, 1).
        mask_mean = np.mean(objects, axis=(2, 3), keepdims=True).astype(np.float32)
        # Each pixel receives the coverage-weighted sum of the masks over it.
        depth = np.sum(objects * mask_mean, axis=1, keepdims=True).astype(np.float32)

        # Drop trailing empty masks. The lower bound keeps a scene without any
        # object from wrapping around via negative indices and raising
        # IndexError.
        max_index = 9
        while max_index >= 0 and mask_mean[0, max_index, 0, 0] < 0.000001:
            max_index -= 1

        return masks[:, 1:max_index + 2], depth, max_index + 1

    def __len__(self):
        """Return the number of cached scenes."""
        return len(self.imgs)

    def __getitem__(self, index: int):
        """Decode and return ``(img, depth, mask)`` for scene ``index``."""
        img = self.imgs[index].to_numpy()
        depth = self.depths[index].to_numpy()
        mask = self.masks[index].to_numpy()

        return img, depth, mask

    @staticmethod
    def _create_growable(hdf5_file, name, row_shape, dtype, chunk_rows=False):
        """Create an empty gzip-compressed dataset that can grow along axis 0.

        Args:
            hdf5_file: open, writable ``h5py.File``.
            name: dataset name.
            row_shape: shape of a single row (everything after axis 0).
            dtype: element type of the dataset.
            chunk_rows: store one row per chunk; otherwise h5py picks the
                chunk shape itself.
        """
        options = dict(
            maxshape=(None, *row_shape),
            dtype=dtype,
            compression='gzip',
            compression_opts=5,
        )
        if chunk_rows:
            options['chunks'] = (1, *row_shape)
        return hdf5_file.create_dataset(name, (0, *row_shape), **options)

    @staticmethod
    def _append_rows(dataset, rows):
        """Append ``rows`` along axis 0 and return the index of the first one.

        An explicit start offset is used instead of ``dataset[-k:]`` because
        ``-0:`` would select the entire dataset when there is nothing to add.
        """
        rows = np.asarray(rows)
        start = dataset.shape[0]
        if rows.shape[0] == 0:
            return start
        dataset.resize(start + rows.shape[0], axis=0)
        dataset[start:] = rows
        return start

    def save_to_hdf5(self, hdf5_file_path):
        """Export images, masks and bounding boxes to a new HDF5 file.

        Args:
            hdf5_file_path: destination path; an existing file is overwritten.
        """
        size = (240, 320)
        mask_bboxes = MaskBBoxNumpy(size)
        with h5py.File(hdf5_file_path, "w") as hdf5_file:

            rgb_images = self._create_growable(
                hdf5_file, "rgb_images", (3, *size), np.float32, chunk_rows=True)
            # NOTE(review): "depth_images" (and "sequence_indices" below) are
            # created but never filled by this exporter -- confirm whether
            # they should be written.
            self._create_growable(
                hdf5_file, "depth_images", (1, *size), np.float32, chunk_rows=True)
            foreground_mask = self._create_growable(
                hdf5_file, "foreground_mask", (1, *size), np.float32, chunk_rows=True)
            # Rows are (start index, number of instances). np.int64 replaces
            # the np.long alias, which NumPy 1.24 removed.
            image_instance_indices = self._create_growable(
                hdf5_file, "image_instance_indices", (2,), np.int64)
            instance_masks = self._create_growable(
                hdf5_file, "instance_masks", (1, *size), np.float32, chunk_rows=True)
            instance_masks_images = self._create_growable(
                hdf5_file, "instance_masks_images", (1,), np.int64)
            instance_mask_bboxes = self._create_growable(
                hdf5_file, "instance_mask_bboxes", (4,), np.float32)
            # Rows are (start index, number of images).
            self._create_growable(
                hdf5_file, "sequence_indices", (2,), np.int64)

            # Create a metadata group and set the attributes
            metadata_grp = hdf5_file.create_group("metadata")
            metadata_grp.attrs["dataset_name"] = 'CLEVR'
            metadata_grp.attrs["type"] = self.type

            total = len(self)
            for index in range(total):
                img, _, mask = self[index]
                # (1, K, H, W) -> (K, 1, H, W): one row per object instance.
                mask = np.expand_dims(mask[0], axis=1)
                num_instances = mask.shape[0]

                bboxes = mask_bboxes.compute(mask)

                image_row = self._append_rows(rgb_images, img)
                self._append_rows(foreground_mask, np.sum(mask, axis=0, keepdims=True))
                first_instance = self._append_rows(instance_masks, mask)
                self._append_rows(instance_mask_bboxes, bboxes)

                # Image -> its slice of instance rows, and instance -> image.
                self._append_rows(
                    image_instance_indices, [[first_instance, num_instances]])
                self._append_rows(
                    instance_masks_images,
                    np.full((num_instances, 1), image_row, dtype=np.int64))

                print(f'Saving to hdf5 file [{index * 100 / total:.2f}%]')

if __name__ == "__main__":
    # Export the "test" split first, then the "train" split, each to its own
    # HDF5 file.
    exports = (
        ("test", "/media/chief/data/CLEVR/dataset-test-320x240.hdf5"),
        ("train", "/media/chief/data/CLEVR/dataset-train-320x240.hdf5"),
    )
    for split_name, target_file in exports:
        dataset = ClevrDataset(split_name)
        dataset.save_to_hdf5(target_file)
