import os
import os.path as osp
from os import PathLike
import cv2
import pickle
import torch
from torch.nn import functional as F
import numpy as np
import torch.distributed as dist
from torchvision.datasets import ImageNet as ImageNet_pytorch
from PIL import Image
import scipy

from torchvision.transforms import Compose
from .base_dataset import BaseDataset
from utils.parallel import get_dist_info
from .utils import download_and_extract_archive, check_integrity

def expanduser(path):
    """Expand a leading ``~`` in ``path`` when it is a filesystem path.

    Strings and ``os.PathLike`` objects go through ``os.path.expanduser``;
    any other value (for example ``None``) is handed back unchanged.
    """
    if not isinstance(path, (str, PathLike)):
        return path
    return osp.expanduser(path)

def depth_preprocess(depth_np, min_depth_val=0.5, max_depth_val=10):
    """Convert a 2-D depth map into a 3-channel 8-bit PIL image.

    Values are clipped to ``[min_depth_val, max_depth_val]``, linearly
    rescaled to ``[0, 1]``, replicated across three channels and quantised
    to ``uint8`` by truncation.

    Args:
        depth_np: 2-D numpy array of depth values (indexed as ``[:, :]``;
            units presumably metres -- TODO confirm).
        min_depth_val: lower clipping bound; maps to 0. Defaults to the
            previously hard-coded 0.5.
        max_depth_val: upper clipping bound; maps to 255. Defaults to the
            previously hard-coded 10.

    Returns:
        A ``PIL.Image.Image`` built from an ``(H, W, 3)`` uint8 array whose
        three channels are identical.

    Raises:
        ValueError: if ``max_depth_val <= min_depth_val`` (the rescale
            would divide by zero or invert the range).
    """
    if max_depth_val <= min_depth_val:
        raise ValueError(
            f"max_depth_val ({max_depth_val}) must exceed "
            f"min_depth_val ({min_depth_val})")
    depth_np = np.clip(depth_np, min_depth_val, max_depth_val)
    depth_np = (depth_np - min_depth_val) / (max_depth_val - min_depth_val)
    # Replicate the single depth channel so RGB-expecting transforms accept it.
    depth_np = np.tile(depth_np[:, :, np.newaxis], (1, 1, 3))

    return Image.fromarray((depth_np * 255).astype(np.uint8))


def load_mat(path="/mnt/beegfs/hlinbh/datasets/NYUDepthV2/nyu_depth_v2_labeled.mat"):
    """Debug helper: print a summary of the NYU Depth V2 labeled ``.mat`` file.

    The file is read with h5py. Presumably it is a MATLAB v7.3 (HDF5) file,
    which would explain why ``scipy.io.loadmat`` was abandoned -- TODO
    confirm. The function prints:

    - every top-level key,
    - the shape of ``depths``,
    - its mean, max and min,
    - the number of entries greater than 5,
    - the key and value views.

    Nothing is returned.

    Args:
        path: location of ``nyu_depth_v2_labeled.mat``. Defaults to the
            originally hard-coded cluster path.
    """
    import h5py  # local import: only this debug helper needs h5py

    # Open read-only and close deterministically. The original left the
    # handle open and, on h5py < 3.0, fell back to the default 'a' mode,
    # which opens the dataset file for writing.
    with h5py.File(path, "r") as data:
        for key in data.keys():
            print(key, data[key])
        images = np.array(data['depths'])
        print(images.shape)
        print(images.mean(), images.max(), images.min())
        print((images > 5).sum())
        print(data.keys(), data.values())

def default_label_process(label_np, resized_size):
    """Resize a label map with nearest-neighbour sampling and cast to int32.

    ``resized_size`` is reversed before being passed as ``dsize``. Because
    ``cv2.resize`` takes ``(width, height)``, this means ``resized_size`` is
    treated as ``(height, width)``. Nearest-neighbour interpolation keeps
    class ids from being blended.
    """
    dsize_wh = resized_size[::-1]
    resized = cv2.resize(label_np, dsize_wh, interpolation=cv2.INTER_NEAREST)
    return resized.astype(np.int32)


