#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import glob
import json
import os
import sys
from pathlib import Path
from typing import NamedTuple, Optional

import numpy as np
from PIL import Image
from plyfile import PlyData, PlyElement

from scene.colmap_loader import (
    qvec2rotmat,
    read_extrinsics_binary,
    read_extrinsics_text,
    read_intrinsics_binary,
    read_intrinsics_text,
    read_points3D_binary,
    read_points3D_text,
)
from utils.graphics_utils import BasicPointCloud, focal2fov, fov2focal, getWorld2View2
from utils.sh_utils import SH2RGB


class CameraInfo(NamedTuple):
    """Per-view camera record produced by the dataset readers in this module.

    Fields:
        uid: identifier assigned by the reader (the COLMAP intrinsics id for
            COLMAP scenes, the frame index for the transforms-based readers).
        R: 3x3 rotation; the readers store the *transposed* world-to-camera
            rotation (see the 'glm' note in the transform readers).
        T: translation part of the world-to-camera transform.
        FovY: vertical field of view.
        FovX: horizontal field of view.
        image: the loaded PIL image.
        image_path: path the image was read from.
        image_name: name used to sort and identify the view.
        width: image width in pixels.
        height: image height in pixels.
    """

    uid: int
    R: np.ndarray
    T: np.ndarray
    FovY: float
    FovX: float
    image: "Image.Image"
    image_path: str
    image_name: str
    width: int
    height: int


class SceneInfo(NamedTuple):
    """Everything a dataset reader in this module returns for one scene.

    Fields:
        point_cloud: initial point cloud loaded from ``ply_path``, or ``None``
            when the readers fail to load that PLY file.
        train_cameras: list of CameraInfo used for training.
        test_cameras: list of CameraInfo held out for evaluation (may be empty).
        nerf_normalization: dict with ``"translate"`` and ``"radius"`` as
            returned by ``getNerfppNorm``.
        ply_path: path of the PLY file the point cloud is read from.
    """

    point_cloud: Optional[BasicPointCloud]
    train_cameras: list
    test_cameras: list
    nerf_normalization: dict
    ply_path: str


def getNerfppNorm(cam_info):
    """Compute a NeRF++-style normalization from the cameras' world positions.

    Each camera's world-space center is the translation column of the inverse
    of its world-to-view matrix. The mean of those centers is the scene
    center, and the radius is 1.1x the largest center-to-mean distance.

    Args:
        cam_info: iterable of CameraInfo-like objects exposing ``R`` and ``T``.

    Returns:
        dict with ``"translate"`` (the negated mean center, flattened) and
        ``"radius"`` (scalar).
    """
    # One 3x1 column per camera, stacked into a 3xN matrix.
    positions = np.hstack(
        [np.linalg.inv(getWorld2View2(cam.R, cam.T))[:3, 3:4] for cam in cam_info]
    )
    mean_position = np.mean(positions, axis=1, keepdims=True)
    spread = np.linalg.norm(positions - mean_position, axis=0, keepdims=True)
    farthest = np.max(spread)

    return {"translate": -mean_position.flatten(), "radius": farthest * 1.1}


