from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle
import random
from PIL import Image
from einops import rearrange, reduce
import scipy.ndimage


# Sequence directory names for each split. KittiDataset joins each name onto
# <root_path>/data/data/video/<dataset_name> and builds one KittiSample from it.
# Names are a raw KITTI drive id plus a "_02" or "_03" suffix; every drive in
# the train/val lists appears with both suffixes. The suffixes presumably
# select the two colour cameras (image_02 / image_03) — TODO confirm against
# the on-disk data layout.

# The test split is a single directory named "test"; KittiDataset loads it at
# 1242x375 with single-frame windows regardless of the requested size/length.
kitti_test = ["test"]

# Validation drives. None of these drives occurs in kitti_train.
kitti_val = [
    "2011_09_26_drive_0017_sync_02",
    "2011_09_26_drive_0017_sync_03",
    "2011_09_28_drive_0043_sync_03",
    "2011_09_28_drive_0043_sync_02",
    "2011_09_26_drive_0032_sync_02",
    "2011_09_26_drive_0032_sync_03",
    "2011_09_26_drive_0061_sync_02",
    "2011_09_26_drive_0061_sync_03",
    "2011_09_28_drive_0047_sync_03",
    "2011_09_28_drive_0047_sync_02"
]

# Training drives (34 drives x 2 suffixes = 68 sequence directories).
kitti_train = [
    "2011_09_30_drive_0028_sync_02",
    "2011_09_30_drive_0028_sync_03",
    "2011_09_26_drive_0091_sync_03",
    "2011_09_26_drive_0091_sync_02",
    "2011_09_30_drive_0034_sync_03",
    "2011_09_30_drive_0034_sync_02",
    "2011_09_26_drive_0018_sync_02",
    "2011_09_26_drive_0018_sync_03",
    "2011_09_26_drive_0095_sync_03",
    "2011_09_26_drive_0095_sync_02",
    "2011_10_03_drive_0034_sync_02",
    "2011_10_03_drive_0034_sync_03",
    "2011_09_26_drive_0014_sync_03",
    "2011_09_26_drive_0014_sync_02",
    "2011_09_28_drive_0034_sync_03",
    "2011_09_28_drive_0034_sync_02",
    "2011_09_26_drive_0057_sync_03",
    "2011_09_26_drive_0057_sync_02",
    "2011_09_28_drive_0039_sync_03",
    "2011_09_28_drive_0039_sync_02",
    "2011_09_30_drive_0033_sync_03",
    "2011_09_30_drive_0033_sync_02",
    "2011_09_26_drive_0087_sync_03",
    "2011_09_26_drive_0087_sync_02",
    "2011_09_26_drive_0028_sync_02",
    "2011_09_26_drive_0028_sync_03",
    "2011_09_26_drive_0015_sync_03",
    "2011_09_26_drive_0015_sync_02",
    "2011_09_26_drive_0113_sync_02",
    "2011_09_26_drive_0113_sync_03",
    "2011_09_26_drive_0001_sync_02",
    "2011_09_26_drive_0001_sync_03",
    "2011_09_28_drive_0038_sync_03",
    "2011_09_28_drive_0038_sync_02",
    "2011_09_29_drive_0026_sync_02",
    "2011_09_29_drive_0026_sync_03",
    "2011_09_26_drive_0079_sync_02",
    "2011_09_26_drive_0079_sync_03",
    "2011_10_03_drive_0042_sync_02",
    "2011_10_03_drive_0042_sync_03",
    "2011_09_28_drive_0001_sync_02",
    "2011_09_28_drive_0001_sync_03",
    "2011_09_26_drive_0019_sync_02",
    "2011_09_26_drive_0019_sync_03",
    "2011_09_26_drive_0051_sync_03",
    "2011_09_26_drive_0051_sync_02",
    "2011_09_26_drive_0011_sync_02",
    "2011_09_26_drive_0011_sync_03",
    "2011_09_28_drive_0045_sync_02",
    "2011_09_28_drive_0045_sync_03",
    "2011_09_26_drive_0104_sync_02",
    "2011_09_26_drive_0104_sync_03",
    "2011_09_28_drive_0037_sync_02",
    "2011_09_28_drive_0037_sync_03",
    "2011_09_26_drive_0005_sync_03",
    "2011_09_26_drive_0005_sync_02",
    "2011_09_26_drive_0070_sync_02",
    "2011_09_26_drive_0070_sync_03",
    "2011_09_26_drive_0039_sync_03",
    "2011_09_26_drive_0039_sync_02",
    "2011_09_26_drive_0022_sync_03",
    "2011_09_26_drive_0022_sync_02",
    "2011_09_26_drive_0035_sync_02",
    "2011_09_26_drive_0035_sync_03",
    "2011_09_29_drive_0004_sync_02",
    "2011_09_29_drive_0004_sync_03",
    "2011_09_30_drive_0020_sync_03",
    "2011_09_30_drive_0020_sync_02"
]

