#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import sys
from PIL import Image
from typing import NamedTuple
from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
    read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
import numpy as np
import json
from pathlib import Path
from plyfile import PlyData, PlyElement
from utils.sh_utils import SH2RGB
from scene.gaussian_model import BasicPointCloud
import pickle as pkl
import torch
import math
import pytorch3d
from tqdm import tqdm
from scipy.spatial.transform import Rotation, RotationSpline
sys.path.append('../ext/NeuralHaircut')
from NeuS.models.dataset import load_K_Rt_from_P
from typing import NamedTuple, Optional



class CameraInfo(NamedTuple):
    """Immutable per-view camera record built by the dataset readers in this module.

    The first ten fields are supplied by every reader. The trailing optional
    fields (``timestep`` onward) are only filled by the multi-timestep
    transforms reader (``readCamerasFromTransforms``).
    """
    uid: int  # view id; the COLMAP reader uses the intrinsics id, readSyntheticCameras uses 0 for all views
    R: np.ndarray  # 3x3 rotation; readers store the transpose (np.transpose(...)) of the rotation they compute
    T: np.ndarray  # translation; the transforms readers take it from the world-to-camera matrix (w2c[:3, 3])
    FovY: float  # vertical field of view (radians, e.g. 2 * atan(...) in readSyntheticCameras)
    FovX: float  # horizontal field of view (radians)
    image: Image.Image  # PIL image: either Image.open(...) or an alpha-composited RGB image
    image_path: str
    image_name: str  # file stem (or zero-padded frame id for interpolated COLMAP views)
    width: int
    height: int
    # Background color used when compositing transparent pixels. NOTE: this default
    # array is a single object shared by every instance that omits ``bg``; don't mutate it in place.
    bg: np.ndarray = np.array([0, 0, 0])
    timestep: Optional[int] = None  # time step relative to the reader's start_time_step
    # Raw frame["camera_id"] value; readStaticNerfInfo compares it against a string serial,
    # so it is presumably a str despite the annotation — TODO confirm against the data.
    camera_id: Optional[int] = None
    camera_index: Optional[int] = None  # raw frame["camera_index"] value, if present
    num_time_steps: Optional[int] = None  # number of time steps selected by the reader

class SceneInfo(NamedTuple):
    """Result of a dataset reader: initial point cloud(s), camera splits and normalization."""
    # Usually a BasicPointCloud; None when the PLY could not be read, and a list of
    # per-timestep clouds when produced by readStaticNerfInfo.
    point_cloud: BasicPointCloud
    train_cameras: list  # list of CameraInfo
    test_cameras: list  # list of CameraInfo (the readers here leave it empty when eval is False)
    nerf_normalization: dict  # {"translate": ..., "radius": ...} as returned by getNerfppNorm
    ply_path: str  # PLY file the point cloud came from (last frame's file for readStaticNerfInfo)
    # NOTE(review): NamedTuple has no default_factory, so each of these `{}` defaults is ONE
    # dict shared by every SceneInfo that omits it. Never mutate them in place; pass a fresh dict.
    train_meshes: dict = {}
    test_meshes: dict = {}
    tgt_train_meshes: dict = {}
    tgt_test_meshes: dict = {}

def getNerfppNorm(cam_info):
    """Compute the translate/radius normalization for a set of cameras.

    Each camera center is recovered as the translation column of the inverted
    world-to-view matrix. ``translate`` moves the mean center to the origin,
    and ``radius`` is 1.1x the largest distance from a center to that mean.

    Args:
        cam_info: iterable of CameraInfo-like objects exposing ``R`` and ``T``.

    Returns:
        dict with ``"translate"`` (length-3 array) and ``"radius"`` (scalar).
    """
    # Stack centers column-wise: shape (3, n_cameras).
    centers = np.hstack([
        np.linalg.inv(getWorld2View2(cam.R, cam.T))[:3, 3:4]
        for cam in cam_info
    ])
    mean_center = centers.mean(axis=1, keepdims=True)
    farthest = np.linalg.norm(centers - mean_center, axis=0, keepdims=True).max()
    return {"translate": -mean_center.flatten(), "radius": farthest * 1.1}