def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
    """Convert COLMAP image/camera records into a name-sorted list of CameraInfo.

    Args:
        cam_extrinsics: mapping of COLMAP image records (``qvec``, ``tvec``,
            ``camera_id``, ``name``).
        cam_intrinsics: mapping from camera id to COLMAP camera records
            (``id``, ``model``, ``width``, ``height``, ``params``).
        images_folder: directory holding the images named in the extrinsics.

    Returns:
        list of CameraInfo sorted by ``image_name``.

    Raises:
        ValueError: if a camera uses a model other than SIMPLE_PINHOLE or
            PINHOLE (distorted models are not supported).
    """
    cam_infos = []
    num_cameras = len(cam_extrinsics)
    for idx, key in enumerate(cam_extrinsics):
        # "\r" rewinds to the start of the line so the counter updates in place.
        sys.stdout.write("\r")
        sys.stdout.write("Reading camera {}/{}".format(idx + 1, num_cameras))
        sys.stdout.flush()

        extr = cam_extrinsics[key]
        intr = cam_intrinsics[extr.camera_id]
        height = intr.height
        width = intr.width

        # NOTE(review): uid is the intrinsics (camera model) id, so views that
        # share a COLMAP camera also share a uid.
        uid = intr.id
        # Stored transposed, like the other readers in this module.
        R = np.transpose(qvec2rotmat(extr.qvec))
        T = np.array(extr.tvec)

        if intr.model == "SIMPLE_PINHOLE":
            focal_length_x = intr.params[0]
            FovY = focal2fov(focal_length_x, height)
            FovX = focal2fov(focal_length_x, width)
        elif intr.model == "PINHOLE":
            focal_length_x = intr.params[0]
            focal_length_y = intr.params[1]
            FovY = focal2fov(focal_length_y, height)
            FovX = focal2fov(focal_length_x, width)
        else:
            # Explicit raise: an assert would be stripped under -O and the
            # code would then fail later on an unbound FovX/FovY.
            raise ValueError(
                "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
            )

        image_path = os.path.join(images_folder, os.path.basename(extr.name))
        # NOTE(review): splits at the first "." so "a.b.jpg" becomes "a".
        image_name = os.path.basename(image_path).split(".")[0]
        image = Image.open(image_path)

        cam_info = CameraInfo(
            uid=uid,
            R=R,
            T=T,
            FovY=FovY,
            FovX=FovX,
            image=image,
            image_path=image_path,
            image_name=image_name,
            width=width,
            height=height,
        )
        cam_infos.append(cam_info)

    cam_infos = sorted(cam_infos, key=lambda x: x.image_name)

    sys.stdout.write("\n")
    return cam_infos


def fetchPly(path):
    """Load a PLY file's vertex element into a BasicPointCloud.

    Reads positions (x, y, z), colors (red, green, blue, rescaled from
    [0, 255] to [0, 1]) and normals (nx, ny, nz) as column-stacked arrays.
    """
    vertex = PlyData.read(path)["vertex"]

    def _columns(*names):
        # Stack the named vertex properties side by side, one row per vertex.
        return np.vstack([vertex[name] for name in names]).T

    return BasicPointCloud(
        points=_columns("x", "y", "z"),
        colors=_columns("red", "green", "blue") / 255.0,
        normals=_columns("nx", "ny", "nz"),
    )


def storePly(path, xyz, rgb):
    """Write points and colors to a PLY file, with all normals set to zero.

    Args:
        path: output file path.
        xyz: per-point positions, one row per point; stored as float32.
        rgb: per-point colors, one row per point; cast to uint8 on write, so
            values are expected to lie in [0, 255].
    """
    # Structured layout of one PLY vertex.
    dtype = [
        ("x", "f4"),
        ("y", "f4"),
        ("z", "f4"),
        ("nx", "f4"),
        ("ny", "f4"),
        ("nz", "f4"),
        ("red", "u1"),
        ("green", "u1"),
        ("blue", "u1"),
    ]

    normals = np.zeros_like(xyz)

    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb), axis=1)
    # Fill one field per column instead of building a Python tuple per point,
    # which is very slow for large point clouds.
    for column, (field_name, _) in enumerate(dtype):
        elements[field_name] = attributes[:, column]

    # Create the PlyData object and write to file
    vertex_element = PlyElement.describe(elements, "vertex")
    ply_data = PlyData([vertex_element])
    ply_data.write(path)


