import torch
from torch.utils.data import Dataset
import glob
import numpy as np
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *

from colmapUtils.read_write_model import *
from dataLoader.get_sparse_depth import *
import pandas as pd

def change_coordinate_system(poses: np.ndarray, p: np.ndarray):
    """Re-express every pose in the coordinate system described by ``p``.

    For each pose, the upper-left 3x3 block ``r`` becomes ``p.T @ r @ p`` and
    the translation column ``t`` becomes ``p @ t``. Any rows below the third
    (for example a homogeneous ``[0, 0, 0, 1]`` row) are copied through
    unchanged.

    Args:
        poses: iterable of 2-D pose matrices with at least 3 rows; columns
            0-2 hold the rotation and the remaining column(s) the translation.
        p: 3x3 change-of-basis matrix.

    Returns:
        The converted poses stacked along a new leading axis.
    """
    def _convert_single(pose):
        rotation = p.T @ pose[:3, :3] @ p
        translation = p @ pose[:3, 3:]
        top_rows = np.concatenate([rotation, translation], axis=1)
        # pose[3:] is empty for 3x4 input and the homogeneous row for 4x4 input.
        return np.concatenate([top_rows, pose[3:]], axis=0)

    return np.stack([_convert_single(pose) for pose in poses])

def compute_average_pose(poses: np.ndarray):
    """Average a stack of world-to-camera poses and return it as world-to-camera.

    Each pose is inverted analytically (transposed rotation, centre
    ``-R^T t``). The camera centres are averaged, the third and second columns
    of the inverted rotations are summed to give a viewing direction and an
    "up" hint, and a camera-to-world frame is assembled from them. That frame
    is inverted before being returned.

    Args:
        poses: array indexed as ``poses[:, :3, :3]`` (rotations) and
            ``poses[:, :3, 3:]`` (translations).

    Returns:
        4x4 matrix: the inverse of the averaged camera-to-world frame.
    """
    def unit(vector):
        return vector / np.linalg.norm(vector)

    def build_frame(forward, up_hint, origin):
        # Derive mutually orthogonal unit axes from the viewing direction and
        # the (not necessarily orthogonal) up hint.
        z_axis = unit(forward)
        x_axis = unit(np.cross(up_hint, z_axis))
        y_axis = unit(np.cross(z_axis, x_axis))
        top_rows = np.stack([x_axis, y_axis, z_axis, origin], 1)
        homogeneous_row = np.array([0, 0, 0, 1])[None]
        return np.concatenate([top_rows, homogeneous_row], axis=0)

    # Work in the camera-to-world system.
    cam2world_rots = np.transpose(poses[:, :3, :3], axes=[0, 2, 1])
    camera_centres = -cam2world_rots @ poses[:, :3, 3:]
    mean_centre = np.mean(camera_centres, axis=0)[:, 0]

    mean_forward = unit(cam2world_rots[:, :3, 2].sum(0))
    summed_up = cam2world_rots[:, :3, 1].sum(0)
    avg_cam2world = build_frame(mean_forward, summed_up, mean_centre)

    # Convert the average back to the world-to-camera system.
    return np.linalg.inv(avg_cam2world)

def recenter_poses(poses, avg_pose):
    """Express the inverse of every pose relative to ``avg_pose``.

    Computes ``avg_pose @ inv(pose)`` for each pose in the stack; ``avg_pose``
    is broadcast over the leading axis.

    Args:
        poses: stack of invertible square matrices.
        avg_pose: a single matrix of matching size.

    Returns:
        Stack of products, one per input pose.
    """
    inverted = np.linalg.inv(poses)
    reference = avg_pose[None]
    return reference @ inverted