def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
    """Build CameraInfo entries from parsed COLMAP extrinsics and intrinsics.

    Views whose image file is missing from ``images_folder`` are skipped.
    Progress is written to stdout on a single, carriage-return-updated line.

    Args:
        cam_extrinsics: mapping of image key -> extrinsic record (``qvec``, ``tvec``,
            ``camera_id``, ``name``).
        cam_intrinsics: mapping of camera id -> intrinsic record (``id``, ``model``,
            ``width``, ``height``, ``params``).
        images_folder: directory containing the images referenced by ``name``.

    Returns:
        list of CameraInfo (``uid`` is the intrinsics camera id).

    Raises:
        AssertionError: for camera models other than SIMPLE_PINHOLE, SIMPLE_RADIAL,
            PINHOLE or OPENCV. The exception is raised explicitly, unlike a bare
            ``assert``, so it is not stripped under ``python -O``.
    """
    cam_infos = []
    total = len(cam_extrinsics)
    for idx, key in enumerate(cam_extrinsics):
        sys.stdout.write('\r')
        sys.stdout.write("Reading camera {}/{}".format(idx+1, total))
        sys.stdout.flush()

        extr = cam_extrinsics[key]
        intr = cam_intrinsics[extr.camera_id]
        width = intr.width
        height = intr.height

        uid = intr.id
        R = np.transpose(qvec2rotmat(extr.qvec))
        T = np.array(extr.tvec)

        if intr.model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL"):
            # One shared focal length for both axes; radial distortion is ignored.
            focal_length_x = intr.params[0]
            FovY = focal2fov(focal_length_x, height)
            FovX = focal2fov(focal_length_x, width)
        elif intr.model in ("PINHOLE", "OPENCV"):
            # Separate fx / fy; OPENCV distortion coefficients (params[4:]) are ignored.
            focal_length_x = intr.params[0]
            focal_length_y = intr.params[1]
            FovY = focal2fov(focal_length_y, height)
            FovX = focal2fov(focal_length_x, width)
        else:
            raise AssertionError(
                "Colmap camera model not handled: only SIMPLE_PINHOLE, SIMPLE_RADIAL, "
                "PINHOLE and OPENCV cameras are supported (distortion is ignored)!"
            )

        image_path = os.path.join(images_folder, os.path.basename(extr.name))
        if not os.path.exists(image_path):
            continue

        image_name = os.path.basename(image_path).split(".")[0]
        image = Image.open(image_path)

        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
                              image_path=image_path, image_name=image_name, width=width, height=height)
        cam_infos.append(cam_info)
    sys.stdout.write('\n')
    return cam_infos

def fetchPly(path):
    """Load a PLY point cloud that has positions, 8-bit colors and normals.

    Colors are divided by 255 so they land in [0, 1] for ``u1`` color channels.

    Returns:
        BasicPointCloud with ``points``, ``colors`` and ``normals`` as (N, 3) arrays.
    """
    vertex = PlyData.read(path)['vertex']

    def _columns(*names):
        # Gather the named vertex properties into an (N, len(names)) array.
        return np.vstack([vertex[name] for name in names]).T

    return BasicPointCloud(
        points=_columns('x', 'y', 'z'),
        colors=_columns('red', 'green', 'blue') / 255.0,
        normals=_columns('nx', 'ny', 'nz'),
    )

def fetchPly_wo_RGB_Normal(path):
    """Load only vertex positions from a PLY file.

    The file does not need color or normal properties. Colors are filled
    with zeros, and normals with uniform random values in [0, 1) drawn from
    NumPy's global RNG (so the result is not deterministic unless seeded).

    Returns:
        BasicPointCloud with (N, 3) ``points``, ``colors`` and ``normals``.
    """
    vertex = PlyData.read(path)['vertex']
    xyz = np.vstack([vertex[axis] for axis in ('x', 'y', 'z')]).T
    return BasicPointCloud(
        points=xyz,
        colors=np.zeros_like(xyz),
        normals=np.random.random((xyz.shape[0], 3)),
    )

