#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import sys
import json
import numpy as np
from PIL import Image
from typing import NamedTuple
from d3gs.scene.colmap_loader import (
    read_extrinsics_text, read_intrinsics_text, qvec2rotmat,
    read_extrinsics_binary, read_intrinsics_binary,
    read_points3D_binary, read_points3D_text
)
from d3gs.utils.graphics_utils import getWorld2View2, focal2fov, fov2focal

from plyfile import PlyData, PlyElement
from d3gs.scene.gaussian_model import BasicPointCloud

class CameraInfo(NamedTuple):
    """Immutable record describing one camera view and the image taken from it."""
    # Camera identifier (the COLMAP reader stores the intrinsics id here).
    uid: int
    # 3x3 rotation; the readers in this file store the transposed
    # world-to-camera rotation.
    R: np.ndarray
    # Translation vector of the world-to-camera transform.
    T: np.ndarray
    # Vertical / horizontal field of view as returned by focal2fov.
    FovY: float
    FovX: float
    # Image object produced by PIL (forward reference so the annotation
    # does not need to be evaluated when the class is created).
    image: "Image.Image"
    # Path of the image on disk and a name used to identify it.
    image_path: str
    image_name: str
    # Image size in pixels.
    width: int
    height: int

class PhysCameraInfo(NamedTuple):
    """Immutable record for one camera view at one time step of a dynamic capture."""
    # Running index assigned by the reader that created the record.
    uid: int
    # 3x3 rotation; the readers in this file store the transposed
    # world-to-camera rotation.
    R: np.ndarray
    # Translation vector of the world-to-camera transform.
    T: np.ndarray
    # Vertical / horizontal field of view as returned by focal2fov.
    FovY: float
    FovX: float
    # Image object produced by PIL (forward reference so the annotation
    # does not need to be evaluated when the class is created).
    image: "Image.Image"
    # Path of the image on disk and a name used to identify it.
    image_path: str
    image_name: str
    # Image size in pixels.
    width: int
    height: int
    # View name: the part of the file name before the last underscore
    # (e.g. "d_0"), so this is a string, not an integer.
    view: str
    # Time step parsed from the file name.
    step: int

class SceneInfo(NamedTuple):
    """Bundle of everything a scene reader returns for one dataset."""
    # Initial point cloud; readColmapSceneInfo stores None here when the
    # .ply file cannot be loaded.
    point_cloud: BasicPointCloud
    # Camera records used for training.
    train_cameras: list
    # Held-out camera records (empty list when no evaluation split is made).
    test_cameras: list
    # Dict with "translate" and "radius" as produced by getNerfppNorm.
    nerf_normalization: dict
    # Location of the point cloud .ply file on disk.
    ply_path: str

def getNerfppNorm(cam_info):
    """Derive a translation and radius that normalise a set of cameras.

    For every camera the world-to-view matrix is built from ``R`` and ``T``,
    inverted, and the first three rows of its last column are taken as the
    camera centre.

    Returns a dict with:
        "translate": the negated mean of all camera centres (flattened).
        "radius":    1.1 times the largest distance from any camera centre
                     to that mean.
    """
    # One column per camera, taken from the inverted world-to-view matrix.
    centre_columns = [
        np.linalg.inv(getWorld2View2(cam.R, cam.T))[:3, 3:4]
        for cam in cam_info
    ]

    stacked = np.hstack(centre_columns)
    mean_centre = np.mean(stacked, axis=1, keepdims=True)

    # Distance of each camera centre from the mean centre.
    distances = np.linalg.norm(stacked - mean_centre, axis=0, keepdims=True)
    farthest = np.max(distances)

    return {"translate": -mean_centre.flatten(), "radius": farthest * 1.1}