def readColmapSceneInfo(
    path, images, eval, llffhold=8, init_type="sfm", num_pts=100000
):
    """Build a SceneInfo from a COLMAP reconstruction in ``path/sparse/0``.

    Args:
        path: dataset root containing ``sparse/0`` and the image folder.
        images: image sub-folder relative to ``path``; ``None`` means "images".
        eval: if true, every ``llffhold``-th camera (in name order) is held
            out as a test camera.
        llffhold: hold-out stride used when ``eval`` is true.
        init_type: "sfm" to initialize from the COLMAP points, "random" for a
            uniformly random cloud sized by the camera spread.
        num_pts: number of random points when ``init_type == "random"``.

    Returns:
        SceneInfo; ``point_cloud`` is None if the PLY file cannot be read.

    Raises:
        ValueError: if ``init_type`` is neither "sfm" nor "random".
    """
    try:
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
    except Exception:
        # Fall back to COLMAP's text export when the binary model is unavailable.
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)

    reading_dir = "images" if images is None else images
    cam_infos_unsorted = readColmapCameras(
        cam_extrinsics=cam_extrinsics,
        cam_intrinsics=cam_intrinsics,
        images_folder=os.path.join(path, reading_dir),
    )
    cam_infos = sorted(cam_infos_unsorted, key=lambda x: x.image_name)

    if eval:
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
    else:
        train_cam_infos = cam_infos
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    if init_type == "sfm":
        ply_path = os.path.join(path, "sparse/0/points3D.ply")
        bin_path = os.path.join(path, "sparse/0/points3D.bin")
        txt_path = os.path.join(path, "sparse/0/points3D.txt")
        if not os.path.exists(ply_path):
            print(
                "Converting point3d.bin to .ply, will happen only the first time you open the scene."
            )
            try:
                xyz, rgb, _ = read_points3D_binary(bin_path)
            except Exception:
                xyz, rgb, _ = read_points3D_text(txt_path)
            storePly(ply_path, xyz, rgb)
    elif init_type == "random":
        ply_path = os.path.join(path, "random.ply")
        print(f"Generating random point cloud ({num_pts})...")

        # Uniform points in a cube of half-width 3 * radius around the origin.
        radius = nerf_normalization["radius"]
        xyz = np.random.random((num_pts, 3)) * radius * 3 * 2 - (radius * 3)
        shs = np.random.random((num_pts, 3)) / 255.0

        storePly(ply_path, xyz, SH2RGB(shs) * 255)
    else:
        # exit(0) would report success to the shell; fail loudly instead.
        raise ValueError(
            f"Please specify a correct init_type: random or sfm (got {init_type!r})"
        )

    try:
        pcd = fetchPly(ply_path)
    except Exception:
        # Best effort: callers receive a SceneInfo without an initial cloud.
        pcd = None

    scene_info = SceneInfo(
        point_cloud=pcd,
        train_cameras=train_cam_infos,
        test_cameras=test_cam_infos,
        nerf_normalization=nerf_normalization,
        ply_path=ply_path,
    )
    return scene_info


def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
    """Read cameras from a NeRF-synthetic (Blender) ``transforms_*.json`` file.

    Each frame's image is converted to RGBA, alpha-composited over a white or
    black background, and kept as an 8-bit RGB PIL image.

    Args:
        path: dataset root; the JSON file and frame paths are relative to it.
        transformsfile: JSON file name inside ``path``.
        white_background: composite over white if true, otherwise black.
        extension: suffix appended to each frame's ``file_path``.

    Returns:
        list of CameraInfo sorted by ``image_name``.
    """
    cam_infos = []

    with open(os.path.join(path, transformsfile)) as json_file:
        contents = json.load(json_file)
    fovx = contents["camera_angle_x"]
    bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])

    for idx, frame in enumerate(contents["frames"]):
        # Joined with `path` exactly once; joining again would duplicate a
        # relative `path` ("data/data/...").
        image_path = os.path.join(path, frame["file_path"] + extension)
        image_name = Path(image_path).stem

        # NeRF 'transform_matrix' is a camera-to-world transform
        c2w = np.array(frame["transform_matrix"])
        # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
        c2w[:3, 1:3] *= -1

        # get the world-to-camera transform and set R, T
        w2c = np.linalg.inv(c2w)
        R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]

        # The context manager releases the file handle once pixels are read.
        with Image.open(image_path) as source_image:
            im_data = np.array(source_image.convert("RGBA"))

        norm_data = im_data / 255.0
        alpha = norm_data[:, :, 3:4]
        arr = norm_data[:, :, :3] * alpha + bg * (1 - alpha)
        # uint8, not np.byte: np.byte is signed int8 and overflows above 127.
        # An (H, W, 3) uint8 array is interpreted as RGB by Pillow.
        image = Image.fromarray(np.array(arr * 255.0, dtype=np.uint8))

        fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])

        cam_infos.append(
            CameraInfo(
                uid=idx,
                R=R,
                T=T,
                FovY=fovy,
                FovX=fovx,
                image=image,
                image_path=image_path,
                image_name=image_name,
                width=image.size[0],
                height=image.size[1],
            )
        )

    cam_infos = sorted(cam_infos, key=lambda x: x.image_name)
    return cam_infos