def storePly(path, xyz, rgb):
    """Write points and colors to a PLY file, with all-zero normals.

    Args:
        path: destination file path.
        xyz: (N, 3) positions, stored as float32.
        rgb: (N, 3) colors, cast to uint8 (float values are truncated).
    """
    # Vertex layout expected by fetchPly.
    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    xyz = np.asarray(xyz)
    rgb = np.asarray(rgb)

    elements = np.empty(xyz.shape[0], dtype=dtype)
    # Fill each field with one vectorized cast instead of building a Python
    # tuple per point (which is slow for the 100k-point clouds generated here).
    for col, name in enumerate(('x', 'y', 'z')):
        elements[name] = xyz[:, col]
    for name in ('nx', 'ny', 'nz'):
        elements[name] = 0.0
    for col, name in enumerate(('red', 'green', 'blue')):
        elements[name] = rgb[:, col]

    # Create the PlyData object and write to file
    vertex_element = PlyElement.describe(elements, 'vertex')
    ply_data = PlyData([vertex_element])
    ply_data.write(path)

def _interpolate_cameras(cam_infos, speed_up, max_frames, frame_offset):
    """Densify sorted keyframe cameras (integer ``image_name``s) to one camera per frame.

    Rotations come from a scipy ``RotationSpline`` through the keyframes. uid, T,
    FovY and FovX are linearly blended between the surrounding keyframes, and each
    synthetic view reuses the previous keyframe's image. The dense sequence is
    cropped to ``[first_frame, last_frame)``, subsampled by ``speed_up`` and
    windowed by ``frame_offset`` / ``max_frames``.
    """
    frames = [int(cam.image_name) for cam in cam_infos]
    keyframe_set = set(frames)  # O(1) membership test inside the per-frame loop
    rotations = Rotation.from_matrix(np.stack([cam.R for cam in cam_infos]))
    spline = RotationSpline(frames, rotations)
    R_interp = spline(list(range(frames[-1]))).as_matrix()

    # prev_j/next_j index the keyframes bracketing frame i. Frames before the first
    # keyframe see prev_j == -1 (wrapped index) and are discarded by the final crop.
    prev_j = -1
    next_j = 0
    cam_infos_interp = []

    for i in range(frames[-1]):
        if i in keyframe_set:
            prev_j += 1
            next_j += 1

        prev_cam = cam_infos[prev_j]
        next_cam = cam_infos[next_j]
        # Weight of the previous keyframe: 1 at the keyframe, falling towards the next one.
        alpha = 1 - (i - frames[prev_j]) / (frames[next_j] - frames[prev_j])

        cam_infos_interp.append(CameraInfo(
            uid=int(prev_cam.uid * alpha + next_cam.uid * (1 - alpha)),
            R=R_interp[i],
            T=prev_cam.T * alpha + next_cam.T * (1 - alpha),
            FovY=prev_cam.FovY * alpha + next_cam.FovY * (1 - alpha),
            FovX=prev_cam.FovX * alpha + next_cam.FovX * (1 - alpha),
            image=prev_cam.image,
            # Substitute the keyframe's frame id (not its list index) with this frame's id.
            image_path=prev_cam.image_path.replace('%06d' % frames[prev_j], '%06d' % i),
            image_name='%06d' % i,
            width=prev_cam.width,
            height=prev_cam.height,
        ))

    return cam_infos_interp[frames[0]:frames[-1]][::speed_up][frame_offset:frame_offset + max_frames]


