from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle
import random
from PIL import Image
from einops import rearrange, reduce
import scipy.ndimage


# Directory names for each split. KittiDataset creates one KittiSample per
# entry (joined onto its data path). NOTE(review): the "_02"/"_03" suffix
# presumably selects the KITTI left/right colour camera — confirm.

# Test split: a single directory literally named "test".
kitti_test = ["test"]

# Validation split drives.
kitti_val = [
    "2011_09_26_drive_0017_sync_02",
    "2011_09_26_drive_0017_sync_03",
    "2011_09_28_drive_0043_sync_03",
    "2011_09_28_drive_0043_sync_02",
    "2011_09_26_drive_0032_sync_02",
    "2011_09_26_drive_0032_sync_03",
    "2011_09_26_drive_0061_sync_02",
    "2011_09_26_drive_0061_sync_03",
    "2011_09_28_drive_0047_sync_03",
    "2011_09_28_drive_0047_sync_02"
]

# Training split drives.
kitti_train = [
    "2011_09_30_drive_0028_sync_02",
    "2011_09_30_drive_0028_sync_03",
    "2011_09_26_drive_0091_sync_03",
    "2011_09_26_drive_0091_sync_02",
    "2011_09_30_drive_0034_sync_03",
    "2011_09_30_drive_0034_sync_02",
    "2011_09_26_drive_0018_sync_02",
    "2011_09_26_drive_0018_sync_03",
    "2011_09_26_drive_0095_sync_03",
    "2011_09_26_drive_0095_sync_02",
    "2011_10_03_drive_0034_sync_02",
    "2011_10_03_drive_0034_sync_03",
    "2011_09_26_drive_0014_sync_03",
    "2011_09_26_drive_0014_sync_02",
    "2011_09_28_drive_0034_sync_03",
    "2011_09_28_drive_0034_sync_02",
    "2011_09_26_drive_0057_sync_03",
    "2011_09_26_drive_0057_sync_02",
    "2011_09_28_drive_0039_sync_03",
    "2011_09_28_drive_0039_sync_02",
    "2011_09_30_drive_0033_sync_03",
    "2011_09_30_drive_0033_sync_02",
    "2011_09_26_drive_0087_sync_03",
    "2011_09_26_drive_0087_sync_02",
    "2011_09_26_drive_0028_sync_02",
    "2011_09_26_drive_0028_sync_03",
    "2011_09_26_drive_0015_sync_03",
    "2011_09_26_drive_0015_sync_02",
    "2011_09_26_drive_0113_sync_02",
    "2011_09_26_drive_0113_sync_03",
    "2011_09_26_drive_0001_sync_02",
    "2011_09_26_drive_0001_sync_03",
    "2011_09_28_drive_0038_sync_03",
    "2011_09_28_drive_0038_sync_02",
    "2011_09_29_drive_0026_sync_02",
    "2011_09_29_drive_0026_sync_03",
    "2011_09_26_drive_0079_sync_02",
    "2011_09_26_drive_0079_sync_03",
    "2011_10_03_drive_0042_sync_02",
    "2011_10_03_drive_0042_sync_03",
    "2011_09_28_drive_0001_sync_02",
    "2011_09_28_drive_0001_sync_03",
    "2011_09_26_drive_0019_sync_02",
    "2011_09_26_drive_0019_sync_03",
    "2011_09_26_drive_0051_sync_03",
    "2011_09_26_drive_0051_sync_02",
    "2011_09_26_drive_0011_sync_02",
    "2011_09_26_drive_0011_sync_03",
    "2011_09_28_drive_0045_sync_02",
    "2011_09_28_drive_0045_sync_03",
    "2011_09_26_drive_0104_sync_02",
    "2011_09_26_drive_0104_sync_03",
    "2011_09_28_drive_0037_sync_02",
    "2011_09_28_drive_0037_sync_03",
    "2011_09_26_drive_0005_sync_03",
    "2011_09_26_drive_0005_sync_02",
    "2011_09_26_drive_0070_sync_02",
    "2011_09_26_drive_0070_sync_03",
    "2011_09_26_drive_0039_sync_03",
    "2011_09_26_drive_0039_sync_02",
    "2011_09_26_drive_0022_sync_03",
    "2011_09_26_drive_0022_sync_02",
    "2011_09_26_drive_0035_sync_02",
    "2011_09_26_drive_0035_sync_03",
    "2011_09_29_drive_0004_sync_02",
    "2011_09_29_drive_0004_sync_03",
    "2011_09_30_drive_0020_sync_03",
    "2011_09_30_drive_0020_sync_02"
]