def convert_pose_to_standard_coordinates(poses):
    """Flip the y and z axes of every pose.

    Converts from the Colmap/RE10K convention to the NeRF convention,
    i.e. (x, -y, -z) to (x, y, z), by delegating to
    ``change_coordinate_system`` with the matrix ``diag(1, -1, -1)``.

    Args:
        poses: stack of pose matrices accepted by ``change_coordinate_system``.

    Returns:
        The converted stack of poses.
    """
    axis_flip = np.diag([1.0, -1.0, -1.0])
    return change_coordinate_system(poses, axis_flip)

def normalize(v):
    """Scale ``v`` to unit Euclidean (Frobenius) norm.

    No guard is applied for a zero-norm input; the division is performed as-is.
    """
    magnitude = np.linalg.norm(v)
    return v / magnitude


def average_poses(poses):
    """
    Build one representative pose from a stack of poses.

    The result is assembled column by column:
    1. center: the mean of the last column (index 3) of every pose.
    2. z axis: the normalized mean of column 2.
    3. y': the mean of column 1 (left un-normalized, it is only a helper).
    4. x axis: normalize(z cross y').
    5. y axis: x cross z.

    y' is not used directly as the y axis because it is not necessarily
    orthogonal to z; going through x yields an orthogonal frame.

    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        pose_avg: (3, 4) the average pose, columns [x, y, z, center]
    """
    center = poses[..., 3].mean(0)  # (3)
    z_axis = normalize(poses[..., 2].mean(0))  # (3)
    y_helper = poses[..., 1].mean(0)  # (3)

    x_axis = normalize(np.cross(z_axis, y_helper))  # (3)
    # x and z are unit length and orthogonal, so no normalization is applied here.
    y_axis = np.cross(x_axis, z_axis)  # (3)

    return np.stack([x_axis, y_axis, z_axis, center], 1)  # (3, 4)


def center_poses(poses, blender2opencv, avg_indices=(10, 20)):
    """
    Center the poses so that we can use NDC.
    See https://github.com/bmild/nerf/issues/34

    Every pose is first right-multiplied by ``blender2opencv``. An average
    pose is then computed with ``average_poses`` from the poses selected by
    ``avg_indices``, and all poses are left-multiplied by the inverse of that
    average.

    Inputs:
        poses: (N_images, 3, 4)
        blender2opencv: (4, 4) matrix applied on the right of every pose
        avg_indices: indices of the poses used to compute the average pose.
            The default ``(10, 20)`` reproduces the previously hard-coded
            selection and therefore needs at least 21 poses; pass other
            indices for smaller scenes, or ``None`` to average all poses.
    Outputs:
        poses_centered: (N_images, 3, 4) the centered poses
        pose_avg_homo: (4, 4) the average pose in homogeneous form
    """
    poses = poses @ blender2opencv
    reference_poses = poses if avg_indices is None else poses[list(avg_indices)]
    pose_avg = average_poses(reference_poses)  # (3, 4)

    # Promote to homogeneous coordinates by appending the row [0, 0, 0, 1].
    pose_avg_homo = np.eye(4)
    pose_avg_homo[:3] = pose_avg
    last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))  # (N_images, 1, 4)
    poses_homo = np.concatenate([poses, last_row], 1)  # (N_images, 4, 4)

    poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo  # (N_images, 4, 4)
    poses_centered = poses_centered[:, :3]  # (N_images, 3, 4)

    return poses_centered, pose_avg_homo


