#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

from argparse import ArgumentParser
from pathlib import Path

import mapbox_earcut as earcut
import numpy as np
import open3d as o3d
import torch
from PIL import Image
from rtree import index
from scipy.spatial import cKDTree
from shapely.geometry import Polygon
from skimage import measure
from torchvision import transforms

from arguments import ModelParams, PipelineParams, get_combined_args
from scene import Scene
from scene.planar_model import GaussianModelPlanes
from utils.general_utils import safe_state
from utils.plane_utils import quaternion_to_rotation_matrix


def get_plane_camera(camera, plane):
    """Express a plane in a camera's view frame as coefficients ``(A, B, C, D)``.

    The plane's translation (taken relative to ``camera.camera_center``) and
    its normal are both multiplied, as row vectors, by the upper-left 3x3
    block of ``camera.world_view_transform``. ``D`` is then chosen so that
    ``A*x + B*y + C*z + D == 0`` holds at the transformed translation.

    Args:
        camera: Object exposing ``camera_center`` (3,) and
            ``world_view_transform`` (at least 3x3) tensors.
        plane: Object exposing ``get_translation`` (3,) and ``get_normal``
            (3,) tensors.
            # NOTE(review): the normal is presumably unit length — confirm.

    Returns:
        torch.Tensor: Detached tensor of shape (4,) holding ``[A, B, C, D]``,
        on the same device as the inputs.
    """
    rotation = camera.world_view_transform[:3, :3]

    # A 1-D tensor on the left of ``@`` is treated as a row vector.
    translation_view = (plane.get_translation - camera.camera_center) @ rotation
    normal_view = plane.get_normal @ rotation

    D = -torch.dot(translation_view, normal_view)

    # Concatenate on-device instead of rebuilding the tensor from Python
    # scalars, which would force one host read per element. ``detach`` keeps
    # the result free of autograd history, like a freshly built tensor.
    return torch.cat((normal_view, D.reshape(1))).detach()


def get_rays(camera, mask):
    """Build one unit-length ray direction per pixel of ``mask``.

    Args:
        camera: Object exposing ``Cx``, ``Cy``, ``Fx`` and ``Fy``.
            # NOTE(review): presumably principal point and focal lengths in
            # pixels — confirm against the camera class.
        mask: Tensor whose dimensions 1 and 2 are read as height and width;
            only its shape and device are used.

    Returns:
        torch.Tensor: Tensor of shape (H, W, 3) with normalised directions
        ``((x - Cx) / Fx, (y - Cy) / Fy, 1)``.
    """
    height = mask.shape[1]
    width = mask.shape[2]

    # Integer pixel coordinates shifted by the principal point
    rows = torch.linspace(0, height - 1, height, device=mask.device) - camera.Cy
    cols = torch.linspace(0, width - 1, width, device=mask.device) - camera.Cx
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")

    # Un-normalised directions in the camera frame, z fixed to 1
    directions = torch.stack(
        (grid_x / camera.Fx, grid_y / camera.Fy, torch.ones_like(grid_x)),
        dim=-1,
    )

    # Scale every direction to unit length
    lengths = torch.norm(directions, dim=-1, keepdim=True)
    return directions / lengths


def get_ray_plane_intersection(parameters, rays, eps=1e-3):
    """Intersect rays through the origin with a plane ``A*x + B*y + C*z + D = 0``.

    Each ray is the set of points ``t * direction``; the returned point is
    the one whose ``t`` satisfies the plane equation.

    Args:
        parameters: Four plane coefficients ``(A, B, C, D)``.
        rays: Tensor of shape (..., 3) holding ray directions.
        eps: Rays whose direction satisfies ``|A*x + B*y + C*z| < eps`` are
            treated as parallel to the plane.

    Returns:
        torch.Tensor: Tensor shaped like ``rays`` with the intersection
        points. Rays treated as parallel yield an all-NaN row instead of a
        division by (almost) zero, so callers can filter them with
        ``torch.isfinite``. No check is made on the sign of ``t``.
    """
    A, B, C, D = parameters

    denominator = A * rays[..., 0] + B * rays[..., 1] + C * rays[..., 2]
    parallel = denominator.abs() < eps

    # Divide by a harmless placeholder where the ray is parallel, then mark
    # those entries as invalid explicitly.
    safe_denominator = torch.where(
        parallel, torch.ones_like(denominator), denominator
    )
    t = -D / safe_denominator
    t = torch.where(parallel, torch.full_like(t, float("nan")), t)

    intersection = rays * t.unsqueeze(-1)
    return intersection