class RamImage():
    """Hold an encoded image file (and an optional depth map) in memory.

    The raw file bytes are kept undecoded so many frames can live in RAM
    cheaply; ``to_numpy`` decodes them on demand.

    Args:
        img_path: path of the encoded image file (read as bytes).
        depth_path: optional ``.npy`` file. It is loaded as float32 and, if
            2-D, given a leading axis so it becomes ``(1, h, w)``.
        rgb: decode as colour (True) or grayscale (False).
    """

    def __init__(self, img_path, depth_path = None, rgb = True):
        self.rgb = rgb

        # `with` guarantees the handle is closed even if read() raises.
        with open(img_path, 'rb') as fd:
            img_str = fd.read()

        self.img_raw = np.frombuffer(img_str, np.uint8)

        if depth_path is not None:
            self.depth = np.load(depth_path).astype(np.float32)
            if self.depth.ndim == 2:
                # (h, w) -> (1, h, w): add a leading channel axis.
                self.depth = self.depth[np.newaxis]
        else:
            self.depth = None

    def to_numpy(self):
        """Decode the stored bytes into a float32 array scaled to [0, 1].

        NOTE: OpenCV decodes colour images in BGR channel order and no RGB
        conversion is applied here.

        Raises:
            ValueError: if OpenCV cannot decode the buffer.
        """
        flag = cv2.IMREAD_COLOR if self.rgb else cv2.IMREAD_GRAYSCALE
        img = cv2.imdecode(self.img_raw, flag)
        if img is None:
            # imdecode signals failure by returning None; fail loudly here
            # instead of with an obscure TypeError on `None / 255.0`.
            raise ValueError("failed to decode image from in-memory buffer")

        return (img / 255.0).astype(np.float32)

class RandomHorizontalFlip(object):
    """Mirror all images left-to-right with probability 0.5.

    On a flip, every image is copied mirrored and a new intrinsics matrix is
    returned whose principal point is reflected (``cx -> w - cx``, with ``w``
    the width of the first image). Without a flip, the very same ``images``
    list and ``intrinsics`` object are returned untouched.
    """

    def __call__(self, images, intrinsics):
        assert intrinsics is not None

        # Guard clause: the "no flip" half of the coin toss.
        if random.random() >= 0.5:
            return images, intrinsics

        mirrored = [np.copy(np.fliplr(frame)) for frame in images]
        new_intrinsics = np.copy(intrinsics)
        width = mirrored[0].shape[1]
        new_intrinsics[0, 2] = width - new_intrinsics[0, 2]
        return mirrored, new_intrinsics


class RandomScaleCrop(object):
    """Randomly zoom images by 1x-1.15x per axis, then crop back to the input size.

    The same zoom factors and crop window are applied to every image in the
    list (each must be ``(h, w, c)``), and a new intrinsics matrix is
    returned: row 0 is scaled by the x factor, row 1 by the y factor, and the
    principal point is shifted by the crop offset.
    """

    def __call__(self, images, intrinsics):
        assert intrinsics is not None
        k = np.copy(intrinsics)

        height, width, _ = images[0].shape
        sx, sy = np.random.uniform(1, 1.15, 2)
        zoomed_h = int(height * sy)
        zoomed_w = int(width * sx)

        # Focal lengths and principal point scale with the zoom.
        k[0] *= sx
        k[1] *= sy
        zoomed = [scipy.ndimage.zoom(frame, (sy, sx, 1), order=1) for frame in images]

        # NOTE(review): zoom() sizes its output by rounding, while the crop
        # range below uses int(); the window therefore stays inside the image.
        top = np.random.randint(zoomed_h - height + 1)
        left = np.random.randint(zoomed_w - width + 1)

        # Express the principal point in the crop's coordinate frame.
        k[0, 2] -= left
        k[1, 2] -= top

        cropped = [frame[top:top + height, left:left + width] for frame in zoomed]
        return cropped, k

