import io
import json
import math
import os
import pickle
import random
from typing import Tuple, Union, List

from torch.utils import data
import numpy as np
import cv2
import h5py
from PIL import Image
from einops import rearrange, reduce
import scipy.ndimage
import lz4.frame

from data.datasets.KITTI.labels import labels


# Drive directories (relative to the dataset directory) that make up each
# split. KittiDataset expects every one of them to contain ``cam0`` and
# ``cam1`` sub-directories.
kitti_train = """
    2013_05_28_drive_0000_sync
    2013_05_28_drive_0002_sync
    2013_05_28_drive_0003_sync
    2013_05_28_drive_0004_sync
    2013_05_28_drive_0005_sync
    2013_05_28_drive_0006_sync
    2013_05_28_drive_0007_sync
    2013_05_28_drive_0009_sync
""".split()
kitti_val = ["2013_05_28_drive_0010_sync"]

class CompressedArray():
    """A numpy array kept as an lz4-compressed ``np.save`` stream.

    The compressed stream lives either in an in-memory buffer
    (``self.compressed_array``) or in the file at ``array_path``. In the
    latter case the file is reopened on every :meth:`to_numpy` call.
    """
    def __init__(self, array=None, array_path=None, load_to_ram=False):
        """
        Args:
            array: array to compress into memory. It is compressed when
                ``load_to_ram`` is True, or when there is no ``array_path``
                to fall back to.
            array_path: path of an lz4-compressed ``.npy`` stream.
            load_to_ram: read the compressed bytes at ``array_path`` into
                memory once instead of reopening the file on each access.
        """
        self.compressed_array = None
        self.array_path = array_path
        self.load_to_ram = load_to_ram

        if array is not None and (load_to_ram or array_path is None):
            # Without a path there is nowhere else to keep the data, so the
            # array is compressed into memory even if load_to_ram is False
            # (previously it was silently dropped and to_numpy() failed).
            self.compress_array(array)
        elif array_path is not None and load_to_ram:
            self.load_compressed_array_from_path(array_path)

    def compress_array(self, array):
        """Compress ``array`` (lz4 level 3) into a fresh in-memory buffer."""
        self.compressed_array = io.BytesIO()

        with lz4.frame.open(self.compressed_array, mode='wb', compression_level=3) as f:
            np.save(f, array)

        # Rewind so the buffer is read from the start.
        self.compressed_array.seek(0)

    def load_compressed_array_from_path(self, array_path):
        """Copy the (still compressed) bytes of ``array_path`` into memory."""
        with open(array_path, 'rb') as f:
            compressed_data = f.read()

        self.compressed_array = io.BytesIO(compressed_data)
        self.compressed_array.seek(0)

    def to_numpy(self):
        """Decompress and return the stored array.

        Reads from the in-memory buffer when one exists, otherwise from
        ``array_path``.
        """
        if self.compressed_array is not None:
            # Rewind before reading as well as afterwards: a read that raised
            # part-way must not leave the shared buffer positioned mid-stream
            # and break every later call.
            self.compressed_array.seek(0)
            try:
                with lz4.frame.open(self.compressed_array, mode='rb') as f:
                    return np.load(f)
            finally:
                self.compressed_array.seek(0)

        with lz4.frame.open(self.array_path, mode='rb') as f:
            return np.load(f)

class RamImage():
    """Image file decoded with OpenCV on demand.

    With ``load_to_ram=True`` the encoded (still compressed) file bytes are
    cached in memory and decoded on every :meth:`to_numpy` call. Otherwise
    the file is re-read from disk each time.
    """
    def __init__(self, img_path, rgb=True, load_to_ram=False):
        """
        Args:
            img_path: path of an image file readable by OpenCV.
            rgb: decode as 3-channel colour if True, else as 1-channel
                grayscale. NOTE: OpenCV returns colour images in BGR order.
            load_to_ram: cache the encoded file bytes in memory.
        """
        self.img_path = img_path
        self.rgb = rgb
        self.load_to_ram = load_to_ram

        if self.load_to_ram:
            with open(img_path, 'rb') as fd:
                img_str = fd.read()
            self.img_raw = np.frombuffer(img_str, np.uint8)

    @staticmethod
    def _to_chw_float(img, rgb, img_path=None):
        """Turn a decoded uint8 ``(H, W, 3)`` / ``(H, W)`` image into float32 ``(C, H, W)`` in [0, 1].

        Raises:
            OSError: if ``img`` is None, which is how OpenCV reports a missing,
                unreadable or undecodable image.
        """
        if img is None:
            raise OSError(f"failed to read or decode image: {img_path}")
        if not rgb:
            img = np.expand_dims(img, axis=2)
        return (img / 255.0).astype(np.float32).transpose(2, 0, 1)

    def to_numpy(self):
        """Decode the image into a float32 ``(C, H, W)`` array scaled to [0, 1].

        ``C`` is 3 for colour and 1 for grayscale images.

        Raises:
            OSError: if the image cannot be read or decoded.
        """
        flag = cv2.IMREAD_COLOR if self.rgb else cv2.IMREAD_GRAYSCALE
        if self.load_to_ram:
            img = cv2.imdecode(self.img_raw, flag)
        else:
            img = cv2.imread(self.img_path, flag)

        return self._to_chw_float(img, self.rgb, self.img_path)