def readNerfSyntheticInfo(path, white_background, eval, num_pts, extension=".png"):
    """Build a SceneInfo for a NeRF-synthetic (Blender) scene.

    Args:
        path: dataset root with ``transforms_train.json`` / ``transforms_test.json``.
        white_background: composite RGBA frames over white instead of black.
        eval: keep the test split separate; otherwise merge it into training.
        num_pts: number of random initial points.
        extension: image file suffix appended to each frame path.

    Returns:
        SceneInfo; ``point_cloud`` is None if the generated PLY cannot be read.
    """
    print("Reading Training Transforms")
    train_cam_infos = readCamerasFromTransforms(
        path, "transforms_train.json", white_background, extension
    )
    print("Reading Test Transforms")
    test_cam_infos = readCamerasFromTransforms(
        path, "transforms_test.json", white_background, extension
    )

    if not eval:
        train_cam_infos.extend(test_cam_infos)
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "points3d.ply")
    # Since this data set has no colmap data, we start with random points
    print(f"Generating random point cloud ({num_pts})...")

    # Random points in the cube [-1.3, 1.3)^3 that bounds the synthetic Blender scenes.
    xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
    shs = np.random.random((num_pts, 3)) / 255.0

    # The PLY is regenerated (overwritten) on every call.
    storePly(ply_path, xyz, SH2RGB(shs) * 255)
    try:
        pcd = fetchPly(ply_path)
    except Exception:
        pcd = None

    scene_info = SceneInfo(
        point_cloud=pcd,
        train_cameras=train_cam_infos,
        test_cameras=test_cam_infos,
        nerf_normalization=nerf_normalization,
        ply_path=ply_path,
    )
    return scene_info


def readCamerasInstantNGPTransforms(
    transforms_file: str, images_folder: str, test: Optional[bool] = None
):
    """Read cameras from an Instant-NGP / nerfstudio style ``transforms.json``.

    All frames share the intrinsics ``fl_x``, ``fl_y``, ``w``, ``h`` from the
    top level of the file. Poses are converted to the COLMAP camera
    convention and to the world axes used by this module.

    Args:
        transforms_file: path to the JSON file.
        images_folder: directory that each frame's ``file_path`` is relative to.
        test: if true and the file has a ``test_frames`` list, read that list.
            NOTE(review): if ``test_frames`` is absent, the regular ``frames``
            are returned even when ``test`` is true.

    Returns:
        list of CameraInfo sorted by ``image_name``.

    Raises:
        ValueError: if ``camera_model`` is present and not PINHOLE, or if an
            image's size differs from the declared ``w`` x ``h``.
    """
    cam_infos = []

    with open(transforms_file) as json_file:
        contents = json.load(json_file)

    if contents.get("camera_model", "PINHOLE") != "PINHOLE":
        raise ValueError("Only PINHOLE camera model supported!")

    fx = contents["fl_x"]
    fy = contents["fl_y"]
    w = contents["w"]
    h = contents["h"]

    # Shared intrinsics: the field of view is the same for every frame.
    FovX = focal2fov(fx, w)
    FovY = focal2fov(fy, h)

    if test and "test_frames" in contents:
        frames = contents["test_frames"]
    else:
        frames = contents["frames"]

    for idx, frame in enumerate(frames):

        # NeRF 'transform_matrix' is a camera-to-world transform
        c2w = np.array(frame["transform_matrix"])
        # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
        c2w[:3, 1:3] *= -1

        # world coordinate transform: inverse for map colmap gravity guess (-y) to nerfstudio convention (+z)
        # to use SfM pointcloud
        # https://github.com/nerfstudio-project/nerfstudio/blob/ec10c49d51cfebc52618ece1221ec4511ac19b67/nerfstudio/data/dataparsers/colmap_dataparser.py#L169
        c2w = c2w[np.array([1, 0, 2, 3]), :]
        c2w[2, :] *= -1

        # get the world-to-camera transform and set R, T
        w2c = np.linalg.inv(c2w)
        R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]

        image_name = frame["file_path"]
        image_path = os.path.join(images_folder, image_name)
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

        if image.size[0] != w or image.size[1] != h:
            raise ValueError(
                f"Image size mismatch! expected {w}x{h}, got "
                f"{image.size[0]}x{image.size[1]} for {image_path}"
            )

        cam_infos.append(
            CameraInfo(
                uid=idx,
                R=R,
                T=T,
                FovY=FovY,
                FovX=FovX,
                image=image,
                image_path=image_path,
                image_name=image_name,
                width=w,
                height=h,
            )
        )

    cam_infos = sorted(cam_infos, key=lambda x: x.image_name)
    return cam_infos