def viewmatrix(z, up, pos):
    """Assemble a 4x4 pose from a viewing direction, an up hint and a position.

    The z column is ``normalize(z)``, the x column is the negated
    ``normalize(up cross z)``, the y column is ``normalize(z cross x')`` where
    ``x'`` is the un-negated x vector, and the last column is ``pos``. The
    bottom row is ``[0, 0, 0, 1]``.
    """
    z_axis = normalize(z)
    x_unflipped = normalize(np.cross(up, z_axis))
    y_axis = normalize(np.cross(z_axis, x_unflipped))

    pose = np.eye(4)
    # Note the sign flip on the x column only.
    pose[:3] = np.stack([-x_unflipped, y_axis, z_axis, pos], 1)
    return pose


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=60):
    """Generate camera poses along a spiral around the pose ``c2w``.

    Args:
        c2w: reference pose; only ``c2w[:3, :4]`` is used.
        up: up hint forwarded to ``viewmatrix``.
        rads: three per-axis radii of the spiral.
        focal: distance along the reference pose's -z axis of the point that
            every generated camera is oriented with respect to.
        zdelta: accepted for interface compatibility; not used.
        zrate: frequency multiplier of the third-axis oscillation.
        N_rots: the angle sweeps from 0 up to (excluding) ``pi * N_rots``.
        N: number of poses generated.

    Returns:
        List of ``N`` 4x4 matrices produced by ``viewmatrix``.
    """
    cam_to_world = c2w[:3, :4]
    scale = np.array(list(rads) + [1.])
    focus_point = np.dot(cam_to_world, np.array([0, 0, -focal, 1.]))

    angles = np.linspace(0., 1. * np.pi * N_rots, N + 1)[:-1]
    spiral = []
    for angle in angles:
        offset = np.array([np.cos(angle), -np.sin(angle), -np.sin(angle * zrate), 1.]) * scale
        position = np.dot(cam_to_world, offset)
        direction = normalize(position - focus_point)
        spiral.append(viewmatrix(direction, up, position))
    return spiral



def generate_spiral_path_dtu(poses, n_frames=30, n_rots=4, zrate=.05, perc=60, radius_ratio=0.5):
    """Calculates a forward facing spiral path for rendering for DTU.

    Args:
        poses: stack of poses indexed as ``poses[:, :3, k]``; column 1 is used
            as the up hint, column 2 as the focal axis and column 3 as the
            camera position.
        n_frames: number of poses to generate.
        n_rots: number of full ``2 * pi`` rotations.
        zrate: frequency multiplier of the third-axis oscillation.
        perc: percentile of the absolute camera positions used as the radii.
        radius_ratio: multiplier applied to those radii.

    Returns:
        Array of shape (n_frames, 3, 4) with columns [x, y, z, position].
    """
    def nearest_point_to_focal_axes(cams):
        """Calculate nearest point to all focal axes in cams."""
        axes, origins = cams[:, :3, 2:3], cams[:, :3, 3:4]
        projectors = np.eye(3) - axes * np.transpose(axes, [0, 2, 1])
        gram = np.transpose(projectors, [0, 2, 1]) @ projectors
        return np.linalg.inv(gram.mean(0)) @ (gram @ origins).mean(0)[:, 0]

    def look_at(target, up_hint, eye):
        """Construct a 3x4 view matrix placed at eye, with z = normalize(target - eye)."""
        z_axis = normalize(target - eye)
        x_axis = normalize(np.cross(up_hint, z_axis))
        y_axis = normalize(np.cross(z_axis, x_axis))
        return np.stack([x_axis, y_axis, z_axis, eye], axis=1)

    # Radii of the spiral: a percentile of the absolute camera positions,
    # extended with 1 for the homogeneous coordinate.
    radii = np.percentile(np.abs(poses[:, :3, 3]), perc, 0) * radius_ratio
    radii = np.concatenate([radii, [1.]])

    cam2world = average_poses(poses)
    up = poses[:, :3, 1].mean(0)
    focus_point = nearest_point_to_focal_axes(poses)

    frames = []
    for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False):
        offset = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]
        frames.append(look_at(focus_point, up, cam2world @ offset))
    return np.stack(frames, axis=0)

