import torch
from arguments import OptimizationParams
from scene.cameras import Camera
from scene.gaussian_model import GaussianModel
from torch_kdtree import build_kd_tree
from torch_kdtree.nn_distance import TorchKDTree
from utils.general_utils import build_rotation
from utils.plane_utils import normal_to_quaternion


def fit_plane_3_points(points: torch.Tensor):
    """Return the plane ``A*x + B*y + C*z + D = 0`` through the first three points.

    ``(A, B, C)`` is the unit normal ``(p2 - p1) x (p3 - p1)`` and
    ``D = -normal . p1``. Each coefficient is returned as a 0-dim tensor.
    If the three points are collinear (or coincident), the cross product is
    zero and every coefficient comes out NaN.
    """
    anchor = points[0]
    edge_a = points[1] - anchor
    edge_b = points[2] - anchor

    unit_normal = torch.linalg.cross(edge_a, edge_b, dim=0)
    unit_normal = unit_normal / torch.norm(unit_normal)

    offset = -torch.dot(unit_normal, anchor)
    return unit_normal[0], unit_normal[1], unit_normal[2], offset


def points_to_plane_distance(points: torch.Tensor, plane: torch.Tensor):
    """Unsigned distance from each row of ``points`` to the plane ``(A, B, C, D)``.

    ``points`` is indexed as ``points[:, 0..2]`` (x, y, z per row). The normal
    need not be unit length; a 1e-6 term in the denominator guards against a
    zero-length normal.
    """
    coef_x, coef_y, coef_z, coef_d = plane
    px = points[:, 0]
    py = points[:, 1]
    pz = points[:, 2]

    residual = coef_x * px + coef_y * py + coef_z * pz + coef_d
    normal_length = torch.sqrt(coef_x**2 + coef_y**2 + coef_z**2) + 1e-6
    return torch.abs(residual) / normal_length


def ransac(points: torch.Tensor, threshold: float, max_iter: int):
    """
    RANSAC algorithm to fit a plane in 3D.

    Runs ``max_iter`` trials. Each trial fits a plane through three randomly
    chosen rows of ``points`` and counts the points closer than ``threshold``.
    The trial with the strictly largest inlier count wins.

    Returns:
        ``(plane, inlier_mask, distances)`` for the winning trial, where
        ``plane`` is the ``(A, B, C, D)`` tuple from ``fit_plane_3_points`` and
        ``distances`` covers every point. All three are ``None`` if no trial
        produced a single inlier. For example, degenerate (collinear) samples
        yield NaN distances, which never pass the threshold.
    """
    num_points = points.size(0)

    winner = (None, None, None)
    winner_count = 0

    for _ in range(max_iter):
        triple = points[torch.randperm(num_points)[:3]]
        candidate = fit_plane_3_points(triple)

        residuals = points_to_plane_distance(points, candidate)
        support = residuals < threshold
        support_count = support.sum().item()

        if support_count <= winner_count:
            continue
        winner_count = support_count
        winner = (candidate, support, residuals)

    return winner


def get_pixel_grid(height: int, width: int, device="cuda"):
    """Return pixel-center coordinates for an image of ``height x width``.

    Args:
        height: number of rows.
        width: number of columns.
        device: device for the result; defaults to ``"cuda"`` as before, but
            can now be set to e.g. ``"cpu"`` or a specific GPU.

    Returns:
        ``(height, width, 2)`` float32 tensor where ``[..., 0]`` is the column
        index + 0.5 (u) and ``[..., 1]`` is the row index + 0.5 (v).
    """
    rows, cols = torch.meshgrid(
        torch.arange(height, device=device),
        torch.arange(width, device=device),
        indexing="ij",
    )

    # Stack to get (height, width, 2) in (u, v) order
    pixels = torch.stack((cols, rows), dim=-1).to(torch.float32)
    pixels += 0.5  # Center of the pixel as in the camera model
    return pixels


