#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import sys
from PIL import Image
from typing import NamedTuple
from scene.colmap_loader import (
    read_extrinsics_text,
    read_intrinsics_text,
    qvec2rotmat,
    read_extrinsics_binary,
    read_intrinsics_binary,
    read_points3D_binary,
    read_points3D_text,
)
from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
import numpy as np
import json
from pathlib import Path
from plyfile import PlyData, PlyElement
from utils.sh_utils import SH2RGB
from scene.gaussian_model import BasicPointCloud
import trimesh
import zipfile
from io import BytesIO


class CameraInfo(NamedTuple):
    uid: int
    R: np.array
    T: np.array
    FovY: np.array
    FovX: np.array
    image: np.array
    image_path: str
    image_name: str
    width: int
    height: int


class SceneInfo(NamedTuple):
    point_cloud: BasicPointCloud
    train_cameras: list
    test_cameras: list
    nerf_normalization: dict
    ply_path: str


def convert_cam_coords(transform_matrix):
    """Re-express a 4x4 camera-to-world matrix in a different axis convention.

    The matrix is left-multiplied by a world-axis change (x' = x, y' = z,
    z' = -y) and right-multiplied by diag(1, -1, -1, 1), which negates the
    camera's local y and z axes (presumably an OpenGL -> OpenCV camera
    convention change; confirm against the dataset format).
    """
    # World frame: rows map (x, y, z) -> (x, z, -y).
    world_swap = np.array(
        [[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]
    )
    # Camera frame: flip local y and z.
    cam_flip = np.diag([1, -1, -1, 1])
    return world_swap @ transform_matrix @ cam_flip


def getNerfppNorm(cam_info):
    """Compute a scene-centering translation and bounding radius from cameras.

    Each camera center is taken as the translation column of the inverse of
    ``getWorld2View2(cam.R, cam.T)``.

    Returns:
        dict with ``"translate"`` (negated mean camera center, shape (3,)) and
        ``"radius"`` (1.1 times the largest distance of any camera center from
        that mean).

    Raises:
        ValueError: from ``np.hstack`` when ``cam_info`` is empty.
    """
    # Stack centers column-wise: shape (3, num_cameras).
    centers = np.hstack(
        [np.linalg.inv(getWorld2View2(cam.R, cam.T))[:3, 3:4] for cam in cam_info]
    )
    mean_center = centers.mean(axis=1, keepdims=True)
    farthest = np.linalg.norm(centers - mean_center, axis=0).max()
    return {"translate": -mean_center.flatten(), "radius": farthest * 1.1}


def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
    """Build CameraInfo records from parsed COLMAP camera data.

    Args:
        cam_extrinsics: mapping whose values expose ``camera_id``, ``qvec``,
            ``tvec`` and ``name`` (one entry per image).
        cam_intrinsics: mapping from camera id to a record exposing ``id``,
            ``model``, ``width``, ``height`` and ``params``.
        images_folder: directory containing the images; only the basename of
            each extrinsics ``name`` is used to locate the file.

    Returns:
        list of CameraInfo, in the iteration order of ``cam_extrinsics``.

    Raises:
        ValueError: if a camera model other than SIMPLE_PINHOLE or PINHOLE is
            encountered. (An ``assert`` would be stripped under ``python -O``.)
    """
    cam_infos = []
    total = len(cam_extrinsics)
    for idx, key in enumerate(cam_extrinsics):
        # Rewrite a single progress line in place.
        sys.stdout.write("\r")
        sys.stdout.write("Reading camera {}/{}".format(idx + 1, total))
        sys.stdout.flush()

        extr = cam_extrinsics[key]
        intr = cam_intrinsics[extr.camera_id]
        height = intr.height
        width = intr.width

        # NOTE(review): this is the intrinsics (camera) id, so images sharing one
        # COLMAP camera also share a uid -- confirm that is intended.
        uid = intr.id
        # Rotation from the COLMAP quaternion, stored transposed.
        R = np.transpose(qvec2rotmat(extr.qvec))
        T = np.array(extr.tvec)

        if intr.model == "SIMPLE_PINHOLE":
            # One focal length (params[0]) shared by both axes.
            focal_length_x = intr.params[0]
            FovY = focal2fov(focal_length_x, height)
            FovX = focal2fov(focal_length_x, width)
        elif intr.model == "PINHOLE":
            focal_length_x = intr.params[0]
            focal_length_y = intr.params[1]
            FovY = focal2fov(focal_length_y, height)
            FovX = focal2fov(focal_length_x, width)
        else:
            raise ValueError(
                "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
            )

        image_path = os.path.join(images_folder, os.path.basename(extr.name))
        # splitext strips only the final extension ("a.1.png" -> "a.1"), so
        # dotted filenames no longer collapse onto the same image_name.
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        image = Image.open(image_path)

        cam_info = CameraInfo(
            uid=uid,
            R=R,
            T=T,
            FovY=FovY,
            FovX=FovX,
            image=image,
            image_path=image_path,
            image_name=image_name,
            width=width,
            height=height,
        )
        cam_infos.append(cam_info)
    sys.stdout.write("\n")
    return cam_infos


def _interpolate_vertex_colors(mesh, points, face_index):
    """Barycentrically interpolate per-vertex RGB (0-255) at points on given faces.

    Converts ``mesh.visual`` to vertex colors in place. Returns a float array of
    shape (len(points), 3) on the same 0-255 scale as the vertex colors.
    """
    mesh.visual = mesh.visual.to_color()
    vertex_colors = mesh.visual.vertex_colors[:, :3]  # drop the alpha channel
    face_vertices = mesh.faces[face_index]  # (n, 3) vertex indices per sample
    tri = mesh.vertices[face_vertices]  # (n, 3, 3) triangle corners
    tri_colors = vertex_colors[face_vertices].astype(np.float64)  # (n, 3, 3)

    # Standard dot-product barycentric solve, vectorized over all samples.
    v0 = tri[:, 1] - tri[:, 0]
    v1 = tri[:, 2] - tri[:, 0]
    v2 = points - tri[:, 0]
    d00 = np.einsum("ij,ij->i", v0, v0)
    d01 = np.einsum("ij,ij->i", v0, v1)
    d11 = np.einsum("ij,ij->i", v1, v1)
    d20 = np.einsum("ij,ij->i", v2, v0)
    d21 = np.einsum("ij,ij->i", v2, v1)
    denom = d00 * d11 - d01 * d01
    v = (d11 * d20 - d01 * d21) / denom
    w = (d00 * d21 - d01 * d20) / denom
    u = 1.0 - v - w

    return (
        u[:, None] * tri_colors[:, 0]
        + v[:, None] * tri_colors[:, 1]
        + w[:, None] * tri_colors[:, 2]
    )


def fetchObj(path, num_points=5000):
    """Sample a point cloud from the surface of a mesh file.

    Args:
        path: mesh file loadable by ``trimesh.load``.
        num_points: number of surface samples to draw (default 5000).

    Returns:
        BasicPointCloud with sampled positions, colors in [0, 1] interpolated
        from the mesh's vertex colors (mid-gray 0.5 when the mesh has no usable
        color information), and the normal of the face each point lies on.
    """
    print("##### loading pointcloud from mesh object ######")

    mesh = trimesh.load(path)
    points, face_index = trimesh.sample.sample_surface(mesh, num_points)
    positions = np.array(points)
    normals = np.array(mesh.face_normals[face_index])

    try:
        colors = _interpolate_vertex_colors(mesh, positions, face_index) / 255.0
    except Exception:
        # Best effort: no usable colors -> mid-gray. This is already on the
        # [0, 1] scale; the old code divided 0.5 by 255 and got near-black.
        colors = np.full((len(positions), 3), 0.5)

    return BasicPointCloud(points=positions, colors=colors, normals=normals)


def fetchPly(path):
    """Load positions, 8-bit colors (scaled to [0, 1]) and normals from a PLY file.

    Expects a ``vertex`` element with x/y/z, red/green/blue and nx/ny/nz fields.
    """
    vertex = PlyData.read(path)["vertex"]

    def columns(*names):
        # One row per vertex, one column per requested field.
        return np.vstack([vertex[name] for name in names]).T

    return BasicPointCloud(
        points=columns("x", "y", "z"),
        colors=columns("red", "green", "blue") / 255.0,
        normals=columns("nx", "ny", "nz"),
    )


def storePly(path, xyz, rgb):
    """Write points and colors to a PLY ``vertex`` element at ``path``.

    Positions and normals are stored as float32, colors as uint8. Normals are
    not known here and are written as zeros.
    """
    field_names = ("x", "y", "z", "nx", "ny", "nz", "red", "green", "blue")
    field_types = ("f4",) * 6 + ("u1",) * 3
    vertex_dtype = list(zip(field_names, field_types))

    table = np.concatenate((xyz, np.zeros_like(xyz), rgb), axis=1)
    records = np.empty(xyz.shape[0], dtype=vertex_dtype)
    records[:] = [tuple(row) for row in table]

    PlyData([PlyElement.describe(records, "vertex")]).write(path)


def readColmapSceneInfo(path, images, eval, llffhold=8):
    """Build a SceneInfo from a COLMAP reconstruction under ``path/sparse/0``.

    Args:
        path: dataset root containing ``sparse/0`` and the image folder.
        images: image folder name relative to ``path``; ``None`` means "images".
        eval: when True, every ``llffhold``-th camera (by sorted image name,
            starting at index 0) is held out as a test camera.
        llffhold: hold-out stride used when ``eval`` is True.

    Returns:
        SceneInfo. ``point_cloud`` is None if the PLY cannot be read.
    """
    try:
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
    except Exception:
        # Binary model missing or unreadable: fall back to the text export.
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)

    reading_dir = "images" if images is None else images
    cam_infos_unsorted = readColmapCameras(
        cam_extrinsics=cam_extrinsics,
        cam_intrinsics=cam_intrinsics,
        images_folder=os.path.join(path, reading_dir),
    )
    # sorted() already returns a new list; no copy needed.
    cam_infos = sorted(cam_infos_unsorted, key=lambda x: x.image_name)

    if eval:
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
    else:
        train_cam_infos = cam_infos
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "sparse/0/points3D.ply")
    bin_path = os.path.join(path, "sparse/0/points3D.bin")
    txt_path = os.path.join(path, "sparse/0/points3D.txt")
    if not os.path.exists(ply_path):
        print(
            "Converting point3d.bin to .ply, will happen only the first time you open the scene."
        )
        try:
            xyz, rgb, _ = read_points3D_binary(bin_path)
        except Exception:
            xyz, rgb, _ = read_points3D_text(txt_path)
        storePly(ply_path, xyz, rgb)
    try:
        pcd = fetchPly(ply_path)
    except Exception:
        # Best effort: callers receive point_cloud=None if the PLY is unreadable.
        pcd = None

    scene_info = SceneInfo(
        point_cloud=pcd,
        train_cameras=train_cam_infos,
        test_cameras=test_cam_infos,
        nerf_normalization=nerf_normalization,
        ply_path=ply_path,
    )
    return scene_info