def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120, zrate=.5, N_rots=2):
    """Build a spiral render path around the average of ``c2ws_all``.

    Args:
        c2ws_all: stack of poses indexed as ``c2ws_all[:, :3, k]``.
        near_fars: array of depth bounds; only its global min and max are used.
        rads_scale: multiplier applied to the spiral radii.
        N_views: number of poses to generate.
        zrate: forwarded to ``render_path_spiral``.
        N_rots: forwarded to ``render_path_spiral``.

    Returns:
        The poses returned by ``render_path_spiral``, stacked into one array.
    """
    center_pose = average_poses(c2ws_all)
    up_direction = normalize(c2ws_all[:, :3, 1].sum(0))

    # Pick a "focus depth": a weighted harmonic blend of a slightly reduced
    # nearest bound and an enlarged farthest bound.
    dt = 0.99
    nearest = near_fars.min() * 0.9
    farthest = near_fars.max() * 5.0
    focus_depth = 1.0 / (((1.0 - dt) / nearest + dt / farthest))

    zdelta = near_fars.min() * .2
    # Spiral radii: 90th percentile of the absolute camera positions.
    camera_positions = c2ws_all[:, :3, 3]
    radii = np.percentile(np.abs(camera_positions), 90, 0) * rads_scale

    path = render_path_spiral(center_pose, up_direction, radii, focus_depth, zdelta,
                              zrate=zrate, N=N_views, N_rots=N_rots)
    return np.stack(path)

def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
    
    poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0]) # 3 x 5 x N
    bds = poses_arr[:, -2:].transpose([1,0])
    
    return poses, bds

def get_poses(images):
    """Convert a mapping of image records into camera-to-world matrices.

    For every key of ``images`` (in iteration order) a 4x4 world-to-camera
    matrix is built from the record's ``qvec2rotmat()`` rotation and ``tvec``
    translation, then inverted.

    Returns:
        Array holding one 4x4 camera-to-world matrix per image.
    """
    cam2world_list = []
    for key in images:
        record = images[key]
        world2cam = np.concatenate(
            [
                np.concatenate([record.qvec2rotmat(), record.tvec.reshape([3, 1])], 1),
                np.array([0, 0, 0, 1.]).reshape([1, 4]),
            ],
            0,
        )
        cam2world_list.append(np.linalg.inv(world2cam))
    return np.array(cam2world_list)

def load_colmap_depth(datadir, factor=8, bd_factor=.75):
    """Extract sparse per-image depths from a COLMAP reconstruction.

    Reads ``<datadir>/sparse/0/images.bin`` and ``points3D.bin``. For every
    2-D observation that has an associated 3-D point, the point is projected
    onto the viewing axis (third column) of the observing camera's
    camera-to-world pose to obtain its depth. Depths outside the fixed
    ``[1, 100]`` bounds (after rescaling) are discarded. The result is also
    saved to ``<datadir>/colmap_depth.npy``.

    Args:
        datadir: scene directory containing ``sparse/0``.
        factor: the 2-D coordinates are divided by this value.
        bd_factor: depths are rescaled by ``1 / (near_bound * bd_factor)``;
            ``None`` disables rescaling.

    Returns:
        A list with one dict per image that has at least one valid point, in
        ascending image-id order. Each dict holds ``"depth"``, ``"coord"`` and
        ``"error"`` arrays; ``"error"`` stores the confidence weight
        ``2 * exp(-(err / mean_err) ** 2)``, not the raw reprojection error.
    """
    data_file = datadir + '/colmap_depth.npy'

    images = read_images_binary(datadir + '/sparse/0/images.bin')
    points = read_points3d_binary(datadir + '/sparse/0/points3D.bin')

    Errs = np.array([point3D.error for point3D in points.values()])
    Err_mean = np.mean(Errs)
    print("Mean Projection Error:", Err_mean)

    # get_poses() walks `images` in iteration order, so remember which row of
    # `poses` belongs to which image id instead of assuming ids are 1..N.
    poses = get_poses(images)
    pose_row = {image_id: row for row, image_id in enumerate(images)}

    # Fixed (near, far) bounds shared by every image.
    bds_raw = np.array([1, 100]).astype('float32')
    # Rescale if bd_factor is provided
    sc = 1. if bd_factor is None else 1./(bds_raw.min() * bd_factor)

    near = np.ndarray.min(bds_raw) * .9 * sc
    far = np.ndarray.max(bds_raw) * 1. * sc
    print('near/far:', near, far)

    min_depth, max_depth = bds_raw[0] * sc, bds_raw[1] * sc

    data_list = []
    for id_im in sorted(images):
        image = images[id_im]
        cam2world = poses[pose_row[id_im]]
        view_axis, cam_center = cam2world[:3, 2], cam2world[:3, 3]

        depth_list = []
        coord_list = []
        weight_list = []
        for point2D, id_3D in zip(image.xys, image.point3D_ids):
            if id_3D == -1:
                # Observation without a triangulated 3-D point.
                continue
            point = points[id_3D]
            depth = (view_axis.T @ (point.xyz - cam_center)) * sc
            if depth < min_depth or depth > max_depth:
                continue
            # Down-weight points whose error is large relative to the mean.
            weight = 2 * np.exp(-(point.error/Err_mean)**2)
            depth_list.append(depth)
            coord_list.append(point2D/factor)
            weight_list.append(weight)

        if depth_list:
            data_list.append({"depth":np.array(depth_list), "coord":np.array(coord_list), "error":np.array(weight_list)})

    # A list of dicts is stored as an object array, so loading it back
    # requires np.load(..., allow_pickle=True).
    np.save(data_file, data_list)
    return data_list