class NYUDepthV2Pair(BaseDataset):
    """NYU Depth V2 dataset yielding ``(image, depth)`` pairs.

    Expects ``root`` to contain two directories:

    - ``{split}_images``
    - ``{split}_depth_images``

    ``split`` is ``'val'`` in test mode and ``'train'`` otherwise. Each
    depth image shares the file name of its RGB image.
    """

    def __init__(self,
                 root,
                 photo_transforms,
                 depth_transforms,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 test_mode=False):
        # NOTE(review): this skips BaseDataset.__init__ and calls the next
        # class in the MRO instead -- presumably on purpose. Confirm before
        # changing.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        # ``DEFAULT_TRNASFORMS`` (sic) is presumably defined on BaseDataset.
        if photo_transforms is None:
            photo_transforms = self.DEFAULT_TRNASFORMS
        self.photo_transforms = photo_transforms
        if depth_transforms is None:
            depth_transforms = self.DEFAULT_TRNASFORMS
        self.depth_transforms = depth_transforms
        # Stored for interface compatibility; __getitem__ does not apply it.
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Return a list of ``{'image': path, 'depth': path}`` records.

        One record is produced per ``.png`` file in the image directory.
        The depth path is built from the same file name and is not checked
        for existence.

        Files are listed in sorted order because ``os.listdir`` order is
        arbitrary. Sorting makes the index -> sample mapping reproducible
        across runs and machines, which matters if a distributed sampler
        shards by index.
        """
        test_tag = 'val' if self.test_mode else 'train'
        image_dir = os.path.join(self.root, f'{test_tag}_images')
        depth_dir = os.path.join(self.root, f'{test_tag}_depth_images')
        return [
            {'image': os.path.join(image_dir, file),
             'depth': os.path.join(depth_dir, file)}
            for file in sorted(os.listdir(image_dir))
            if file.endswith('.png')
        ]

    def __getitem__(self, idx):
        """Return ``(image, depth)`` for sample ``idx``.

        Each item is passed through its transform when that transform is
        not None.
        """
        info = self.data_infos[idx]
        image = Image.open(info['image'])
        if self.photo_transforms is not None:
            image = self.photo_transforms(image)
        depth = Image.open(info['depth'])
        if self.depth_transforms is not None:
            depth = self.depth_transforms(depth)
        return image, depth


class NYUDepthV2Photo(BaseDataset):
    """NYU Depth V2 dataset yielding ``(image, label)`` pairs.

    Expects ``root`` to contain two directories:

    - ``{split}_images``
    - ``{split}_labels``

    ``split`` is ``'val'`` in test mode and ``'train'`` otherwise. Each
    label image shares the file name of its RGB image.
    """

    def __init__(self,
                 root,
                 photo_transforms,
                 resized_size,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 test_mode=False):
        # NOTE(review): this skips BaseDataset.__init__ and calls the next
        # class in the MRO instead -- presumably on purpose. Confirm before
        # changing.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        # ``DEFAULT_TRNASFORMS`` (sic) is presumably defined on BaseDataset.
        if photo_transforms is None:
            photo_transforms = self.DEFAULT_TRNASFORMS
        self.photo_transforms = photo_transforms
        # Stored but not used inside this class.
        self.resized_size = resized_size
        # Stored for interface compatibility; __getitem__ does not apply it.
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.data_infos = self.load_annotations()
        # When True, __getitem__ transforms only the image and returns the
        # label as a numpy array. Otherwise the transform receives both.
        self.contrastive = False

    def load_annotations(self):
        """Return a list of ``{'image': path, 'label': path}`` records.

        One record is produced per ``.png`` file in the image directory.
        The label path is built from the same file name and is not checked
        for existence.

        Files are listed in sorted order because ``os.listdir`` order is
        arbitrary. Sorting makes the index -> sample mapping reproducible
        across runs and machines.
        """
        test_tag = 'val' if self.test_mode else 'train'
        label_dir = os.path.join(self.root, f'{test_tag}_labels')
        image_dir = os.path.join(self.root, f'{test_tag}_images')
        return [
            {'image': os.path.join(image_dir, file),
             'label': os.path.join(label_dir, file)}
            for file in sorted(os.listdir(image_dir))
            if file.endswith('.png')
        ]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        If ``photo_transforms`` is None, both items are returned as opened
        PIL images. Otherwise there are two modes:

        - Contrastive mode: only the image is transformed and the label is
          converted with ``np.array``.
        - Joint mode: the transform is called as
          ``photo_transforms(image, label)``.

        ``target_transforms`` is not applied.
        """
        info = self.data_infos[idx]
        image = Image.open(info['image'])
        label = Image.open(info['label'])
        if self.photo_transforms is not None:
            if self.contrastive:
                image = self.photo_transforms(image)
                label = np.array(label)
            else:
                image, label = self.photo_transforms(image, label)
        return image, label


