import os
import os.path as osp
from os import PathLike
import pickle
import torch
from torch.nn import functional as F
import numpy as np
import torch.distributed as dist
from torchvision.datasets import ImageNet as ImageNet_pytorch
from PIL import Image
import scipy

from torchvision.transforms import Compose
from .base_dataset import BaseDataset
from utils.parallel import get_dist_info
from .utils import download_and_extract_archive, check_integrity

def expanduser(path):
    """Expand a leading ``~`` in *path* if it is path-like.

    ``str`` and :class:`os.PathLike` values are passed through
    :func:`os.path.expanduser` (which returns a ``str``). Any other value,
    e.g. ``None``, is returned unchanged. This lets optional path arguments
    be forwarded without a separate ``None`` check.
    """
    is_path_like = isinstance(path, (str, PathLike))
    if not is_path_like:
        return path
    return osp.expanduser(path)

def depth_preprocess(depth_np, min_depth_val=0.5, max_depth_val=10):
    """Convert a depth map into a 3-channel 8-bit ``PIL.Image``.

    Depths are clipped to ``[min_depth_val, max_depth_val]``, linearly
    rescaled to ``[0, 1]``, replicated across three channels and multiplied
    by 255. The ``uint8`` cast truncates fractional values.

    Args:
        depth_np: 2-D array of depth values (it is indexed with
            ``[:, :, np.newaxis]``, so it must be 2-D). NOTE(review): units are
            presumably metres for NYU Depth V2; confirm.
        min_depth_val: Depth mapped to 0. Defaults to 0.5, the value that
            was previously hard-coded.
        max_depth_val: Depth mapped to 255. Defaults to 10, the value that
            was previously hard-coded.

    Returns:
        ``PIL.Image.Image`` built from an ``(H, W, 3)`` ``uint8`` array.

    Raises:
        ValueError: If ``max_depth_val <= min_depth_val``, which would divide
            by zero or invert the scale.
    """
    if max_depth_val <= min_depth_val:
        raise ValueError(
            f"max_depth_val ({max_depth_val}) must be greater than "
            f"min_depth_val ({min_depth_val})")

    depth_np = np.clip(depth_np, min_depth_val, max_depth_val)
    depth_np = (depth_np - min_depth_val) / (max_depth_val - min_depth_val)
    # Replicate the single depth channel so image-style transforms see RGB.
    depth_np = np.tile(depth_np[:, :, np.newaxis], (1, 1, 3))

    return Image.fromarray((depth_np * 255).astype(np.uint8))


def load_mat(path="/mnt/beegfs/hlinbh/datasets/NYUDepthV2/nyu_depth_v2_labeled.mat"):
    """Debug helper: print a summary of the NYU Depth V2 labeled ``.mat`` file.

    Prints every top-level key with its h5py object, then the shape,
    mean/max/min, and the count of values above 5 for the ``'depths'``
    dataset. Finally it prints the file's keys and values views.

    Args:
        path: Location of ``nyu_depth_v2_labeled.mat``. The default is the
            path that was previously hard-coded.
    """
    # The file is HDF5 (MATLAB v7.3), which scipy.io.loadmat cannot read,
    # so h5py is used. It is imported here so h5py is needed only when this
    # helper runs.
    import h5py

    # Open read-only and close the handle on exit. The final print needs
    # the file to still be open, so everything stays inside the block.
    with h5py.File(path, 'r') as data:
        for key in data.keys():
            print(key, data[key])
        depths = np.array(data['depths'])
        print(depths.shape)
        print(depths.mean(), depths.max(), depths.min())
        print((depths > 5).sum())
        print(data.keys(), data.values())

