import glob
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from src.common import as_intrinsics_matrix
from torch.utils.data import Dataset
import logging

def readEXR_onlydepth(filename):
    """
    Read depth data from EXR image file.

    Only the ``Y`` channel is decoded; any other channels in the file are
    ignored.

    Args:
        filename (str): File path.

    Returns:
        Y (numpy.array): Depth buffer in float32 format, shaped
            (height, width) from the file's data window, or None if the file
            has no ``Y`` channel.
    """
    # move the import here since only CoFusion needs these package
    # sometimes installation of openexr is hard, you can run all other datasets
    # even without openexr
    import Imath
    import OpenEXR as exr

    exrfile = exr.InputFile(filename)
    header = exrfile.header()
    if 'Y' not in header['channels']:
        return None

    dw = header['dataWindow']
    isize = (dw.max.y - dw.min.y + 1, dw.max.x - dw.min.x + 1)

    raw = exrfile.channel('Y', Imath.PixelType(Imath.PixelType.FLOAT))
    # np.frombuffer replaces the deprecated binary mode of np.fromstring. It
    # returns a read-only view of `raw`, so copy to keep returning a writable
    # array as before.
    return np.frombuffer(raw, dtype=np.float32).reshape(isize).copy()


def get_dataset(cfg, args, device='cuda'):
    """Build the dataset registered under ``cfg['dataset']``.

    The name is looked up in the module-level ``dataset_dict`` and the class is
    constructed as ``cls(cfg, args, device=device)``. An unregistered name
    raises KeyError.
    """
    dataset_name = cfg['dataset']
    dataset_cls = dataset_dict[dataset_name]
    return dataset_cls(cfg, args, device=device)


class BaseDataset(Dataset):
    """Shared camera/config handling and RGB-D frame loading.

    Subclasses must populate ``self.color_paths``, ``self.depth_paths``,
    ``self.poses`` (per-frame torch tensors) and ``self.n_img``.
    """

    def __init__(self, cfg, args, device='cuda'):
        super(BaseDataset, self).__init__()
        self.name = cfg['dataset']
        self.device = device
        cam = cfg['cam']
        # Raw depth values are divided by this to obtain metric depth.
        self.png_depth_scale = cam['png_depth_scale']

        self.H, self.W = cam['H'], cam['W']
        self.fx, self.fy, self.cx, self.cy = cam['fx'], cam['fy'], cam['cx'], cam['cy']

        # Optional distortion coefficients; only the colour image is undistorted.
        self.distortion = np.array(cam['distortion']) if 'distortion' in cam else None
        # Optional target size passed to F.interpolate in __getitem__.
        self.crop_size = cam['crop_size'] if 'crop_size' in cam else None

        # A folder given on the command line overrides the config one.
        if args.input_folder is None:
            self.input_folder = cfg['data']['input_folder']
        else:
            self.input_folder = args.input_folder

        # Border pixels removed from every frame; 0 (the default) disables it.
        self.crop_edge = cam.get('crop_edge', 0)

    def __len__(self):
        return self.n_img

    @staticmethod
    def set_edge_pixels_to_zero(depth_data, crop_edge):
        """Return a copy of a 2-D ``depth_data`` tensor with a border of width
        ``crop_edge`` set to zero. A non-positive ``crop_edge`` leaves the
        values unchanged."""
        # Guard is required: with crop_edge == 0, `mask[-0:]` selects the
        # whole axis and would zero every pixel.
        if crop_edge <= 0:
            return depth_data.clone()
        mask = torch.ones_like(depth_data)
        mask[:crop_edge, :] = 0
        mask[-crop_edge:, :] = 0
        mask[:, :crop_edge] = 0
        mask[:, -crop_edge:] = 0

        depth_data = depth_data * mask
        return depth_data

    def __getitem__(self, index):
        """Load frame ``index``.

        Returns:
            (index, color, depth, pose): colour is RGB scaled to [0, 1], depth
            is divided by ``png_depth_scale``; all tensors are on ``self.device``.

        Raises:
            OSError: if the colour or depth image cannot be read.
            ValueError: if the depth file is neither ``.png`` nor ``.exr``.
        """
        color_path = self.color_paths[index]
        depth_path = self.depth_paths[index]

        color_data = cv2.imread(color_path)
        if color_data is None:
            raise OSError(f"Could not read color image: {color_path}")

        depth_ext = os.path.splitext(depth_path)[1].lower()
        if depth_ext == '.png':
            depth_data = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
        elif depth_ext == '.exr':
            depth_data = readEXR_onlydepth(depth_path)
        else:
            raise ValueError(f"Unsupported depth file format: {depth_path}")
        if depth_data is None:
            raise OSError(f"Could not read depth data: {depth_path}")

        if self.distortion is not None:
            K = as_intrinsics_matrix([self.fx, self.fy, self.cx, self.cy])
            # undistortion is only applied on color image, not depth!
            color_data = cv2.undistort(color_data, K, self.distortion)

        color_data = cv2.cvtColor(color_data, cv2.COLOR_BGR2RGB)
        color_data = color_data / 255.
        depth_data = depth_data.astype(np.float32) / self.png_depth_scale
        # Resize colour to the depth resolution so the two are pixel-aligned.
        H, W = depth_data.shape
        color_data = cv2.resize(color_data, (W, H))
        color_data = torch.from_numpy(color_data)
        depth_data = torch.from_numpy(depth_data)
        if self.crop_size is not None:
            # follow the pre-processing step in lietorch, actually is resize
            color_data = color_data.permute(2, 0, 1)  # HWC -> CHW
            color_data = F.interpolate(
                color_data[None], self.crop_size, mode='bilinear', align_corners=True)[0]
            depth_data = F.interpolate(
                depth_data[None, None], self.crop_size, mode='nearest')[0, 0]
            color_data = color_data.permute(1, 2, 0).contiguous()  # CHW -> HWC

        edge = self.crop_edge
        if edge > 0:
            color_data = color_data[edge:-edge, edge:-edge]
            depth_data = depth_data[edge:-edge, edge:-edge]
        pose = self.poses[index]
        return index, color_data.to(self.device), depth_data.to(self.device), pose.to(self.device)