def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
    """Build one CameraInfo per COLMAP image entry.

    Args:
        cam_extrinsics: mapping of image key -> extrinsics record exposing
            ``camera_id``, ``qvec``, ``tvec`` and ``name``.
        cam_intrinsics: mapping of camera id -> intrinsics record exposing
            ``id``, ``model``, ``params``, ``width`` and ``height``.
        images_folder: directory that contains the image files.

    Returns:
        List of CameraInfo in the iteration order of ``cam_extrinsics``.

    Raises:
        AssertionError: if a camera model other than SIMPLE_PINHOLE or
            PINHOLE is encountered.
    """
    cam_infos = []
    total = len(cam_extrinsics)
    for idx, key in enumerate(cam_extrinsics):
        # Single-line progress indicator, rewritten in place via '\r'.
        sys.stdout.write('\r')
        sys.stdout.write("Reading camera {}/{}".format(idx+1, total))
        sys.stdout.flush()

        extr = cam_extrinsics[key]
        intr = cam_intrinsics[extr.camera_id]
        height = intr.height
        width = intr.width

        # Validate the camera model first so unsupported data fails before
        # any pose or image work is done.
        if intr.model == "SIMPLE_PINHOLE":
            # A single focal length is shared by both axes.
            focal_length_x = intr.params[0]
            FovY = focal2fov(focal_length_x, height)
            FovX = focal2fov(focal_length_x, width)
        elif intr.model == "PINHOLE":
            focal_length_x = intr.params[0]
            focal_length_y = intr.params[1]
            FovY = focal2fov(focal_length_y, height)
            FovX = focal2fov(focal_length_x, width)
        else:
            # Raised explicitly (instead of `assert False`) so the check is
            # not stripped when Python runs with -O.
            raise AssertionError("Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!")

        uid = intr.id
        # The rotation is stored transposed.
        R = np.transpose(qvec2rotmat(extr.qvec))
        T = np.array(extr.tvec)

        image_path = os.path.join(images_folder, os.path.basename(extr.name))
        # Name is the file name up to its first dot.
        image_name = os.path.basename(image_path).split(".")[0]
        image = Image.open(image_path)

        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
                              image_path=image_path, image_name=image_name, width=width, height=height)
        cam_infos.append(cam_info)
    sys.stdout.write('\n')
    return cam_infos


def fetchPly(path):
    """Read the 'vertex' element of a PLY file into a BasicPointCloud.

    Positions come from x/y/z, colors from red/green/blue divided by 255.0,
    and normals from nx/ny/nz.
    """
    vertex = PlyData.read(path)['vertex']

    def as_columns(*names):
        # Stack the named properties as rows, then transpose so that each
        # vertex becomes one row.
        return np.vstack([vertex[name] for name in names]).T

    xyz = as_columns('x', 'y', 'z')
    rgb = as_columns('red', 'green', 'blue') / 255.0
    nrm = as_columns('nx', 'ny', 'nz')
    return BasicPointCloud(points=xyz, colors=rgb, normals=nrm)


def storePly(path, xyz, rgb):
    """Write points and colors to ``path`` as a PLY file with one 'vertex' element.

    Every vertex gets float32 x/y/z, float32 normals (all zero) and
    uint8 red/green/blue.
    """
    # Layout of one vertex record in the output file.
    vertex_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
                    ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
                    ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]

    # No normals are known here, so zeros shaped like xyz are written.
    zero_normals = np.zeros_like(xyz)
    rows = np.concatenate((xyz, zero_normals, rgb), axis=1)

    # Fill the structured array one tuple per vertex.
    vertices = np.empty(xyz.shape[0], dtype=vertex_dtype)
    vertices[:] = [tuple(row) for row in rows]

    PlyData([PlyElement.describe(vertices, 'vertex')]).write(path)