class RamImage():
    """An encoded image file held in memory and decoded with OpenCV on demand.

    The raw file bytes are read once at construction. ``to_numpy`` decodes them
    into a channel-first float32 array scaled to [0, 1]. ``preprocess`` decodes
    once, keeps the decoded array and drops the encoded bytes.

    Args:
        path: image file to read.
        type: ``'greyscale'`` decodes a single channel; any other value decodes
            three colour channels. (The name shadows the builtin but is kept for
            caller compatibility.)
    """

    def __init__(self, path, type = 'color'):
        self.preprocessed = False
        self.type         = type

        # Context manager guarantees the handle is closed even if read() fails.
        with open(path, 'rb') as fd:
            self.img_raw = np.frombuffer(fd.read(), np.uint8)

    def preprocess(self):
        """Decode once and cache the result, releasing the encoded bytes."""
        if not self.preprocessed:
            self.img          = self.to_numpy()
            self.img_raw      = None
            self.preprocessed = True

    def to_numpy(self):
        """Return the image as a float32 array in [0, 1].

        Returns:
            ``(1, H, W)`` for ``'greyscale'``, otherwise ``(3, H, W)`` in
            OpenCV's channel order (BGR; no conversion to RGB is done). After
            ``preprocess`` the cached array itself is returned, not a copy.

        Raises:
            ValueError: if OpenCV cannot decode the stored bytes.
        """
        if self.preprocessed:
            return self.img

        greyscale = self.type == 'greyscale'
        flag = cv2.IMREAD_GRAYSCALE if greyscale else cv2.IMREAD_COLOR
        decoded = cv2.imdecode(self.img_raw, flag)
        # cv2.imdecode signals failure by returning None rather than raising.
        if decoded is None:
            raise ValueError("cv2.imdecode could not decode the stored image bytes")

        pixels = decoded.astype(np.float32)
        if greyscale:
            return np.expand_dims(pixels, axis=0) / 255.0
        # (H, W, C) -> (C, H, W)
        return pixels.transpose(2, 0, 1) / 255.0

class KittiSample(data.Dataset):
    """One KITTI sequence directory exposed as sliding windows of frames.

    Colour frames (``*.jpg`` / ``*.png``) and depth maps (``*.depth.jpg``) are
    read into memory as still-encoded ``RamImage`` objects at construction.
    Item ``i`` is the window of ``length`` consecutive frames starting at frame
    ``i``.

    Args:
        root_path: directory containing the sequence directories.
        data_path: sequence directory name, joined onto ``root_path``.
        size: ``(width, height)``. ``get_data`` allocates its output with this
            spatial shape, so decoded images must already have this size.
        length: number of consecutive frames per window.

    Raises:
        ValueError: if the directory holds a different number of colour frames
            and depth maps.
    """

    # Auxiliary images stored next to the colour frames that are not frames.
    _AUX_SUFFIXES = (".depth.jpg", ".mask.jpg", ".centers.jpg")

    def __init__(self, root_path: str, data_path: str, size: Tuple[int, int], length: int):
        data_path = os.path.join(root_path, data_path)

        self.size   = size
        self.length = length

        frames = []
        depths = []
        for file in os.listdir(data_path):
            if file.endswith(".depth.jpg"):
                depths.append(os.path.join(data_path, file))
            elif file.endswith((".jpg", ".png")) and not file.endswith(self._AUX_SUFFIXES):
                frames.append(os.path.join(data_path, file))

        # Frames and depth maps are paired purely by sorted position, so their
        # file names must sort in the same order.
        frames.sort()
        depths.sort()

        # Explicit raise (not assert) so the check survives `python -O`.
        if len(frames) != len(depths):
            raise ValueError(
                f"frame/depth count mismatch: {len(frames)} != {len(depths)}, {data_path}"
            )

        self.imgs   = [RamImage(path) for path in frames]
        self.depths = [RamImage(path, 'greyscale') for path in depths]

    def __len__(self):
        """Number of full windows of ``self.length`` frames (never negative)."""
        return max(0, len(self.imgs) - self.length + 1)

    def get_data(self, index):
        """Return the window starting at frame ``index``.

        Returns:
            ``(frames, depths)`` float32 arrays of shape
            ``(length, 3, height, width)`` and ``(length, 1, height, width)``.

        Raises:
            IndexError: if ``index`` is not in ``[0, len(self))``; negative
                indices would otherwise silently wrap to the sequence end.
        """
        if not 0 <= index < len(self):
            raise IndexError(f"KittiSample index out of range: {index}")

        width, height = self.size
        frames = np.zeros((self.length, 3, height, width), dtype=np.float32)
        depths = np.zeros((self.length, 1, height, width), dtype=np.float32)

        for offset in range(self.length):
            frames[offset] = self.imgs[index + offset].to_numpy()
            depths[offset] = self.depths[index + offset].to_numpy()

        return frames, depths