def get_xyz_depth_view_h(depth: torch.Tensor, cam: Camera):
    """Back-project a 2-D depth map to homogeneous camera-space points.

    ``depth`` is read as ``depth.shape[0]`` rows by ``depth.shape[1]`` columns
    and is used directly as the view-space z coordinate (not as ray length).
    The pinhole intrinsics ``cam.Fx``, ``cam.Fy``, ``cam.Cx`` and ``cam.Cy``
    invert the projection at each pixel center.

    Returns:
        ``(H * W, 4)`` tensor of ``(x, y, z, 1)`` rows in row-major pixel order.
    """
    rows, cols = depth.shape[0], depth.shape[1]
    uv = get_pixel_grid(rows, cols).reshape(-1, 2)
    z = depth.reshape(-1)

    # Invert u = Fx * x / z + Cx (and likewise for v)
    x = (uv[:, 0] - cam.Cx) * z / cam.Fx
    y = (uv[:, 1] - cam.Cy) * z / cam.Fy
    w = torch.ones_like(z)

    return torch.stack([x, y, z, w], dim=1)


def get_xyz_depth(depth: torch.Tensor, cam: Camera):
    """Back-project a depth map to world-space points, one per pixel.

    The camera-space points from ``get_xyz_depth_view_h`` are right-multiplied
    by the inverse of ``cam.world_view_transform`` (row-vector convention,
    matching ``get_xyz_view_h``). The result is then de-homogenized.

    Returns:
        ``(H * W, 3)`` world-space points in row-major pixel order.
    """
    view_points_h = get_xyz_depth_view_h(depth, cam)
    view_to_world = cam.world_view_transform.inverse()

    world_h = view_points_h @ view_to_world
    return world_h[:, :3] / world_h[:, 3:4]


def get_xyz_view_h(xyz_world: torch.Tensor, cam: Camera):
    xyz_world_h = torch.cat(
        [
            xyz_world,
            torch.ones(xyz_world.shape[0], 1, device="cuda"),
        ],
        dim=1,
    )
    xyz_view_h = xyz_world_h @ cam.world_view_transform
    return xyz_view_h


def get_image_pixels(xyz_world: torch.Tensor, cam: Camera):
    """Project world-space points to ``(u, v)`` pixel coordinates in ``cam``.

    Uses the pinhole model ``u = Fx * x / z + Cx`` and ``v = Fy * y / z + Cy``
    on view-space coordinates. This is the inverse of the back-projection in
    ``get_xyz_depth_view_h``. Points with ``z <= 0`` (at or behind the camera)
    are not filtered: they produce infinite or mirrored coordinates.

    Args:
        xyz_world: ``(N, 3)`` world-space points.
        cam: camera providing ``world_view_transform`` and ``Fx``, ``Fy``,
            ``Cx``, ``Cy``.

    Returns:
        ``(N, 2)`` tensor of ``(u, v)`` pixel coordinates.
    """
    xyz_view = get_xyz_view_h(xyz_world, cam)[:, :3]
    x, y, z = xyz_view.unbind(dim=1)

    # Equivalent to [x/z, y/z, 1] @ K^T. Written out directly so nothing is
    # allocated on a hard-coded "cuda" device; results follow xyz_world's device.
    u = cam.Fx * (x / z) + cam.Cx
    v = cam.Fy * (y / z) + cam.Cy

    image_pixels = torch.stack((u, v), dim=1)
    return image_pixels


def build_image_pixels_kdtree(xyz_world: torch.Tensor, cam: Camera):
    """Build a KD-tree over the 2-D image-plane projections of ``xyz_world``.

    The tree's points are the rows of ``get_image_pixels(xyz_world, cam)``.
    Presumably query indices map back to rows of ``xyz_world`` (callers rely
    on this); confirm against ``torch_kdtree``.
    """
    return build_kd_tree(get_image_pixels(xyz_world, cam))