def intersection_camera_to_world(intersection, camera):
    """Map points from the camera frame back to the world frame.

    The points are treated as homogeneous row vectors and right-multiplied
    by the inverse of ``camera.world_view_transform``.

    Args:
        intersection (torch.Tensor): Points in the camera frame, shape (N, 3).
        camera (Camera): Object exposing a 4x4 ``world_view_transform``.

    Returns:
        torch.Tensor: The same points in the world frame, shape (N, 3).
    """
    # Invert world-to-view to obtain view-to-world
    camera_to_world = torch.inverse(camera.world_view_transform)

    # Append a column of ones to get homogeneous coordinates, shape (N, 4)
    ones_column = torch.ones((intersection.shape[0], 1), device=intersection.device)
    homogeneous = torch.cat((intersection, ones_column), dim=1)

    # Row-vector convention: points on the left, matrix on the right
    world_homogeneous = homogeneous @ camera_to_world

    # Divide by the homogeneous component to return to 3D
    xyz = world_homogeneous[:, :3]
    w = world_homogeneous[:, 3:4]
    return xyz / w


def project_to_plane_2d(points_3d, origin, R):
    """Express 3D points in the 2D coordinates of a plane.

    Each point is shifted by ``origin`` and multiplied by the transpose of
    ``R``; the first two components of the result are returned. This equals
    taking dot products with the first two columns of ``R``, which is how it
    is computed here: one ``(N, 3) @ (3, 2)`` product instead of N separate
    3x3 products.

    Args:
        points_3d: (N, 3) tensor of 3D points.
        origin: (3,) tensor subtracted from every point.
        R: (3, 3) matrix whose first two columns are used as in-plane axes.
            # NOTE(review): assumed orthonormal, so that the transpose acts
            # as the inverse — confirm against the caller.

    Returns:
        (N, 2) tensor of 2D coordinates.
    """
    points_local = points_3d - origin

    # (R^T p)[:2] for a column vector p == (p^T R)[:2] for the row vector p^T
    return points_local @ R[:, :2]


def voxel_downsample(points, voxel_size=0.01):
    """Keep a single point per occupied grid cell of edge ``voxel_size``.

    Args:
        points: (N, D) array of coordinates.
        voxel_size: Edge length of the grid cells.

    Returns:
        Array with one row of ``points`` per distinct cell, taken at the
        indices reported by ``np.unique(..., return_index=True)`` and
        therefore ordered by sorted cell index rather than by input order.
    """
    # Integer cell index of every point
    cell_ids = np.floor(points / voxel_size).astype(int)

    # Row index of the representative chosen for each distinct cell
    representative_rows = np.unique(cell_ids, axis=0, return_index=True)[1]
    return points[representative_rows]


def remove_statistical_outliers(points, nb_neighbors=20, std_ratio=2.0):
    """Drop points whose farthest queried neighbour is unusually distant.

    For every point the distances to its ``nb_neighbors`` nearest
    neighbours are queried from a KD-tree built on ``points`` itself. A
    point is removed when the largest of those distances exceeds the mean
    plus ``std_ratio`` standard deviations of that same point's distances.

    NOTE(review): the statistic is computed per point, not over the whole
    cloud as in the usual "statistical outlier removal" filter — confirm
    this is intended before changing it.

    Args:
        points: (N, D) array of coordinates.
        nb_neighbors: Number of nearest neighbours queried per point.
        std_ratio: Multiplier on the per-point standard deviation.

    Returns:
        Array containing the rows of ``points`` that were kept.
    """
    points = np.asarray(points)
    n_points = len(points)

    # With fewer points than requested neighbours the query pads with inf,
    # which can never flag a point; return early instead of pushing
    # inf/NaN through the statistics.
    if n_points < nb_neighbors:
        return points.copy()

    tree = cKDTree(points)
    distances, _ = tree.query(points, k=nb_neighbors)

    # ``query`` drops the neighbour axis when k == 1; restore it so the
    # per-point reductions below always see a 2-D array.
    distances = np.reshape(distances, (n_points, -1))

    mean_distances = distances.mean(axis=1)
    std_distances = distances.std(axis=1)
    is_outlier = distances[:, -1] > mean_distances + std_ratio * std_distances
    return points[~is_outlier]