class KittiDataset(data.Dataset):
    """Map-style dataset over all KITTI sequences of one split.

    Index ``i`` addresses the ``i``-th sliding window across the concatenation
    of all non-empty ``KittiSample`` sequences of the split, and returns that
    sample's ``get_data`` result.

    The sample list is cached in
    ``<root_path>/data/data/video/<dataset_name>/dataset-<W>x<H>-<type>.pickle``.
    It is reused on later constructions only when its window length matches.

    Args:
        root_path: project root containing ``data/data/video``.
        dataset_name: sub-directory holding the sequence directories.
        type: ``"train"`` / ``"val"`` load ``kitti_train`` / ``kitti_val``. Any
            other value loads ``kitti_test`` at 1242x375 with single-frame
            windows. (The name shadows the builtin but is kept for callers.)
        size: ``(width, height)`` passed to each train/val ``KittiSample``.
        length: frames per window (ignored for the test split).
    """

    def save(self):
        """Write the sample list and total item count to ``self.file``."""
        state = { 'samples': self.samples, 'length': self.length }
        with open(self.file, "wb") as outfile:
            pickle.dump(state, outfile)

    def load(self):
        """Restore the sample list and total item count from ``self.file``.

        SECURITY NOTE: ``pickle.load`` can execute arbitrary code. Only load
        cache files written by ``save`` on a trusted machine.
        """
        with open(self.file, "rb") as infile:
            state = pickle.load(infile)
        self.samples = state['samples']
        self.length  = state['length']

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], length: int):
        data_path   = os.path.join(root_path, f'data/data/video/{dataset_name}')
        self.file   = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}-{type}.pickle')
        self.type   = type
        self.length = 0

        self.samples = []

        # The test split is always read at full KITTI resolution with
        # single-frame windows, independent of the requested size/length.
        is_test       = type not in ("train", "val")
        sample_size   = [1242, 375] if is_test else size
        sample_length = 1 if is_test else length

        if os.path.exists(self.file):
            self.load()
            # The cache file name does not encode the window length, so a cache
            # written with a different `length` must not be silently reused.
            if all(getattr(s, 'length', None) == sample_length for s in self.samples):
                print(f"Loaded KITTI {type} {size[0]}x{size[1]} {self.length}", flush=True)
                return
            print(f"Rebuilding KITTI {type} {size[0]}x{size[1]}: cached window length != {sample_length}", flush=True)
            self.samples = []
            self.length  = 0

        self._build(data_path, sample_size, sample_length, size)
        self.save()

    def _build(self, data_path, sample_size, sample_length, size):
        """Load every sequence of ``self.type`` and accumulate the item count.

        ``size`` is only used for the progress message, matching the cache
        file name.
        """
        dirs = {"train": kitti_train, "val": kitti_val}.get(self.type, kitti_test)

        for i, dir_name in enumerate(dirs):
            sample = KittiSample(data_path, dir_name, sample_size, sample_length)
            if len(sample) > 0:
                self.samples.append(sample)
                self.length += len(sample)

            # (i + 1) so the final message reports 100%.
            print(f"Loading KITTI {self.type} {size[0]}x{size[1]} [{(i + 1) * 100 / len(dirs):.2f}]", flush=True)

    def __len__(self):
        return self.length

    def __getitem__(self, index: int):
        """Return the window at global ``index`` (negative indices supported).

        Raises:
            IndexError: if ``index`` is out of range. The sequence-iteration
                protocol (``for x in dataset``) relies on this exception to stop.
        """
        original = index
        if index < 0:
            index += self.length
        if not 0 <= index < self.length:
            raise IndexError(f"KittiDataset index out of range: {original}")

        # Linear scan over the (few) sequences to find the one owning `index`.
        for sample in self.samples:
            count = len(sample)
            if index < count:
                return sample.get_data(index)
            index -= count

        # Only reachable if self.length disagrees with the samples' lengths.
        raise IndexError(f"KittiDataset index out of range: {original}")
