#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import random
import json
from gaussian_splatting.utils.system_utils import searchForMaxIteration
from gaussian_splatting.scene.dataset_readers import sceneLoadTypeCallbacks
from gaussian_splatting.scene.gaussian_model import GaussianModel
from gaussian_splatting.arguments import ModelParams
from gaussian_splatting.utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
from interpolation.optimize_pose_linear import *
from interpolation.optimize_pose_bspline import *
from interpolation.optimize_pose_adaptive import *
from interpolation.optimize_pose_ada_bspline import *
# from edit_tools.trajectory_edit import apply_trajectory_control_to_lightcam
from edit_tools.trajectory_edit_origin import apply_trajectory_control
# from edit_tools.novel_view import generate_variable_speed_interpolated_lightcam_list#, generate_novel_view_groups
from edit_tools.novel_view_origin import generate_variable_speed_interpolated_camera_list, generate_novel_view_groups
from plot.plot import analyze_camera_trajectory
from plot.plot_pro import analyze_camera_trajectory_3d

class Scene:
    """Dataset cameras plus the Gaussian model trained on them.

    Besides the standard train/test camera accessors, this class provides
    helpers that turn the training cameras into interpolated and/or edited
    camera trajectories. Each helper also saves a diagnostic trajectory plot
    under ``./imgs`` (or ``./imgs/interp-method``).
    """

    gaussians : GaussianModel

    def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=(1.0,)):
        """Load the scene dataset and initialise (or restore) the Gaussian model.

        :param args: model parameters. Reads ``model_path``, ``source_path``,
            ``images``, ``depths``, ``eval``, ``train_test_exp`` and
            ``white_background``.
        :param gaussians: model to populate, either from a saved point cloud
            or from the dataset's initial point cloud.
        :param load_iteration: iteration of a saved point cloud to load.
            ``-1`` selects the iteration returned by ``searchForMaxIteration``
            on ``<model_path>/point_cloud``. A falsy value starts from the
            dataset point cloud instead.
        :param shuffle: shuffle the train and test camera infos in place
            before the per-scale camera lists are built.
        :param resolution_scales: iterable of scales; one train and one test
            camera list is built per scale.
        :raises AssertionError: if ``source_path`` contains neither ``sparse``
            (COLMAP) nor ``transforms_train.json`` (Blender).
        """
        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussians

        # Resolve which trained model (if any) should be restored
        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        # Camera lists keyed by resolution scale
        self.train_cameras = {}
        self.test_cameras = {}

        # Detect the dataset type from the files present in source_path
        if os.path.exists(os.path.join(args.source_path, "sparse")):
            scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.depths, args.eval, args.train_test_exp)
        elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
            print("Found transforms_train.json file, assuming Blender data set!")
            scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.depths, args.eval)
        else:
            # Raise explicitly: a bare `assert` is removed under `python -O`,
            # which would leave `scene_info` unbound and fail later with NameError.
            raise AssertionError("Could not recognize scene type!")

        # Fresh run: snapshot the input point cloud and the camera set into model_path
        if not self.loaded_iter:
            with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply"), 'wb') as dest_file:
                dest_file.write(src_file.read())
            camlist = []
            if scene_info.test_cameras:
                camlist.extend(scene_info.test_cameras)
            if scene_info.train_cameras:
                camlist.extend(scene_info.train_cameras)
            json_cams = [camera_to_JSON(cam_id, cam) for cam_id, cam in enumerate(camlist)]
            with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
                json.dump(json_cams, file)

        if shuffle:
            # Shuffle once, before building per-scale lists, so every resolution
            # scale sees the cameras in the same order.
            random.shuffle(scene_info.train_cameras)
            random.shuffle(scene_info.test_cameras)

        self.cameras_extent = scene_info.nerf_normalization["radius"]

        # Build camera lists (multi-resolution)
        for resolution_scale in resolution_scales:
            print("Loading Training Cameras")
            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args, scene_info.is_nerf_synthetic, False)
            print("Loading Test Cameras")
            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args, scene_info.is_nerf_synthetic, True)

        if self.loaded_iter:
            self.gaussians.load_ply(os.path.join(self.model_path,
                                                 "point_cloud",
                                                 "iteration_" + str(self.loaded_iter),
                                                 "point_cloud.ply"), args.train_test_exp)
        else:
            self.gaussians.create_from_pcd(scene_info.point_cloud, scene_info.train_cameras, self.cameras_extent)

    def save(self, iteration):
        """Save the Gaussians for ``iteration`` and dump per-image exposures.

        The point cloud goes to
        ``<model_path>/point_cloud/iteration_<iteration>/point_cloud.ply``.
        Exposures go to ``<model_path>/exposure.json`` (overwritten on every
        call), keyed by image name.
        """
        point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
        self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
        exposure_dict = {
            image_name: self.gaussians.get_exposure_from_name(image_name).detach().cpu().numpy().tolist()
            for image_name in self.gaussians.exposure_mapping
        }

        with open(os.path.join(self.model_path, "exposure.json"), "w") as f:
            json.dump(exposure_dict, f, indent=2)

    def _scene_name(self):
        """Return the last path component of ``model_path`` (tolerates a trailing separator)."""
        return os.path.basename(os.path.normpath(self.model_path))

    def getTrainCameras(self, scale=1.0):
        """Return the training camera list built for ``scale``."""
        return self.train_cameras[scale]

    def getLinearInterpolatedCameras(self, scale=1.0, nums_inserted=10):
        """Linearly interpolate the training cameras at ``scale`` and plot the 3D trajectory.

        :param nums_inserted: forwarded to ``generate_linear_interpolated_camera_list``.
        :returns: the interpolated camera list.
        """
        tag = f"{self._scene_name()}_{scale}"
        Interpolated_viewpoints = generate_linear_interpolated_camera_list(self.train_cameras[scale], nums_inserted)
        analyze_camera_trajectory_3d(Interpolated_viewpoints, title=f"Linear 3D Trajectory:{tag}", save_path=os.path.join("./imgs/interp-method", f"linear_interpolated_3D_trajectory_{tag}.png"))
        return Interpolated_viewpoints

    def getBspineInterpolatedCameras(self, scale=1.0, nums_inserted=10):
        """B-spline-interpolate the training cameras at ``scale`` and plot the 3D trajectory.

        (The method name keeps its historical "Bspine" spelling for callers.)

        :param nums_inserted: forwarded to ``generate_bspline_interpolated_camera_list``.
        :returns: the interpolated camera list.
        """
        tag = f"{self._scene_name()}_{scale}"
        Interpolated_viewpoints = generate_bspline_interpolated_camera_list(self.train_cameras[scale], nums_inserted)
        analyze_camera_trajectory_3d(Interpolated_viewpoints, title=f"BSpline Trajectory:{tag}", save_path=os.path.join("./imgs/interp-method", f"bspline_interpolated_3D_trajectory_{tag}.png"))
        return Interpolated_viewpoints

    # def getAdaptiveInterpolatedCameras(self, scale=1.0, interp_multiplier=10):
    #     # CameraList
    #     origin_viewpoints = self.train_cameras[scale]
    #     Interpolated_viewpoints = generate_adaptive_interpolated_camera_list(origin_viewpoints, interp_multiplier)
    #     visualize_camera_trajectory(Interpolated_viewpoints, title=f"Adaptive Trajectory:{self.model_path.split('/')[-1]}_{scale}", save_path=os.path.join("./imgs", f"adaptive_interpolated_trajectory_{self.model_path.split('/')[-1]}_{scale}.png"))
    #     return Interpolated_viewpoints

    def getEditedAdaptiveInterpolatedCameras(self, scale=1.0, interp_multiplier=10):
        """Apply trajectory control to the training cameras, then adaptive B-spline interpolate.

        :param interp_multiplier: forwarded to
            ``generate_adaptive_bspline_interpolated_camera_list``.
        :returns: the interpolated camera list (a 2D trajectory plot is saved under ``./imgs``).
        """
        tag = f"{self._scene_name()}_{scale}"
        edited_viewpoint = apply_trajectory_control(self.train_cameras[scale])
        Interpolated_viewpoints = generate_adaptive_bspline_interpolated_camera_list(edited_viewpoint, interp_multiplier)
        analyze_camera_trajectory(Interpolated_viewpoints, title=f"Edited Adaptive Trajectory:{tag}", save_path=os.path.join("./imgs", f"adaptive_interpolated_trajectory_{tag}.png"))
        return Interpolated_viewpoints

    def getEditedSpeedInterpolatedCameras(self, scale=1.0, interp_multiplier=5, speed_profile=None):
        """Apply trajectory control, then interpolate with a variable speed profile.

        :param interp_multiplier: forwarded to
            ``generate_variable_speed_interpolated_camera_list``.
        :param speed_profile: callable forwarded as the speed profile. When it
            is ``None`` (the default) or the string ``'None'`` (e.g. from a
            CLI flag), the built-in profile ``sin(pi * t) * 0.25 + 1`` is used.
        :returns: the interpolated camera list (2D and 3D trajectory plots are saved).
        """
        tag = f"{self._scene_name()}_{scale}"
        edited_viewpoint = apply_trajectory_control(self.train_cameras[scale])

        def speed_fn(t):
            return np.sin(np.pi * t) * 0.25 + 1  # smooth & safe

        # The original check compared only against the string 'None', so the
        # declared default (the None object) never selected this fallback.
        if speed_profile is None or (isinstance(speed_profile, str) and speed_profile == 'None'):
            speed_profile = speed_fn
        Interpolated_viewpoints = generate_variable_speed_interpolated_camera_list(edited_viewpoint, interp_multiplier, speed_profile)
        analyze_camera_trajectory(Interpolated_viewpoints, title=f"Edited Ada-speed Trajectory:{tag}", save_path=os.path.join("./imgs", f"ada_speed_interpolated_trajectory_{tag}.png"))
        analyze_camera_trajectory_3d(Interpolated_viewpoints, title=f"Edited Ada-speed 3D Trajectory:{tag}", save_path=os.path.join("./imgs/interp-method", f"ada_speed_interpolated_3D_trajectory_{tag}.png"))
        return Interpolated_viewpoints

    def getNovelViewInterpolatedCameras(self, scale=1.0, interp_multiplier=5, speed_profile=None):
        """Generate groups of novel-view trajectories from the edited training cameras.

        Three groups are requested (``max_keyframes=5``), each with
        ``interp_multiplier * len(train_cameras[scale])`` frames. ``speed_profile``
        is accepted for signature parity but is not used here.

        :returns: whatever ``generate_novel_view_groups`` returns; it is iterated
            here as one camera list per group, and each group's trajectory is plotted.
        """
        tag = f"{self._scene_name()}_{scale}"
        origin_viewpoints = self.train_cameras[scale]
        img_num = len(origin_viewpoints)
        edited_viewpoint = apply_trajectory_control(origin_viewpoints)
        num_groups = 3
        novels = generate_novel_view_groups(edited_viewpoint, num_groups=num_groups, max_keyframes=5, num_frames=interp_multiplier * img_num)
        for i, view in enumerate(novels):
            analyze_camera_trajectory(view, title=f"Novel View Trajectory {i}:{tag}", save_path=os.path.join("./imgs", f"novel_view_{i}_interpolated_trajectory_{tag}.png"))
        return novels

    def getTestCameras(self, scale=1.0, interpolate=False):
        """Return the test camera list for ``scale``.

        Interpolated test cameras are not implemented: any ``interpolate``
        value other than ``False`` returns ``None``.
        """
        if interpolate is False:
            return self.test_cameras[scale]
        else:
            return None
        