class KittiSample(data.Dataset):
    """One KITTI directory of frames, served as fixed-length windows.

    Frames (``.jpg``/``.png``) are kept encoded in memory via ``RamImage``.
    Ground-truth depths (plain ``.npy``) are loaded only when ``type`` is not
    ``'train'``. Predicted depths (``*.depth.npy``) and poses (``*.pose.npy``)
    are loaded eagerly. ``cam.txt`` supplies the 3x3 intrinsics, which are
    adjusted to the requested ``size``.

    Args:
        root_path: base directory.
        data_path: directory name relative to ``root_path``.
        size: output size; arrays are allocated as ``(..., size[1], size[0])``,
            i.e. ``size`` is ``(width, height)``.
        length: number of consecutive frames per training item.
        type: ``'train'`` enables random flip + scale-crop augmentation;
            any other value returns one evaluation frame with its depth.
    """

    def __init__(self, root_path: str, data_path: str, size: Tuple[int, int], length: int, type):

        data_path = os.path.join(root_path, data_path)

        frames          = []
        gt_depths       = []
        pred_depths     = []
        pred_pose       = []
        self.size       = size
        self.length     = length
        self.type       = type
        self.flip       = RandomHorizontalFlip()
        self.crop       = RandomScaleCrop()

        # The four suffix classes are mutually exclusive.
        for file in os.listdir(data_path):
            path = os.path.join(data_path, file)
            if file.endswith((".jpg", ".png")) and not file.endswith((".depth.jpg", ".mask.jpg", ".centers.jpg")):
                frames.append(path)
            elif file.endswith(".depth.npy"):
                pred_depths.append(path)
            elif file.endswith(".pose.npy"):
                pred_pose.append(path)
            elif file.endswith(".npy"):
                gt_depths.append(path)

        frames.sort()
        gt_depths.sort()
        pred_depths.sort()
        pred_pose.sort()

        # Ground-truth depth is only kept for evaluation splits.
        self.imgs = [
            RamImage(img_path, depth_path if type != 'train' else None)
            for img_path, depth_path in zip(frames, gt_depths)
        ]
        self.pred_depths = [np.load(path) for path in pred_depths]
        self.pred_pose = [np.load(path) for path in pred_pose]

        if len(gt_depths) == 0:
            # No ground-truth files at all: keep every frame, without depth.
            self.imgs.extend(RamImage(img_path) for img_path in frames)

        self.intrinsics = np.genfromtxt(os.path.join(data_path, 'cam.txt')).astype(np.float32).reshape((3, 3))

        print(self.size)
        print(self.intrinsics)
        self._adjust_intrinsics()
        print(self.intrinsics)

    def _adjust_intrinsics(self):
        """Crop/scale ``self.intrinsics`` in place to match ``self.size``.

        NOTE(review): the crop offset ``(832 - 384) // 2`` suggests the
        stored frames are 832 px wide — confirm against the data.

        Raises:
            ValueError: for an unsupported width below 832.
        """
        width = self.size[0]
        crop_x = (832 - 384) // 2
        if width == 208:
            print("dowscalling about 0.25")
            self.intrinsics[0] *= 0.25
            self.intrinsics[1] *= 0.25
        elif width == 96:
            print(f"croping about {crop_x} and dowscalling about 0.25")
            self.intrinsics[0, 2] -= crop_x
            self.intrinsics[0] *= 0.25
            self.intrinsics[1] *= 0.25
        elif width == 384:
            print(f"croping about {crop_x}")
            self.intrinsics[0, 2] -= crop_x
        elif width < 832:
            # Explicit raise: `assert False` is stripped under `python -O`,
            # which would silently continue with wrong intrinsics.
            raise ValueError(f"unknowen size: {self.size}")

    def __len__(self):
        return max(0, len(self.imgs) - self.length + 1)

    def get_data(self, index):
        """Build the item whose window starts at frame ``index``.

        Evaluation (``type != 'train'``): returns
        ``(frames, intrinsics, gt_depths)`` where only slot 0 is filled, with
        frame ``index`` and its ground-truth depth; ``intrinsics`` is a copy.

        Training: returns ``(frames, intrinsics, gt_depths, pred_depths,
        pred_pose)`` for frames ``index .. index + length - 1``, with the
        RGB frames and predicted depths flipped/scale-cropped together.
        """
        height, width = self.size[1], self.size[0]
        frames    = np.zeros((self.length, 3, height, width), dtype=np.float32)
        gt_depths = np.zeros((self.length, 1, height, width), dtype=np.float32)

        if self.type != "train":
            # Previously this always read self.imgs[0], ignoring `index`.
            img = self.imgs[index]
            frames[0]    = img.to_numpy().transpose(2, 0, 1)
            gt_depths[0] = img.depth
            # Copy so callers cannot mutate the sample's shared intrinsics.
            return frames, np.copy(self.intrinsics), gt_depths

        pred_depths = np.zeros((self.length, 1, height, width), dtype=np.float32)
        pred_pose   = np.zeros((self.length, 6), dtype=np.float32)

        # Interleaved list: [rgb_0, depth_0, rgb_1, depth_1, ...] so both get
        # the identical random augmentation.
        imgs = []
        for offset in range(self.length):
            i = index + offset
            imgs.append(self.imgs[i].to_numpy())
            # assumes stored predictions are (1, 1, H, W) — TODO confirm
            imgs.append(np.expand_dims(self.pred_depths[i][0, 0], 2))
            # In training, depth is None (see __init__). NOTE(review): numpy
            # appears to fill the slot with NaN on None — confirm intent.
            gt_depths[offset] = self.imgs[i].depth
            if i < len(self.pred_pose):
                pred_pose[offset] = self.pred_pose[i]

        imgs, intrinsics = self.flip(imgs, self.intrinsics)
        imgs, intrinsics = self.crop(imgs, intrinsics)

        for k, img in enumerate(imgs):
            target = frames if k % 2 == 0 else pred_depths
            target[k // 2] = img.transpose(2, 0, 1)

        # FIXME: normalize depth with sigmoid.
        # NOTE: ImageNet-style normalization ((x - 0.45) / 0.225) is disabled.
        return frames, intrinsics, gt_depths, pred_depths, pred_pose

class KittiDataset(data.Dataset):
    """Concatenation of per-directory ``KittiSample`` datasets for one split.

    The built samples are cached as a pickle in the data directory
    (``dataset-{size[0]}x{size[1]}-{type}.pickle``) and reused on later runs.
    ``__getitem__`` maps a global index to the owning sample. For training
    items, it squashes the predicted depths with
    ``1 / (1 + exp((d - mean) / std))``, using the statistics stored in
    ``mean-std.pickle``.
    """

    def save(self):
        """Write ``samples`` and the total ``length`` to the cache file."""
        state = {'samples': self.samples, 'length': self.length}
        with open(self.file, "wb") as outfile:
            pickle.dump(state, outfile)

    def load(self):
        """Restore ``samples`` and ``length`` from the cache file."""
        # SECURITY: pickle.load can execute arbitrary code; only load cache
        # files this code created itself.
        with open(self.file, "rb") as infile:
            state = pickle.load(infile)
        self.samples = state['samples']
        self.length  = state['length']

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], length: int):

        data_path   = os.path.join(root_path, f'data/data/video/{dataset_name}')
        self.file   = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}-{type}.pickle')
        self.type   = type
        self.length = 0

        # SECURITY: same pickle caveat as in load().
        with open(os.path.join(data_path, 'mean-std.pickle'), 'rb') as f:
            loaded_dict = pickle.load(f)
        self.depth_mean = loaded_dict['mean']
        self.depth_std  = loaded_dict['std']

        self.samples    = []

        # Split-specific directory list, sample size and sequence length.
        if self.type == "train":
            dirs, sample_size, seq_length = kitti_train, size, length
        elif self.type == "val":
            dirs, sample_size, seq_length = kitti_val, size, 2
        else:
            dirs, sample_size, seq_length = kitti_test, [1242, 375], 1

        if os.path.exists(self.file):
            self.load()
            if all(getattr(s, 'length', seq_length) == seq_length for s in self.samples):
                return
            # The cache filename does not encode the sequence length, so a
            # cache built with a different `length` is stale: rebuild it.
            self.samples = []
            self.length = 0

        self._build_samples(data_path, dirs, sample_size, seq_length)
        self.save()

    def _build_samples(self, data_path, dirs, size, seq_length):
        """Create one ``KittiSample`` per directory, keeping non-empty ones."""
        for i, name in enumerate(dirs):
            sample = KittiSample(data_path, name, size, seq_length, self.type)
            if len(sample) > 0:
                self.samples.append(sample)
                self.length += len(sample)

            print(f"Loading KITTI {self.type} [{(i + 1) * 100 / len(dirs):.2f}]", flush=True)

    def __len__(self):
        return self.length

    def __getitem__(self, index: int):
        """Return the item at global ``index`` across all samples.

        Items in the 5-tuple training layout ``(frames, intrinsics,
        gt_depths, pred_depths, pred_pose)`` get ``pred_depths`` normalised.
        Any other tuple from ``get_data`` (e.g. the 3-tuple evaluation
        layout) is returned unchanged.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        if index < 0:
            index += len(self)
        if index < 0:
            raise IndexError("KittiDataset index out of range")

        for s in self.samples:
            n = len(s)
            if index < n:
                item = s.get_data(index)
                if len(item) != 5:
                    return item
                frames, intrinsics, gt_depths, pred_depths, pred_pose = item
                pred_depths = 1 / (1 + np.exp((pred_depths - self.depth_mean) / self.depth_std))
                return frames, intrinsics, gt_depths, pred_depths, pred_pose

            index -= n

        raise IndexError("KittiDataset index out of range")