class Replica(BaseDataset):
    """Replica sequence with ``results/frame*.jpg`` colour images,
    ``results/depth*.png`` depth images and a ``traj.txt`` holding one
    row-major 4x4 camera-to-world matrix per line."""

    def __init__(self, cfg, args, device='cuda'):
        super(Replica, self).__init__(cfg, args, device)
        self.color_paths = sorted(glob.glob(f'{self.input_folder}/results/frame*.jpg'))
        self.depth_paths = sorted(glob.glob(f'{self.input_folder}/results/depth*.png'))
        self.n_img = len(self.color_paths)
        self.load_poses(f'{self.input_folder}/traj.txt')

    def load_poses(self, path):
        """Parse the first ``n_img`` lines of ``path`` into ``self.poses``."""
        self.poses = []
        with open(path, "r") as f:
            rows = f.readlines()
        for idx in range(self.n_img):
            values = [float(token) for token in rows[idx].split()]
            c2w = np.array(values).reshape(4, 4)
            # The codebase uses X right, Y up, Z backward camera axes, while the
            # dataset uses X right, Y down, Z forward. Negating the Y and Z
            # columns equals right-multiplying by a 180-degree rotation about X.
            c2w[:3, 1:3] *= -1
            self.poses.append(torch.from_numpy(c2w).float())


class ScanNet(BaseDataset):
    """ScanNet scene exported as ``<input_folder>/frames/{color,depth,pose}``.

    Colour (``*.jpg``), depth (``*.png``) and pose (``*.txt``) files are named
    by integer frame index and are ordered numerically, not lexicographically.
    """

    def __init__(self, cfg, args, device='cuda'):
        super(ScanNet, self).__init__(cfg, args, device)
        self.input_folder = os.path.join(self.input_folder, 'frames')
        self.color_paths = sorted(glob.glob(os.path.join(
            self.input_folder, 'color', '*.jpg')), key=self._frame_id)
        self.depth_paths = sorted(glob.glob(os.path.join(
            self.input_folder, 'depth', '*.png')), key=self._frame_id)
        self.load_poses(os.path.join(self.input_folder, 'pose'))
        self.n_img = len(self.color_paths)

    @staticmethod
    def _frame_id(path):
        """Return the integer frame index of a file named like ``123.jpg``."""
        return int(os.path.basename(path)[:-4])

    def load_poses(self, path):
        """Load every ``<path>/*.txt`` 4x4 camera-to-world matrix into ``self.poses``."""
        self.poses = []
        pose_paths = sorted(glob.glob(os.path.join(path, '*.txt')), key=self._frame_id)
        for pose_path in pose_paths:
            with open(pose_path, "r") as f:
                # split() (not split(' ')) tolerates repeated or trailing
                # whitespace; blank lines are skipped instead of crashing float().
                ls = [list(map(float, line.split())) for line in f if line.strip()]
            c2w = np.array(ls).reshape(4, 4)
            # The codebase assumes that the camera coordinate system is X left to right,
            # Y down to up and Z in the negative viewing direction. Most datasets assume
            # X left to right, Y up to down and Z in the positive viewing direction.
            # Therefore, we need to rotate the camera coordinate system.
            # Multiplication of R_x (rotation around X-axis 180 degrees) from the right.
            c2w[:3, 1] *= -1
            c2w[:3, 2] *= -1
            c2w = torch.from_numpy(c2w).float()
            self.poses.append(c2w)