class NYUDepthV2Depth(BaseDataset):
    """NYU Depth V2 dataset yielding ``(depth, label)`` pairs.

    Expects ``root`` to contain two directories:

    - ``{split}_depth_images``
    - ``{split}_labels``

    ``split`` is ``'val'`` in test mode and ``'train'`` otherwise. Each
    label image shares the file name of its depth image.
    """

    def __init__(self,
                 root,
                 depth_transforms,
                 resized_size,
                 target_transforms=None,
                 classes=None,
                 ann_file=None,
                 test_mode=False):
        # NOTE(review): this skips BaseDataset.__init__ and calls the next
        # class in the MRO instead -- presumably on purpose. Confirm before
        # changing.
        super(BaseDataset, self).__init__()
        self.root = expanduser(root)
        # ``DEFAULT_TRNASFORMS`` (sic) is presumably defined on BaseDataset.
        if depth_transforms is None:
            depth_transforms = self.DEFAULT_TRNASFORMS
        self.depth_transforms = depth_transforms
        # Stored but not used inside this class.
        self.resized_size = resized_size
        # Stored for interface compatibility; __getitem__ does not apply it.
        self.target_transforms = target_transforms
        self.CLASSES = self.get_classes(classes)
        self.ann_file = expanduser(ann_file)
        self.test_mode = test_mode
        self.data_infos = self.load_annotations()
        # When True, __getitem__ transforms only the depth image and returns
        # the label as a numpy array. Otherwise the transform receives both.
        self.contrastive = False

    def load_annotations(self):
        """Return a list of ``{'depth': path, 'label': path}`` records.

        One record is produced per ``.png`` file in the depth directory.
        The label path is built from the same file name and is not checked
        for existence.

        Files are listed in sorted order because ``os.listdir`` order is
        arbitrary. Sorting makes the index -> sample mapping reproducible
        across runs and machines.
        """
        test_tag = 'val' if self.test_mode else 'train'
        label_dir = os.path.join(self.root, f'{test_tag}_labels')
        depth_dir = os.path.join(self.root, f'{test_tag}_depth_images')
        return [
            {'depth': os.path.join(depth_dir, file),
             'label': os.path.join(label_dir, file)}
            for file in sorted(os.listdir(depth_dir))
            if file.endswith('.png')
        ]

    def __getitem__(self, idx):
        """Return ``(depth, label)`` for sample ``idx``.

        If ``depth_transforms`` is None, both items are returned as opened
        PIL images. Otherwise there are two modes:

        - Contrastive mode: only the depth image is transformed and the
          label is converted with ``np.array``.
        - Joint mode: the transform is called as
          ``depth_transforms(depth, label)``.

        ``target_transforms`` is not applied.
        """
        info = self.data_infos[idx]
        depth = Image.open(info['depth'])
        label = Image.open(info['label'])
        if self.depth_transforms is not None:
            if self.contrastive:
                depth = self.depth_transforms(depth)
                label = np.array(label)
            else:
                depth, label = self.depth_transforms(depth, label)
        return depth, label

# if __name__ == '__main__':
#     load_mat()