def readColmapSceneInfo(path, images, eval, llffhold=2, interpolate_cameras=False, speed_up=4, max_frames=300, frame_offset=0):
    """Build a SceneInfo from a COLMAP reconstruction in ``<path>/sparse/0``.

    Binary model files are tried first, falling back to the text export. Cameras
    are sorted by image name and, if ``interpolate_cameras`` is set, densified
    between keyframes (see ``_interpolate_cameras``). With ``eval``, every
    ``llffhold``-th camera goes to the test split. ``points3D.ply`` is created
    from ``points3D.bin``/``.txt`` on first use. The point cloud is None if the
    PLY cannot be read.
    """
    sparse_dir = os.path.join(path, "sparse/0")
    try:
        cam_extrinsics = read_extrinsics_binary(os.path.join(sparse_dir, "images.bin"))
        cam_intrinsics = read_intrinsics_binary(os.path.join(sparse_dir, "cameras.bin"))
    except Exception:
        # Binary model missing/unreadable: use the text export instead.
        cam_extrinsics = read_extrinsics_text(os.path.join(sparse_dir, "images.txt"))
        cam_intrinsics = read_intrinsics_text(os.path.join(sparse_dir, "cameras.txt"))

    reading_dir = "images" if images is None else images
    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
    cam_infos = sorted(cam_infos_unsorted, key=lambda x: x.image_name)

    if interpolate_cameras:
        cam_infos = _interpolate_cameras(cam_infos, speed_up, max_frames, frame_offset)

    if eval:
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
    else:
        train_cam_infos = cam_infos
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(sparse_dir, "points3D.ply")
    bin_path = os.path.join(sparse_dir, "points3D.bin")
    txt_path = os.path.join(sparse_dir, "points3D.txt")
    if not os.path.exists(ply_path):
        print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
        try:
            xyz, rgb, _ = read_points3D_binary(bin_path)
        except Exception:
            xyz, rgb, _ = read_points3D_text(txt_path)
        storePly(ply_path, xyz, rgb)
    try:
        pcd = fetchPly(ply_path)
    except Exception:
        # Best effort: callers receive None when the PLY cannot be parsed.
        pcd = None

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info

def readCamerasFromTransformsOld(path, transformsfile, white_background, extension=".png"):
    """Legacy reader for a NeRF-synthetic style ``transforms*.json`` file.

    Each frame's ``transform_matrix`` (camera-to-world, OpenGL axes) is flipped
    to COLMAP axes and inverted. RGBA images are alpha-composited over a white
    or black background. A single ``camera_angle_x`` is shared by all frames.

    Returns:
        list of CameraInfo with ``uid`` equal to the frame index.
    """
    cam_infos = []
    bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])

    with open(os.path.join(path, transformsfile)) as json_file:
        contents = json.load(json_file)
    fovx = contents["camera_angle_x"]

    for idx, frame in enumerate(contents["frames"]):
        # Joined once: re-joining `path` onto this would duplicate it for relative paths.
        image_path = os.path.join(path, frame["file_path"] + extension)
        image_name = Path(image_path).stem

        # NeRF 'transform_matrix' is a camera-to-world transform
        c2w = np.array(frame["transform_matrix"])
        # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
        c2w[:3, 1:3] *= -1

        # get the world-to-camera transform and set R, T
        w2c = np.linalg.inv(c2w)
        R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]

        # Close the source file as soon as the pixels are decoded.
        with Image.open(image_path) as src:
            im_data = np.array(src.convert("RGBA"))

        norm_data = im_data / 255.0
        alpha = norm_data[:, :, 3:4]
        arr = norm_data[:, :, :3] * alpha + bg * (1 - alpha)
        # uint8, not np.byte (int8): values above 127 would overflow a signed cast.
        image = Image.fromarray(np.array(arr * 255.0, dtype=np.uint8), "RGB")

        FovX = fovx
        FovY = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])

        cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
                                    image_path=image_path, image_name=image_name,
                                    width=image.size[0], height=image.size[1]))

    return cam_infos