@torch.no_grad()
def find_closest_mask_points(
    mask: torch.Tensor,
    kdtree: TorchKDTree,
    kdtree_filter: torch.Tensor,
):
    """Flag the Gaussians whose projections lie within 1 px of a mask pixel.

    Args:
        mask: ``(H, W)`` soft mask; pixels with value > 0.5 are foreground.
        kdtree: tree over image-plane projections. Its i-th point is assumed
            to be the i-th True entry of ``kdtree_filter``.
        kdtree_filter: mask over all Gaussians selecting those in ``kdtree``.
            A trailing singleton dimension is squeezed away.

    Returns:
        Tensor shaped like ``kdtree_filter.squeeze(-1)``, with the same dtype
        and device, set to True for every Gaussian that is the nearest
        neighbour of some foreground pixel at distance < 1 px.
    """
    kdtree_filter = kdtree_filter.squeeze(-1)
    # zeros_like already inherits kdtree_filter's device; forcing "cuda" (as
    # before) broke CPU callers and could mismatch a non-default GPU.
    mask_xyz = torch.zeros_like(kdtree_filter)

    mask_pixels = get_pixel_grid(mask.shape[0], mask.shape[1])[mask > 0.5]
    if mask_pixels.shape[0] == 0:
        # Empty mask: nothing can be close. Skip querying with zero points.
        return mask_xyz

    kdtree_indices = torch.where(kdtree_filter)[0]

    dists, close_indices = kdtree.query(mask_pixels, nr_nns_searches=1)
    # The query distances are treated as squared, hence the sqrt.
    dists = torch.sqrt(dists)

    close_indices = close_indices.squeeze(-1)[dists.squeeze(-1) < 1]
    close_indices = torch.unique(close_indices)

    # Map kdtree-local indices back to indices over all Gaussians
    mask_xyz[kdtree_indices[close_indices]] = True

    return mask_xyz


import open3d as o3d
import numpy as np

# Running counter used to number debug point-cloud dumps. Within this module it
# is only referenced (via `global INDEX`) by the commented-out DEBUG blocks.
INDEX: int = 0


@torch.no_grad()
def find_plane_fitting_points(
    mask: torch.Tensor,
    cam: Camera,
    render_pkg: dict,
    gaussians: GaussianModel,
    opt: OptimizationParams,
    kdtree: TorchKDTree,
    kdtree_filter: torch.Tensor,
):
    """Select the Gaussian centers used to fit a plane to the region ``mask``.

    A Gaussian is kept when all of the following hold:

    * it is selected by ``kdtree_filter``;
    * its projection is the nearest neighbour of some pixel with
      ``mask > 0.5``, at a distance below 1 px;
    * its center lies strictly inside the axis-aligned bounding box, padded by
      0.1 on each side, of the rendered depth back-projected over the masked
      pixels with positive depth.

    Args:
        mask: ``(H, W)`` soft mask; pixels > 0.5 are foreground.
        cam: camera used for depth back-projection.
        render_pkg: render output. ``render_pkg["depth"]`` is squeezed on dim 0
            (presumably shaped ``(1, H, W)``; confirm against the renderer).
        gaussians: model whose ``get_xyz`` supplies candidate centers.
        opt: provides ``plane_fit_min_points``.
        kdtree: tree over image-plane projections. Its i-th point is assumed
            to be the i-th True entry of ``kdtree_filter``.
        kdtree_filter: mask over all Gaussians selecting those in ``kdtree``.

    Returns:
        ``(xyz, indices)``: the selected centers and their indices into
        ``gaussians.get_xyz``. Returns ``None`` when fewer than
        ``opt.plane_fit_min_points`` candidates survive any stage, or when no
        masked pixel has positive depth.
    """
    min_points = opt.plane_fit_min_points
    max_pixel_dist = 1  # projection-to-mask-pixel tolerance, in pixels
    bbox_margin = 0.1  # padding of the depth bounding box, in world units

    vis_indices = torch.where(kdtree_filter)[0]
    if vis_indices.shape[0] < min_points:
        print("Not enough visible points")
        return None

    mask_pixels = get_pixel_grid(mask.shape[0], mask.shape[1])[mask > 0.5]
    # Each mask pixel contributes at most one nearest Gaussian. Too few mask
    # pixels can therefore never yield enough candidates; stop before querying.
    if mask_pixels.shape[0] < min_points:
        return None

    dists, close_indices = kdtree.query(mask_pixels, nr_nns_searches=1)
    dists = torch.sqrt(dists)  # query distances are treated as squared
    close_indices = close_indices.squeeze(-1)[dists.squeeze(-1) < max_pixel_dist]
    close_indices = torch.unique(close_indices)

    if close_indices.shape[0] < min_points:
        return None

    # Map kdtree-local indices to indices over all Gaussians
    candidate_indices = vis_indices[close_indices]
    candidate_xyz = gaussians.get_xyz[candidate_indices]

    # Back-project the rendered depth over the masked pixels with depth > 0
    depth = render_pkg["depth"].squeeze(0)
    valid_depth_mask = (depth > 0) & (mask > 0.5)
    depth_xyz = get_xyz_depth(depth, cam)[valid_depth_mask.reshape(-1)]

    # Without any valid depth the bounding box is undefined; min()/max() on an
    # empty tensor would raise.
    if depth_xyz.shape[0] == 0:
        return None

    # Filter out points that are outside the (padded) bounds of the depth points
    lower = depth_xyz.min(dim=0).values - bbox_margin
    upper = depth_xyz.max(dim=0).values + bbox_margin
    inside = ((candidate_xyz > lower) & (candidate_xyz < upper)).all(dim=1)

    candidate_indices = candidate_indices[inside]
    candidate_xyz = candidate_xyz[inside]

    if candidate_xyz.shape[0] < min_points:
        return None

    return candidate_xyz, candidate_indices


