import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.common import knn, rigid_transform_3d
from utils.SE3 import transform

class ResNet_Block(nn.Module):
    """Residual block built from kernel-size-1 ``Conv1d`` layers.

    The main branch (``self.left``) is Conv -> InstanceNorm -> BatchNorm ->
    ReLU -> Conv -> InstanceNorm -> BatchNorm.  The shortcut is either the
    input itself or, when ``pre`` is True, a kernel-size-1 convolution of it
    (``self.right``).  The two branches are summed and passed through ReLU.

    Args:
        inchannel:  number of channels of the input feature map.
        outchannel: number of channels produced by both branches.
        pre:        when True the shortcut goes through ``self.right``;
                    otherwise the raw input is added, so the caller must make
                    sure ``inchannel == outchannel``.
    """

    def __init__(self, inchannel, outchannel, pre=False):
        super().__init__()
        self.pre = pre

        # Shortcut projection. It is always created (and therefore always
        # present in the state dict) but only used by forward() when
        # ``pre`` is True.
        shortcut_layers = [nn.Conv1d(inchannel, outchannel, kernel_size=1)]
        self.right = nn.Sequential(*shortcut_layers)

        main_layers = [
            nn.Conv1d(inchannel, outchannel, kernel_size=1),
            nn.InstanceNorm1d(outchannel),
            nn.BatchNorm1d(outchannel),
            nn.ReLU(),
            nn.Conv1d(outchannel, outchannel, kernel_size=1),
            nn.InstanceNorm1d(outchannel),
            nn.BatchNorm1d(outchannel),
        ]
        self.left = nn.Sequential(*main_layers)

    def forward(self, x):
        """Return ``relu(left(x) + shortcut(x))`` for ``x`` of shape [bs, inchannel, n]."""
        if self.pre is True:
            shortcut = self.right(x)
        else:
            shortcut = x
        return torch.relu(self.left(x) + shortcut)

def darboux(points, normals, k):
    """Return the pairwise angle, in degrees, between all normals of a batch.

    Args:
        points:  unused; kept so existing callers keep working.
        normals: [B, N, 3] normal vectors (they do not need to be unit length,
                 the cosine is divided by the product of the two norms).
        k:       unused; kept so existing callers keep working.

    Returns:
        [B, N, N] tensor whose entry (b, i, j) is the absolute angle in
        degrees between ``normals[b, i]`` and ``normals[b, j]``.
    """
    normals = normals.permute(0, 2, 1)  # [B, 3, N]

    # Per-normal L2 length, arranged so that l1 * l2 broadcasts to the
    # [B, N, N] grid of |n_i| * |n_j|.  The previous code read an undefined
    # ``normals_knn`` tensor here and raised NameError on every call.
    lengths = torch.norm(normals, p=2, dim=1)  # [B, N]
    l1 = lengths[:, :, None]  # [B, N, 1]
    l2 = lengths[:, None, :]  # [B, 1, N]

    # Cosine of the angle between every pair of normals; the small constant
    # guards against division by zero for zero-length normals.
    a3 = torch.sum(normals[:, :, :, None] * normals[:, :, None, :], dim=1) / (l2 * l1 + 1e-10)  # [B, N, N]

    # Clamp before acos so rounding errors cannot push the cosine outside
    # the domain of acos.
    epsilon = 1e-8
    angle_radians = torch.acos(torch.clamp(a3, -1 + epsilon, 1 - epsilon))
    a3 = torch.abs(angle_radians * (180.0 / np.pi))
    return a3