def readColmapSceneInfo(path, images, eval, llffhold=8):
    """Load a COLMAP reconstruction located under ``path``/sparse/0.

    Args:
        path: dataset root directory.
        images: name of the image sub-folder; "images" is used when None.
        eval: when true, every ``llffhold``-th camera (by sorted image name)
            is held out as a test camera.  (The name shadows the builtin but
            is kept because callers may pass it by keyword.)
        llffhold: hold-out interval used when ``eval`` is true.

    Returns:
        SceneInfo with the point cloud (None if it could not be loaded),
        the train/test cameras, the normalization computed from the train
        cameras, and the path of the point cloud .ply file.
    """
    # Prefer the binary model files; fall back to the text variants when
    # reading them fails.  `except Exception` (not a bare except) lets
    # KeyboardInterrupt / SystemExit propagate.
    try:
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
    except Exception:
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)

    reading_dir = "images" if images is None else images
    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
    # sorted() already returns a new list, so no explicit copy is needed.
    cam_infos = sorted(cam_infos_unsorted, key=lambda x: x.image_name)

    if eval:
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
    else:
        train_cam_infos = cam_infos
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "sparse/0/points3D.ply")
    bin_path = os.path.join(path, "sparse/0/points3D.bin")
    txt_path = os.path.join(path, "sparse/0/points3D.txt")
    if not os.path.exists(ply_path):
        # One-time conversion of the COLMAP points file into a .ply file.
        print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
        try:
            xyz, rgb, _ = read_points3D_binary(bin_path)
        except Exception:
            xyz, rgb, _ = read_points3D_text(txt_path)
        storePly(ply_path, xyz, rgb)

    # Loading the point cloud stays best-effort, but the failure is now
    # reported instead of being swallowed silently.
    try:
        pcd = fetchPly(ply_path)
    except Exception as err:
        print(f"Could not load point cloud from {ply_path}: {err}")
        pcd = None

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info


def readNeuMASyntheticCameras(path, transformsfile, white_background, extension=".png", init_frame=None, exclude_steps=[-1], used_views=None, **kwargs):
    """Read the camera setup from the NeuMA Synthetic dataset.

    Image files are expected in ``path``/<subfolder>, where <subfolder> is
    ``transformsfile`` up to its first dot, and are named
    ``<view>_<step>`` plus extension (e.g. ``d_0_000.png`` -> view ``d_0``,
    step ``0``).

    Args:
        path: dataset root directory.
        transformsfile: JSON file (relative to ``path``) holding a list of
            entries with "file_path", "c2w" and "intrinsic".
        white_background: composite images over white when true, else black.
        extension: image file extension used to build file names.
        init_frame: when not None, only this step is read for every view.
        exclude_steps: steps to ignore when discovering steps.  The default
            list is only read, never mutated.
        used_views: optional collection of view names to keep; None or an
            empty collection keeps every view.
        **kwargs: accepted and ignored.

    Returns:
        Dict with 'cam_infos' (PhysCameraInfo records ordered by view, then
        step), 'views' (sorted view names) and 'steps' (the steps that were
        read).

    Raises:
        AssertionError: if an expected image is missing from the transforms
            file.
    """
    cam_infos = list()
    subfolder = transformsfile.split(".")[0]
    print(f"Reading NeuMA Synthetic [{subfolder}] Cameras with init_frame={init_frame}...")

    # Discover the available views and steps from the file names.
    views = set()
    steps = set()
    for d in os.listdir(os.path.join(path, subfolder)):
        # Split once from the right: "d_0_000.png" -> ["d_0", "000.png"].
        parts = d.rsplit("_", 1)
        view = str(parts[0])
        # Keep every view when no filter is given, else only the listed ones.
        if used_views is None or len(used_views) == 0 or view in used_views:
            views.add(view)
        step = int(parts[1].split(".")[0])  # "000.png" -> 0
        if step not in exclude_steps:
            steps.add(step)
    views = sorted(list(views))
    steps = sorted(list(steps))
    print(f"Views found: {views if len(views) < 20 else views[:20]} {'' if len(views) < 20 else f'#all: {len(views)} ...'}\n"
          f"Steps found: {steps if len(steps) < 20 else steps[:20]} {'' if len(steps) < 20 else f'#all: {len(steps)} ...'}")

    # Load the camera metadata; the file is closed before any image is read.
    with open(os.path.join(path, transformsfile)) as json_file:
        contents = json.load(json_file)

    # Index the entries by their file path; the remaining keys are the value.
    meta_info = dict()
    for entry in contents:
        file_path = entry.pop("file_path")
        meta_info[file_path] = entry

    # only read the first frame if `init_frame` is set
    steps = [init_frame] if init_frame is not None else steps

    # The background colour does not depend on the image, so build it once.
    bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])

    idx = 0
    for view in views:
        for step in steps:
            file_path_to_fetch = f"./{subfolder}/{view}_{step:03d}{extension}"

            # Raised explicitly so the check also runs under `python -O`.
            if file_path_to_fetch not in meta_info:
                raise AssertionError(f"File {file_path_to_fetch} not found in meta_info!")
            meta = meta_info[file_path_to_fetch]

            # NeRF 'transform_matrix' is a camera-to-world transform
            c2w = np.array(meta["c2w"])
            if c2w.shape[0] == 3:  # (3, 4): append the homogeneous row
                c2w = np.concatenate([c2w, np.array([[0, 0, 0, 1]])], axis=0)
            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
            c2w[:3, 1:3] *= -1

            # get the world-to-camera transform and set R, T
            w2c = np.linalg.inv(c2w)
            R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
            T = w2c[:3, 3]

            image_name = os.path.join(path, file_path_to_fetch)
            # The context manager closes the source file once the pixels
            # have been copied into the array.
            with Image.open(image_name) as raw_image:
                im_data = np.array(raw_image.convert("RGBA"))

            # Alpha-composite the image over the background colour.
            norm_data = im_data / 255.0
            alpha = norm_data[:, :, 3:4]
            arr = norm_data[:, :, :3] * alpha + bg * (1 - alpha)
            # uint8 covers 0..255; the previous np.byte is signed int8 and
            # cannot represent values above 127.
            image = Image.fromarray(np.array(arr*255.0, dtype=np.uint8), "RGB")

            intrinsics = meta["intrinsic"]
            focalx = intrinsics[0][0]
            focaly = intrinsics[1][1]
            FovX = focal2fov(focalx, image.size[0])
            FovY = focal2fov(focaly, image.size[1])

            # Record the camera for this (view, step) pair.
            cam_infos.append(
                PhysCameraInfo(
                    uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
                    image_path=image_name, image_name=image_name,
                    width=image.size[0], height=image.size[1],
                    view=view, step=step
                )
            )
            idx += 1

    # cam_infos: PhysCameraInfo records laid out view by view, step by step
    # views: sorted view names, e.g. ['d_0', 'd_1', ..., 'd_n']
    # steps: the steps that were read, e.g. [0, 1, 2, ..., T]
    return {'cam_infos': cam_infos, 'views': views, 'steps': steps}


