import numpy as np
import open3d as o3d
import torch

from pytorch3d.ops import knn_points
from models.utils import apply_transform, weighted_svd


def registration_ransac_based_on_correspondence(
    ref_corres_xyz: torch.Tensor,
    src_corres_xyz: torch.Tensor,
    corr_weight: torch.Tensor,
    verified_ref_points: torch.Tensor = None,
    verified_src_points: torch.Tensor = None,
    inlier_threshold = 0.05,
    topk = 250,
    ransac_iters = 50000,
    ransac_n = 4
):
    """Estimate a rigid transformation with Open3D correspondence-based RANSAC.

    The ``topk`` correspondences with the highest ``corr_weight`` are kept. When BOTH
    ``verified_ref_points`` and ``verified_src_points`` are given, they are appended as extra
    row-aligned correspondences (``verified_ref_points[i]`` <-> ``verified_src_points[i]``).
    A verified set supplied on only one side cannot be paired and is ignored.

    Args:
        ref_corres_xyz (Tensor): (C, 3) reference side of the correspondences.
        src_corres_xyz (Tensor): (C, 3) source side of the correspondences.
        corr_weight (Tensor): (C,) confidence used to rank correspondences.
        verified_ref_points (Tensor, optional): (V, 3) extra reference points, paired by row
            with ``verified_src_points``.
        verified_src_points (Tensor, optional): (V, 3) extra source points.
        inlier_threshold (float): max correspondence distance for RANSAC inliers and for the
            distance checker.
        topk (int): number of highest-weighted correspondences to keep.
        ransac_iters (int): maximum RANSAC iterations.
        ransac_n (int): correspondences sampled per hypothesis.

    Returns:
        Tensor: (4, 4) float32 transformation aligning source onto reference, placed on
        ``corr_weight.device``.

    Raises:
        ValueError: if both verified sets are given with different numbers of rows.
    """
    indices = torch.argsort(corr_weight, descending=True)[:topk]
    ref_points = ref_corres_xyz[indices]
    src_points = src_corres_xyz[indices]

    if verified_ref_points is not None and verified_src_points is not None:
        if verified_ref_points.shape[0] != verified_src_points.shape[0]:
            raise ValueError(
                "verified_ref_points and verified_src_points must have the same number of rows, "
                f"got {verified_ref_points.shape[0]} and {verified_src_points.shape[0]}"
            )
        ref_points = torch.cat([ref_points, verified_ref_points], dim=0)
        src_points = torch.cat([src_points, verified_src_points], dim=0)

    # Correspondence i pairs src_points[i] with ref_points[i]. Open3D evaluates hypotheses only
    # over this set, so every row (including the verified pairs) must be listed.
    corr_ids = np.arange(ref_points.shape[0])
    correspondences = o3d.utility.Vector2iVector(np.stack([corr_ids, corr_ids], axis=1))

    ref_pcd = o3d.geometry.PointCloud()
    ref_pcd.points = o3d.utility.Vector3dVector(ref_points.detach().cpu().numpy())
    src_pcd = o3d.geometry.PointCloud()
    src_pcd.points = o3d.utility.Vector3dVector(src_points.detach().cpu().numpy())

    transform = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
        src_pcd, ref_pcd, correspondences, inlier_threshold,
        o3d.pipelines.registration.TransformationEstimationPointToPoint(False), ransac_n,
        [o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
         o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(inlier_threshold)],
        o3d.pipelines.registration.RANSACConvergenceCriteria(ransac_iters, 0.999)
    ).transformation

    return torch.as_tensor(np.asarray(transform), dtype=torch.float32, device=corr_weight.device)


