from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle
import random
import lz4.frame
import io
import pickle
from multiprocessing import shared_memory
from data.datasets.utils import RamImage, CompressedArray


class MOViEObjectSample(data.Dataset):
    """Per-object samples of a single MOVi-E sequence directory.

    The directory ``<root_path>/<data_path>/<size[0]>x<size[1]>`` is scanned
    for ``frame*.jpg`` (RGB), ``depth*.jpg`` and ``segmentations*.png`` files.
    Every sufficiently large instance found in a segmentation image becomes
    one dataset item, so the length of this dataset is the total number of
    retained instance masks over all frames.

    Attributes read by ``__len__`` / ``__getitem__`` (kept unchanged so that
    previously pickled samples remain usable):
        rgb, depth:          per-frame ``RamImage`` objects, sorted by path.
        instances:           per-frame ``CompressedArray`` of stacked masks.
        num_instances:       number of retained masks per frame.
        num_instances_total: sum of ``num_instances``.
    """

    def __init__(self, root_path: str, data_path: str, size: Tuple[int, int], min_instance_size: int = 25):
        """Index the sequence directory and pre-compute the instance masks.

        Args:
            root_path: Directory containing the sequence directories.
            data_path: Name of the sequence directory below ``root_path``.
            size: Resolution pair used to select the ``{size[0]}x{size[1]}``
                sub-directory.
            min_instance_size: Minimum number of pixels an instance must
                cover to be kept (default 25, the previously fixed value).

        Raises:
            OSError: If a segmentation image cannot be decoded.
        """
        data_path = os.path.join(root_path, data_path, f'{size[0]}x{size[1]}')

        rgb = []
        depth = []
        instances = []
        self.size = size

        for file in os.listdir(data_path):
            if file.startswith("frame") and file.endswith(".jpg"):
                rgb.append(os.path.join(data_path, file))
            if file.startswith("depth") and file.endswith(".jpg"):
                depth.append(os.path.join(data_path, file))
            if file.startswith("segmentations") and file.endswith(".png"):
                instances.append(os.path.join(data_path, file))

        # Sorting by path keeps the three modalities aligned frame by frame
        # (assumes the file names share a sortable frame counter - TODO confirm).
        self.rgb = [RamImage(path, 'color') for path in sorted(rgb)]
        self.depth = [RamImage(path, 'greyscale') for path in sorted(depth)]
        self.instances = []
        self.num_instances = []

        for path in sorted(instances):
            image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise OSError(f'Could not read segmentation image {path}')

            # Shape (1, H, W); grey values are scaled by 1/255.
            instance_mask = np.expand_dims(image.astype(np.float32), axis=0) / 255.0

            filtered_masks = []
            for index in np.unique(instance_mask):
                if index == 0:
                    # Value 0 is skipped (presumably background - confirm).
                    continue
                mask = instance_mask == index
                if np.sum(mask) >= min_instance_size:
                    filtered_masks.append(mask.astype(np.float32))

            if filtered_masks:
                stacked = np.stack(filtered_masks)
            else:
                # np.stack raises ValueError on an empty list; keep the frame
                # with an empty (0, 1, H, W) stack instead. __getitem__ never
                # decodes such a frame because it contributes zero items.
                stacked = np.zeros((0,) + instance_mask.shape, dtype=np.float32)

            self.instances.append(CompressedArray(stacked))
            self.num_instances.append(len(filtered_masks))

        self.num_instances_total = sum(self.num_instances)

    def __len__(self):
        """Return the total number of retained instance masks."""
        return self.num_instances_total

    def __getitem__(self, index: int):
        """Return ``(rgb, depth, instance_mask)`` for the ``index``-th instance.

        Instances are numbered frame by frame in sorted frame order. Negative
        indices count from the end.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        total = self.num_instances_total
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f'index out of range for {total} instances')

        # Walk the per-frame counts until the flat index falls inside a frame.
        frame_index = 0
        while index >= self.num_instances[frame_index]:
            index -= self.num_instances[frame_index]
            frame_index += 1

        rgb           = self.rgb[frame_index].to_numpy()
        depth         = self.depth[frame_index].to_numpy()
        instance_mask = self.instances[frame_index].to_numpy()[index]

        return rgb, depth, instance_mask

class MOViEBackgroundSample(data.Dataset):
    """Full-sequence RGB, depth and background data of one MOVi-E sequence.

    The directory ``<root_path>/<data_path>/<size[0]>x<size[1]>`` is scanned
    for ``frame*.jpg``, ``depth*.jpg``, ``bg-depth*.jpg``, ``bg-rgb*.jpg`` and
    ``uncertainty*.jpg`` files, each of which is wrapped in a ``RamImage``.
    ``get_data`` decodes all frames into dense float32 arrays.
    """

    def __init__(self, root_path: str, data_path: str, size: Tuple[int, int],
                 load_instances: bool = False, load_bg: bool = True):
        """Index the image files of the sequence directory.

        Args:
            root_path: Directory containing the sequence directories.
            data_path: Name of the sequence directory below ``root_path``.
            size: Resolution pair; selects the ``{size[0]}x{size[1]}``
                sub-directory and the spatial size used by ``get_data``.
            load_instances: Stored on the instance; not read by this class.
            load_bg: Whether ``get_data`` fills the background arrays.

        NOTE(review): ``load_instances`` and ``load_bg`` used to be read from
        undefined names, which raised ``NameError`` on every construction.
        The defaults chosen here are an assumption - confirm.
        """
        self.load_instances = load_instances
        self.load_bg = load_bg
        data_path = os.path.join(root_path, data_path, f'{size[0]}x{size[1]}')

        rgb = []
        depth = []
        bg_depth = []
        bg_rgb = []
        uncertainty = []
        self.size = size

        for file in os.listdir(data_path):
            if file.startswith("frame") and file.endswith(".jpg"):
                rgb.append(os.path.join(data_path, file))
            if file.startswith("depth") and file.endswith(".jpg"):
                depth.append(os.path.join(data_path, file))
            if file.startswith("bg-depth") and file.endswith(".jpg"):
                bg_depth.append(os.path.join(data_path, file))
            if file.startswith("bg-rgb") and file.endswith(".jpg"):
                bg_rgb.append(os.path.join(data_path, file))
            if file.startswith("uncertainty") and file.endswith(".jpg"):
                uncertainty.append(os.path.join(data_path, file))

        # Sorting by path keeps the modalities aligned frame by frame
        # (assumes the file names share a sortable frame counter - TODO confirm).
        self.rgb = [RamImage(path) for path in sorted(rgb)]
        self.depth = [RamImage(path, 'greyscale') for path in sorted(depth)]
        self.bg_depth = [RamImage(path, 'greyscale') for path in sorted(bg_depth)]
        self.bg_rgb = [RamImage(path) for path in sorted(bg_rgb)]
        self.uncertainty = [RamImage(path, 'greyscale') for path in sorted(uncertainty)]

    def __len__(self):
        """Return 1: the whole sequence is exposed as a single item.

        NOTE(review): added so that ``len(sample)`` and ``sample[index]``,
        which ``MOViEDataset`` applies to every sample, work for this class;
        the one-item-per-sequence granularity is an assumption - confirm.
        """
        return 1

    def __getitem__(self, index: int):
        """Return ``get_data()`` for index 0 (or -1).

        Raises:
            IndexError: For any other index.
        """
        if index not in (0, -1):
            raise IndexError('MOViEBackgroundSample holds a single item')
        return self.get_data()

    def get_data(self, num_frames: int = 24):
        """Decode the sequence into dense float32 arrays.

        Args:
            num_frames: Length of the frame axis of the returned arrays
                (default 24, the previously fixed value). Slots beyond the
                number of available frames stay zero.

        Returns:
            ``(rgb, depth, bg_rgb, bg_depth, uncertainty)``. ``rgb`` and
            ``bg_rgb`` have shape ``(num_frames, 3, size[1], size[0])``, the
            others ``(num_frames, 1, size[1], size[0])``. The three
            background arrays are ``None`` when ``self.load_bg`` is false.

        Raises:
            IndexError: If the sequence has more than ``num_frames`` frames.
        """
        if len(self.rgb) > num_frames:
            raise IndexError(
                f'sequence has {len(self.rgb)} frames but only {num_frames} slots'
            )

        height, width = self.size[1], self.size[0]
        rgb = np.zeros((num_frames, 3, height, width), dtype=np.float32)
        depth = np.zeros((num_frames, 1, height, width), dtype=np.float32)
        bg_depth = np.zeros((num_frames, 1, height, width), dtype=np.float32) if self.load_bg else None
        bg_rgb = np.zeros((num_frames, 3, height, width), dtype=np.float32) if self.load_bg else None
        uncertainty = np.zeros((num_frames, 1, height, width), dtype=np.float32) if self.load_bg else None

        for i, frame in enumerate(self.rgb):
            rgb[i] = frame.to_numpy()
            depth[i] = self.depth[i].to_numpy()
            if self.load_bg:
                # Only touch the background buffers when they were allocated;
                # writing into None raised TypeError before.
                bg_depth[i] = self.bg_depth[i].to_numpy()
                bg_rgb[i] = self.bg_rgb[i].to_numpy()
                uncertainty[i] = self.uncertainty[i].to_numpy()

        return rgb, depth, bg_rgb, bg_depth, uncertainty

class MOViEDataset(data.Dataset):
    """Concatenation of per-sequence MOVi-E samples behind one flat index.

    On first use the sequence directories are scanned, wrapped into sample
    objects and the resulting list is pickled next to the data; later runs
    load that pickle instead of rescanning.
    """

    def save(self):
        """Pickle ``self.samples`` to ``self.file``."""
        with open(self.file, "wb") as outfile:
            pickle.dump(self.samples, outfile)

    def load(self):
        """Replace ``self.samples`` with the list pickled in ``self.file``."""
        # NOTE: pickle.load can execute arbitrary code; the cache file must
        # come from a trusted source (normally it is written by save()).
        with open(self.file, "rb") as infile:
            self.samples = pickle.load(infile)

    def __getstate__(self):
        print("Pickling the object...")
        return self.__dict__

    def __setstate__(self, state):
        print("Unpickling the object...")
        self.__dict__.update(state)

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], objects = False, background = False):
        """Load the cached sample list or build it from the sequence folders.

        Args:
            root_path: Base directory; data is expected below
                ``<root_path>/data/data/video/<dataset_name>``.
            dataset_name: Name of the dataset directory.
            type: Name of the split sub-directory; also part of the cache
                file name.
            size: Resolution pair forwarded to the samples and used in the
                cache file name.
            objects: Build ``MOViEObjectSample`` items.
            background: Build ``MOViEBackgroundSample`` items (only consulted
                for sample construction when ``objects`` is false; it does
                take precedence for the cache file name).

        Raises:
            NotImplementedError: If samples must be built but neither
                ``objects`` nor ``background`` is set.
            FileNotFoundError: If the resulting dataset is empty.
        """
        data_path  = os.path.join(root_path, f'data/data/video/{dataset_name}')
        self.file  = os.path.join(data_path, f'dataset-{type}-{size[0]}x{size[1]}.pickle')
        if objects:
            self.file = os.path.join(data_path, f'dataset-objects-lightning-v2-{type}-{size[0]}x{size[1]}.pickle')
        if background:
            self.file = os.path.join(data_path, f'dataset-background-{type}-{size[0]}x{size[1]}.pickle')

        print(f"Loading MOVi-E {type} from {data_path}", flush=True)
        # data_path already contains root_path; joining root_path a second
        # time duplicated the prefix whenever root_path was relative.
        data_path = os.path.join(data_path, type)

        self.samples = []

        if os.path.exists(self.file):
            self.load()
        else:
            # os.walk yields nothing for a missing directory; fall back to an
            # empty listing so the FileNotFoundError below is raised instead
            # of a bare StopIteration.
            sub_dirs = next(os.walk(data_path), (data_path, [], []))[1]
            # Sorted so the flat index -> sequence mapping is reproducible.
            sequences   = sorted(name for name in sub_dirs if name.startswith("0"))
            num_samples = len(sequences)

            for i, sequence in enumerate(sequences):
                if objects:
                    self.samples.append(MOViEObjectSample(data_path, sequence, size))
                elif background:
                    self.samples.append(MOViEBackgroundSample(data_path, sequence, size))
                else:
                    #self.samples.append(MOViESample(data_path, dir, size))
                    raise NotImplementedError("Not implemented")

                # Only object samples are known to carry num_instances.
                num_instances = getattr(self.samples[-1], "num_instances", None)
                print(f"Loading MOVi-E {type} [{i * 100 / num_samples:.2f}]: {num_instances}", flush=True)

            if self.samples:
                # Never cache an empty list: it would mask data that is
                # added later, because the cache is preferred over a rescan.
                self.save()

        self.length = sum(len(sample) for sample in self.samples)
        print(f"MOViEDataset: {self.length}")

        if len(self) == 0:
            raise FileNotFoundError(f'Found no dataset at {data_path}')

    def __len__(self):
        """Return the total number of items over all samples."""
        return self.length

    def __getitem__(self, index: int):
        """Return item ``index`` of the concatenated samples.

        Negative indices count from the end.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(f'index out of range for dataset of length {self.length}')

        # find the right sample
        for sample in self.samples:
            count = len(sample)
            if index < count:
                return sample[index]
            index -= count

        # Only reachable if self.length disagrees with the samples.
        raise IndexError(f'index out of range for dataset of length {self.length}')