def _trace_occupancy_contours(points2d, grid_resolution):
    """Rasterise 2D points onto a boolean grid and trace its 0.5-level contours."""
    grid_min = points2d.min(axis=0) - grid_resolution
    grid_max = points2d.max(axis=0) + grid_resolution
    grid_shape = np.ceil((grid_max - grid_min) / grid_resolution).astype(int)

    # Occupancy grid: a cell is True when at least one point falls into it
    grid = np.zeros((grid_shape[0], grid_shape[1]), dtype=bool)
    cells = np.floor((points2d - grid_min) / grid_resolution).astype(int)

    inside = (
        (cells[:, 0] >= 0)
        & (cells[:, 0] < grid_shape[0])
        & (cells[:, 1] >= 0)
        & (cells[:, 1] < grid_shape[1])
    )
    grid[cells[inside, 0], cells[inside, 1]] = True

    # Surround the grid with one ring of empty cells so that no occupied
    # cell touches the array border; contours reaching the border would
    # otherwise be returned open.
    padded = np.pad(grid, 1, mode="constant", constant_values=False)
    raw_contours = measure.find_contours(padded, 0.5)

    # Contour coordinates are array indices, i.e. they refer to cell
    # centres. Undo the padding (-1) and move from cell corner to cell
    # centre (+0.5) before scaling back to plane coordinates.
    return [grid_min + (c - 0.5) * grid_resolution for c in raw_contours]


def _group_rings_by_nesting(contours):
    """Pair every outer ring with the hole contours lying directly inside it."""
    polygons = [Polygon(contour) for contour in contours]

    # R-tree over bounding boxes to shortlist possible containers
    tree = index.Index()
    for i, poly in enumerate(polygons):
        tree.insert(i, poly.bounds)

    outer_ids = []
    holes_by_outer = {}

    for i, poly in enumerate(polygons):
        containers = [
            j
            for j in tree.intersection(poly.bounds)
            if j != i and polygons[j].contains(poly)
        ]

        # An even number of enclosing contours means filled region (outer
        # ring); an odd number means hole of the innermost enclosing one.
        if len(containers) % 2 == 0:
            outer_ids.append(i)
        else:
            parent = min(containers, key=lambda j: polygons[j].area)
            holes_by_outer.setdefault(parent, []).append(contours[i])

    return [(contours[i], holes_by_outer.get(i, [])) for i in outer_ids]


def triangulate(points2d, grid_resolution, min_contour_points=100):
    """Triangulate the region covered by a 2D point set.

    The points are rasterised onto an occupancy grid with cells of size
    ``grid_resolution``, the outlines of the occupied area are traced,
    grouped into outer rings with their holes, and each outer ring is
    triangulated with earcut.

    Args:
        points2d: (N, 2) array of 2D points.
        grid_resolution: Edge length of the occupancy grid cells.
        min_contour_points: Contours with this many vertices or fewer are
            discarded.

    Returns:
        tuple[list, list]: ``(vertices_all, indices_all)``. For every
        triangulated outer ring, ``vertices_all`` holds the (M, 2) array of
        its vertices followed by those of its holes, and ``indices_all``
        the flat array of vertex indices returned by earcut. Both lists are
        empty when no contour could be extracted.
    """
    points2d = np.asarray(points2d)
    if points2d.size == 0:
        print("Failed to find contours")
        return [], []

    # Contour extraction is best-effort: report the reason and give up on
    # this point set rather than aborting the caller.
    try:
        contours = _trace_occupancy_contours(points2d, grid_resolution)
    except Exception as exc:
        print(f"Failed to find contours: {exc}")
        return [], []

    # Longest first, then drop contours too short to be meaningful
    contours = sorted(contours, key=len, reverse=True)
    contours = [
        c for c in contours if len(c) > min_contour_points and len(c) >= 3
    ]

    vertices_all = []
    indices_all = []

    for outer, holes in _group_rings_by_nesting(contours):
        vertices = np.vstack([outer] + holes)

        # Cumulative end index of every ring (outer first, then each hole);
        # the last entry equals len(vertices).
        ring_ends = np.cumsum([len(outer)] + [len(hole) for hole in holes])
        ring_ends = ring_ends.astype(np.uint32)

        try:
            indices = earcut.triangulate_float64(vertices, ring_ends)
        except Exception as exc:
            print(f"Failed to triangulate contour: {exc}")
            continue

        vertices_all.append(vertices)
        indices_all.append(indices)

    return vertices_all, indices_all


