import torch as th
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DistributedSampler
import h5py
import numpy as np
from data.datasets.utils import RandomHorizontalFlip, ScaleCrop
import cv2
from turbojpeg import decompress as jpeg_decompress

class HDF5_Dataset(Dataset):
    """Frame dataset backed by a single HDF5 file.

    The file is expected to provide the datasets ``"sequence_indices"`` and
    ``"rgb_images"``, and optionally ``"depth_images"`` and
    ``"foreground_mask"``. Items are not addressed by an integer index but by
    a sample dict (see ``__getitem__``).
    """

    def __init__(self, hdf5_file_path, crop_size, split=None):
        """Read the dataset metadata from ``hdf5_file_path``.

        Args:
            hdf5_file_path: Path of the HDF5 file to read.
            crop_size: Crop size handed to ``ScaleCrop``; a single value is
                expanded to ``(crop_size, crop_size)``.
            split: ``None`` keeps all sequences, ``"train"`` keeps the first
                80% of ``sequence_indices`` and any other value keeps the
                remaining 20%.
        """
        if not isinstance(crop_size, (tuple, list)):
            crop_size = (crop_size, crop_size)

        self.filename = hdf5_file_path
        self.hdf5_file_path = hdf5_file_path
        self.random_crop = ScaleCrop(crop_size)
        self.random_horizontal_flip = RandomHorizontalFlip(flip_dim=3)

        # The handle is opened lazily on the first __getitem__ call, so the
        # dataset object itself never carries an open file after construction
        # (presumably so that it can be handed to DataLoader workers).
        self.hdf5_file = None

        # The context manager guarantees the file is closed even when reading
        # the metadata fails.
        with h5py.File(hdf5_file_path, "r") as hdf5_file:
            self.sequence_indices = hdf5_file["sequence_indices"][:]
            if split is not None:
                split_point = int(len(self.sequence_indices) * 0.8)
                if split == "train":
                    self.sequence_indices = self.sequence_indices[:split_point]
                else:
                    self.sequence_indices = self.sequence_indices[split_point:]

            # The second entry of every sequence record is summed up as the
            # dataset length (presumably the number of frames per sequence).
            self.dataset_length = sum(seq[1] for seq in self.sequence_indices)

            self.use_depth    = "depth_images" in hdf5_file and hdf5_file["depth_images"].shape[0] > 1
            self.use_fg_masks = "foreground_mask" in hdf5_file and hdf5_file["foreground_mask"].shape[0] > 1
            # Masks whose first 100 entries are all zero are treated as absent.
            self.use_fg_masks = self.use_fg_masks and np.sum(hdf5_file["foreground_mask"][0:100]) > 0

        print(f"Loaded HDF5 dataset {hdf5_file_path} with size {self.dataset_length}")

    def __len__(self):
        """Return the dataset length computed from ``sequence_indices``."""
        return self.dataset_length

    def __getitem__(self, sample):
        """Load, augment and return one frame.

        Args:
            sample: Mapping with the keys ``'sequence_index'`` (index into the
                HDF5 image datasets), ``'seed'`` (passed to both
                augmentations) and ``'time_step'`` (returned unchanged).

        Returns:
            Tuple ``(rgb_image, depth_image, foreground_mask, time_step,
            use_depth, use_fg_masks)``. When depth is unavailable
            ``depth_image`` is a single-channel tensor filled with ``-1``;
            when masks are unavailable ``foreground_mask`` is a
            single-channel tensor filled with ``0``.
        """
        index     = sample['sequence_index']
        seed      = sample['seed']
        time_step = sample['time_step']

        # Open the HDF5 file if it is not already open
        if self.hdf5_file is None:
            self.hdf5_file = h5py.File(self.hdf5_file_path, "r")

        rgb_image       = self.hdf5_file["rgb_images"][index]
        depth_image     = self.hdf5_file["depth_images"][index] if self.use_depth else None
        foreground_mask = self.hdf5_file["foreground_mask"][index] if self.use_fg_masks else None

        # handle compressed datasets: uint8 RGB data is treated as an encoded
        # JPEG buffer, and depth is decoded with OpenCV in that case
        if rgb_image.dtype == np.uint8:
            rgb_image = np.array(jpeg_decompress(rgb_image)).transpose(2, 0, 1).astype(np.float32) / 255.0

            if self.use_depth:
                depth_image = np.expand_dims(cv2.imdecode(depth_image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0, axis=0)

            if self.use_fg_masks:
                foreground_mask = foreground_mask.astype(np.float32) / 255.0

        # Stack all available modalities along the channel axis so that the
        # augmentations below are applied to them jointly.
        channels = [th.from_numpy(rgb_image)]
        if self.use_depth:
            channels.append(th.from_numpy(depth_image))
        if self.use_fg_masks:
            channels.append(th.from_numpy(foreground_mask))

        tensor = th.cat(channels, dim=0).unsqueeze(0)

        # apply data augmentation (both transforms receive the same seed)
        tensor = self.random_crop(tensor, seed=seed)
        tensor = self.random_horizontal_flip(tensor, seed=seed)[0]

        if self.use_depth and self.use_fg_masks:
            rgb_image, depth_image, foreground_mask = tensor[0].split([3, 1, 1], dim=0)
        elif self.use_depth:
            rgb_image, depth_image = tensor[0].split([3, 1], dim=0)
            foreground_mask = th.zeros_like(depth_image)
        elif self.use_fg_masks:
            rgb_image, foreground_mask = tensor[0].split([3, 1], dim=0)
            # -1 marks the depth as missing
            depth_image = th.zeros_like(foreground_mask) - 1
        else:
            rgb_image = tensor[0]
            depth_image = th.zeros_like(rgb_image[:1]) - 1
            foreground_mask = th.zeros_like(rgb_image[:1])

        # save images for debugging
        #dataset_index = sample['dataset_index']
        #rank = sample['rank']
        #cv2.imwrite(f"rgb_image_{rank}_{dataset_index:03d}_{seed:03d}_{time_step:03d}.png", rgb_image.numpy().transpose(1, 2, 0) * 255)
        #if self.use_depth:
        #    cv2.imwrite(f"depth_image_{rank}_{dataset_index:03d}_{seed:03d}_{time_step:03d}.png", depth_image.numpy().transpose(1, 2, 0) * 255)
        #if self.use_fg_masks:
        #    cv2.imwrite(f"foreground_mask_{rank}_{dataset_index:03d}_{seed:03d}_{time_step:03d}.png", foreground_mask.numpy().transpose(1, 2, 0) * 255)

        return rgb_image, depth_image, foreground_mask, time_step, self.use_depth, self.use_fg_masks

class ChainedHDF5_Dataset(Dataset):
    """Combine several datasets and resample them according to weights.

    The weights are normalized to sum to one. Every dataset is assigned a
    resampled length of ``int(total_length * weight)``, where ``total_length``
    is the summed length of all wrapped datasets.
    """

    def __init__(self, hdf5_datasets, weights):
        """Store the datasets and derive the resampled lengths.

        Args:
            hdf5_datasets: Sequence of datasets; each must support ``len()``,
                expose a ``filename`` attribute and accept a sample dict in
                ``__getitem__``.
            weights: One relative weight per dataset.
        """
        self.datasets = hdf5_datasets
        self.weights  = weights / np.sum(weights)

        total_length = sum(len(ds) for ds in self.datasets)

        # resampled number of items per dataset, computed once and reused below
        resampled = [int(total_length * weight) for weight in self.weights]

        self.lenght = sum(resampled)
        print(f"dataset size: {self.lenght}, total length: {total_length}")

        print(f"dataset size: {self.lenght}")
        for ds, weight, target in zip(self.datasets, self.weights, resampled):
            print(f"resampling dataset {len(ds):10d}|{100*len(ds)/total_length:.1f}% -> {target:10d}|{100*weight:.1f}% ({ds.filename})")

        # running totals of the resampled lengths; the last entry is the size
        # of the chained dataset
        self.cumulative_lengths = np.cumsum(resampled)
        assert self.lenght == self.cumulative_lengths[-1]

    def __len__(self):
        """Return the total resampled length."""
        return self.cumulative_lengths[-1]

    def __getitem__(self, sample):
        """Forward ``sample`` to the dataset selected by ``sample['dataset_index']``."""
        selected = self.datasets[sample['dataset_index']]
        return selected[sample]

# run test if the file itself is called
if __name__ == "__main__":
    # Smoke test: builds a chained dataset from two hard-coded validation
    # files and iterates once over all batches in a single-process
    # "distributed" setup.

    import os
    import torch.distributed as dist

    # single-process process group (rank 0 of world size 1)
    os.environ['RANK'] = "0"
    os.environ['WORLD_SIZE'] = str(1)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'

    # NCCL needs CUDA devices; fall back to gloo so the test also runs on
    # CPU-only machines.
    use_cuda = th.cuda.is_available()
    dist.init_process_group(backend='nccl' if use_cuda else 'gloo', init_method='env://')

    try:
        from data.sampler.background_sampler import DistributedBackgroundBatchSampler

        dataset1 = HDF5_Dataset("/media/chief/data/KITTI-360/dataset-objects-lightning-validation-1312x352.hdf5", 256)
        dataset2 = HDF5_Dataset("/media/chief/data/movi-e/dataset-objects-lightning-v2-validation-256x256.hdf5", 256)
        dataset = ChainedHDF5_Dataset([dataset1, dataset2], [0.5, 0.5])
        sampler = DistributedBackgroundBatchSampler(dataset, sequence_length=24, batch_size=3)
        # pinned memory is only useful (and only available) with CUDA
        dataloader = DataLoader(dataset, batch_sampler=sampler, num_workers=0, pin_memory=use_cuda)

        # set random seed
        th.manual_seed(1234)

        num_batches = len(dataloader)
        for i, batch in enumerate(dataloader):
            print(f"Testing... {i}/{num_batches}, {i/num_batches*100:.2f}%")

        print("done")
    finally:
        # always release the process group, even if the test fails
        dist.destroy_process_group()