def init_scannetpp_pcd(init_type, path, num_pts, nerf_normalization):
    """Make sure an initial-point-cloud PLY exists for a ScanNet++ scene.

    Args:
        init_type: "sfm" to convert ``colmap/points3D.{bin,txt}`` (once) into
            ``colmap/points3D.ply``; "random" to write ``random.ply`` with
            ``num_pts`` uniform points.
        path: scene root directory.
        num_pts: number of random points for ``init_type == "random"``.
        nerf_normalization: dict from ``getNerfppNorm``; its ``"radius"``
            sizes the random cube.

    Returns:
        Path of the PLY file to load.

    Raises:
        ValueError: if ``init_type`` is neither "sfm" nor "random".
    """
    if init_type == "sfm":
        ply_path = os.path.join(path, "colmap/points3D.ply")
        bin_path = os.path.join(path, "colmap/points3D.bin")
        txt_path = os.path.join(path, "colmap/points3D.txt")
        if not os.path.exists(ply_path):
            print(
                "Converting point3d to .ply, will happen only the first time you open the scene."
            )
            try:
                xyz, rgb, _ = read_points3D_binary(bin_path)
            except Exception:
                # Fall back to COLMAP's text export.
                xyz, rgb, _ = read_points3D_text(txt_path)
            storePly(ply_path, xyz, rgb)

    elif init_type == "random":
        ply_path = os.path.join(path, "random.ply")
        print(f"Generating random point cloud ({num_pts})...")

        # Uniform points in a cube of half-width 3 * radius around the origin.
        radius = nerf_normalization["radius"]
        xyz = np.random.random((num_pts, 3)) * radius * 3 * 2 - (radius * 3)
        shs = np.random.random((num_pts, 3)) / 255.0

        storePly(ply_path, xyz, SH2RGB(shs) * 255)

    else:
        # exit(0) would report success to the shell; fail loudly instead.
        raise ValueError(
            f"Please specify a correct init_type: random or sfm (got {init_type!r})"
        )

    return ply_path


def readScannetppDSLRInfo(path, eval, num_pts=100_000, init_type="sfm"):
    """Build a SceneInfo for a ScanNet++ DSLR capture.

    Train and test cameras come from the ``frames`` and ``test_frames`` lists
    of ``nerfstudio/transforms_undistorted.json``. Images are read from
    ``undistorted_images``.

    Args:
        path: scene root directory.
        eval: keep the test split separate; otherwise merge it into training.
        num_pts: number of random points when ``init_type == "random"``.
        init_type: "sfm" or "random" (see ``init_scannetpp_pcd``).

    Returns:
        SceneInfo; ``point_cloud`` is None if the PLY cannot be read.
    """
    transforms_file = os.path.join(path, "nerfstudio/transforms_undistorted.json")
    images_folder = os.path.join(path, "undistorted_images")

    print("Reading train transforms")
    train_cam_infos = readCamerasInstantNGPTransforms(
        transforms_file, images_folder, test=False
    )

    # NOTE(review): if the JSON lacks "test_frames", this returns the
    # training frames again.
    print("Reading test transforms")
    test_cam_infos = readCamerasInstantNGPTransforms(
        transforms_file, images_folder, test=True
    )

    if not eval:
        train_cam_infos.extend(test_cam_infos)
        test_cam_infos = []

    print(
        f"Num train cameras: {len(train_cam_infos)}, num test cameras: {len(test_cam_infos)}"
    )

    nerf_normalization = getNerfppNorm(train_cam_infos)
    ply_path = init_scannetpp_pcd(init_type, path, num_pts, nerf_normalization)

    try:
        pcd = fetchPly(ply_path)
    except Exception:
        pcd = None

    scene_info = SceneInfo(
        point_cloud=pcd,
        train_cameras=train_cam_infos,
        test_cameras=test_cam_infos,
        nerf_normalization=nerf_normalization,
        ply_path=ply_path,
    )
    return scene_info