class TUM_RGBD(BaseDataset):
    """TUM RGB-D sequence described by ``rgb.txt``, ``depth.txt`` and a pose list.

    Poses are re-expressed relative to the first selected frame, so that
    frame's pose is the identity before the axis flip is applied.
    """

    def __init__(self, cfg, args, device='cuda'):
        super(TUM_RGBD, self).__init__(cfg, args, device)
        self.color_paths, self.depth_paths, self.poses = self.loadtum(
            self.input_folder, frame_rate=32)
        self.n_img = len(self.color_paths)

    def parse_list(self, filepath, skiprows=0):
        """Read a space-delimited list file into a 2-D array of strings.

        Lines starting with ``#`` are skipped (np.loadtxt's default
        ``comments``); ``skiprows`` drops that many leading lines first.
        """
        # np.str_ replaces np.unicode_, which was removed in NumPy 2.0.
        # ndmin=2 keeps `data[:, k]` indexing valid for single-row files.
        data = np.loadtxt(filepath, delimiter=' ',
                          dtype=np.str_, skiprows=skiprows, ndmin=2)
        return data

    def associate_frames(self, tstamp_image, tstamp_depth, tstamp_pose, max_dt=0.08):
        """Pair each image with its nearest depth (and pose) in time.

        Returns a list of index tuples ``(i, j)`` or, when ``tstamp_pose`` is
        given, ``(i, j, k)``. An image is kept only if every match lies strictly
        within ``max_dt`` of its timestamp.
        """
        associations = []
        for i, t in enumerate(tstamp_image):
            if tstamp_pose is None:
                j = np.argmin(np.abs(tstamp_depth - t))
                if (np.abs(tstamp_depth[j] - t) < max_dt):
                    associations.append((i, j))

            else:
                j = np.argmin(np.abs(tstamp_depth - t))
                k = np.argmin(np.abs(tstamp_pose - t))

                if (np.abs(tstamp_depth[j] - t) < max_dt) and \
                        (np.abs(tstamp_pose[k] - t) < max_dt):
                    associations.append((i, j, k))

        return associations

    def loadtum(self, datapath, frame_rate=-1):
        """Read video data in TUM RGB-D format.

        Args:
            datapath: folder containing ``rgb.txt``, ``depth.txt`` and either
                ``groundtruth.txt`` or ``pose.txt``.
            frame_rate: keep frames spaced more than ``1 / frame_rate`` seconds
                apart; a non-positive value keeps every associated frame.

        Returns:
            (image_paths, depth_paths, poses), with poses as 4x4 float tensors.

        Raises:
            FileNotFoundError: if no pose list file exists in ``datapath``.
        """
        pose_list = None
        for candidate in ('groundtruth.txt', 'pose.txt'):
            candidate_path = os.path.join(datapath, candidate)
            if os.path.isfile(candidate_path):
                pose_list = candidate_path
                break
        if pose_list is None:
            raise FileNotFoundError(
                f"No groundtruth.txt or pose.txt found in {datapath}")

        image_list = os.path.join(datapath, 'rgb.txt')
        depth_list = os.path.join(datapath, 'depth.txt')

        image_data = self.parse_list(image_list)
        depth_data = self.parse_list(depth_list)
        pose_data = self.parse_list(pose_list, skiprows=1)
        pose_vecs = pose_data[:, 1:].astype(np.float64)

        tstamp_image = image_data[:, 0].astype(np.float64)
        tstamp_depth = depth_data[:, 0].astype(np.float64)
        tstamp_pose = pose_data[:, 0].astype(np.float64)
        associations = self.associate_frames(
            tstamp_image, tstamp_depth, tstamp_pose)
        if not associations:
            return [], [], []

        # Temporal subsampling; frame_rate == 0 no longer divides by zero.
        min_dt = 1.0 / frame_rate if frame_rate > 0 else -np.inf
        indicies = [0]
        for i in range(1, len(associations)):
            t0 = tstamp_image[associations[indicies[-1]][0]]
            t1 = tstamp_image[associations[i][0]]
            if t1 - t0 > min_dt:
                indicies += [i]

        images, poses, depths = [], [], []
        inv_pose = None
        for ix in indicies:
            (i, j, k) = associations[ix]
            images += [os.path.join(datapath, image_data[i, 1])]
            depths += [os.path.join(datapath, depth_data[j, 1])]
            c2w = self.pose_matrix_from_quaternion(pose_vecs[k])
            # Express every pose relative to the first selected frame.
            if inv_pose is None:
                inv_pose = np.linalg.inv(c2w)
                c2w = np.eye(4)
            else:
                c2w = inv_pose@c2w

            # The codebase assumes that the camera coordinate system is X left to right,
            # Y down to up and Z in the negative viewing direction. Most datasets assume
            # X left to right, Y up to down and Z in the positive viewing direction.
            # Therefore, we need to rotate the camera coordinate system.
            # Multiplication of R_x (rotation around X-axis 180 degrees) from the right.
            c2w[:3, 1] *= -1
            c2w[:3, 2] *= -1
            c2w = torch.from_numpy(c2w).float()
            poses += [c2w]

        return images, depths, poses

    def pose_matrix_from_quaternion(self, pvec):
        """Convert ``(tx, ty, tz, qx, qy, qz, qw)`` into a 4x4 pose matrix.

        scipy's ``Rotation.from_quat`` expects scalar-last (x, y, z, w) order.
        """
        from scipy.spatial.transform import Rotation

        pose = np.eye(4)
        pose[:3, :3] = Rotation.from_quat(pvec[3:]).as_matrix()
        pose[:3, 3] = pvec[:3]
        return pose