class KittiSample(data.Dataset):
    """One KITTI camera directory served as overlapping RGB + depth clips.

    Colour frames are the ``*.jpg`` files whose name contains a single dot;
    depth maps are the ``*.depth.jpg`` files. Both lists are sorted by path
    and paired by position, so clip ``k`` covers entries ``k .. k+length-1``.
    """
    def __init__(self, data_path: str, size: Tuple[int, int], length: int, load_to_ram=False):
        self.size   = size
        self.length = length

        listing = os.listdir(data_path)
        frame_paths = sorted(
            os.path.join(data_path, name)
            for name in listing
            if name.endswith(".jpg") and name.count(".") == 1
        )
        depth_paths = sorted(
            os.path.join(data_path, name)
            for name in listing
            if name.endswith(".depth.jpg")
        )

        self.imgs        = [RamImage(p, load_to_ram=load_to_ram) for p in frame_paths]
        self.depths      = [RamImage(p, rgb=False, load_to_ram=load_to_ram) for p in depth_paths]
        # Present as attributes but never filled by this class.
        self.semantics   = []
        self.confidences = []
        self.fg_mask     = []

    def __len__(self):
        """Number of complete clips of ``self.length`` frames (never negative)."""
        return max(0, len(self.imgs) - self.length + 1)

    def get_data(self, index):
        """Return the clip of ``self.length`` consecutive frames starting at ``index``.

        Returns:
            ``(frames, depths)`` float32 arrays shaped
            ``(length, 3, size[1], size[0])`` and ``(length, 1, size[1], size[0])``.
        """
        height, width = self.size[1], self.size[0]
        frames = np.zeros((self.length, 3, height, width), dtype=np.float32)
        depths = np.zeros((self.length, 1, height, width), dtype=np.float32)

        for slot in range(self.length):
            source = index + slot
            frames[slot] = self.imgs[source].to_numpy()
            depths[slot] = self.depths[source].to_numpy()

        return frames, depths

class KittiBackgroundSample(data.Dataset):
    """One KITTI camera directory served as clips with depth, confidence and fg masks.

    Per frame the directory holds ``<stem>.jpg`` (colour, single dot in the
    name), ``<stem>.depth.jpg``, ``<stem>.confidence.jpg`` and
    ``<stem>.fg_mask.jpg``. Each kind is sorted by path on its own and the
    kinds are paired by position, so the four file sets are assumed to cover
    the same frames. The number of files found per kind is printed.
    """
    def __init__(self, data_path: str, size: Tuple[int, int], length: int, load_to_ram=False):
        self.size   = size
        self.length = length

        listing = os.listdir(data_path)

        def matching(accept):
            # Sorted full paths of the directory entries accepted by ``accept``.
            return sorted(os.path.join(data_path, name) for name in listing if accept(name))

        frame_paths      = matching(lambda name: name.endswith(".jpg") and name.count(".") == 1)
        depth_paths      = matching(lambda name: name.endswith(".depth.jpg"))
        confidence_paths = matching(lambda name: name.endswith(".confidence.jpg"))
        fg_mask_paths    = matching(lambda name: name.endswith(".fg_mask.jpg"))

        print("Found {} frames".format(len(frame_paths)))
        print("Found {} depths".format(len(depth_paths)))
        print("Found {} confidences".format(len(confidence_paths)))
        print("Found {} fg_mask".format(len(fg_mask_paths)))

        self.imgs        = [RamImage(p, load_to_ram=load_to_ram) for p in frame_paths]
        self.depths      = [RamImage(p, rgb=False, load_to_ram=load_to_ram) for p in depth_paths]
        self.semantics   = []  # present as an attribute, never filled here
        self.fg_mask     = [RamImage(p, rgb=False, load_to_ram=load_to_ram) for p in fg_mask_paths]
        self.confidences = [RamImage(p, rgb=False, load_to_ram=load_to_ram) for p in confidence_paths]

    def __len__(self):
        """Number of complete clips of ``self.length`` frames (never negative)."""
        return max(0, len(self.imgs) - self.length + 1)

    def get_data(self, index):
        """Return the clip of ``self.length`` consecutive frames starting at ``index``.

        Returns:
            ``(frames, depths, fg_mask, confidences)`` float32 arrays; frames are
            ``(length, 3, size[1], size[0])``, the other three
            ``(length, 1, size[1], size[0])``.
        """
        height, width = self.size[1], self.size[0]
        frames      = np.zeros((self.length, 3, height, width), dtype=np.float32)
        depths      = np.zeros((self.length, 1, height, width), dtype=np.float32)
        confidences = np.zeros_like(depths)
        fg_mask     = np.zeros_like(depths)

        for slot in range(self.length):
            source = index + slot
            frames[slot]      = self.imgs[source].to_numpy()
            depths[slot]      = self.depths[source].to_numpy()
            fg_mask[slot]     = self.fg_mask[source].to_numpy()
            confidences[slot] = self.confidences[source].to_numpy()

        return frames, depths, fg_mask, confidences