def _frame_pose(frame):
    """Return (R, T) for one transforms-JSON frame.

    ``transform_matrix`` is converted with ``convert_cam_coords`` and inverted
    to a world-to-camera matrix. R is returned transposed and T is its
    translation column.
    """
    c2w = convert_cam_coords(np.array(frame["transform_matrix"]))
    w2c = np.linalg.inv(c2w)
    # R is stored transposed due to 'glm' in CUDA code.
    return np.transpose(w2c[:3, :3]), w2c[:3, 3]


def _composite_on_background(image, white_background):
    """Flatten an image's alpha onto solid white or black; return an 8-bit RGB image.

    Channel values are truncated (not rounded) to integers.
    """
    norm_data = np.array(image.convert("RGBA")) / 255.0
    bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
    alpha = norm_data[:, :, 3:4]
    arr = norm_data[:, :, :3] * alpha + bg * (1 - alpha)
    # uint8 (not the signed np.byte) so values above 127 don't overflow;
    # an (H, W, 3) uint8 array is inferred as mode "RGB" by Pillow.
    return Image.fromarray(np.clip(arr * 255.0, 0, 255).astype(np.uint8))


def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
    """Read CameraInfo records from a NeRF-synthetic style transforms JSON.

    If ``<path>/image.zip`` exists, each image is read from the archive root as
    ``<stem><extension>``. Otherwise it is read from
    ``<path>/<file_path><extension>`` on disk. Alpha is composited onto a white
    or black background.

    Args:
        path: dataset root directory.
        transformsfile: JSON filename relative to ``path``; must contain
            ``camera_angle_x`` and ``frames`` (each with ``file_path`` and
            ``transform_matrix``).
        white_background: composite onto white if True, else black.
        extension: suffix appended to each frame's ``file_path``.

    Returns:
        list of CameraInfo whose uid is the frame's index in ``frames``.
    """
    with open(os.path.join(path, transformsfile)) as json_file:
        contents = json.load(json_file)
    fovx = contents["camera_angle_x"]
    frames = contents["frames"]

    zip_path = os.path.join(path, "image.zip")
    zip_ref = zipfile.ZipFile(zip_path, "r") if os.path.exists(zip_path) else None

    cam_infos = []
    try:
        for idx, frame in enumerate(frames):
            R, T = _frame_pose(frame)

            # Join `path` only once; a second join would duplicate a relative
            # `path` (e.g. "data/x/data/x/...").
            image_path = os.path.join(path, frame["file_path"] + extension)
            image_name = Path(image_path).stem

            if zip_ref is not None:
                source = BytesIO(zip_ref.read(image_name + extension))
            else:
                source = image_path
            with Image.open(source) as raw_image:
                image = _composite_on_background(raw_image, white_background)

            width, height = image.size
            cam_infos.append(
                CameraInfo(
                    uid=idx,
                    R=R,
                    T=T,
                    FovY=focal2fov(fov2focal(fovx, width), height),
                    FovX=fovx,
                    image=image,
                    image_path=image_path,
                    image_name=image_name,
                    width=width,
                    height=height,
                )
            )
    finally:
        if zip_ref is not None:
            zip_ref.close()

    return cam_infos