class SemanticReplica(BaseDataset):
    """Replica scenes with per-pixel semantic class labels.

    Expected under ``input_folder``:
      - ``rgb/rgb_<i>.png`` colour images;
      - ``depth/depth_<i>.png`` uint16 depth in millimetres;
      - ``semantic_class/semantic_class_<i>.png`` label images;
      - an optional ``semantic_instance/`` folder;
      - ``traj.txt`` with one row-major 4x4 camera-to-world matrix per line.
    """

    def __init__(self, cfg, args, device='cuda'):
        super(SemanticReplica, self).__init__(cfg, args, device)
        self.rgb_dir = os.path.join(self.input_folder, "rgb")
        self.depth_dir = os.path.join(self.input_folder, "depth")  # depth is in mm uint
        self.semantic_class_dir = os.path.join(self.input_folder, "semantic_class")
        self.semantic_instance_dir = os.path.join(self.input_folder, "semantic_instance")
        if not os.path.exists(self.semantic_instance_dir):
            self.semantic_instance_dir = None

        self.rgb_list = self._sorted_by_frame_id(self.rgb_dir + '/rgb*.png')
        self.depth_list = self._sorted_by_frame_id(self.depth_dir + '/depth*.png')
        self.semantic_list = self._sorted_by_frame_id(
            self.semantic_class_dir + '/semantic_class_*.png')
        if self.semantic_instance_dir is not None:
            self.instance_list = self._sorted_by_frame_id(
                self.semantic_instance_dir + '/semantic_instance_*.png')
        self.n_img = len(self.rgb_list)

        self.load_poses(f'{self.input_folder}/traj.txt')

        # Semantic label ids are taken from the config rather than derived by
        # scanning every label image.
        self.semantic_classes = np.array(cfg['data']['semantic_classes']).astype(np.uint8)
        self.num_semantic_class = self.semantic_classes.shape[0]  # number of semantic classes, including the void class of 0
        logging.info("num semantic classes: {}".format(self.num_semantic_class))

        self.enable_semantic = cfg['model']['enable_semantic']

    @staticmethod
    def _sorted_by_frame_id(pattern):
        """Glob ``pattern`` and sort numerically by the integer after the last
        ``_`` in the path (e.g. ``rgb_12.png`` -> 12)."""
        return sorted(glob.glob(pattern),
                      key=lambda file_name: int(file_name.split("_")[-1][:-4]))

    @staticmethod
    def _imread(path, flags=None):
        """``cv2.imread`` that raises OSError instead of returning None."""
        image = cv2.imread(path) if flags is None else cv2.imread(path, flags)
        if image is None:
            raise OSError(f"Could not read image: {path}")
        return image

    def load_poses(self, path):
        """Parse the first ``n_img`` lines of ``path`` into ``self.poses``."""
        self.poses = []
        with open(path, "r") as f:
            lines = f.readlines()
        for i in range(self.n_img):
            line = lines[i]
            c2w = np.array(list(map(float, line.split()))).reshape(4, 4)
            # The codebase assumes that the camera coordinate system is X left to right,
            # Y down to up and Z in the negative viewing direction. Most datasets assume
            # X left to right, Y up to down and Z in the positive viewing direction.
            # Therefore, we need to rotate the camera coordinate system.
            # Multiplication of R_x (rotation around X-axis 180 degrees) from the right.
            c2w[:3, 1] *= -1
            c2w[:3, 2] *= -1
            c2w = torch.from_numpy(c2w).float()
            self.poses.append(c2w)

    def __getitem__(self, index):
        """Return ``(index, color, depth, pose, semantic)`` for frame ``index``.

        ``semantic`` is the remapped label map when ``enable_semantic`` is set,
        otherwise a scalar ``0.0`` placeholder tensor.
        """
        # BGR uint8 -> RGB float in [0, 1]
        color = self._imread(self.rgb_list[index])[:, :, ::-1] / 255.0
        # uint16 mm depth -> metres
        depth = self._imread(self.depth_list[index], cv2.IMREAD_UNCHANGED) / 1000.0
        semantic = self._imread(self.semantic_list[index], cv2.IMREAD_UNCHANGED)
        # Instance masks (self.instance_list) are never returned, so they are
        # not read here.

        if (self.H is not None and self.H != color.shape[0]) or \
                (self.W is not None and self.W != color.shape[1]):
            color = cv2.resize(color, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
            depth = cv2.resize(depth, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
            semantic = cv2.resize(semantic, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
        pose = self.poses[index]

        # Map raw label ids to their index in semantic_classes; ids not listed
        # there keep their raw value.
        semantic_remap = semantic.copy()
        for i in range(self.num_semantic_class):
            semantic_remap[semantic == self.semantic_classes[i]] = i

        color = torch.from_numpy(color).float()
        depth = torch.from_numpy(depth).float()
        # Cast in NumPy first: torch.from_numpy may reject uint16 label images.
        semantic = torch.from_numpy(semantic_remap.astype(np.int64))
        if self.crop_size is not None:
            color = color.permute(2, 0, 1)  # HWC -> CHW
            color = F.interpolate(
                color[None], self.crop_size, mode='bilinear', align_corners=True)[0]
            depth = F.interpolate(
                depth[None, None], self.crop_size, mode='nearest')[0, 0]
            # Nearest interpolation may not be implemented for int64 tensors;
            # resample as float (exact for these label ids) and cast back.
            semantic = F.interpolate(
                semantic[None, None].float(), self.crop_size, mode='nearest')[0, 0].long()
            color = color.permute(1, 2, 0).contiguous()  # CHW -> HWC

        edge = self.crop_edge
        if edge > 0:
            color = color[edge:-edge, edge:-edge]
            depth = depth[edge:-edge, edge:-edge]
            semantic = semantic[edge:-edge, edge:-edge]
        # NOTE(review): unlike BaseDataset.__getitem__, the pose is returned
        # without .to(self.device); confirm callers expect a CPU pose.
        if self.enable_semantic:
            return index, color.to(self.device), depth.to(self.device), pose, semantic.to(self.device)
        else:
            return index, color.to(self.device), depth.to(self.device), pose, torch.tensor(0.0).to(self.device)


# Registry used by get_dataset: maps cfg['dataset'] to a dataset class.
# Note that "replica" resolves to SemanticReplica; the plain Replica class is
# not registered here.
dataset_dict = dict(
    replica=SemanticReplica,
    scannet=ScanNet,
    tumrgbd=TUM_RGBD,
)


if __name__ == "__main__":
    from easydict import EasyDict
    from src.config import load_config
    from src.common import setup_seed

    # Smoke test: build each Replica scene's dataset once, on the CPU.
    # EasyDict exposes dictionary keys as attributes, mimicking argparse output.
    args = EasyDict()
    args.input_folder = None

    scenes = ["office0", "room0", "room1", "room2"]
    for scene in scenes:
        args.config = "configs/Replica/{}.yaml".format(scene)
        cfg = load_config(args.config, "configs/q_slam.yaml")

        print("Loading dataset {}".format(scene))
        dataset = get_dataset(cfg, args, "cpu")
        del dataset  # free the dataset before loading the next scene

        print("=" * 80)