def draw_render_path(render_poses=None, poses=None, avg_poses=None, bbox=None):
    """Plot camera poses in 3-D and save a rotating view as ``Real_2view.gif``.

    Args:
        render_poses: optional stack of poses drawn in blue (semi-transparent arrows).
        poses: optional stack of poses drawn in orange.
        avg_poses: optional single pose whose position is drawn in green; a
            pink marker is added at the origin as well.
        bbox: accepted but not used.

    Each stack is drawn as a scatter of column 3 (positions) with arrows along
    column 2. The animation is written with the 'imagemagick' writer.
    """
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.set_box_aspect([1, 1, 1])

    def plot_cameras(cams, colour, **quiver_extra):
        # Positions come from column 3, arrow directions from column 2.
        centres = cams[:, :3, 3]
        directions = cams[:, :3, 2]
        ax.scatter(centres[:, 0], centres[:, 1], centres[:, 2], c=colour)
        ax.quiver(centres[:, 0], centres[:, 1], centres[:, 2],
                  directions[:, 0], directions[:, 1], directions[:, 2],
                  color=colour, length=0.3, arrow_length_ratio=0.01, zorder=1,
                  **quiver_extra)

    if render_poses is not None:
        plot_cameras(render_poses, 'b', alpha=0.5)
    if poses is not None:
        plot_cameras(poses, 'orange')
    if avg_poses is not None:
        avg_centre = avg_poses[:3, 3]
        ax.scatter(avg_centre[0], avg_centre[1], avg_centre[2], c='g')
        ax.scatter(0, 0, 0, c='pink')

    # plot origin
    ax.scatter(0, 0, 0, c='purple')
    # plot axis
    ax.set_xlabel('X Axis')
    ax.set_ylabel('Y Axis')
    ax.set_zlabel('Z Axis')

    def rotate_view(azimuth):
        return ax.view_init(elev=30, azim=azimuth)

    ani = FuncAnimation(fig, rotate_view, frames=np.arange(0, 360, 1), interval=10)
    ani.save('Real_2view.gif', writer='imagemagick', fps=30)