class NonLocalBlock(nn.Module):
    """Attention block whose weights are modulated by a spatial consistency matrix.

    Query/key/value projections are kernel-size-1 convolutions.  The scaled
    dot-product similarity is multiplied element-wise by the supplied
    ``attention`` matrix both before and after the softmax, and the aggregated
    message is passed through one ``ResNet_Block``.

    Args:
        num_channels: channel count of the input and output features.
        num_heads:    number of attention heads; ``num_channels`` is split
                      evenly between them.
    """

    def __init__(self, num_channels=128, num_heads=1):
        super().__init__()

        self.embed_1 = nn.Sequential(
            ResNet_Block(num_channels, num_channels, pre=True),
        )
        self.projection_q = nn.Conv1d(num_channels, num_channels, kernel_size=1)
        self.projection_k = nn.Conv1d(num_channels, num_channels, kernel_size=1)
        self.projection_v = nn.Conv1d(num_channels, num_channels, kernel_size=1)

        self.num_channels = num_channels
        self.head = num_heads

    def forward(self, feat, attention):
        """
        Input:
            - feat:     [bs, num_channels, num_corr]  input feature
            - attention [bs, num_corr, num_corr]      spatial consistency matrix
        Output:
            - res:      [bs, num_channels, num_corr]  updated feature
        """
        bs, num_corr = feat.shape[0], feat.shape[-1]
        head_dim = self.num_channels // self.head
        per_head_shape = [bs, self.head, head_dim, num_corr]

        Q = self.projection_q(feat).view(per_head_shape)
        K = self.projection_k(feat).view(per_head_shape)
        V = self.projection_v(feat).view(per_head_shape)

        # Scaled dot-product similarity between every pair of correspondences.
        feat_attention = torch.einsum('bhco, bhci->bhoi', Q, K) / head_dim ** 0.5

        # Broadcast the spatial consistency over the head dimension and use it
        # to gate the similarity before the softmax and the weights after it.
        spatial = attention[:, None, :, :]
        weight = torch.softmax(spatial * feat_attention, dim=-1)
        message = torch.einsum('bhoi, bhci-> bhco', weight * spatial, V)
        message = message.reshape([bs, -1, num_corr])

        return self.embed_1(message)