def _collect_plane_points_2d(plane, plane_masks, cameras, device, origin_plane, R_plane):
    """Back-project every available mask of one plane into 2D plane coordinates."""
    to_tensor = transforms.ToTensor()
    projected = []

    for cam_id, camera in enumerate(cameras):
        mask_path = plane_masks / f"{cam_id:05d}.png"
        if not mask_path.exists():
            continue

        # Close the image file as soon as the pixel data has been copied
        with Image.open(mask_path) as mask_image:
            mask = to_tensor(mask_image.convert("L")).to(device)

        # Plane expressed in this camera's frame
        parameters = get_plane_camera(camera, plane)
        if torch.abs(parameters[3]) < 1e-3:  # D close to 0
            continue

        # Keep only pixels confidently inside the mask
        mask_indices = (mask > 0.8).reshape(-1)
        if not mask_indices.any():
            continue

        rays = get_rays(camera, mask)
        points_camera = get_ray_plane_intersection(parameters, rays).reshape(-1, 3)
        points_camera = points_camera[mask_indices]

        # Rays (nearly) parallel to the plane produce inf/NaN coordinates
        # that would corrupt the grid bounds downstream.
        points_camera = points_camera[torch.isfinite(points_camera).all(dim=1)]
        if points_camera.shape[0] == 0:
            continue

        points_world = intersection_camera_to_world(points_camera, camera)
        points_2d = project_to_plane_2d(points_world, origin_plane, R_plane)
        projected.append(points_2d.detach().cpu().numpy())

    return projected