class KittiObjectSample(data.Dataset):
    """Per-instance view of one KITTI camera directory.

    Every file whose name contains ``"instance_mask"`` describes one frame.
    Its name is split on ``"."``: field 0 is the integer frame number and
    field 2 the number of instances stored in that file. For each such file
    the matching ``<frame>.jpg``, ``<frame>.depth.jpg`` and
    ``<frame>.confidence.jpg`` are loaded, with the frame number zero-padded
    to 10 digits. Dataset indices enumerate all instances of all frames
    consecutively, in sorted file-name order.
    """
    def __init__(self, data_path: str, size: Tuple[int, int], load_to_ram=False):
        self.frames         = []
        self.depths         = []
        self.instance_masks = []
        self.confidences    = []
        self.num_instances  = []
        self.size           = size  # stored, but not used within this class

        # all files whose name contains "instance_mask", in sorted order
        files = sorted(f for f in os.listdir(data_path) if "instance_mask" in f)

        for file in files:
            fields = file.split(".")
            frame_number = int(fields[0])
            self.num_instances.append(int(fields[2]))

            self.frames.append(RamImage(os.path.join(data_path, "{:010d}.jpg".format(frame_number)), load_to_ram=load_to_ram))
            self.depths.append(RamImage(os.path.join(data_path, "{:010d}.depth.jpg".format(frame_number)), rgb=False, load_to_ram=load_to_ram))
            self.confidences.append(RamImage(os.path.join(data_path, "{:010d}.confidence.jpg".format(frame_number)), rgb=False, load_to_ram=load_to_ram))
            self.instance_masks.append(CompressedArray(array_path = os.path.join(data_path, file), load_to_ram=load_to_ram))

        self.num_instances_total = sum(self.num_instances)

    def __len__(self):
        """Total number of instances over all frames."""
        return self.num_instances_total

    def get_data(self, index):
        """Return ``(frame, depth, instance_mask, confidence)`` for the ``index``-th instance.

        ``instance_mask`` is the ``[local:local+1]`` slice (leading axis kept)
        of the frame's stored instance-mask array. Presumably that array
        stacks one mask per instance along axis 0 — TODO confirm.
        Negative indices count from the end.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        total = self.num_instances_total
        requested = index
        if index < 0:
            index += total
        # Explicit check instead of an ``assert`` (stripped under ``-O``);
        # it also stops negative indices producing an empty mask slice.
        if not 0 <= index < total:
            raise IndexError(f"instance index {requested} out of range for {total} instances")

        # Walk the frames, subtracting each frame's instance count until the
        # remaining index falls inside the current frame (frames with zero
        # instances are skipped automatically).
        frame_index = 0
        while index >= self.num_instances[frame_index]:
            index -= self.num_instances[frame_index]
            frame_index += 1

        frames        = self.frames[frame_index].to_numpy()
        depths        = self.depths[frame_index].to_numpy()
        instance_mask = self.instance_masks[frame_index].to_numpy()[index:index+1]
        confidences   = self.confidences[frame_index].to_numpy()

        return frames, depths, instance_mask, confidences


class KittiDataset(data.Dataset):
    """KITTI video dataset assembled from per-camera sample collections.

    Every drive of the selected split contributes one collection for ``cam0``
    and one for ``cam1``. Depending on the flags the collections are
    ``KittiBackgroundSample``, ``KittiObjectSample`` or plain ``KittiSample``
    objects. A flat dataset index is mapped onto the collections in order.
    With ``load_to_ram`` the built collections are pickled to ``self.file``
    and reused on the next construction.
    """

    def save(self):
        """Pickle the sample collections and total length to ``self.file``.

        The pickle is written to a temporary file and then moved into place,
        so an interrupted save cannot leave a truncated cache behind.
        """
        state = { 'samples': self.samples, 'length': self.length }
        tmp_file = self.file + '.tmp'
        with open(tmp_file, "wb") as outfile:
            pickle.dump(state, outfile)
        os.replace(tmp_file, self.file)

    def load(self):
        """Restore ``self.samples`` and ``self.length`` from ``self.file``.

        NOTE: this uses ``pickle``; only load cache files this code wrote.
        """
        with open(self.file, "rb") as infile:
            state = pickle.load(infile)
        self.samples = state['samples']
        self.length  = state['length']

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], length: int, background: bool = False, objects: bool = False, load_to_ram=False):
        """
        Args:
            root_path: root directory; data is read from
                ``<root_path>/data/data/video/<dataset_name>``.
            dataset_name: name of the dataset directory.
            type: ``"train"`` selects ``kitti_train``; any other value selects
                ``kitti_val``.
            size: image size; ``size[0]`` is the width (last array axis) and
                ``size[1]`` the height.
            length: frames per clip (not used for object samples).
            background: build ``KittiBackgroundSample`` collections.
            objects: build ``KittiObjectSample`` collections (only when
                ``background`` is False).
            load_to_ram: cache encoded images in memory and pickle the result
                to ``self.file``.
        """
        print(f'Loading {dataset_name} dataset with config: size={size}, length={length}, background={background}, objects={objects}')

        data_path   = os.path.join(root_path, f'data/data/video/{dataset_name}')

        # name the cache file according to size, split, background and object settings
        self.file   = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}-{type}-{"bg" if background else "nobg"}-{"obj" if objects else "noobj"}.pickle')
        self.type   = type
        self.length = 0

        with open(os.path.join(data_path, 'mean-std.pickle'), 'rb') as f:
            loaded_dict = pickle.load(f)
        self.depth_mean = loaded_dict['mean']
        self.depth_std  = loaded_dict['std']

        self.samples    = []

        use_cache = os.path.exists(self.file)
        if use_cache:
            self.load()
            # The cache file name does not encode the clip length, so a cache
            # whose clips were built with a different ``length`` must not be
            # reused. Object samples have no ``length`` and always match.
            use_cache = all(getattr(s, 'length', length) == length for s in self.samples)
            if not use_cache:
                print(f"Ignoring cached KITTI {type} dataset built with a different clip length")
                self.samples = []
                self.length  = 0

        if not use_cache:
            self._build_samples(data_path, size, length, background, objects, load_to_ram)
            if load_to_ram:
                self.save()

        print(f"loaded KITTI {type} dataset with {self.length} samples")

    @staticmethod
    def _make_sample(sample_path, size, length, background, objects, load_to_ram):
        """Build the sample collection for one camera directory."""
        if background:
            return KittiBackgroundSample(sample_path, size, length, load_to_ram=load_to_ram)
        if objects:
            return KittiObjectSample(sample_path, size, load_to_ram=load_to_ram)
        return KittiSample(sample_path, size, length, load_to_ram=load_to_ram)

    def _build_samples(self, data_path, size, length, background, objects, load_to_ram):
        """Append cam0 and cam1 collections for every drive of the split and sum their lengths.

        NOTE: earlier code loaded ``cam0`` twice for the plain (no bg / no obj)
        configuration. Caches written by that code still contain the
        duplicate and should be deleted so they get rebuilt.
        """
        drives = kitti_train if self.type == "train" else kitti_val
        for i, drive in enumerate(drives):
            # cam0 reports half-way progress for this drive, cam1 completes it.
            for cam, progress in (('cam0', i + 0.5), ('cam1', i + 1)):
                sample = self._make_sample(os.path.join(data_path, drive, cam), size, length, background, objects, load_to_ram)
                self.samples.append(sample)
                self.length += len(sample)
                print(f"Loading KITTI {self.type} [{progress * 100 / len(drives):.2f}]", flush=True)

    def __len__(self):
        return self.length

    def __getitem__(self, index: int):
        """Return the data of the ``index``-th clip/instance across all collections.

        Negative indices count from the end.

        Raises:
            IndexError: if ``index`` is out of range. This also lets Python's
                ``__getitem__``-based iteration terminate.
        """
        requested = index
        if index < 0:
            index += len(self)

        if index >= 0:
            for s in self.samples:
                n = len(s)
                if index < n:
                    return s.get_data(index)
                index -= n

        raise IndexError(f"index {requested} out of range for KITTI {self.type} dataset of length {self.length}")
