#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import random
import json
from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel
from scene.gaussian_model_latent_strands import GaussianModelHair
from scene.gaussian_model_strands import GaussianModelCurves
from scene.gaussian_render import GaussRenderer
from scene.gaussian_weight import GaussianWeightPred
from arguments import ModelParams
from utils.camera_utils import HybirdFacialTextureDataset, cameraList_from_camInfos, camera_to_JSON, PreloadHybirdFacialTextureDataset,PreloadHybirdFacialTextureDatasetCPU

class Scene:
    """Container tying a dataset on disk to a Gaussian model.

    On construction it detects the dataset layout under ``args.source_path``,
    loads the scene through the matching ``sceneLoadTypeCallbacks`` reader,
    wraps the train/test cameras in ``HybirdFacialTextureDataset`` objects
    (one per resolution scale), and initializes ``gaussians`` from the
    dataset's point cloud.
    """

    gaussians : GaussianModel

    def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, pointcloud_path=None, shuffle=True, resolution_scales=None, scene_suffix=""):
        """Load the scene described by ``args`` and initialize ``gaussians``.

        :param args: model/dataset parameters; ``source_path`` selects the
            dataset and ``model_path`` is used for outputs and checkpoints.
        :param gaussians: model initialized here via ``create_from_pcd``.
        :param load_iteration: if truthy, stored in ``self.loaded_iter``;
            ``-1`` resolves to the highest iteration found under
            ``<model_path>/point_cloud``. NOTE: this version does not reload
            the saved point cloud (the reload code below is disabled).
        :param pointcloud_path: unused in this version (only referenced by
            the disabled checkpoint-reload code).
        :param shuffle: shuffle the train and test camera lists in place
            before building the camera datasets.
        :param resolution_scales: scales to build camera datasets for;
            defaults to ``[1.0]``.
        :param scene_suffix: unused in this version (only referenced by the
            disabled checkpoint-reload code).
        :raises ValueError: if the dataset layout under ``args.source_path``
            is not recognized.
        """
        # Avoid a shared mutable default list.
        if resolution_scales is None:
            resolution_scales = [1.0]

        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussians

        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        # All camera containers are keyed by resolution scale. The *_val
        # dicts are not populated anywhere in this class.
        self.train_cameras = {}
        self.test_cameras = {}
        self.train_cameras_val = {}
        self.test_cameras_val = {}

        scene_info = self._read_scene_info(args)

        if shuffle:
            # Shuffle once, before the per-scale loop, so every resolution
            # scale sees the same camera order.
            random.shuffle(scene_info.train_cameras)
            random.shuffle(scene_info.test_cameras)

        self.cameras_extent = scene_info.nerf_normalization["radius"]

        for resolution_scale in resolution_scales:
            # PreloadHybirdFacialTextureDataset / PreloadHybirdFacialTextureDatasetCPU
            # are imported as drop-in alternatives to HybirdFacialTextureDataset.
            print("Loading Training Cameras")
            self.train_cameras[resolution_scale] = HybirdFacialTextureDataset(scene_info.train_cameras, resolution_scale, args)
            print("Loading Test Cameras")
            self.test_cameras[resolution_scale] = HybirdFacialTextureDataset(scene_info.test_cameras, resolution_scale, args)

        # Gaussians are always initialized from the dataset point cloud; the
        # checkpoint reload below is disabled in this version:
        self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
        # if self.loaded_iter:
        #     if pointcloud_path is not None:
        #         print(f'Loading point cloud from {pointcloud_path}')
        #         self.gaussians.load_ply(pointcloud_path)
        #     else:
        #         self.gaussians.load_ply(os.path.join(self.model_path,
        #                                              f"point_cloud{scene_suffix}",
        #                                              "iteration_" + str(self.loaded_iter),
        #                                              "raw_point_cloud.ply"))

    @staticmethod
    def _read_scene_info(args):
        """Detect the dataset layout under ``args.source_path`` and load it.

        Layouts are checked in a fixed priority order; the first match wins:
        ``sparse/`` (Colmap), ``flame_fitting/`` (StaticNerf if
        ``args.static_nerf`` else DynamicNerf), ``transforms_train.json``
        (Blender), then ``projection.npy`` / ``cameras.npz`` (Synthetic).

        :raises ValueError: if no known layout is found.
        """
        source_path = args.source_path

        def _exists(name):
            return os.path.exists(os.path.join(source_path, name))

        if _exists("sparse"):
            # NOTE(review): the literal 2 is passed positionally to the Colmap
            # reader; its meaning is defined there — TODO confirm and name it.
            return sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, 2, args.interpolate_cameras, args.speed_up, args.max_frames, args.frame_offset)
        if _exists("flame_fitting"):
            if args.static_nerf:
                print("Found FLAME parameter, assuming static NeRF data set!")
                return sceneLoadTypeCallbacks["StaticNerf"](args.source_path, args.flame_mesh_dir, args.white_background, args.eval, target_path=args.target_path, start_time_step = args.start_time_step, num_time_steps = args.num_time_steps)
            print("Found FLAME parameter, assuming dynamic NeRF data set!")
            return sceneLoadTypeCallbacks["DynamicNerf"](args.source_path, args.white_background, args.eval, target_path=args.target_path, start_time_step = args.start_time_step, num_time_steps = args.num_time_steps)
        if _exists("transforms_train.json"):
            print("Found transforms_train.json file, assuming Blender data set!")
            return sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
        if _exists("projection.npy") or _exists("cameras.npz"):
            return sceneLoadTypeCallbacks["Synthetic"](args.source_path, args.images, args.eval)
        # Explicit raise: an `assert` would be stripped under `python -O`.
        raise ValueError("Could not recognize scene type!")

    def save(self, iteration):
        """Save the Gaussians to ``<model_path>/point_cloud/iteration_<iteration>/point_cloud.ply``."""
        point_cloud_path = os.path.join(self.model_path, "point_cloud", "iteration_{}".format(iteration))
        self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))

    def getTrainCameras(self, scale=1.0):
        """Return the training-camera dataset built for resolution ``scale``."""
        return self.train_cameras[scale]

    def getTestCameras(self, scale=1.0):
        """Return the test-camera dataset built for resolution ``scale``."""
        return self.test_cameras[scale]

    def getTrainCamerasVal(self, scale=1.0):
        """Return the training validation cameras for resolution ``scale``.

        NOTE: ``__init__`` leaves ``train_cameras_val`` empty, so this raises
        ``KeyError`` unless the dict is populated elsewhere.
        """
        return self.train_cameras_val[scale]

    def getTestCamerasVal(self, scale=1.0):
        """Return the test validation cameras for resolution ``scale``.

        NOTE: ``__init__`` leaves ``test_cameras_val`` empty, so this raises
        ``KeyError`` unless the dict is populated elsewhere.
        """
        return self.test_cameras_val[scale]