def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png", start_time_step=-1, _num_time_steps=-1,except_name=None,include_name=None):
    """Read a multi-view, multi-timestep transforms JSON into CameraInfo entries.

    The JSON is read from ``<path>/flame_fitting/preprocessed_data/<transformsfile>``;
    image paths in it are resolved relative to ``path``.

    Args:
        path: dataset root.
        transformsfile: JSON file name.
        white_background: composite alpha over white (True) or black (False).
        extension: appended to ``file_path`` when it does not already contain it.
        start_time_step: first ``timestep_index`` to keep (-1 means 0).
        _num_time_steps: number of time steps to keep. -1, or a value >= the
            number of entries in ``timestep_indices``, selects that total.
        except_name: skip frames whose ``camera_id`` equals this value.
        include_name: keep only frames whose ``camera_id`` equals this value.

    Returns:
        list of CameraInfo. ``timestep`` is relative to ``start_time_step`` and
        ``num_time_steps`` is the resolved count.

    Raises:
        KeyError: if neither the frame nor the file provides ``camera_angle_x``.
    """
    cam_infos = []
    with open(os.path.join(path, "flame_fitting", "preprocessed_data", transformsfile)) as json_file:
        contents = json.load(json_file)

    # Shared FoV, used for frames that lack their own camera_angle_x.
    fovx_shared = contents["camera_angle_x"] if 'camera_angle_x' in contents else None

    num_time_steps = len(contents["timestep_indices"])
    if _num_time_steps == -1 or _num_time_steps >= num_time_steps:
        _num_time_steps = num_time_steps
    if start_time_step == -1:
        start_time_step = 0
    last_time_step = start_time_step + _num_time_steps - 1

    frames = contents["frames"]
    for idx, frame in tqdm(enumerate(frames), total=len(frames)):
        if except_name is not None and frame["camera_id"] == except_name:
            continue
        if include_name is not None and frame["camera_id"] != include_name:
            continue
        timestep = frame["timestep_index"]
        if timestep < start_time_step or timestep > last_time_step:
            continue

        file_path = frame["file_path"]
        if extension not in file_path:
            file_path += extension
        # Joined once: re-joining `path` onto this would duplicate it for relative paths.
        image_path = os.path.join(path, file_path)
        image_name = Path(image_path).stem

        # NeRF 'transform_matrix' is a camera-to-world transform
        c2w = np.array(frame["transform_matrix"])
        # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
        c2w[:3, 1:3] *= -1

        # get the world-to-camera transform and set R, T
        w2c = np.linalg.inv(c2w)
        R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]

        bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])

        # Pixels are read from the 'images' directory rather than 'fg_images'; the
        # stored image_path keeps the path as listed in the transforms file.
        with Image.open(image_path.replace('fg_images', 'images')) as src:
            im_data = np.array(src.convert("RGBA"))
        norm_data = im_data / 255.0
        alpha = norm_data[:, :, 3:4]
        arr = norm_data[:, :, :3] * alpha + bg * (1 - alpha)
        # uint8, not np.byte (int8): values above 127 would overflow a signed cast.
        image = Image.fromarray(np.array(arr * 255.0, dtype=np.uint8), "RGB")
        width, height = image.size

        if 'camera_angle_x' in frame:
            fovx = frame["camera_angle_x"]
        elif fovx_shared is not None:
            fovx = fovx_shared
        else:
            raise KeyError(f"'camera_angle_x' missing for frame {idx} and in {transformsfile}")
        fovy = focal2fov(fov2focal(fovx, width), height)

        camera_id = frame["camera_id"] if 'camera_id' in frame else None
        camera_index = frame["camera_index"] if 'camera_index' in frame else None

        cam_infos.append(CameraInfo(
            uid=idx, R=R, T=T, FovY=fovy, FovX=fovx, bg=bg, image=image,
            image_path=image_path, image_name=image_name,
            width=width, height=height,
            timestep=timestep - start_time_step, camera_id=camera_id,
            camera_index=camera_index, num_time_steps=_num_time_steps))
    return cam_infos