def main(
    gaussians: GaussianModelPlanes,
    scene: Scene,
    device: torch.device,
    output_path: str,
    grid_resolution: float = 0.1,
    mask_iteration: int = 30000,
):
    """Reconstruct a triangle mesh for every plane and write them as one PLY.

    For each plane in ``gaussians.planes`` the rendered masks found under
    ``<model_path>/train/ours_<mask_iteration>/renders_mask/<plane_id>/`` are
    back-projected onto the plane, merged, downsampled, filtered and
    triangulated. Every resulting mesh gets its own uniform colour and all
    of them are merged into ``<output_path>/<scene_id>/<scene_id>_planar_mesh.ply``.

    Mask files are looked up as ``{cam_id:05d}.png`` where ``cam_id`` is the
    position of the camera in ``scene.getTrainCameras()``.
    # NOTE(review): assumes the renderer numbered the masks in that same
    # camera order — confirm.

    Args:
        gaussians: Model exposing ``planes``; each plane provides
            ``get_translation``, ``get_rotation`` and ``get_normal``.
        scene: Scene exposing ``getTrainCameras()`` and ``model_path``.
        device: Device the masks are moved to.
        output_path: Root output directory; a sub-directory named after the
            scene is created inside it.
        grid_resolution: Cell size of the occupancy grid used for meshing.
        mask_iteration: Iteration number used in the ``ours_<n>`` directory
            that holds the rendered masks.
    """
    cameras = scene.getTrainCameras()
    planes_mesh = []

    scene_id = Path(scene.model_path).name

    output_path = Path(output_path) / scene_id
    output_path.mkdir(exist_ok=True, parents=True)

    print(f"[{scene_id}] number of cameras: ", len(cameras))
    print(f"[{scene_id}] number of planes: ", len(gaussians.planes))

    masks_root = (
        Path(scene.model_path) / "train" / f"ours_{mask_iteration}" / "renders_mask"
    )

    for plane_id, plane in enumerate(gaussians.planes):
        origin_plane = plane.get_translation
        R_plane = quaternion_to_rotation_matrix(plane.get_rotation)

        plane_masks = masks_root / str(plane_id)
        mask_paths = list(plane_masks.glob("*.png"))
        print(f"[{scene_id}] number of masks: ", len(mask_paths))

        projected = _collect_plane_points_2d(
            plane, plane_masks, cameras, device, origin_plane, R_plane
        )
        if not projected:
            print(f"[{scene_id}] no valid points for plane: ", plane_id)
            continue

        # Merge the points of all views, then thin and clean them
        points_2d = np.concatenate(projected, axis=0)
        points_2d = voxel_downsample(points_2d, voxel_size=0.01)
        points_2d = remove_statistical_outliers(
            points_2d, nb_neighbors=20, std_ratio=2.0
        )

        vertices_all, indices_all = triangulate(points_2d, grid_resolution)
        if not vertices_all:
            print(f"[{scene_id}] no valid triangles for plane: ", plane_id)
            continue

        # Plane pose on the host, converted once per plane
        R_np = R_plane.detach().cpu().numpy()
        origin_np = origin_plane.detach().cpu().numpy()

        for vertices, indices in zip(vertices_all, indices_all):
            # Lift the 2D vertices into the plane (z = 0), then rotate and
            # translate them back to world coordinates.
            local_vertices = np.concatenate(
                [vertices, np.zeros((len(vertices), 1))], axis=1
            )
            world_vertices = (
                np.matmul(R_np[None, ...], local_vertices[..., None]).squeeze(-1)
                + origin_np
            )
            world_vertices = np.ascontiguousarray(world_vertices)
            triangles = np.ascontiguousarray(
                np.asarray(indices).reshape(-1, 3), dtype=np.int32
            )

            mesh = o3d.geometry.TriangleMesh()
            mesh.vertices = o3d.utility.Vector3dVector(world_vertices)
            mesh.triangles = o3d.utility.Vector3iVector(triangles)
            planes_mesh.append(mesh)

    # Counts meshes, i.e. one per triangulated outer contour, which can be
    # more than one per plane.
    numPlanes = len(planes_mesh)
    if numPlanes == 0:
        # Nothing to merge; indexing the first mesh below would fail.
        print(f"[{scene_id}] no plane meshes were reconstructed, nothing written")
        return

    # Distinct colour per mesh: split (i + 1) * 100 into three base-256 digits
    segmentationColor = (np.arange(numPlanes + 1) + 1) * 100
    colorMap = np.stack(
        [
            segmentationColor // (256 * 256),
            segmentationColor // 256 % 256,
            segmentationColor % 256,
        ],
        axis=1,
    ).astype("uint8")

    for mesh_id, mesh in enumerate(planes_mesh):
        mesh.paint_uniform_color(colorMap[mesh_id] / 255.0)

    # Accumulate everything into the first mesh
    mesh_big = planes_mesh[0]
    for mesh in planes_mesh[1:]:
        mesh_big += mesh

    print(f"[{scene_id}] number of planes: ", numPlanes)
    o3d.io.write_triangle_mesh(
        str(output_path / f"{scene_id}_planar_mesh.ply"), mesh_big
    )


if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--grid_resolution", default=0.1, type=float)
    parser.add_argument("--output_path", type=str)
    args = get_combined_args(parser)

    # --output_path has no default. Fail with a usage message now instead of
    # crashing after the whole scene has been loaded. ``getattr`` is used
    # because the combined namespace may not carry the attribute at all
    # when the option was not given.
    output_path = getattr(args, "output_path", None)
    if output_path is None:
        parser.error("--output_path is required")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Rendering " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    dataset = model.extract(args)
    iteration = args.iteration

    gaussians = GaussianModelPlanes(dataset.sh_degree)
    scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

    # Pure inference: no gradients are needed for mesh extraction
    with torch.no_grad():
        main(gaussians, scene, device, output_path, args.grid_resolution)