class DGC(nn.Module):
    """Feature encoder for correspondences.

    Pipeline: kernel-size-1 convolution -> three ``ResNet_Block`` ->
    InstanceNorm/ReLU/Conv (stored as ``blocks['PointCN_layer']``) ->
    ``NonLocalBlock`` (stored as ``blocks['NonLocal_layer']``).

    Args:
        in_dim:       channel count of the input feature map.
        num_layers:   stored on the instance; not otherwise used by this class.
        num_channels: channel count of every intermediate and output feature.
        k_num:        stored on the instance; not otherwise used by this class.
    """

    def __init__(self, in_dim=6, num_layers=6, num_channels=128, k_num=20):
        super().__init__()
        self.num_layers = num_layers
        self.k_num = k_num
        self.blocks = nn.ModuleDict()

        self.layer0 = nn.Conv1d(in_dim, num_channels, kernel_size=1, bias=True)

        residual_stack = [
            ResNet_Block(num_channels, num_channels, pre=False)
            for _ in range(3)
        ]
        self.embed_0 = nn.Sequential(*residual_stack)

        self.blocks['PointCN_layer'] = nn.Sequential(
            nn.InstanceNorm1d(num_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(num_channels, num_channels, kernel_size=1, bias=True),
        )
        self.blocks['NonLocal_layer'] = NonLocalBlock(num_channels)

    def forward(self, corr_feat, corr_compatibility):
        """
        Input:
            - corr_feat:          [bs, in_dim, num_corr]   input feature map
            - corr_compatibility: [bs, num_corr, num_corr] spatial consistency matrix
        Output:
            - feat:               [bs, num_channels, num_corr] updated feature
        """
        embedded = self.embed_0(self.layer0(corr_feat))
        normalized = self.blocks['PointCN_layer'](embedded)
        return self.blocks['NonLocal_layer'](normalized, corr_compatibility)



class PGNet(nn.Module):
    def __init__(self,
                 in_dim=6,
                 num_layers=6,
                 num_channels=128,
                 num_iterations=10,
                 d_thre=0.10,
                 a_thre=15.0,
                 ratio=0.20,
                 pruning=0.25,
                 inlier_threshold=0.10,
                 sigma_d=0.10,
                 sigma_a=15.0,
                 k=40,
                 nms_radius=0.10,
                 k1=30,
                 k2=20,
                 relax_match_num=30,
                 FS_TCD_thre=0.05,
                 NS_by_IC=20,
                 ):
        """Build the encoder, the per-correspondence classifier and store the hyper-parameters.

        Args:
            in_dim:           channel count of the correspondence input fed to the encoder.
            num_layers:       forwarded to the encoder.
            num_channels:     feature width of the encoder output.
            num_iterations:   maximum number of power-iteration steps.
            ratio:            fraction of correspondences selected as seeds.
            pruning:          fraction of correspondences kept after each pruning step.
            inlier_threshold: residual below which a correspondence counts as an inlier.
            sigma_d:          initial value of the fixed length-compatibility bandwidth (``sigma_spat``).
            sigma_a:          initial value of the fixed angle-compatibility bandwidth.
            k:                neighbourhood size used around every seed.
            NS_by_IC:         number of top-fitness seed hypotheses returned by ``cal_seed_trans``.
            d_thre, a_thre, nms_radius, k1, k2, relax_match_num, FS_TCD_thre:
                              stored on the instance; not read by the methods of this class.
        """
        super().__init__()

        # Plain hyper-parameters.
        self.num_iterations = num_iterations
        self.ratio = ratio
        self.pruning = pruning
        self.num_channels = num_channels
        self.inlier_threshold = inlier_threshold
        self.d_thre = d_thre
        self.a_thre = a_thre
        self.k = k
        self.nms_radius = nms_radius
        self.k1 = k1
        self.k2 = k2
        self.relax_match_num = relax_match_num
        self.FS_TCD_thre = FS_TCD_thre
        self.NS_by_IC = NS_by_IC

        # ``sigma`` is learnt; the two compatibility bandwidths are registered
        # as parameters but frozen.
        self.sigma = nn.Parameter(torch.Tensor([1.0]).float(), requires_grad=True)
        self.sigma_spat = nn.Parameter(torch.Tensor([sigma_d]).float(), requires_grad=False)
        self.sigma_a = nn.Parameter(torch.Tensor([sigma_a]).float(), requires_grad=False)

        self.encoder = DGC(
            in_dim=in_dim,
            num_layers=num_layers,
            num_channels=num_channels,
        )

        # Per-correspondence scoring head: num_channels -> 32 -> 32 -> 1.
        hidden = 32
        self.classification = nn.Sequential(
            nn.Conv1d(num_channels, hidden, kernel_size=1, bias=True),
            nn.InstanceNorm1d(hidden, eps=1e-3),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden, hidden, kernel_size=1, bias=True),
            nn.InstanceNorm1d(hidden, eps=1e-3),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden, 1, kernel_size=1, bias=True),
        )

        # Weight initialisation: Xavier-normal for conv/linear weights,
        # identity affine transform for batch norm.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.xavier_normal_(module.weight, gain=1)

    def forward(self, data):
        """Run three encode / score / hypothesise stages, pruning correspondences between them.

        Input: ``data`` is a dict with
            - corr_pos:   [bs, num_corr, in_dim]  correspondence input of the encoder
            - src_keypts: [bs, num_corr, 3]
            - tgt_keypts: [bs, num_corr, 3]
            - src_normal: [bs, num_corr, 3]
            - tgt_normal: [bs, num_corr, 3]
            - gt_labels:  [bs, num_corr]
            - testing:    optional key; its mere presence switches to the test
                          path (no similarity matrices, post-refinement enabled).
        Output: (dict)  with n1 = num_corr, n2 = int(n1 * pruning), n3 = int(n2 * pruning)
            - final_trans:   [bs, 4, 4]  transformation estimated in the last stage.
            - final_labels:  [bs, n3]    last-stage labels: classifier scores during
                                         training, binary inlier mask during testing.
            - gt_labels:     [bs, n3]    ground-truth labels of the surviving correspondences.
            - final_labels3: [bs, n1]    first-stage labels.
            - gt_labels3:    [bs, n1]    the ``gt_labels`` tensor exactly as passed in.
            - final_labels2: [bs, n2]    second-stage labels.
            - gt_labels2:    [bs, n2]    ground-truth labels after the first pruning.
            - M, M3, M2:     feature similarity matrices of the last, first and
                             second stage ([bs, n, n]); ``None`` during testing.
        """
        testing = 'testing' in data
        # Follow the device the module lives on instead of hard-coding 'cuda',
        # so the network also runs on CPU or on a non-default GPU.
        device = self.sigma.device

        corr_pos = data['corr_pos'].to(device)
        src_keypts = data['src_keypts'].to(device)
        tgt_keypts = data['tgt_keypts'].to(device)
        src_normal = data['src_normal'].to(device)
        tgt_normal = data['tgt_normal'].to(device)
        gt_labels = data['gt_labels']

        #################################
        # Stage 1: all correspondences
        #################################
        M3, _, final_labels3, residual = self._run_stage(
            corr_pos, src_keypts, tgt_keypts, src_normal, tgt_normal, testing)
        # Returned untouched (same tensor, same device) as handed in by the caller.
        gt_labels3 = gt_labels
        corr_pos, src_keypts, tgt_keypts, src_normal, tgt_normal, gt_labels = self._prune_correspondences(
            residual, corr_pos, src_keypts, tgt_keypts, src_normal, tgt_normal, gt_labels)

        #################################
        # Stage 2: correspondences kept by the first pruning
        #################################
        M2, _, final_labels2, residual = self._run_stage(
            corr_pos, src_keypts, tgt_keypts, src_normal, tgt_normal, testing)
        gt_labels2 = gt_labels
        corr_pos, src_keypts, tgt_keypts, src_normal, tgt_normal, gt_labels = self._prune_correspondences(
            residual, corr_pos, src_keypts, tgt_keypts, src_normal, tgt_normal, gt_labels)

        #################################
        # Stage 3: correspondences kept by the second pruning
        #################################
        M, final_trans, final_labels, _ = self._run_stage(
            corr_pos, src_keypts, tgt_keypts, src_normal, tgt_normal, testing)

        res = {
            "final_trans": final_trans,
            "final_labels": final_labels,
            "gt_labels": gt_labels,
            "final_labels3": final_labels3,
            "gt_labels3": gt_labels3,
            "final_labels2": final_labels2,
            "gt_labels2": gt_labels2,
            "M": M,
            "M3": M3,
            "M2": M2
        }
        return res

    def _run_stage(self, corr_pos, src_keypts, tgt_keypts, src_normal, tgt_normal, testing):
        """Encode the correspondences, score them and estimate one transformation.

        All tensors must already be on the module's device.

        Returns:
            - M:            [bs, n, n] feature similarity matrix, ``None`` when testing.
            - final_trans:  [bs, 4, 4] best transformation (post-refined when testing).
            - final_labels: [bs, n]    classifier scores (training) or binary inlier mask (testing).
            - residual:     [bs, n]    distance between the warped source and the target
                                       key-points under the best transformation.
        """
        num_corr = corr_pos.shape[1]

        # The compatibility terms are pure geometry and are kept out of autograd.
        with torch.no_grad():
            src_angle = darboux(src_keypts, src_normal, num_corr)
            tgt_angle = darboux(tgt_keypts, tgt_normal, num_corr)
            angle_compatibility = torch.abs(src_angle - tgt_angle)

            src_dist = torch.norm((src_keypts[:, :, None, :] - src_keypts[:, None, :, :]), dim=-1)
            tgt_dist = torch.norm((tgt_keypts[:, :, None, :] - tgt_keypts[:, None, :, :]), dim=-1)
            length_compatibility = src_dist - tgt_dist

            # Soft compatibility: Gaussian in both the length and the angle difference.
            corr_compatibility = torch.exp(
                - length_compatibility ** 2 / (self.sigma_spat ** 2 * 2)
                - angle_compatibility ** 2 / (self.sigma_a ** 2 * 2))
            length_compatibility = torch.abs(length_compatibility)

        corr_features = self.encoder(corr_pos.permute(0, 2, 1), corr_compatibility).permute(0, 2, 1)
        normed_corr_features = F.normalize(corr_features, p=2, dim=-1)

        if testing:
            M = None
        else:
            # Feature similarity matrix for the loss, with its diagonal set to zero.
            M = torch.matmul(normed_corr_features, normed_corr_features.permute(0, 2, 1))
            M = torch.clamp(1 - (1 - M) / self.sigma ** 2, min=0, max=1)
            diag = torch.arange(M.shape[1])
            M[:, diag, diag] = 0

        # Highest-scoring correspondences become the seeds.
        confidence = self.classification(corr_features.permute(0, 2, 1)).squeeze(1)
        seeds = torch.argsort(confidence, dim=1, descending=True)[:, 0:int(num_corr * self.ratio)]

        # Hard compatibility and, per seed, the number of correspondences
        # compatible with both the seed and each candidate.
        hard_SC_measure_tight = ((length_compatibility < self.sigma_spat) & (
                angle_compatibility < self.sigma_a)).float()
        seed_hard_SC_measure_tight = hard_SC_measure_tight.gather(
            dim=1, index=seeds[:, :, None].expand(-1, -1, num_corr))
        SC2_measure = torch.matmul(seed_hard_SC_measure_tight, hard_SC_measure_tight)

        _, _, final_trans, final_labels, residual, _ = self.cal_seed_trans(
            seeds, normed_corr_features, SC2_measure, src_keypts, tgt_keypts)

        if testing:
            final_trans = self.post_refinement(final_trans, src_keypts, tgt_keypts)
            warped_src = transform(src_keypts, final_trans)
            residual = torch.sum((warped_src - tgt_keypts) ** 2, dim=-1) ** 0.5
            final_labels = (residual < self.inlier_threshold).float()
        else:
            final_labels = confidence

        return M, final_trans, final_labels, residual

    def _prune_correspondences(self, residual, corr_pos, src_keypts, tgt_keypts,
                               src_normal, tgt_normal, gt_labels):
        """Keep the ``self.pruning`` fraction of correspondences with the smallest residual.

        Returns the six per-correspondence tensors restricted (along dim 1) to
        the kept indices, in the order they were passed in.
        """
        num_keep = int(corr_pos.shape[1] * self.pruning)
        keep = torch.argsort(residual, dim=1, descending=False)[:, 0:num_keep]

        def take(tensor):
            # Gather whole rows; the index is widened to the tensor's own last
            # dimension rather than a hard-coded 6 / 3.
            index = keep[:, :, None].expand(-1, -1, tensor.shape[-1])
            return tensor.gather(dim=1, index=index)

        gt_labels = gt_labels.to(keep.device).gather(dim=1, index=keep)
        return (take(corr_pos), take(src_keypts), take(tgt_keypts),
                take(src_normal), take(tgt_normal), gt_labels)

    def cal_seed_trans(self, seeds, corr_features, SC2_measure, src_keypts, tgt_keypts):
        """
        Estimate one transformation per seed and select the best one.
        Input:
            - seeds:         [bs, num_seeds]              the index to the seeding correspondence
            - corr_features: [bs, num_corr, num_channels]
            - SC2_measure:   [bs, num_seeds, num_corr]    per-seed score used to pick the k neighbours
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
        Output:
            - seedwise_trans_relax: [bs, r, 4, 4]    transformations of the r = min(NS_by_IC, num_seeds) fittest seeds
            - seedwise_fitness:     [bs, num_seeds]  inlier ratio of every seed hypothesis
            - final_trans:          [bs, 4, 4]       transformation of the fittest seed
            - final_labels:         [bs, num_corr]   binary inlier mask under final_trans
            - final_labels1:        [bs, num_corr]   residual distance under final_trans
            - final_trans1:         [bs, 16]         final_trans flattened
        """
        bs, num_corr, num_channels = corr_features.shape[:3]
        k = min(self.k, num_corr - 1)

        # For every seed, the k correspondences with the highest SC2 score.
        knn_idx = torch.argsort(SC2_measure, dim=2, descending=True)[:, :, 0:k].contiguous()
        flat_idx = knn_idx.view([bs, -1])  # [bs, num_seeds * k]

        # Feature similarity inside each neighbourhood, diagonal zeroed.
        feature_index = flat_idx[:, :, None].expand(-1, -1, num_channels)
        knn_features = corr_features.gather(dim=1, index=feature_index).view(
            [bs, -1, k, num_channels])  # [bs, num_seeds, k, num_channels]
        knn_M = torch.matmul(knn_features, knn_features.permute(0, 1, 3, 2))
        knn_M = torch.clamp(1 - (1 - knn_M) / self.sigma ** 2, min=0).view([-1, k, k])
        diag = torch.arange(knn_M.shape[1])
        knn_M[:, diag, diag] = 0

        # Leading eigenvector of each neighbourhood matrix, normalised to sum
        # to one, serves as the per-point weight.
        weight = self.cal_leading_eigenvector(knn_M, method='power').view([bs, -1, k])
        weight = weight / (torch.sum(weight, dim=-1, keepdim=True) + 1e-6)
        weight = weight.view([-1, k])  # [bs * num_seeds, k]

        #################################
        # weighted least-squares transformation for every neighbourhood in parallel
        #################################
        xyz_index = flat_idx[:, :, None].expand(-1, -1, 3)
        src_knn = src_keypts.gather(dim=1, index=xyz_index).view([-1, k, 3])
        tgt_knn = tgt_keypts.gather(dim=1, index=xyz_index).view([-1, k, 3])
        seedwise_trans = rigid_transform_3d(src_knn, tgt_knn, weight).view([bs, -1, 4, 4])

        #################################
        # score every hypothesis on all correspondences and pick the best per batch element
        #################################
        rotation = seedwise_trans[:, :, :3, :3]
        translation = seedwise_trans[:, :, :3, 3:4]
        pred_position = torch.einsum('bsnm,bmk->bsnk', rotation, src_keypts.permute(0, 2, 1)) + translation
        pred_position = pred_position.permute(0, 1, 3, 2)  # [bs, num_seeds, num_corr, 3]
        L2_dis = torch.norm(pred_position - tgt_keypts[:, None, :, :], dim=-1)  # [bs, num_seeds, num_corr]
        seedwise_fitness = torch.mean((L2_dis < self.inlier_threshold).float(), dim=-1)  # [bs, num_seeds]
        best_seed = seedwise_fitness.argmax(dim=1)

        relax_num = min(self.NS_by_IC, seedwise_fitness.shape[1])
        _, relax_idx = torch.topk(seedwise_fitness, relax_num)

        final_trans = seedwise_trans.gather(
            dim=1, index=best_seed[:, None, None, None].expand(-1, -1, 4, 4)).squeeze(1)
        final_labels1 = L2_dis.gather(
            dim=1, index=best_seed[:, None, None].expand(-1, -1, L2_dis.shape[2])).squeeze(1)
        final_labels = (final_labels1 < self.inlier_threshold).float()
        final_trans1 = final_trans.reshape(final_trans.shape[0], -1)
        seedwise_trans_relax = seedwise_trans.gather(
            dim=1, index=relax_idx[:, :, None, None].expand(-1, -1, 4, 4))

        return seedwise_trans_relax, seedwise_fitness, final_trans, final_labels, final_labels1, final_trans1

    def cal_leading_eigenvector(self, M, method='power'):
        """
        Calculate the leading eigenvector using power iteration algorithm or torch.symeig
        Input:
            - M:      [bs, num_corr, num_corr] the compatibility matrix
            - method: select different method for calculating the learding eigenvector.
        Output:
            - solution: [bs, num_corr] leading eigenvector
        """
        if method == 'power':
            # power iteration algorithm
            leading_eig = torch.ones_like(M[:, :, 0:1])
            leading_eig_last = leading_eig
            for i in range(self.num_iterations):
                leading_eig = torch.bmm(M, leading_eig)
                leading_eig = leading_eig / (torch.norm(leading_eig, dim=1, keepdim=True) + 1e-6)
                if torch.allclose(leading_eig, leading_eig_last):
                    break
                leading_eig_last = leading_eig
            leading_eig = leading_eig.squeeze(-1)
            return leading_eig
        elif method == 'eig':  # cause NaN during back-prop
            e, v = torch.symeig(M, eigenvectors=True)
            leading_eig = v[:, :, -1]
            return leading_eig
        else:
            exit(-1)

    def cal_confidence(self, M, leading_eig, method='eig_value'):
        """
        Calculate the confidence of the spectral matching solution based on spectral analysis.
        Input:
            - M:          [bs, num_corr, num_corr] the compatibility matrix
            - leading_eig [bs, num_corr]           the leading eigenvector of matrix M
        Output:
            - confidence
        """
        if method == 'eig_value':
            # max eigenvalue as the confidence (Rayleigh quotient)
            max_eig_value = (leading_eig[:, None, :] @ M @ leading_eig[:, :, None]) / (
                    leading_eig[:, None, :] @ leading_eig[:, :, None])
            confidence = max_eig_value.squeeze(-1)
            return confidence
        elif method == 'eig_value_ratio':
            # max eigenvalue / second max eigenvalue as the confidence
            max_eig_value = (leading_eig[:, None, :] @ M @ leading_eig[:, :, None]) / (
                    leading_eig[:, None, :] @ leading_eig[:, :, None])
            # compute the second largest eigen-value
            B = M - max_eig_value * leading_eig[:, :, None] @ leading_eig[:, None, :]
            solution = torch.ones_like(B[:, :, 0:1])
            for i in range(self.num_iterations):
                solution = torch.bmm(B, solution)
                solution = solution / (torch.norm(solution, dim=1, keepdim=True) + 1e-6)
            solution = solution.squeeze(-1)
            second_eig = solution
            second_eig_value = (second_eig[:, None, :] @ B @ second_eig[:, :, None]) / (
                    second_eig[:, None, :] @ second_eig[:, :, None])
            confidence = max_eig_value / second_eig_value
            return confidence
        elif method == 'xMx':
            # max xMx as the confidence (x is the binary solution)
            # rank = torch.argsort(leading_eig, dim=1, descending=True)[:, 0:int(M.shape[1]*self.ratio)]
            # binary_sol = torch.zeros_like(leading_eig)
            # binary_sol[0, rank[0]] = 1
            confidence = leading_eig[:, None, :] @ M @ leading_eig[:, :, None]
            confidence = confidence.squeeze(-1) / M.shape[1]
            return confidence

    def post_refinement(self, initial_trans, src_keypts, tgt_keypts, weights=None):
        """
        Perform post refinement using the initial transformation matrix, only adopted during testing.

        Every batch element is refined on its own: the transformation is
        re-estimated from the current inliers until the inlier count stops
        changing or 20 rounds have been run.
        Input
            - initial_trans: [bs, 4, 4]
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
            - weights:       unused; kept for interface compatibility
        Output:
            - final_trans:   [bs, 4, 4]
        """
        if self.inlier_threshold == 0.10:  # for 3DMatch
            inlier_threshold_list = [0.10] * 20
        else:  # for KITTI
            inlier_threshold_list = [1.2] * 20

        batch_size = initial_trans.shape[0]
        if batch_size == 0:
            return initial_trans

        # The previous implementation applied the inlier mask of batch element
        # 0 to every element, which is only correct for bs == 1.  Refining
        # element by element gives the same result for bs == 1 and a correct
        # one for larger batches.
        refined = []
        for b in range(batch_size):
            trans = initial_trans[b:b + 1]
            src = src_keypts[b:b + 1]
            tgt = tgt_keypts[b:b + 1]

            previous_inlier_num = 0
            for inlier_threshold in inlier_threshold_list:
                warped_src_keypts = transform(src, trans)
                L2_dis = torch.norm(warped_src_keypts - tgt, dim=-1)
                pred_inlier = (L2_dis < inlier_threshold)[0]
                inlier_num = int(torch.sum(pred_inlier))
                # Converged (or no inlier at all on the first round): stop.
                if inlier_num == previous_inlier_num:
                    break
                previous_inlier_num = inlier_num

                # Down-weight inliers with a large residual.
                ## https://link.springer.com/article/10.1007/s10589-014-9643-2
                robust_weights = 1 / (1 + (L2_dis / inlier_threshold) ** 2)
                trans = rigid_transform_3d(
                    A=src[:, pred_inlier, :],
                    B=tgt[:, pred_inlier, :],
                    weights=robust_weights[:, pred_inlier],
                )
            refined.append(trans)

        return torch.cat(refined, dim=0)