def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
    """Build a SceneInfo from ``transforms_train.json`` / ``transforms_test.json``.

    Without ``eval`` the test cameras are merged into the training split. If
    ``<path>/points3d.ply`` does not exist, 100k random points in [-1.3, 1.3)^3
    are written there first. The point cloud is None if the PLY cannot be read.

    NOTE(review): readCamerasFromTransforms reads the JSON from
    ``<path>/flame_fitting/preprocessed_data`` and requires ``timestep_indices``,
    which plain Blender scenes may not provide — TODO confirm this loader still
    works for that layout.
    """
    print("Reading Training Transforms")
    train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
    print("Reading Test Transforms")
    test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)

    if not eval:
        train_cam_infos.extend(test_cam_infos)
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "points3d.ply")
    if not os.path.exists(ply_path):
        # Since this data set has no colmap data, we start with random points
        num_pts = 100_000
        print(f"Generating random point cloud ({num_pts})...")

        # We create random points inside the bounds of the synthetic Blender scenes
        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
        shs = np.random.random((num_pts, 3)) / 255.0
        storePly(ply_path, xyz, SH2RGB(shs) * 255)
    try:
        pcd = fetchPly(ply_path)
    except Exception:
        # Best effort: callers receive None when the PLY cannot be parsed.
        pcd = None

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info


def readSyntheticCameras(extrinsics_all, intrinsics_all, images_folder, resolution=1024):
    """Create CameraInfo entries for a synthetic dataset with square images.

    Args:
        extrinsics_all: tensor indexed as ``[i, :3, :3]`` / ``[i, :3, 3]`` per view;
            presumably world-to-camera (the caller passes inverted poses) — confirm.
        intrinsics_all: per-view intrinsics; ``[i][0][0]`` / ``[i][1][1]`` are the focals.
        images_folder: directory holding ``0000.png``, ``0001.png``, ...
        resolution: image width and height in pixels.

    Returns:
        list of CameraInfo; every entry gets ``uid=0``.
    """
    def _fov(focal):
        # Field of view for a principal point at the image center.
        return 2 * math.atan(resolution / 2 / focal)

    cam_infos = []
    for view_idx in range(extrinsics_all.shape[0]):
        extrinsic = extrinsics_all[view_idx]
        intrinsic = intrinsics_all[view_idx]
        name = '%04d' % view_idx
        file_path = f'{images_folder}/{name}.png'
        cam_infos.append(CameraInfo(
            uid=0,
            R=np.transpose(extrinsic[:3, :3].cpu().numpy()),  # stored transposed due to 'glm' in CUDA code
            T=extrinsic[:3, 3].cpu().numpy(),
            FovY=_fov(intrinsic[1][1].cpu()),
            FovX=_fov(intrinsic[0][0].cpu()),
            image=Image.open(file_path),
            image_path=file_path,
            image_name=name,
            width=resolution,
            height=resolution,
        ))
    return cam_infos


def scale_matrix(mat, scale_factor):
    """Divide entries (0,0), (1,1), (0,2) and (1,2) of ``mat`` by ``scale_factor``.

    In the usual intrinsics layout these are fx, fy, cx and cy, so this rescales
    a camera matrix for an image downsampled by ``scale_factor``. ``mat`` is
    modified in place and also returned.
    """
    for row, col in ((0, 0), (1, 1), (0, 2), (1, 2)):
        mat[row, col] /= scale_factor
    return mat