def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
    """Build a SceneInfo for a NeRF-synthetic (Blender) style dataset.

    Cameras come from ``transforms_train.json`` and, when present,
    ``transforms_test.json``. A lone ``transforms.json`` is renamed on disk to
    ``transforms_train.json``. The initial point cloud is sampled from
    ``point_cloud.obj`` when it exists. Otherwise 100k random points are
    generated and written to ``points3d.ply``.

    Args:
        path: dataset root directory.
        white_background: composite image alpha onto white (True) or black.
        eval: keep test cameras separate; when False they are merged into the
            training set.
        extension: image file extension appended to each frame's file_path.

    Returns:
        SceneInfo for the dataset.
    """
    print("Reading Training Transforms")
    if not os.path.exists(os.path.join(path, "transforms_train.json")):
        # rename the transforms.json to transforms_train.json
        os.rename(
            os.path.join(path, "transforms.json"),
            os.path.join(path, "transforms_train.json"),
        )

    train_cam_infos = readCamerasFromTransforms(
        path, "transforms_train.json", white_background, extension
    )
    print("Reading Test Transforms")
    if os.path.exists(os.path.join(path, "transforms_test.json")):
        test_cam_infos = readCamerasFromTransforms(
            path, "transforms_test.json", white_background, extension
        )
    else:
        # Single-file datasets (e.g. the renamed transforms.json case) have no
        # test split; reading the missing file would raise FileNotFoundError.
        print("No transforms_test.json found; continuing without test cameras.")
        test_cam_infos = []

    if not eval:
        train_cam_infos.extend(test_cam_infos)
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    # NOTE(review): ply_path is only written in the random-points branch below;
    # in the mesh branch it is returned whether or not the file exists.
    ply_path = os.path.join(path, "points3d.ply")
    mesh_path = os.path.join(path, "point_cloud.obj")
    if not os.path.exists(mesh_path):
        # if mesh do not exist, then randomize
        num_pts = 100_000
        print(f"Generating random point cloud ({num_pts})...")

        # Uniform random points in the cube [-1.3, 1.3)^3.
        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
        shs = np.random.random((num_pts, 3)) / 255.0
        pcd = BasicPointCloud(
            points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
        )
        storePly(ply_path, xyz, SH2RGB(shs) * 255)
    else:
        pcd = fetchObj(mesh_path)

    scene_info = SceneInfo(
        point_cloud=pcd,
        train_cameras=train_cam_infos,
        test_cameras=test_cam_infos,
        nerf_normalization=nerf_normalization,
        ply_path=ply_path,
    )
    return scene_info


# Dataset-type name -> loader function that returns a SceneInfo.
sceneLoadTypeCallbacks = dict(
    [
        ("Colmap", readColmapSceneInfo),
        ("Blender", readNerfSyntheticInfo),
    ]
)