def _report_plane(plane, verdict, num_gaussians, median_distance, inliers_percentage):
    """Print the fitted plane, the verdict and the fit statistics."""
    a, b, c, d = plane
    print(
        f"Plane {a:.4f}x + {b:.4f}y + {c:.4f}z + {d:.4f} = 0 {verdict} ({num_gaussians} Gaussians)"
    )
    print(f"Median residual: {median_distance:.4f}")
    print(f"Inliers percentage: {inliers_percentage:.4f}")
    print()


@torch.no_grad()
def plane_fitting_pipeline(
    mask: torch.Tensor,
    cam: Camera,
    render_pkg: dict,
    gaussians: GaussianModel,
    opt: OptimizationParams,
    kdtree: TorchKDTree,
    kdtree_filter: torch.Tensor,
):
    """Fit a plane to the Gaussians under ``mask`` and return its pose.

    Candidate points come from ``find_plane_fitting_points``. A plane is then
    fitted with 1000 RANSAC iterations at inlier threshold
    ``opt.plane_fit_threshold``. The plane is rejected when the median
    point-to-plane distance over *all* candidates (not only inliers) exceeds
    ``opt.plane_reject_threshold``.

    Returns:
        ``(rotation, translation, plane_indices)``:

        * ``rotation`` is ``normal_to_quaternion`` of the unit plane normal;
        * ``translation`` is the mean of the inlier points;
        * ``plane_indices`` are the unique Gaussian indices of the inliers.

        Returns ``None`` when too few candidate points exist, when RANSAC finds
        no plane with any inlier, or when the plane is rejected.
    """
    plane_points = find_plane_fitting_points(
        mask, cam, render_pkg, gaussians, opt, kdtree, kdtree_filter
    )

    if plane_points is None:
        return None

    candidate_xyz, candidate_indices = plane_points

    # Fit a plane to the candidate points
    plane_model, inliers, distances = ransac(
        candidate_xyz, opt.plane_fit_threshold, max_iter=1000
    )

    # ransac returns a None plane when no trial produced any inlier (e.g. every
    # sampled triple was degenerate); unpacking it below would crash.
    if plane_model is None:
        print("RANSAC found no valid plane")
        return None

    a, b, c, d = plane_model

    plane_indices = candidate_indices[inliers].unique()
    num_gaussians = plane_indices.shape[0]

    median_distance = distances.median().item()
    inliers_percentage = inliers.sum().item() / candidate_xyz.shape[0]

    if median_distance > opt.plane_reject_threshold:
        print("Median distance is too high")
        _report_plane(
            plane_model, "REJECTED", num_gaussians, median_distance, inliers_percentage
        )
        return None

    translation = torch.mean(candidate_xyz[inliers], dim=0)
    # Build the normal on the points' device instead of a hard-coded "cuda".
    # float() accepts both 0-dim tensors and plain numbers.
    normal = torch.as_tensor(
        [float(a), float(b), float(c)],
        device=translation.device,
        dtype=torch.float32,
    )
    normal = normal / torch.norm(normal)

    rotation = normal_to_quaternion(normal)

    _report_plane(
        plane_model, "ACCEPTED", num_gaussians, median_distance, inliers_percentage
    )

    return rotation, translation, plane_indices