class RealEstate10KDataset(Dataset):
    """Ray dataset for a single RealEstate10K scene directory.

    The scene directory is expected to contain an ``rgb/`` folder of images
    plus ``CameraExtrinsics.csv`` (one 4x4 matrix per image) and
    ``CameraIntrinsics.csv`` (one 3x3 matrix per image). All rays, colours,
    sparse depths and per-ray frame ids are precomputed in ``read_meta``.
    """

    def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8, frame_num=None):
        """
        Args:
            datadir: scene directory; its last two path components are used as
                the scene name and the dataset name.
            split: 'train', 'test' or 'novel'. Only these three set up
                ``render_path``.
            downsample: factor by which the image resolution is divided.
            is_stack: if True, per-image tensors are stacked (one entry per
                image); otherwise they are concatenated into flat ray lists.
            hold_every: when ``frame_num`` is empty, every ``hold_every``-th
                frame is held out as a test frame.
            frame_num: indices of the frames to use. ``None`` or an empty list
                falls back to the ``hold_every`` selection.
        """

        self.root_dir = datadir
        self.split = split
        self.hold_every = hold_every
        self.is_stack = is_stack
        self.downsample = downsample
        self.define_transforms()
        # Use a fresh list per instance rather than a shared mutable default.
        self.frame_num = [] if frame_num is None else frame_num
        self.frame_len = len(self.frame_num)
        self.scene_name = datadir.split("/")[-1]
        self.dataset_name = datadir.split("/")[-2]

        self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
        self.blender2opencv = np.eye(4)
        self.read_meta()
        self.white_bg = False

        self.near_far = [0.0, 1.0]

        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

        self.mask_lis = None

    def pre_calculate_nearest_pose(self, img_list):
        """Find, for every frame in ``img_list``, the closest other frame in the list.

        Distance is the Euclidean distance between the last columns of the poses.

        Returns:
            Array of length ``len(self.poses)``: entry ``k`` is the index of the
            frame nearest to frame ``k``, or -1 for frames that are not in
            ``img_list`` (or have no partner).
        """
        num_camera_pose = len(img_list)

        nearest_dist = np.full(len(self.poses), np.inf)
        nearest_pose = np.full(len(self.poses), -1)

        # Visit every unordered pair once and update both ends of the pair.
        for i in range(num_camera_pose - 1):
            cur = img_list[i]
            for j in range(i + 1, num_camera_pose):
                other = img_list[j]
                dist = np.linalg.norm(self.poses[cur][:, 3] - self.poses[other][:, 3])
                if dist < nearest_dist[cur]:
                    nearest_dist[cur] = dist
                    nearest_pose[cur] = other
                if dist < nearest_dist[other]:
                    nearest_dist[other] = dist
                    nearest_pose[other] = cur
        return nearest_pose

    def get_nearest_pose(self, c2w, img_list, i):
        """Return the frame in ``img_list`` whose position is closest to ``c2w``.

        Args:
            c2w: query pose; its position is read from ``c2w[:3, 3]``.
            img_list: candidate frame indices into ``self.poses``.
            i: index of the query frame; in the 'train' split it is excluded
                from the candidates so a frame is never its own neighbour.

        Raises:
            ValueError: if no candidate remains (e.g. an empty ``img_list``, or
                a 'train' split whose only frame is ``i``).
        """
        min_distance = None
        nearest_id = None
        for j in img_list:
            if j == i and self.split == 'train':
                continue
            distance = (torch.sum(((c2w[:3,3] - self.poses[j,:,3])**2)))**0.5

            if min_distance is None or distance < min_distance:
                min_distance = distance
                nearest_id = j
        if nearest_id is None:
            raise ValueError(
                "get_nearest_pose: no candidate frame available for frame "
                f"{i} (split={self.split!r}, {len(img_list)} frame(s) given)"
            )
        return nearest_id

    def read_meta(self):
        """Load cameras and images and precompute all rays for the split.

        Populates, among others, ``poses``, ``near_fars``, ``img_wh``,
        ``focal``, ``render_path``, ``directions``, ``all_rays``, ``all_rgbs``,
        ``all_depths``, ``all_depth_weights``, ``all_ids`` and
        ``all_nearest_ids``. Only the 'train', 'novel' and 'test' splits define
        ``render_path``.
        """
        self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'rgb/*')))

        # read camera poses
        extrinsics_path = os.path.join(self.root_dir, 'CameraExtrinsics.csv')
        extrinsics_df = pd.read_csv(extrinsics_path, header=None)
        homo_poses = extrinsics_df.to_numpy().reshape(-1, 4, 4)
        # read camera intrinsics
        intrinsics_path = os.path.join(self.root_dir, "CameraIntrinsics.csv")
        intrinsics_df = pd.read_csv(intrinsics_path, header=None)
        intrinsics = intrinsics_df.to_numpy().reshape(-1, 3, 3)
        # Constant depth bounds for every image.
        self.near_fars = np.full((homo_poses.shape[0], 2), [0.1, 100])

        # Step 1: rescale focal length according to training resolution.
        # The full-resolution size is taken as twice the principal point of
        # the first camera; the first camera's intrinsics are used for all images.
        H, W = intrinsics[0, 1, 2] * 2, intrinsics[0, 0, 2] * 2
        self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
        self.focal = [intrinsics[0][0, 0] * self.img_wh[0] / W, intrinsics[0][1, 1] * self.img_wh[1] / H]

        # Step 2: recenter the poses around their average, then flip the y/z
        # axes (Colmap/RE10K convention -> NeRF convention).
        avg_pose = compute_average_pose(homo_poses)
        self.poses = recenter_poses(homo_poses, avg_pose)
        self.poses = convert_pose_to_standard_coordinates(self.poses)
        self.poses = self.poses[:, :3, :]

        # Step 3: rescale the scene. With the constant bounds above the divisor
        # is 0.1 * 7.5 = 0.75, which maps the near bound to 1 / 7.5.
        # See https://github.com/bmild/nerf/issues/34
        near_original = self.near_fars.min()
        scale_factor = near_original * 7.5

        self.near_fars /= scale_factor
        self.poses[..., 3] /= scale_factor

        # build rendering path
        N_views = 60
        if self.split == 'train':
            print('train')
            self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)
        elif self.split == 'novel':
            print('novel')
            self.render_path = get_spiral(self.poses[self.frame_num], self.near_fars, N_views=N_views, zrate=0.5, rads_scale=2.0, N_rots=4)
        elif self.split == 'test':
            print('test')
            self.render_path = self.poses[self.frame_num]

        # ray directions for all pixels, same for all images (same H, W, focal)
        W, H = self.img_wh
        self.directions = get_ray_directions_blender(H, W, self.focal)  # (H, W, 3)

        # Choose the frames of this split: explicit frame_num wins; otherwise
        # 'novel' uses no source frames and the other splits use hold_every.
        if self.frame_num is not None and len(self.frame_num) > 0:
            img_list = self.frame_num
        elif self.split == 'novel':
            img_list = []
        else:
            i_test = np.arange(0, self.poses.shape[0], self.hold_every)
            img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))

        self.all_rays = []
        self.all_rgbs = []
        self.all_ids = []
        self.all_nearest_ids = []
        self.all_depths = []
        self.all_depth_weights = []

        if self.split != 'novel':
            # -10 marks frames that contribute no rays in this split.
            self.frameid2_startpoints_in_allray = [-10] * self.poses.shape[0]
            cnt = 0
            for i in img_list:
                image_path = self.image_paths[i]
                c2w = torch.FloatTensor(self.poses[i])
                # convert() returns a new, fully loaded image, so the file
                # can be closed right away.
                with Image.open(image_path) as raw_img:
                    img = raw_img.convert('RGB')
                if self.downsample != 1.0:
                    img = img.resize(self.img_wh, Image.LANCZOS)

                img = self.transform(img)  # (3, h, w)

                # -1 marks pixels without a sparse depth sample.
                depth = -torch.ones(H, W)
                weight = -torch.ones(H, W)

                # In training, we use sparse depth
                if self.split == "train":
                    SD = load_sparse_depth(self.dataset_name, self.scene_name, self.frame_len, i, int(self.downsample))
                    for j in range(len(SD)):
                        depth[round(SD.y[j]), round(SD.x[j])] = SD.depth[j] / scale_factor
                        weight[round(SD.y[j]), round(SD.x[j])] = SD.weight[j]
                depth = depth.view(-1)
                weight = weight.view(-1)

                nearest_id = self.get_nearest_pose(c2w, img_list, i)

                img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGB
                self.all_rgbs += [img]
                self.all_depths += [depth]
                self.all_depth_weights += [weight]
                rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
                rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
                self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
                cur_ids = torch.full([rays_o.shape[0]], i)
                self.all_ids += [cur_ids]
                self.all_nearest_ids += [torch.ones_like(cur_ids).int() * nearest_id]
                self.frameid2_startpoints_in_allray[i] = cnt * cur_ids.shape[0] - 1
                cnt += 1

        if self.split == 'novel':
            # Novel views have no ground-truth images: only rays and ids are built.
            cnt = 0
            self.frameid2_startpoints_in_allray = [-10] * self.render_path.shape[0]
            for i, c2w in enumerate(self.render_path):
                c2w = torch.FloatTensor(c2w)
                rays_o, rays_d = get_rays(self.directions, c2w)
                rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
                self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
                cur_ids = torch.full([rays_o.shape[0]], i)
                self.all_ids += [cur_ids]
                nearest_id = self.get_nearest_pose(c2w, img_list, i)
                self.all_nearest_ids += [torch.ones_like(cur_ids).int() * nearest_id]
                self.frameid2_startpoints_in_allray[i] = cnt * cur_ids.shape[0] - 1
                cnt += 1
            self.all_rays = torch.cat(self.all_rays, 0)
            self.all_ids = torch.cat(self.all_ids, 0).to(torch.int)
            self.all_nearest_ids = torch.cat(self.all_nearest_ids, 0).to(torch.int)
        else:
            if not self.is_stack:
                # Flat layout: one row per ray across all images.
                self.all_rays = torch.cat(self.all_rays, 0)
                self.all_rgbs = torch.cat(self.all_rgbs, 0)
                self.all_depths = torch.cat(self.all_depths, 0)
                self.all_depth_weights = torch.cat(self.all_depth_weights, 0)
                self.all_ids = torch.cat(self.all_ids, 0).to(torch.int)
                self.all_nearest_ids = torch.cat(self.all_nearest_ids, 0).to(torch.int)
            else:
                # Stacked layout: one entry per image; image-shaped tensors
                # are reshaped to (n_images, h, w, channels).
                self.all_rays = torch.stack(self.all_rays, 0)
                self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)
                self.all_depths = torch.stack(self.all_depths, 0).reshape(-1,*self.img_wh[::-1], 1)
                self.all_depth_weights = torch.stack(self.all_depth_weights, 0).reshape(-1,*self.img_wh[::-1], 1)
                self.all_ids = torch.stack(self.all_ids, 0).to(torch.int)
                self.all_nearest_ids = torch.stack(self.all_nearest_ids, 0).to(torch.int)
        self.poses = torch.FloatTensor(self.poses)
        self.render_path = torch.FloatTensor(self.render_path[:,:3,:])

    def define_transforms(self):
        """Set the image transform applied to every loaded image."""
        self.transform = T.ToTensor()

    def __len__(self):
        """Return the number of entries in ``all_rgbs``.

        NOTE(review): in the 'novel' split ``all_rgbs`` stays an empty list, so
        the length is 0 there.
        """
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        """Return the rays and colours stored at ``idx``."""

        sample = {'rays': self.all_rays[idx],
                  'rgbs': self.all_rgbs[idx]}

        return sample