def readSyntheticSceneInfo(path, images, eval, llffhold=2):
    """Build a SceneInfo for a synthetic dataset described by projection matrices.

    Projection matrices come from ``projection.npy``, falling back to the
    ``arr_0`` entry of ``cameras.npz``. Each is decomposed with
    ``load_K_Rt_from_P``. Intrinsics are halved via ``scale_matrix(..., 2)``
    and the inverted poses are passed to readSyntheticCameras. With ``eval``,
    every ``llffhold``-th camera goes to the test split.

    A fresh random point cloud (100k points in [-1.3, 1.3)^3) is generated on
    every call and written to ``<path>/points3d.ply``, overwriting any existing file.
    """
    try:
        camera_dict = np.load(f'{path}/projection.npy')
    except Exception:
        # NpzFile keeps its file open; close it once arr_0 has been read.
        with np.load(f'{path}/cameras.npz') as cameras:
            camera_dict = cameras['arr_0']

    intrinsics_all = []
    pose_all = []
    for world_mat in camera_dict:
        intrinsics, pose = load_K_Rt_from_P(None, world_mat[:3, :4])
        intrinsics_all.append(scale_matrix(torch.from_numpy(intrinsics).float(), 2))
        pose_all.append(torch.from_numpy(pose).float())

    intrinsics_all = torch.stack(intrinsics_all)   # [n_images, 4, 4]
    pose_all_inv = torch.inverse(torch.stack(pose_all))

    reading_dir = "images" if images is None else images
    cam_infos_unsorted = readSyntheticCameras(extrinsics_all=pose_all_inv, intrinsics_all=intrinsics_all, images_folder=os.path.join(path, reading_dir))
    cam_infos = sorted(cam_infos_unsorted, key=lambda x: x.image_name)

    if eval:
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
    else:
        train_cam_infos = cam_infos
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "points3d.ply")
    # Since this data set has no colmap data, we start with random points
    num_pts = 100_000
    print(f"Generating random point cloud ({num_pts})...")

    # We create random points inside the bounds of the synthetic Blender scenes
    xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
    shs = np.random.random((num_pts, 3)) / 255.0
    pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))

    storePly(ply_path, xyz, SH2RGB(shs) * 255)

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info

def readMeshesFromTransforms(path, transformsfile):
    """Load FLAME parameters once per timestep referenced by a transforms JSON.

    Frames without a ``timestep_index`` are skipped, and only the first frame
    seen for each timestep is loaded.

    Returns:
        dict mapping ``timestep_index`` -> dict of the arrays stored at the frame's
        ``flame_param_path`` (resolved relative to ``path``).
    """
    with open(os.path.join(path, transformsfile)) as json_file:
        frames = json.load(json_file)["frames"]

    mesh_infos = {}
    for frame in tqdm(frames, total=len(frames)):
        if 'timestep_index' not in frame or frame["timestep_index"] in mesh_infos:
            continue

        # SECURITY: allow_pickle=True can execute code from the file; only load trusted data.
        loaded = np.load(os.path.join(path, frame['flame_param_path']), allow_pickle=True)
        try:
            flame_param = dict(loaded)
        finally:
            # .npz loads return an NpzFile that holds the file open until closed.
            close = getattr(loaded, "close", None)
            if close is not None:
                close()
        mesh_infos[frame["timestep_index"]] = flame_param
    return mesh_infos