class NYUDepthV2Pair(BaseDataset):
    """NYU Depth V2 (labeled subset) dataset yielding ``(image, depth)`` pairs.

    All samples are read from ``<root>/nyu_depth_v2_labeled.mat`` with h5py
    when the dataset is constructed, and kept in memory.

    Args:
        root: Directory containing ``nyu_depth_v2_labeled.mat``; ``~`` is
            expanded.
        photo_transforms: Callable applied to the RGB ``PIL.Image``. If
            ``None``, ``self.DEFAULT_TRNASFORMS`` is used instead.
        depth_transforms: Callable applied to the depth ``PIL.Image``
            produced by ``depth_preprocess``. If ``None``,
            ``self.DEFAULT_TRNASFORMS`` is used instead.
        target_transforms: Stored as ``self.target_transforms``; this class's
            ``__getitem__`` does not apply it.
        classes: Passed to ``self.get_classes``; the result is stored as
            ``self.CLASSES``.
        ann_file: Stored as ``self.ann_file`` with ``~`` expanded; not used
            by this class's loading code.
        test_mode: Stored as ``self.test_mode``.
    """

    def __init__(self,
                 root,
                 photo_transforms,
                 depth_transforms,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 test_mode=False):
        # super(BaseDataset, ...) starts the MRO lookup after BaseDataset, so
        # BaseDataset.__init__ itself is skipped. NOTE(review): presumably
        # intentional because this class loads data itself; confirm.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if photo_transforms is None:
            photo_transforms = self.DEFAULT_TRNASFORMS
        self.photo_transforms = photo_transforms
        if depth_transforms is None:
            depth_transforms = self.DEFAULT_TRNASFORMS
        self.depth_transforms = depth_transforms
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Read every ``(image, depth)`` sample from the labeled ``.mat`` file.

        Returns:
            list of dict: one ``{'image': ndarray, 'depth': ndarray}`` per
            sample, in file order.

        Raises:
            ValueError: If the ``images`` and ``depths`` datasets hold a
                different number of samples.
        """
        # Imported here so h5py is required only when annotations are loaded.
        import h5py

        data_path = os.path.join(self.root, 'nyu_depth_v2_labeled.mat')
        # Open read-only explicitly (older h5py releases did not default to
        # 'r'), and close the handle once the arrays are copied into memory.
        with h5py.File(data_path, 'r') as data:
            images = np.array(data['images'])
            depths = np.array(data['depths'])

        if len(images) != len(depths):
            raise ValueError(
                f"{data_path}: {len(images)} images but {len(depths)} depth maps")
        return [{'image': image, 'depth': depth}
                for image, depth in zip(images, depths)]

    def __getitem__(self, idx):
        """Return the ``(image, depth)`` pair for sample ``idx``.

        The raw arrays are transposed before conversion. NOTE(review): this
        assumes the HDF5 file stores axes reversed (MATLAB column-major), so
        the transposes yield height-first ``(H, W[, C])`` arrays; confirm.
        If a transform is ``None``, the corresponding ``PIL.Image`` is
        returned untransformed instead of raising ``UnboundLocalError``.
        """
        info = self.data_infos[idx]

        image = Image.fromarray(info['image'].transpose((2, 1, 0)).astype(np.uint8))
        if self.photo_transforms is not None:
            image = self.photo_transforms(image)

        depth = depth_preprocess(info['depth'].transpose((1, 0)))
        if self.depth_transforms is not None:
            depth = self.depth_transforms(depth)

        return image, depth


class NYUDepthV2Photo(BaseDataset):
    """NYU Depth V2 (labeled subset) dataset yielding ``(image, label)`` pairs.

    All samples are read from ``<root>/nyu_depth_v2_labeled.mat`` with h5py
    when the dataset is constructed, and kept in memory. Label maps are
    resized to ``resized_size`` with nearest-neighbour interpolation and
    returned as ``torch.long`` tensors.

    Args:
        root: Directory containing ``nyu_depth_v2_labeled.mat``; ``~`` is
            expanded.
        photo_transforms: Callable applied to the RGB ``PIL.Image``. If
            ``None``, ``self.DEFAULT_TRNASFORMS`` is used instead.
        resized_size: ``size`` argument passed to
            ``torch.nn.functional.interpolate`` for the label map.
        target_transforms: Optional callable applied to the resized label
            tensor.
        classes: Passed to ``self.get_classes``; the result is stored as
            ``self.CLASSES``.
        ann_file: Stored as ``self.ann_file`` with ``~`` expanded; not used
            by this class's loading code.
        test_mode: Stored as ``self.test_mode``.
    """

    def __init__(self,
                 root,
                 photo_transforms,
                 resized_size,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 test_mode=False):
        # super(BaseDataset, ...) starts the MRO lookup after BaseDataset, so
        # BaseDataset.__init__ itself is skipped. NOTE(review): presumably
        # intentional because this class loads data itself; confirm.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if photo_transforms is None:
            photo_transforms = self.DEFAULT_TRNASFORMS
        self.photo_transforms = photo_transforms
        self.resized_size = resized_size
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Read every ``(image, label)`` sample from the labeled ``.mat`` file.

        Returns:
            list of dict: one ``{'image': ndarray, 'label': ndarray}`` per
            sample, in file order.

        Raises:
            ValueError: If the ``images`` and ``labels`` datasets hold a
                different number of samples.
        """
        # Imported here so h5py is required only when annotations are loaded.
        import h5py

        data_path = os.path.join(self.root, 'nyu_depth_v2_labeled.mat')
        # Open read-only explicitly (older h5py releases did not default to
        # 'r'), and close the handle once the arrays are copied into memory.
        with h5py.File(data_path, 'r') as data:
            images = np.array(data['images'])
            labels = np.array(data['labels'])

        if len(images) != len(labels):
            raise ValueError(
                f"{data_path}: {len(images)} images but {len(labels)} label maps")
        return [{'image': image, 'label': label}
                for image, label in zip(images, labels)]

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for sample ``idx``.

        The raw arrays are transposed before conversion. NOTE(review): this
        assumes the HDF5 file stores axes reversed (MATLAB column-major), so
        the transposes yield height-first arrays; confirm. If
        ``photo_transforms`` is ``None``, the ``PIL.Image`` is returned
        untransformed instead of raising ``UnboundLocalError``.
        """
        info = self.data_infos[idx]

        image = Image.fromarray(info['image'].transpose((2, 1, 0)).astype(np.uint8))
        if self.photo_transforms is not None:
            image = self.photo_transforms(image)

        # PyTorch's nearest-neighbour upsampling kernels do not support
        # int64 input, so resize in float32 and cast back to long. float32
        # represents every integer up to 2**24 exactly, which presumably
        # covers all label ids. This matches NYUDepthV2Depth.
        label = torch.from_numpy(info['label'].transpose((1, 0)).astype(np.float32))
        label = F.interpolate(label.unsqueeze(0).unsqueeze(0),
                              size=self.resized_size, mode='nearest')
        label = label.to(torch.long)[0, 0]
        if self.target_transforms is not None:
            label = self.target_transforms(label)
        return image, label


class NYUDepthV2Depth(BaseDataset):
    """NYU Depth V2 (labeled subset) dataset yielding ``(depth, label)`` pairs.

    All samples are read from ``<root>/nyu_depth_v2_labeled.mat`` with h5py
    when the dataset is constructed, and kept in memory. Depth maps are
    rendered to 3-channel images by ``depth_preprocess``. Label maps are
    resized to ``resized_size`` with nearest-neighbour interpolation and
    returned as ``torch.long`` tensors.

    Args:
        root: Directory containing ``nyu_depth_v2_labeled.mat``; ``~`` is
            expanded.
        depth_transforms: Callable applied to the depth ``PIL.Image``. If
            ``None``, ``self.DEFAULT_TRNASFORMS`` is used instead.
        resized_size: ``size`` argument passed to
            ``torch.nn.functional.interpolate`` for the label map.
        target_transforms: Optional callable applied to the resized label
            tensor.
        classes: Passed to ``self.get_classes``; the result is stored as
            ``self.CLASSES``.
        ann_file: Stored as ``self.ann_file`` with ``~`` expanded; not used
            by this class's loading code.
        test_mode: Stored as ``self.test_mode``.
    """

    def __init__(self,
                 root,
                 depth_transforms,
                 resized_size,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 test_mode=False):
        # super(BaseDataset, ...) starts the MRO lookup after BaseDataset, so
        # BaseDataset.__init__ itself is skipped. NOTE(review): presumably
        # intentional because this class loads data itself; confirm.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        if depth_transforms is None:
            depth_transforms = self.DEFAULT_TRNASFORMS
        self.depth_transforms = depth_transforms
        self.resized_size = resized_size
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Read every ``(depth, label)`` sample from the labeled ``.mat`` file.

        Returns:
            list of dict: one ``{'depth': ndarray, 'label': ndarray}`` per
            sample, in file order.

        Raises:
            ValueError: If the ``depths`` and ``labels`` datasets hold a
                different number of samples.
        """
        # Imported here so h5py is required only when annotations are loaded.
        import h5py

        data_path = os.path.join(self.root, 'nyu_depth_v2_labeled.mat')
        # Open read-only explicitly (older h5py releases did not default to
        # 'r'), and close the handle once the arrays are copied into memory.
        with h5py.File(data_path, 'r') as data:
            depths = np.array(data['depths'])
            labels = np.array(data['labels'])

        if len(depths) != len(labels):
            raise ValueError(
                f"{data_path}: {len(depths)} depth maps but {len(labels)} label maps")
        return [{'depth': depth, 'label': label}
                for depth, label in zip(depths, labels)]

    def __getitem__(self, idx):
        """Return the ``(depth, label)`` pair for sample ``idx``.

        The raw arrays are transposed with ``(1, 0)`` before use.
        NOTE(review): this assumes the HDF5 file stores axes reversed (MATLAB
        column-major), so the transpose yields height-first maps; confirm.
        If ``depth_transforms`` is ``None``, the ``PIL.Image`` from
        ``depth_preprocess`` is returned untransformed instead of raising
        ``UnboundLocalError``.
        """
        info = self.data_infos[idx]

        depth = depth_preprocess(info['depth'].transpose((1, 0)))
        if self.depth_transforms is not None:
            depth = self.depth_transforms(depth)

        # Nearest-neighbour resize is done in float32 (int64 input is not
        # supported by PyTorch's nearest-upsampling kernels), then cast back
        # to long so the labels stay integral.
        label = torch.from_numpy(info['label'].transpose((1, 0)).astype(np.float32))
        label = F.interpolate(label.unsqueeze(0).unsqueeze(0),
                              size=self.resized_size, mode='nearest')
        label = label.to(torch.long)[0, 0]
        if self.target_transforms is not None:
            label = self.target_transforms(label)
        return depth, label

# if __name__ == '__main__':
#     load_mat()