def readScannetpIphoneInfo(path, eval, num_pts=100_000, llffhold=8, init_type="sfm"):
    """Build a SceneInfo for a ScanNet++ iPhone capture.

    Cameras are read from ``nerfstudio/transforms_undistorted.json``. Every
    ``llffhold``-th camera (in name order) is held out for testing.

    Args:
        path: scene root directory.
        eval: keep the held-out split separate; otherwise merge it into training.
        num_pts: number of random points when ``init_type == "random"``.
        llffhold: hold-out stride; ``<= 0`` disables the hold-out.
        init_type: "sfm" or "random" (see ``init_scannetpp_pcd``).

    Returns:
        SceneInfo; ``point_cloud`` is None if the PLY cannot be read.
    """
    cam_infos = readCamerasInstantNGPTransforms(
        os.path.join(path, "nerfstudio/transforms_undistorted.json"),
        os.path.join(path, "undistorted_images"),
        test=False,
    )
    if llffhold > 0:
        # take 1/llffhold of the cameras for testing
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]

        # take the rest of the cameras for training cameras
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
    else:
        # No hold-out requested; previously both lists were left unbound here.
        train_cam_infos = list(cam_infos)
        test_cam_infos = []

    if not eval:
        train_cam_infos.extend(test_cam_infos)
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    print(
        f"Num train cameras: {len(train_cam_infos)}, num test cameras: {len(test_cam_infos)}"
    )

    ply_path = init_scannetpp_pcd(init_type, path, num_pts, nerf_normalization)

    try:
        pcd = fetchPly(ply_path)
    except Exception:
        pcd = None

    scene_info = SceneInfo(
        point_cloud=pcd,
        train_cameras=train_cam_infos,
        test_cameras=test_cam_infos,
        nerf_normalization=nerf_normalization,
        ply_path=ply_path,
    )
    return scene_info


def readFoxInfo(source_path, eval, num_pts, llffhold=8):
    """Build a SceneInfo for the Instant-NGP "fox" style dataset.

    Cameras come from ``transforms.json`` in ``source_path``. Every
    ``llffhold``-th camera (in name order) is held out for testing. The
    initial cloud is always random.

    Args:
        source_path: dataset directory containing ``transforms.json`` and images.
        eval: keep the held-out split separate; otherwise merge it into training.
        num_pts: number of random initial points.
        llffhold: hold-out stride; ``<= 0`` disables the hold-out.

    Returns:
        SceneInfo; ``point_cloud`` is None if the generated PLY cannot be read.
    """
    print("Reading train transforms")
    cam_infos = readCamerasInstantNGPTransforms(
        os.path.join(source_path, "transforms.json"),
        source_path,
        test=False,
    )
    if llffhold > 0:
        # take 1/llffhold of the cameras for testing
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]

        # take the rest of the cameras for training cameras
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
    else:
        # No hold-out requested; previously both lists were left unbound here.
        train_cam_infos = list(cam_infos)
        test_cam_infos = []

    if not eval:
        train_cam_infos.extend(test_cam_infos)
        test_cam_infos = []

    print(
        f"Num train cameras: {len(train_cam_infos)}, num test cameras: {len(test_cam_infos)}"
    )

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(source_path, "random.ply")
    print(f"Generating random point cloud ({num_pts})...")

    # Uniform points in a cube of half-width 3 * radius around the origin.
    radius = nerf_normalization["radius"]
    xyz = np.random.random((num_pts, 3)) * radius * 3 * 2 - (radius * 3)
    shs = np.random.random((num_pts, 3)) / 255.0

    storePly(ply_path, xyz, SH2RGB(shs) * 255)

    try:
        pcd = fetchPly(ply_path)
    except Exception:
        pcd = None

    scene_info = SceneInfo(
        point_cloud=pcd,
        train_cameras=train_cam_infos,
        test_cameras=test_cam_infos,
        nerf_normalization=nerf_normalization,
        ply_path=ply_path,
    )
    return scene_info