def readDynamicNerfInfo(path, white_background, eval, extension=".png", target_path="", select_time_step=-1):
    """Build a SceneInfo from train/test transforms files for a dynamic scene.

    Cameras are read from ``target_path`` when it is non-empty, otherwise from
    ``path``. ``select_time_step`` is forwarded as the start time step. When a
    target path is given, or ``eval`` is False, all cameras go to the training
    split. The point cloud is loaded from ``<path>/points3D_multipleview.ply``,
    which is first filled with 100k random points if missing. It is None if the
    PLY cannot be read.
    """
    camera_root = target_path if target_path != "" else path

    print("Reading Training Transforms")
    train_cam_infos = readCamerasFromTransforms(camera_root, "transforms_train.json", white_background, extension, select_time_step)

    print("Reading Test Transforms")
    test_cam_infos = readCamerasFromTransforms(camera_root, "transforms_test.json", white_background, extension, select_time_step)

    if target_path != "" or not eval:
        train_cam_infos.extend(test_cam_infos)
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "points3D_multipleview.ply")
    if not os.path.exists(ply_path):
        # Since this data set has no colmap data, we start with random points
        num_pts = 100_000
        print(f"Generating random point cloud ({num_pts})...")

        # We create random points inside the bounds of the synthetic Blender scenes
        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
        shs = np.random.random((num_pts, 3)) / 255.0
        storePly(ply_path, xyz, SH2RGB(shs) * 255)
    try:
        pcd = fetchPly(ply_path)
    except Exception:
        # Best effort: callers receive None when the PLY cannot be parsed.
        pcd = None

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info
def readStaticNerfInfo(path,flame_path ,white_background, eval, extension=".png", target_path="", start_time_step=-1, num_time_steps=-1):
    """Build a SceneInfo with one point cloud per time step.

    Cameras come from ``transforms.json`` under ``target_path`` (if non-empty)
    or ``path``. Views of camera ``test_name`` form the test split, and the
    rest form the training split. Both are sorted by timestep. When a target
    path is given, or ``eval`` is False, all cameras go to the training split.

    Point clouds are read from
    ``<flame_path>/raw_data/eval_30/point_cloud/frame_XXXXX.ply``, one per
    selected time step. A random 100k-point cloud is written first for any
    missing frame. Entries are None when a PLY cannot be read.

    Args:
        start_time_step: first time step (-1 means 0, matching readCamerasFromTransforms).
        num_time_steps: number of time steps (-1 means the count resolved by
            readCamerasFromTransforms, i.e. all of them).

    Returns:
        SceneInfo whose ``point_cloud`` is the per-timestep list and whose
        ``ply_path`` is the last frame's PLY path.

    Raises:
        ValueError: if no time steps are selected.
    """
    # Views from this camera serial are held out as the test split.
    test_name = "222200037"
    camera_root = target_path if target_path != "" else path

    print("Reading Training Transforms")
    train_cam_infos = readCamerasFromTransforms(camera_root, "transforms.json", white_background, extension, start_time_step, num_time_steps, except_name=test_name, include_name=None)

    print("Reading Test Transforms")
    test_cam_infos = readCamerasFromTransforms(camera_root, "transforms.json", white_background, extension, start_time_step, num_time_steps, except_name=None, include_name=test_name)

    train_cam_infos.sort(key=lambda x: x.timestep)
    test_cam_infos.sort(key=lambda x: x.timestep)
    print("Reading Training items : ", len(train_cam_infos))
    print("Reading Test items     : ", len(test_cam_infos))

    # Resolve the -1 sentinels exactly as readCamerasFromTransforms does, so the
    # point clouds line up with the cameras' relative timesteps. The raw -1 made
    # the loop below start at frame -1, or run over an empty range.
    if start_time_step == -1:
        start_time_step = 0
    if num_time_steps == -1:
        num_time_steps = next(
            (c.num_time_steps for c in train_cam_infos + test_cam_infos if c.num_time_steps is not None),
            0,
        )
    if num_time_steps <= 0:
        raise ValueError(f"No time steps selected (start_time_step={start_time_step}, num_time_steps={num_time_steps})")

    if target_path != "" or not eval:
        train_cam_infos.extend(test_cam_infos)
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_base_path = f'{flame_path}/raw_data/eval_30/point_cloud'
    pcd_seq = []
    for i in range(start_time_step, start_time_step + num_time_steps):
        ply_path = os.path.join(ply_base_path, f'frame_{i:05d}.ply')
        if not os.path.exists(ply_path):
            # Missing frame: write random points (Blender-scene bounds) so it can be loaded below.
            os.makedirs(ply_base_path, exist_ok=True)
            num_pts = 100_000
            xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
            shs = np.random.random((num_pts, 3)) / 255.0
            storePly(ply_path, xyz, SH2RGB(shs) * 255)
        try:
            pcd = fetchPly_wo_RGB_Normal(ply_path)
        except Exception:
            # Best effort: keep the list aligned with time steps even if a PLY is unreadable.
            pcd = None
        pcd_seq.append(pcd)

    scene_info = SceneInfo(point_cloud=pcd_seq,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info

# Dataset type name -> reader function that returns a SceneInfo.
sceneLoadTypeCallbacks = dict(
    Colmap=readColmapSceneInfo,
    Blender=readNerfSyntheticInfo,
    Synthetic=readSyntheticSceneInfo,
    DynamicNerf=readDynamicNerfInfo,
    StaticNerf=readStaticNerfInfo,
)