def readRealCaptureCameras(
    path,
    white_background,
    extension=".jpg",
    width=1920,             # hardcoded for RealCapture
    height=1080,            # hardcoded for RealCapture
    init_frame=None,
    exclude_steps=[-1],
    used_views=None,
    read_mask_only=False,    # if the images are binary masks
    **kwargs   # accepted and ignored
):
    """Read the camera setup from the RealCapture dataset.

    Intrinsics come from ``path``/sparse/0/cameras.bin (camera id 1), poses
    from ``path``/cameras_calib.json ("rvecs"/"tvecs" per view), images from
    ``path``/dynamics and masks from ``path``/dynamic_masks.  Files are named
    ``<view>_<step>`` plus extension.

    Args:
        path: dataset root directory.
        white_background: composite masked images over white when true,
            else black.
        extension: extension of the image files in ``dynamics``.
        width, height: image size in pixels that the intrinsics are scaled to.
        init_frame: when not None, only this step is read for every view.
        exclude_steps: steps to ignore when discovering steps.  The default
            list is only read, never mutated.
        used_views: optional collection of view names to keep; None or an
            empty collection keeps every view.
        read_mask_only: when true, the mask itself (repeated over three
            channels) is returned as the image.
        **kwargs: accepted and ignored.

    Returns:
        Dict with 'cam_infos' (PhysCameraInfo records ordered by view, then
        step), 'views' (sorted view names) and 'steps' (the steps that were
        read).
    """
    import cv2  # imported here so only this reader requires OpenCV

    cam_infos = list()
    print("Reading RealCapture Cameras...")

    # Camera intrinsics from the COLMAP model.
    cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
    cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)

    # Focal lengths are rescaled from 4752 x 2672 to the requested size.
    # NOTE(review): 4752 x 2672 is presumably the resolution the calibration
    # was computed at - confirm against the dataset.
    K = np.array([
        [cam_intrinsics[1].params[0] * width / 4752, 0, width / 2],
        [0, cam_intrinsics[1].params[1] * height / 2672, height / 2],
        [0, 0, 1]
    ])

    # K[0][0] is the horizontal focal length and K[1][1] the vertical one,
    # so they pair with width and height respectively (as in
    # readColmapCameras).  The previous code had the two swapped.
    FovX = focal2fov(K[0][0], width)
    FovY = focal2fov(K[1][1], height)

    # Per-view camera poses.
    with open(os.path.join(path, 'cameras_calib.json'), 'r') as f:
        cam_calib = json.load(f)

    # Discover the available views and steps from the file names.
    views = set()
    steps = set()
    for d in os.listdir(os.path.join(path, "dynamics")):
        parts = d.rsplit("_", 1)
        view = str(parts[0])
        # Keep every view when no filter is given, else only the listed ones.
        if used_views is None or len(used_views) == 0 or view in used_views:
            views.add(view)
        step = int(parts[1].split(".")[0])
        if step not in exclude_steps:
            steps.add(step)
    views = sorted(list(views))
    steps = sorted(list(steps))
    print(f"Views found: {views if len(views) < 20 else views[:20]} {'' if len(views) < 20 else f'#all: {len(views)} ...'}\n"
          f"Steps found: {steps if len(steps) < 20 else steps[:20]} {'' if len(steps) < 20 else f'#all: {len(steps)} ...'}")

    # The background colour does not depend on the image, so build it once.
    bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])

    idx = 0
    # only read the first frame if `init_frame` is set
    steps = [init_frame] if init_frame is not None else steps
    for view in views:
        # The pose is shared by every step of a view.
        rvecs = cam_calib[view]["rvecs"]
        tvecs = cam_calib[view]["tvecs"]
        rot_mat, _ = cv2.Rodrigues(np.array(rvecs))
        R = np.transpose(rot_mat)  # the rotation is stored transposed
        T = np.array(tvecs).reshape(3)

        for step in steps:
            file_path_to_fetch = f"./dynamics/{view}_{step}{extension}"
            image_name = os.path.join(path, file_path_to_fetch)
            # Build the mask path directly instead of string-replacing inside
            # the image path, which would also rewrite a matching "/dynamics/"
            # or extension occurring elsewhere in `path`.
            mask_name = os.path.join(path, f"./dynamic_masks/{view}_{step}.png")

            # Context managers close the files once the pixel data has been
            # copied into arrays.
            with Image.open(mask_name) as mask_file:
                mask_data = np.array(mask_file)

            if read_mask_only:
                # Repeat the mask over three channels to form an RGB image.
                image = Image.fromarray(np.repeat(mask_data[:, :, np.newaxis], 3, axis=-1), "RGB")
            else:
                with Image.open(image_name) as image_file:
                    im_data = np.array(image_file)
                mask = mask_data[:, :, np.newaxis] / 255.0
                # Keep the image where the mask is set, background elsewhere.
                arr = (im_data / 255.0) * mask + bg * (1 - mask)
                # uint8 covers 0..255; the previous np.byte is signed int8
                # and cannot represent values above 127.
                image = Image.fromarray(np.array(arr * 255.0, dtype=np.uint8), "RGB")

            cam_infos.append(
                PhysCameraInfo(
                    uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
                    image_path=image_name, image_name=image_name,
                    width=width, height=height,
                    view=view, step=step
                )
            )
            idx += 1

    return {'cam_infos': cam_infos, 'views': views, 'steps': steps}