def readCamerasScannetV2(source_path):
    """Read ScanNet v2 color frames and their poses.

    Expects ``<source_path>/intrinsic/intrinsic_color.txt`` (a 4x4 matrix)
    and ``<source_path>/sensor_data/*.color.jpg``, each paired with a
    ``*.pose.txt`` 4x4 camera-to-world matrix. Every frame with a finite pose
    is used (no temporal subsampling).

    Args:
        source_path: scene directory.

    Returns:
        list of CameraInfo sorted by ``image_name`` (the file name including
        its extension).

    Raises:
        ValueError: if the frames do not all share the first frame's size.
    """
    cam_infos = []

    intrinsic_path = os.path.join(source_path, "intrinsic", "intrinsic_color.txt")
    intrinsic = np.fromfile(intrinsic_path, sep=" ").reshape(4, 4)

    fx = intrinsic[0][0]
    fy = intrinsic[1][1]
    # NOTE(review): the principal point (intrinsic[0][2], intrinsic[1][2]) is
    # dropped; CameraInfo only carries field-of-view values.

    h = None
    w = None

    color_suffix = "color.jpg"
    colors_path = os.path.join(source_path, "sensor_data", "*." + color_suffix)

    for idx, color_filename in enumerate(sorted(glob.glob(colors_path))):

        image_name = Path(color_filename).name
        # Swap only the trailing suffix, never a match elsewhere in the path.
        pose_filename = color_filename[: -len(color_suffix)] + "pose.txt"

        C2W = np.fromfile(pose_filename, sep=" ").reshape(4, 4)
        if not np.all(np.isfinite(C2W)):
            # A non-finite pose (e.g. a frame where tracking failed) cannot
            # be inverted into a usable camera; skip the frame.
            continue

        # change gravity
        C2W = C2W[np.array([1, 0, 2, 3]), :]

        # get the world-to-camera transform and set R, T
        W2C = np.linalg.inv(C2W)
        R = np.transpose(W2C[:3, :3])
        T = W2C[:3, 3]

        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(color_filename) as source_image:
            image = source_image.convert("RGB")

        if h is None and w is None:
            w, h = image.size

        if image.size[0] != w or image.size[1] != h:
            raise ValueError(
                f"Image size mismatch! expected {w}x{h}, got "
                f"{image.size[0]}x{image.size[1]} for {color_filename}"
            )

        FovX = focal2fov(fx, w)
        FovY = focal2fov(fy, h)

        cam_infos.append(
            CameraInfo(
                uid=idx,
                R=R,
                T=T,
                FovY=FovY,
                FovX=FovX,
                image=image,
                image_path=color_filename,
                image_name=image_name,
                width=w,
                height=h,
            )
        )

    cam_infos = sorted(cam_infos, key=lambda x: x.image_name)
    return cam_infos


def readScannetV2Info(source_path, eval, num_pts=100_000, llffhold=8):
    """Build a SceneInfo for a ScanNet v2 scene with a random initial cloud.

    Args:
        source_path: scene directory (see ``readCamerasScannetV2``).
        eval: if true and ``llffhold > 0``, hold out every ``llffhold``-th
            camera (in name order) for testing.
        num_pts: number of random initial points.
        llffhold: hold-out stride; ``<= 0`` disables the hold-out.

    Returns:
        SceneInfo; ``point_cloud`` is None if the generated PLY cannot be read.
    """
    cam_infos = readCamerasScannetV2(source_path)

    # The llffhold guard avoids a ZeroDivisionError for llffhold == 0.
    if eval and llffhold > 0:
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
    else:
        train_cam_infos = cam_infos
        test_cam_infos = []

    print("Num train cameras: ", len(train_cam_infos))
    print("Num test cameras: ", len(test_cam_infos))

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(source_path, "random.ply")
    print(f"Generating random point cloud ({num_pts})...")

    # Uniform points in a cube of half-width 3 * radius around the origin.
    radius = nerf_normalization["radius"]
    xyz = np.random.random((num_pts, 3)) * radius * 3 * 2 - (radius * 3)
    shs = np.random.random((num_pts, 3)) / 255.0

    storePly(ply_path, xyz, SH2RGB(shs) * 255)

    try:
        pcd = fetchPly(ply_path)
    except Exception:
        pcd = None

    scene_info = SceneInfo(
        point_cloud=pcd,
        train_cameras=train_cam_infos,
        test_cameras=test_cam_infos,
        nerf_normalization=nerf_normalization,
        ply_path=ply_path,
    )
    return scene_info


# Dataset-type name -> reader function that returns the scene's SceneInfo.
sceneLoadTypeCallbacks = dict(
    Colmap=readColmapSceneInfo,
    Blender=readNerfSyntheticInfo,
    ScannetppDSLR=readScannetppDSLRInfo,
    ScannetppIPhone=readScannetpIphoneInfo,
    Fox=readFoxInfo,
    ScannetV2=readScannetV2Info,
)
