import imp
from logging import root
from os import path as osp
from typing import Dict
from unicodedata import name
from kornia import depth

import numpy as np
import torch
import torch.utils as utils
from numpy.linalg import inv
from src.utils.dataset import(
    read_scannet_gray,
    read_scannet_depth,
    read_scannet_pose,
    read_scannet_instance
)


class ScanNet25k(utils.data.Dataset):
    """Image-pair dataset built from one ScanNet25k pair list (``.npz``).

    Every item bundles two grayscale frames, their depth and instance maps
    (only in 'train'/'val' mode), the depth-camera intrinsics of the scene
    and the relative poses between the two frames.
    """

    # Modes in which depth and instance maps are read from disk; any other
    # mode gets empty placeholder tensors instead.
    _SUPERVISED_MODES = ('train', 'val')

    def __init__(
        self,
        root_dir,
        npz_path,
        intrinsic_path,
        mode='train',
        augment_fn=None,
        pose_dir=None,
        **kwargs
    ):
        """Manage one scene of ScanNet25k Dataset.

        Args:
            root_dir (str): ScanNet25k root directory that contains scene folders.
            npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
            intrinsic_path (str): path to depth-camera intrinsic file (an ``.npz``
                archive indexed by scene name).
            mode (str): options are ['train', 'val', 'test'].
            augment_fn (callable, optional): augments images with pre-defined visual effects.
                Only kept when ``mode == 'train'``; see the note in ``__getitem__``.
            pose_dir (str): ScanNet root directory that contains all poses.
                (we use a separate (optional) pose_dir since we store images and poses separately.)
                Falls back to ``root_dir`` when None.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        super().__init__()
        self.root_dir = root_dir
        self.pose_dir = pose_dir if pose_dir is not None else root_dir
        self.mode = mode

        # prepare data_names
        with np.load(npz_path) as data:
            print(npz_path)
            self.data_names = data['name']

        # Materialise every intrinsic array into a plain dict while the archive
        # is open, then close it deterministically instead of relying on
        # garbage collection to release the file handle.
        with np.load(intrinsic_path) as intrinsics:
            self.intrinsics = dict(intrinsics)

        # for training LoFTR
        self.augment_fn = augment_fn if mode == 'train' else None

    def __len__(self):
        """Return the number of image pairs listed in the pair file."""
        return len(self.data_names)

    def __getitem__(self, idx):
        """Load the ``idx``-th image pair with its geometry.

        Returns:
            dict: with keys 'image0'/'image1', 'depth0'/'depth1',
            'instance0'/'instance1', 'T_0to1', 'T_1to0', 'K0', 'K1',
            'dataset_name', 'scene_id', 'pair_id' and 'pair_names'.
            Outside 'train'/'val' mode the depth and instance entries are
            empty tensors.
        """
        scene_idx, scene_sub_idx, stem_name_0, stem_name_1 = self.data_names[idx]
        scene_name = f'scene{scene_idx:04d}_{scene_sub_idx:02d}'
        stems = (stem_name_0, stem_name_1)

        # read the grayscale image which will be resized to (1, 480, 640)
        # (shape as stated by the original author; it is produced by read_scannet_gray)
        # NOTE(review): self.augment_fn is stored by __init__, but the reader is
        # deliberately still called with augment_fn=None, so no augmentation is
        # applied here. Confirm whether it should be wired in.
        image0, image1 = (
            read_scannet_gray(
                self._join(self.root_dir, scene_name, 'color', f'{stem}.jpg'),
                resize=(640, 480),
                augment_fn=None,
            )
            for stem in stems
        )

        if self.mode in self._SUPERVISED_MODES:
            # read the depthmap which is stored as (480, 640)
            depth0, depth1 = (
                read_scannet_depth(self._join(self.root_dir, scene_name, 'depth', f'{stem}.png'))
                for stem in stems
            )
            # read the instance map which is stored as (480, 640)
            instance0, instance1 = (
                read_scannet_instance(self._join(self.root_dir, scene_name, 'instance', f'{stem}.png'))
                for stem in stems
            )
        else:
            # Placeholders: within each pair both names refer to the same empty tensor.
            depth0 = depth1 = torch.tensor([])
            instance0 = instance1 = torch.tensor([])

        # read the intrinsic of depthmap; both frames share one tensor object
        K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)

        # read and compute relative poses
        T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), dtype=torch.float32)
        T_1to0 = T_0to1.inverse()

        return {
            'image0': image0,  # color data read in gray scale
            'depth0': depth0,  # depth data read in float
            'image1': image1,
            'depth1': depth1,
            'instance0': instance0,  # instance data for the image
            'instance1': instance1,
            'T_0to1': T_0to1,  # pose1 @ inv(pose0): relative pose from 0 to 1
            'T_1to0': T_1to0,  # inverse of T_0to1
            'K0': K_0,  # depth camera matrix for depth data
            'K1': K_1,
            'dataset_name': 'ScanNet',
            'scene_id': scene_name,
            'pair_id': idx,
            'pair_names': (
                self._join(scene_name, 'color', f'{stem_name_0}.jpg'),
                self._join(scene_name, 'color', f'{stem_name_1}.jpg'),
            ),
        }

    @staticmethod
    def _join(*parts):
        """Join path components and normalise backslashes to forward slashes."""
        return osp.join(*parts).replace('\\', '/')

    def _read_abs_pose(self, scene_name, name):
        """Read the pose file ``<pose_dir>/<scene_name>/pose/<name>.txt``."""
        pth = self._join(self.pose_dir, scene_name, 'pose', f'{name}.txt')
        return read_scannet_pose(pth)

    def _compute_rel_pose(self, scene_name, name0, name1):
        """Return ``pose1 @ inv(pose0)`` for the two named frames of a scene."""
        pose0 = self._read_abs_pose(scene_name, name0)
        pose1 = self._read_abs_pose(scene_name, name1)
        return np.matmul(pose1, inv(pose0))  # (4, 4)
    