class LocalGlobalRegistration(torch.nn.Module):
    """Local-to-global rigid transformation estimation from dense correspondences.

    Correspondences are grouped by the nearest reference patch node of their reference point.
    Every group with at least ``correspondence_n`` members yields a local hypothesis via
    ``weighted_svd``. The hypothesis that places the largest fraction of verified source points
    within ``acceptance_radius`` of a verified reference point is kept. It is then refined
    globally by repeatedly zeroing the weights of outlier correspondences and re-solving.
    """

    def __init__(self, acceptance_radius, correspondence_n=4, num_refinement_steps=5):
        """
        Args:
            acceptance_radius (float): distance below which a residual counts as an inlier.
            correspondence_n (int): minimum group size required to form a local hypothesis.
            num_refinement_steps (int): number of weighted-SVD solves in the global refinement
                stage (at least one solve is always performed).
        """
        super(LocalGlobalRegistration, self).__init__()
        self.acceptance_radius = acceptance_radius
        self.correspondence_n = correspondence_n
        self.num_refinement_steps = num_refinement_steps

    @staticmethod
    def convert_to_batch(ref_corr_points, src_corr_points, corr_scores, chunks):
        """Convert stacked correspondences to batched points.

        The extracted dense correspondences from all patch correspondences are stacked. However, to compute the
        transformations from all patch correspondences in parallel, the dense correspondences need to be reorganized
        into a batch. Output tensors live on the device and have the dtype of the inputs.

        Args:
            ref_corr_points (Tensor): (C, 3)
            src_corr_points (Tensor): (C, 3)
            corr_scores (Tensor): (C,)
            chunks (List[Tuple[int, int]]): the starting index and ending index of each patch correspondences.
                Must be non-empty.

        Returns:
            batch_ref_corr_points (Tensor): (B, K, 3), padded with zeros.
            batch_src_corr_points (Tensor): (B, K, 3), padded with zeros.
            batch_corr_scores (Tensor): (B, K), padded with zeros.
        """
        device = ref_corr_points.device
        batch_size = len(chunks)
        max_corr = max(y - x for x, y in chunks)

        # Rows to read from the stacked inputs, and the rows they occupy in the flattened
        # (batch_size * max_corr) padded layout: chunk i starts at row i * max_corr.
        read_rows = torch.cat([torch.arange(x, y, device=device) for x, y in chunks], dim=0)
        write_rows = torch.cat(
            [torch.arange(i * max_corr, i * max_corr + (y - x), device=device) for i, (x, y) in enumerate(chunks)],
            dim=0,
        )

        batch_ref_corr_points = ref_corr_points.new_zeros(batch_size * max_corr, *ref_corr_points.shape[1:])
        batch_ref_corr_points[write_rows] = ref_corr_points[read_rows]

        batch_src_corr_points = src_corr_points.new_zeros(batch_size * max_corr, *src_corr_points.shape[1:])
        batch_src_corr_points[write_rows] = src_corr_points[read_rows]

        batch_corr_scores = corr_scores.new_zeros(batch_size * max_corr)
        batch_corr_scores[write_rows] = corr_scores[read_rows]

        return (
            batch_ref_corr_points.view(batch_size, max_corr, *ref_corr_points.shape[1:]),
            batch_src_corr_points.view(batch_size, max_corr, *src_corr_points.shape[1:]),
            batch_corr_scores.view(batch_size, max_corr),
        )

    def recompute_correspondence_scores(self, ref_corr_points, src_corr_points, corr_scores, estimated_transform):
        """Zero the scores of correspondences that are outliers under ``estimated_transform``.

        A correspondence is kept (score unchanged) only if the Euclidean distance between its
        reference point and its transformed source point is strictly below ``acceptance_radius``.
        """
        aligned_src_corr_points = apply_transform(src_corr_points, estimated_transform)
        corr_residuals = torch.norm(ref_corr_points - aligned_src_corr_points, dim=1)
        inlier_masks = torch.lt(corr_residuals, self.acceptance_radius)
        new_corr_scores = corr_scores * inlier_masks.float()
        return new_corr_scores

    def forward(self,
        ref_patch_node: torch.Tensor,
        ref_corres_xyz: torch.Tensor,
        src_corres_xyz: torch.Tensor,
        corr_weight: torch.Tensor,
        verified_ref_points: torch.Tensor,
        verified_src_points: torch.Tensor,
    ):
        """Estimate the transformation aligning source onto reference.

        Args:
            ref_patch_node (Tensor): (P, 3) reference patch node positions used to group correspondences.
            ref_corres_xyz (Tensor): (C, 3) reference side of the correspondences.
            src_corres_xyz (Tensor): (C, 3) source side of the correspondences.
            corr_weight (Tensor): (C,) correspondence weights.
            verified_ref_points (Tensor): (M_ref, 3) points used to score local hypotheses.
            verified_src_points (Tensor): (M_src, 3) points used to score local hypotheses.

        Returns:
            The transformation produced by the final ``weighted_svd`` call
            (presumably a (4, 4) matrix — confirm against models.utils).
        """
        # Nearest patch node per correspondence. knn_points' idx is (1, C, 1) here; index it
        # explicitly so that C == 1 does not collapse to a 0-d tensor (as .squeeze() would).
        patch_ids = knn_points(ref_corres_xyz.unsqueeze(0), ref_patch_node.unsqueeze(0))[1][0, :, 0]
        point_to_patch_indices, sorted_indices = torch.sort(patch_ids)
        ref_corres_xyz = ref_corres_xyz[sorted_indices]
        src_corres_xyz = src_corres_xyz[sorted_indices]
        corr_weight = corr_weight[sorted_indices]

        # Split the sorted patch ids into runs of equal value -> (start, end) chunks.
        unique_masks = torch.ne(point_to_patch_indices[1:], point_to_patch_indices[:-1])
        boundaries = (torch.nonzero(unique_masks, as_tuple=True)[0] + 1).tolist()
        boundaries = [0] + boundaries + [point_to_patch_indices.shape[0]]
        chunks = [(x, y) for x, y in zip(boundaries[:-1], boundaries[1:]) if y - x >= self.correspondence_n]

        if chunks:
            # local registration: one hypothesis per sufficiently large patch group
            batch_ref_corr_points, batch_src_corr_points, batch_corr_scores = self.convert_to_batch(
                ref_corres_xyz, src_corres_xyz, corr_weight, chunks
            )
            batch_transforms = weighted_svd(batch_src_corr_points, batch_ref_corr_points, batch_corr_scores)
            aligned_src_points = apply_transform(verified_src_points.unsqueeze(0), batch_transforms)  # (B, M, 3)
            batch_verified_ref_points = verified_ref_points.unsqueeze(0).expand(aligned_src_points.shape[0], -1, -1)
            # knn_points returns SQUARED distances; compare against the squared radius so the
            # threshold matches the Euclidean test in recompute_correspondence_scores.
            sq_dist = knn_points(aligned_src_points, batch_verified_ref_points)[0].squeeze(-1)  # (B, M)
            inlier_ratio = torch.lt(sq_dist, self.acceptance_radius ** 2).float().mean(1)
            estimated_transform = batch_transforms[inlier_ratio.argmax()]
        else:
            # degenerate: initialize transformation with all correspondences
            estimated_transform = weighted_svd(src_corres_xyz, ref_corres_xyz, corr_weight)

        cur_corr_scores = self.recompute_correspondence_scores(
            ref_corres_xyz, src_corres_xyz, corr_weight, estimated_transform
        )

        # global refinement: re-solve with outliers (w.r.t. the latest estimate) zero-weighted
        estimated_transform = weighted_svd(src_corres_xyz, ref_corres_xyz, cur_corr_scores)
        for _ in range(self.num_refinement_steps - 1):
            cur_corr_scores = self.recompute_correspondence_scores(
                ref_corres_xyz, src_corres_xyz, corr_weight, estimated_transform
            )
            estimated_transform = weighted_svd(src_corres_xyz, ref_corres_xyz, cur_corr_scores)

        return estimated_transform