


import json
import os
import random
import shutil

from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel
from arguments import ModelParams
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON

class Scene:
    """Load a dataset scene and bind it to a Gaussian model.

    Reads scene information from ``args.source_path`` (COLMAP or Blender
    layout) and builds train/test camera lists for each requested resolution
    scale. It then either initializes ``gaussians`` from the scene point cloud
    or restores them from a saved iteration under ``args.model_path``.
    """

    gaussians : GaussianModel

    def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=(1.0,)):
        """Initialize the scene.

        Args:
            args (ModelParams): Model/dataset parameters. This constructor reads
                ``model_path``, ``source_path``, ``images``,
                ``white_background`` and ``eval``, and forwards ``args`` to
                camera construction.
            gaussians (GaussianModel): Model to initialize or restore in place.
            load_iteration (int, optional): Saved iteration to restore. ``-1``
                selects the iteration returned by ``searchForMaxIteration`` on
                ``<model_path>/point_cloud`` (presumably the latest one). A falsy
                value (``None`` or ``0``) means "start from the dataset point
                cloud". Defaults to None.
            shuffle (bool, optional): Shuffle train and test cameras once,
                before any resolution scale is built, so every scale shares the
                same order. Defaults to True.
            resolution_scales (sequence of float, optional): Resolution scale
                factors to build camera lists for. Defaults to ``(1.0,)``.

        Raises:
            AssertionError: If ``args.source_path`` contains neither a
                ``sparse`` directory nor a ``transforms_train.json`` file.
        """
        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussians

        # Truthiness check: None and 0 both mean "do not load a checkpoint".
        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        # Camera lists keyed by resolution scale.
        self.train_cameras = {}
        self.test_cameras = {}

        scene_info = self._read_scene_info(args)

        # Fresh run: record the input point cloud and cameras next to the model.
        if self.loaded_iter is None:
            self._export_inputs(scene_info)

        if shuffle:
            # Shuffle the source lists once, before building any scale, so all
            # resolution scales see the same (multi-res consistent) order.
            random.shuffle(scene_info.train_cameras)
            random.shuffle(scene_info.test_cameras)

        self.cameras_extent = scene_info.nerf_normalization["radius"]

        for resolution_scale in resolution_scales:
            print("Resolution: ", resolution_scale)
            print("Loading Training Cameras")
            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
            print("Loading Test Cameras")
            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)

        if self.loaded_iter is not None:
            # Restore Gaussians from the saved checkpoint (same layout as save()).
            self.gaussians.load_ply(os.path.join(self.model_path,
                                                 "point_cloud",
                                                 "iteration_" + str(self.loaded_iter),
                                                 "point_cloud.ply"))
        else:
            # Initialize Gaussians from the dataset point cloud.
            self.gaussians.create_from_pcd(scene_info.point_cloud, scene_info.train_cameras, self.cameras_extent)

    @staticmethod
    def _read_scene_info(args):
        """Detect the dataset layout under ``args.source_path`` and load it.

        A ``sparse`` directory selects the COLMAP loader. Otherwise a
        ``transforms_train.json`` file selects the Blender loader.

        Raises:
            AssertionError: If neither layout is found.
        """
        if os.path.exists(os.path.join(args.source_path, "sparse")):
            return sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
        if os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
            print("Found transforms_train.json file, assuming Blender data set!")
            return sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
        # Explicit raise instead of `assert False` so the check survives
        # `python -O`; the exception type is kept for existing callers.
        raise AssertionError("Could not recognize scene type!")

    def _export_inputs(self, scene_info):
        """Copy the input point cloud to ``input.ply`` and dump cameras to ``cameras.json``.

        Both files are written into ``self.model_path``.
        """
        # Streamed copy; unlike open(src) + open(dst, 'wb'), this refuses to
        # truncate the source if both paths name the same file.
        shutil.copyfile(scene_info.ply_path, os.path.join(self.model_path, "input.ply"))

        # Test cameras first, then train cameras. Each JSON id is the camera's
        # position in this combined list.
        camlist = []
        if scene_info.test_cameras:
            camlist.extend(scene_info.test_cameras)
        if scene_info.train_cameras:
            camlist.extend(scene_info.train_cameras)
        json_cams = [camera_to_JSON(cam_id, cam) for cam_id, cam in enumerate(camlist)]

        with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
            json.dump(json_cams, file)

    def save(self, iteration):
        """Save the current Gaussians for ``iteration``.

        The target is ``<model_path>/point_cloud/iteration_<iteration>/point_cloud.ply``,
        the same path the constructor loads from. Whether the directory is
        created is up to ``GaussianModel.save_ply``.
        """
        point_cloud_path = os.path.join(self.model_path, "point_cloud", "iteration_{}".format(iteration))
        self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))

    def getTrainCameras(self, scale=1.0):
        """Return the training cameras built for ``scale``.

        Raises KeyError if ``scale`` was not in ``resolution_scales``.
        """
        return self.train_cameras[scale]

    def getTestCameras(self, scale=1.0):
        """Return the test cameras built for ``scale``.

        Raises KeyError if ``scale`` was not in ``resolution_scales``.
        """
        return self.test_cameras[scale]