import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.common import knn, rigid_transform_3d
from utils.SE3 import transform


##111
class ResNet_Block(nn.Module):
    """Pointwise residual block for channels-first 1-D feature maps.

    The main branch (``left``) is conv1x1 -> InstanceNorm -> BatchNorm -> ReLU
    -> conv1x1 -> InstanceNorm -> BatchNorm. The shortcut is a 1x1 conv
    (``right``) only when ``pre`` is literally ``True``; otherwise it is the
    identity, which requires ``inchannel == outchannel``. The sum is passed
    through a ReLU.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        pre: ``True`` to project the shortcut with a 1x1 conv.
    """

    def __init__(self, inchannel, outchannel, pre=False):
        super(ResNet_Block, self).__init__()
        self.pre = pre

        def norms():
            # Instance normalisation followed by batch normalisation.
            return [nn.InstanceNorm1d(outchannel), nn.BatchNorm1d(outchannel)]

        # Built before `left` so parameter creation order (and RNG use) is unchanged.
        self.right = nn.Sequential(nn.Conv1d(inchannel, outchannel, kernel_size=1))
        self.left = nn.Sequential(
            nn.Conv1d(inchannel, outchannel, kernel_size=1),
            *norms(),
            nn.ReLU(),
            nn.Conv1d(outchannel, outchannel, kernel_size=1),
            *norms(),
        )

    def forward(self, x):
        # Identity comparison on purpose: only the literal True enables the projection.
        shortcut = self.right(x) if self.pre is True else x
        return torch.relu(self.left(x) + shortcut)


def knnms(x, k):
    """Return the indices of the ``k`` nearest neighbours of every point.

    ``x`` is channels-first, ``[batch, dims, num_points]`` (squared norms are
    summed over dim 1). The squared Euclidean distance is expanded as
    ``2<xi, xj> - |xi|^2 - |xj|^2`` (its negation), so ``topk`` with the
    default ``largest=True`` picks the closest points. Each point's own
    index normally comes first because its distance is zero.

    Returns:
        LongTensor ``[batch, num_points, k]``.
    """
    gram = torch.matmul(x.transpose(2, 1), x)
    sq_norm = (x ** 2).sum(dim=1, keepdim=True)
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)
    return neg_sq_dist.topk(k=k, dim=-1).indices


def get_graph_feature(x, k=20, idx=None):
    """Build per-edge features ``[x_i, x_i - x_j]`` over a k-NN graph.

    Args:
        x: ``[batch, num_dims, num_pts]`` channels-first point features.
        k: number of neighbours when ``idx`` is not given.
        idx: optional precomputed neighbour indices ``[batch, num_pts, k']``
            (indices local to each batch element). When given, its last
            dimension defines the neighbour count and ``k`` is ignored.

    Returns:
        ``[batch, 2 * num_dims, num_pts, k]`` tensor on the same device as ``x``.
    """
    batch_size, num_dims, num_pts = x.size()
    if idx is None:
        idx = knnms(x, k=k)
    else:
        # Trust the supplied graph; a mismatched `k` used to crash in `view`.
        k = idx.size(-1)

    # Offset each batch element's indices so they address the flattened
    # [batch * num_pts, num_dims] table. Use the input's device instead of a
    # hard-coded CUDA device so the function also runs on CPU / other GPUs.
    idx_base = torch.arange(batch_size, device=x.device).view(-1, 1, 1) * num_pts
    flat_idx = (idx + idx_base).view(-1)

    points = x.transpose(2, 1).contiguous()  # [batch, num_pts, num_dims]
    neighbors = points.view(batch_size * num_pts, num_dims)[flat_idx, :]
    neighbors = neighbors.view(batch_size, num_pts, k, num_dims)

    # expand (no copy) is enough: the centre tensor is only read below.
    center = points.view(batch_size, num_pts, 1, num_dims).expand(-1, -1, k, -1)
    return torch.cat((center, center - neighbors), dim=3).permute(0, 3, 1, 2).contiguous()


class NonLocalBlock(nn.Module):
    """Multi-head attention over correspondences, gated by spatial consistency.

    Query/key/value are 1x1-conv projections of the input. The scaled
    query-key logits are multiplied element-wise by the supplied consistency
    matrix (broadcast over heads) before the softmax. The aggregated message
    goes through an MLP (``fc_message``) and one residual block
    (``embed_1``). The input features are not added back to the output.

    Args:
        num_channels: feature width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, num_channels, num_heads=1):
        super(NonLocalBlock, self).__init__()
        half = num_channels // 2
        self.fc_message = nn.Sequential(
            nn.Conv1d(num_channels, half, kernel_size=1),
            nn.BatchNorm1d(half),
            nn.ReLU(inplace=True),
            nn.Conv1d(half, half, kernel_size=1),
            nn.BatchNorm1d(half),
            nn.ReLU(inplace=True),
            nn.Conv1d(half, num_channels, kernel_size=1),
        )
        self.embed_1 = nn.Sequential(ResNet_Block(num_channels, num_channels, pre=False))
        # Creation order q, k, v is kept so random initialisation is unchanged.
        self.projection_q = nn.Conv1d(num_channels, num_channels, kernel_size=1)
        self.projection_k = nn.Conv1d(num_channels, num_channels, kernel_size=1)
        self.projection_v = nn.Conv1d(num_channels, num_channels, kernel_size=1)
        self.num_channels = num_channels
        self.head = num_heads

    def forward(self, feat, attention):
        """
        Input:
            - feat:      [bs, num_channels, num_corr]  input feature
            - attention: [bs, num_corr, num_corr]      spatial consistency matrix
        Output:
            - res:       [bs, num_channels, num_corr]  updated feature
        """
        bs, num_corr = feat.shape[0], feat.shape[-1]
        head_dim = self.num_channels // self.head

        def split_heads(t):
            # [bs, C, n] -> [bs, heads, C / heads, n]
            return t.view([bs, self.head, head_dim, num_corr])

        query = split_heads(self.projection_q(feat))
        key = split_heads(self.projection_k(feat))
        value = split_heads(self.projection_v(feat))

        logits = torch.einsum('bhco,bhci->bhoi', query, key) / head_dim ** 0.5
        # Feature similarity is modulated by spatial consistency before normalising.
        weight = torch.softmax(attention[:, None, :, :] * logits, dim=-1)
        message = torch.einsum('bhoi,bhci->bhco', weight, value).reshape([bs, -1, num_corr])
        return self.embed_1(self.fc_message(message))


class NonLocalNet(nn.Module):
    """Correspondence encoder.

    Pipeline: pointwise conv (``layer0``) -> three residual blocks
    (``embed_0``) -> one InstanceNorm/ReLU/conv layer
    (``blocks['PointCN_layer']``) -> one consistency-gated attention block
    (``blocks['NonLocal_layer']``).

    Args:
        in_dim: channels of the input correspondence features.
        num_layers: stored as ``self.num_layers`` only; the architecture
            always contains exactly one PointCN layer and one non-local block.
        num_channels: width of all hidden layers and of the output.
    """

    def __init__(self, in_dim, num_layers, num_channels):
        super(NonLocalNet, self).__init__()
        self.num_layers = num_layers
        # Registration/creation order is intentional: PGNet re-initialises weights
        # by walking self.modules(), so reordering would change the random init.
        self.blocks = nn.ModuleDict()
        self.layer0 = nn.Conv1d(in_dim, num_channels, kernel_size=1, bias=True)
        self.embed_0 = nn.Sequential(
            *[ResNet_Block(num_channels, num_channels, pre=False) for _ in range(3)]
        )
        self.blocks['PointCN_layer'] = nn.Sequential(
            nn.InstanceNorm1d(num_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(num_channels, num_channels, kernel_size=1, bias=True),
        )
        self.blocks['NonLocal_layer'] = NonLocalBlock(num_channels)

    def forward(self, corr_feat, corr_compatibility):
        """
        Input:
            - corr_feat:          [bs, in_dim, num_corr]   input feature map
            - corr_compatibility: [bs, num_corr, num_corr] spatial consistency matrix
        Output:
            - feat:               [bs, num_channels, num_corr] updated feature
        """
        embedded = self.embed_0(self.layer0(corr_feat))
        pointwise = self.blocks['PointCN_layer'](embedded)
        return self.blocks['NonLocal_layer'](pointwise, corr_compatibility)


class PGNet(nn.Module):
    def __init__(self,
                 in_dim=6,
                 num_layers=6,
                 num_channels=256,
                 num_iterations=10,
                 d_thre=0.10,
                 a_thre=10.0,
                 ratio=0.20,
                 pruning=0.50,
                 inlier_threshold=0.10,
                 sigma_d=0.10,
                 sigma_a=10.0,
                 k=40,
                 nms_radius=0.10,
                 k1=30,
                 k2=20,
                 num_node=1000,
                 relax_match_num=100,
                 FS_TCD_thre=0.05,
                 NS_by_IC=50,
                 ):
        """Build the correspondence encoder, the inlier classifier and hyper-parameters.

        Args:
            in_dim: channels of ``corr_pos`` fed to the encoder.
            num_layers: forwarded to ``NonLocalNet``.
            num_channels: encoder width.
            num_iterations: maximum iteration count of the power-iteration algorithm.
            ratio: maximum fraction of correspondences taken as seeds.
            pruning: fraction of correspondences kept by each pruning round.
            inlier_threshold: residual threshold used to label inliers at test time.
            sigma_d: length-consistency scale, stored as the frozen ``sigma_spat``.
            sigma_a: angle-consistency scale (degrees), stored as the frozen ``sigma_a``.
            k: neighbourhood size of the NSM module.
            nms_radius: only used during testing.
            d_thre, a_thre, k1, k2, num_node, relax_match_num, FS_TCD_thre, NS_by_IC:
                stored on the instance; presumably consumed by methods outside
                this constructor — verify against their callers.
        """
        super(PGNet, self).__init__()

        # Plain hyper-parameters (no registration side effects).
        self.num_iterations = num_iterations
        self.ratio = ratio
        self.pruning = pruning
        self.num_channels = num_channels
        self.inlier_threshold = inlier_threshold
        self.d_thre = d_thre
        self.a_thre = a_thre
        self.k = k
        self.nms_radius = nms_radius
        self.k1 = k1
        self.k2 = k2
        self.num_node = num_node
        self.relax_match_num = relax_match_num
        self.FS_TCD_thre = FS_TCD_thre
        self.NS_by_IC = NS_by_IC

        # Scalars kept as Parameters so they follow the module across devices.
        # Registered before the sub-networks to preserve parameter ordering.
        self.sigma = nn.Parameter(torch.tensor([1.0], dtype=torch.float32), requires_grad=True)
        self.sigma_spat = nn.Parameter(torch.tensor([sigma_d], dtype=torch.float32), requires_grad=False)
        self.sigma_a = nn.Parameter(torch.tensor([sigma_a], dtype=torch.float32), requires_grad=False)

        self.encoder = NonLocalNet(
            in_dim=in_dim,
            num_layers=num_layers,
            num_channels=num_channels,
        )

        hidden = 32
        self.classification = nn.Sequential(
            nn.Conv1d(num_channels, hidden, kernel_size=1, bias=True),
            nn.InstanceNorm1d(hidden, eps=1e-3),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden, hidden, kernel_size=1, bias=True),
            nn.InstanceNorm1d(hidden, eps=1e-3),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden, 1, kernel_size=1, bias=True),
        )

        # Re-initialise every conv / linear weight with Xavier-normal and reset
        # BatchNorm affine parameters, visiting modules in registration order.
        linear_like = (nn.Conv1d, nn.Linear)
        for module in self.modules():
            if isinstance(module, linear_like):
                nn.init.xavier_normal_(module.weight, gain=1)
                continue
            if isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, data):
        """Estimate the rigid transformation between two point sets.

        Two rounds of (encode -> seed -> per-seed hypotheses) are used to prune
        the correspondence set, keeping a ``self.pruning`` fraction each time.
        A final round on the surviving correspondences produces seed
        hypotheses, from which ``select_best_trans`` picks the result.

        Input (``data`` mapping):
            - corr_pos:     [bs, num_corr, in_dim] encoder input per correspondence
            - src_keypts:   [bs, num_corr, 3]
            - tgt_keypts:   [bs, num_corr, 3]
            - src_normal:   [bs, num_corr, 3]
            - tgt_normal:   [bs, num_corr, 3]
            - gt_labels:    [bs, num_corr]
            - src_features, tgt_features: per-correspondence descriptors used for
              relaxed matching (assumed unit-norm — TODO confirm)
            - src_desc:     pruned alongside the correspondences and passed to
              ``select_best_trans`` as its ``src_keypts`` argument
            - testing:      optional; its presence enables test mode
            (``tgt_desc`` and ``distance`` are not read.)
        Output (dict):
            - final_trans: transformation chosen by ``select_best_trans``
              (post-refined in test mode)
            - final_labels: labels from the last ``cal_seed_trans`` call, or, in
              test mode, inlier labels under the final transformation
            - gt_labels:   ground-truth labels of the surviving correspondences
            - M:           [bs, n, n] feature similarity matrix (None when testing)
        """
        corr_pos = data['corr_pos']
        src_keypts, tgt_keypts = data['src_keypts'], data['tgt_keypts']
        src_normal, tgt_normal = data['src_normal'], data['tgt_normal']
        src_features, tgt_features = data['src_features'], data['tgt_features']
        src_desc = data['src_desc']
        gt_labels = data['gt_labels']
        testing = 'testing' in data.keys()

        # Two pruning rounds. The encoder/classifier are called in the same
        # order as before so BatchNorm running statistics evolve identically.
        for _ in range(2):
            corr_compatibility, hard_corr_compatibility = self._pg_pair_compatibility(
                src_keypts, tgt_keypts, src_normal, tgt_normal)
            normed_corr_features, seeds, SC2_measure = self._pg_encode_and_seed(
                corr_pos, corr_compatibility, hard_corr_compatibility)
            _, _, _, _, final_labels1, _ = self.cal_seed_trans(
                seeds, normed_corr_features, SC2_measure,
                src_keypts, tgt_keypts, src_normal, tgt_normal)

            num_corr = corr_pos.shape[1]
            # NOTE(review): ascending order keeps the correspondences with the
            # smallest final_labels1 values — confirm against cal_seed_trans.
            candidates = torch.argsort(final_labels1, dim=1, descending=False)[:, 0:int(num_corr * self.pruning)]
            # Gather full rows; the old fixed widths (3 / 6) silently truncated
            # wider descriptor tensors and any corr_pos with in_dim != 6.
            (corr_pos, src_keypts, tgt_keypts, src_normal, tgt_normal,
             src_features, tgt_features, src_desc) = tuple(
                self._pg_gather_rows(t, candidates)
                for t in (corr_pos, src_keypts, tgt_keypts, src_normal, tgt_normal,
                          src_features, tgt_features, src_desc))
            gt_labels = gt_labels.gather(dim=1, index=candidates)

        # Final round on the surviving correspondences.
        corr_compatibility, hard_corr_compatibility = self._pg_pair_compatibility(
            src_keypts, tgt_keypts, src_normal, tgt_normal)
        normed_corr_features, seeds, SC2_measure = self._pg_encode_and_seed(
            corr_pos, corr_compatibility, hard_corr_compatibility)

        if not testing:  # during training or validation
            # Feature similarity matrix M for the loss; diagonal zeroed.
            M = torch.matmul(normed_corr_features, normed_corr_features.permute(0, 2, 1))
            M = torch.clamp(1 - (1 - M) / self.sigma ** 2, min=0, max=1)
            M[:, torch.arange(M.shape[1]), torch.arange(M.shape[1])] = 0
        else:
            M = None

        # Relaxed matching: for each source correspondence keep the
        # `relax_match_num` target points with the smallest descriptor distance.
        # NOTE(review): only batch element 0 is used (output batch dim is 1),
        # so this path effectively assumes bs == 1.
        distance = torch.sqrt(2 - 2 * (src_features[0] @ tgt_features[0].T) + 1e-6).unsqueeze(0)
        relax_num = min(self.relax_match_num, distance.size(-1))
        relax_distance, relax_idx = torch.topk(distance, k=relax_num, dim=-1, largest=False)
        flat_idx = relax_idx.view(relax_idx.shape[0], -1)[:, :, None].expand(-1, -1, 3)
        relax_match_points = tgt_keypts.gather(dim=1, index=flat_idx).view(
            flat_idx.shape[0], -1, relax_num, 3)

        seed_trans, _, _, final_labels, _, _ = self.cal_seed_trans(
            seeds, normed_corr_features, SC2_measure,
            src_keypts, tgt_keypts, src_normal, tgt_normal)
        final_trans = self.select_best_trans(
            seed_trans, src_desc, relax_match_points, relax_distance, src_keypts, tgt_keypts)

        if testing:
            final_trans = self.post_refinement(final_trans, src_keypts, tgt_keypts, 20)
            frag1_warp = transform(src_keypts, final_trans)
            residual = torch.sum((frag1_warp - tgt_keypts) ** 2, dim=-1) ** 0.5
            final_labels = (residual < self.inlier_threshold).float()

        res = {
            "final_trans": final_trans,
            "final_labels": final_labels,
            "gt_labels": gt_labels,
            "M": M
        }
        return res

    def _pg_pair_compatibility(self, src_keypts, tgt_keypts, src_normal, tgt_normal):
        """Return soft and hard pairwise consistency matrices ([bs, n, n]) without gradients.

        Soft: exp(-dl^2 / sigma_spat^2 - da^2 / sigma_a^2). Hard: 1 where
        |dl| < sigma_spat and |da| < sigma_a, else 0. Here dl is the difference
        of pairwise lengths and da the difference of normal-pair angles (degrees).
        """
        with torch.no_grad():
            src_dist = torch.norm((src_keypts[:, :, None, :] - src_keypts[:, None, :, :]), dim=-1)
            tgt_dist = torch.norm((tgt_keypts[:, :, None, :] - tgt_keypts[:, None, :, :]), dim=-1)
            length_compatibility = src_dist - tgt_dist

            src_normal1 = F.normalize(src_normal, dim=-1, p=2)
            tgt_normal1 = F.normalize(tgt_normal, dim=-1, p=2)
            src_dot_products = torch.sum(src_normal1[:, :, None, :] * src_normal1[:, None, :, :], dim=-1)
            tgt_dot_products = torch.sum(tgt_normal1[:, :, None, :] * tgt_normal1[:, None, :, :], dim=-1)
            # NOTE(review): acos((dot - 1) / 2) only spans [90, 180] degrees; the
            # angle between unit normals is usually acos(dot). Kept unchanged
            # because sigma_a / trained weights may depend on it — confirm.
            src_angle = torch.acos(torch.clamp((src_dot_products - 1) / 2.0, min=-1, max=1))
            src_angle = src_angle * (180.0 / np.pi)
            tgt_angle = torch.acos(torch.clamp((tgt_dot_products - 1) / 2.0, min=-1, max=1))
            tgt_angle = tgt_angle * (180.0 / np.pi)
            angle_compatibility = src_angle - tgt_angle

            corr_compatibility = torch.exp(
                - length_compatibility ** 2 / self.sigma_spat ** 2 - angle_compatibility ** 2 / self.sigma_a ** 2)
            hard_corr_compatibility = (
                (torch.abs(length_compatibility) < self.sigma_spat)
                & (torch.abs(angle_compatibility) < self.sigma_a)).float()
        return corr_compatibility, hard_corr_compatibility

    def _pg_encode_and_seed(self, corr_pos, corr_compatibility, hard_corr_compatibility):
        """Encode correspondences, pick the top-`ratio` seeds and build their SC^2 measure.

        Returns (normed_corr_features [bs, n, C], seeds [bs, s], SC2_measure [bs, s, n]).
        """
        corr_features = self.encoder(corr_pos.permute(0, 2, 1), corr_compatibility).permute(0, 2, 1)
        normed_corr_features = F.normalize(corr_features, p=2, dim=-1)
        confidence = self.classification(corr_features.permute(0, 2, 1)).squeeze(1)

        num_corr = corr_pos.shape[1]
        seeds = torch.argsort(confidence, dim=1, descending=True)[:, 0:int(num_corr * self.ratio)]
        seed_hard = hard_corr_compatibility.gather(dim=1, index=seeds[:, :, None].expand(-1, -1, num_corr))
        # Second-order consistency restricted to pairs that are first-order consistent.
        SC2_measure = torch.matmul(seed_hard, hard_corr_compatibility) * seed_hard
        return normed_corr_features, seeds, SC2_measure

    @staticmethod
    def _pg_gather_rows(tensor, index):
        """Select rows `index` ([bs, m]) along dim 1 of a [bs, n, C] tensor, keeping all C columns."""
        return tensor.gather(dim=1, index=index[:, :, None].expand(-1, -1, tensor.size(-1)))

    def select_best_trans(self, seed_trans, src_keypts, relax_match_points, relax_distance, src_keypts_corr,
                          tgt_keypts_corr):
        """
        Select the best model from the rough models filtered by the IC metric (FS-TCD).

        For every candidate transformation:
          1. refine it with one post-refinement round on the one-to-one correspondences;
          2. warp the source cloud and, for each source point, choose one of its K relaxed target
             candidates (best feature distance, with candidates far from the warped point penalised);
          3. keep the source points that have at least one candidate within FS_TCD_thre;
          4. score the candidate by the largest number of pairwise length-consistent matches among
             the kept correspondences.

        Input:
            - seed_trans:  [bs, N_s^{'}, 4, 4]   the models selected by IC, N_s^{'} is the number of reserved transformations
            - src_keypts   [bs, N, 3]   the source point cloud
            - relax_match_points  [1, N, K, 3]  for each source point, K target points as potential correspondences
            - relax_distance [bs, N, K]  feature distance for the relaxed matches
            - src_keypts_corr [bs, N_C, 3]  source points of N_C one-to-one correspondences
            - tgt_keypts_corr [bs, N_C, 3]  target points of N_C one-to-one correspondences
        Output:
            - the best transformation [bs, 4, 4] selected by FS-TCD. If no candidate verifies any
              correspondence, the refined first candidate is returned instead (None only when
              seed_trans holds no candidates).
        Note: the correspondence bookkeeping below indexes batch element 0, i.e. bs == 1 is assumed.
        """
        SC_thre = self.FS_TCD_thre
        best_trans = None
        best_fitness = 0
        fallback_trans = None

        for i in range(seed_trans.shape[1]):
            # 1. refine the hypothesis on the one-to-one correspondences (these inputs are never overwritten)
            initial_trans = self.post_refinement(seed_trans[:, i, :, :], src_keypts_corr, tgt_keypts_corr, 1)
            if fallback_trans is None:
                fallback_trans = initial_trans

            # 2. warp the source cloud with the hypothesis and measure the distance to every relaxed candidate
            warped_src_keypts = transform(src_keypts, initial_trans)
            cross_dist = torch.norm(warped_src_keypts[:, :, None, :] - relax_match_points, dim=-1)  # [bs, N, K]
            # 1 where a candidate lies within the threshold of the warped source point, else 0
            warped_neighbors = (cross_dist <= SC_thre).float()
            # penalise candidates farther than 1.5 * threshold by +2 (presumably larger than any feature
            # distance for unit-norm features), then take the best remaining feature match
            renew_distance = relax_distance + 2 * (cross_dist > SC_thre * 1.5).float()
            _, mask_min_idx = renew_distance.min(dim=-1)  # [bs, N]

            # 3. (source index, chosen candidate index) pairs; keep sources with at least one near candidate
            point_idx = torch.arange(mask_min_idx.shape[1], device=mask_min_idx.device)
            corr = torch.cat([point_idx[:, None], mask_min_idx[0][:, None]], dim=-1)  # [N, 2], bs == 1
            verify_mask_row = warped_neighbors.sum(-1) > 0  # [bs, N]

            # 4. score by the spatial (pairwise length) consistency of the verified correspondences
            if verify_mask_row.any():
                verified_idx = torch.where(verify_mask_row)[1]
                corr_select = corr[verified_idx]
                select_relax_match_points = relax_match_points[:, verified_idx]
                sel_src = src_keypts[:, corr_select[:, 0]]
                sel_tgt = select_relax_match_points.gather(
                    dim=2, index=corr_select[:, 1][None, :, None, None].expand(-1, -1, -1, 3)).squeeze(dim=2)
                src_dist = torch.norm(sel_src[:, :, None, :] - sel_src[:, None, :, :], dim=-1)
                tgt_dist = torch.norm(sel_tgt[:, :, None, :] - sel_tgt[:, None, :, :], dim=-1)
                compatible = (torch.abs(src_dist - tgt_dist) < SC_thre).float()
                renew_fitness = torch.max(torch.sum(compatible, -1))
            else:
                renew_fitness = 0

            if renew_fitness > best_fitness:
                best_trans = initial_trans
                best_fitness = renew_fitness

        return best_trans if best_trans is not None else fallback_trans

    def cal_seed_trans(self, seeds, corr_features, SC2_measure, src_keypts, tgt_keypts, src_normal, tgt_normal):
        """
        Calculate a rigid transformation for each seeding correspondence from its consistent neighbourhood.

        A feature-consistency matrix and a spatial (length + normal angle) consistency matrix are built
        for each seed's neighbourhood and summed. The leading eigenvector of the sum gives per-neighbour
        weights for a weighted least-squares fit of the transformation. Each hypothesis is then scored by
        its inliers over all correspondences.

        Input:
            - seeds:         [bs, num_seeds]              the index to the seeding correspondence
            - corr_features: [bs, num_corr, num_channels] (presumably L2-normalised — TODO confirm)
            - SC2_measure:   [bs, num_seeds, num_corr]    compatibility of each seed to every correspondence
                                                          (dim 1 must equal num_seeds for the matrix sum below)
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
            - src_normal:    [bs, num_corr, 3]
            - tgt_normal:    [bs, num_corr, 3]
        Output:
            - seedwise_trans_relax: [bs, min(NS_by_IC, num_seeds), 4, 4] hypotheses with the most inliers, best first
            - seedwise_fitness:     [bs, num_seeds]  L2 norm of the 0/1 inlier mask, i.e. sqrt(inlier count)
            - final_trans:          [bs, 4, 4]       best hypothesis (no post refinement is applied here)
            - final_labels:         [bs, num_corr]   1.0 for inliers of final_trans, else 0.0
            - final_labels1:        [bs, num_corr]   residual distances under final_trans
            - final_trans1:         [bs, 16]         final_trans flattened
        """
        bs, num_corr, num_channels = corr_features.shape
        k = min(self.k, num_corr - 1)

        #################################
        # feature consistency matrix over each seed's k nearest neighbours in feature space
        #################################
        feat_knn_idx = knn(corr_features, k=k, ignore_self=True, normalized=True)  # [bs, num_corr, k]
        feat_knn_idx = feat_knn_idx.gather(dim=1, index=seeds[:, :, None].expand(-1, -1, k))  # [bs, num_seeds, k]
        knn_features = corr_features.gather(
            dim=1, index=feat_knn_idx.view([bs, -1])[:, :, None].expand(-1, -1, num_channels)
        ).view([bs, -1, k, num_channels])  # [bs, num_seeds, k, num_channels]
        feature_knn_M = torch.matmul(knn_features, knn_features.permute(0, 1, 3, 2))
        feature_knn_M = torch.clamp(1 - (1 - feature_knn_M) / self.sigma ** 2, min=0)
        feature_knn_M = feature_knn_M.view([-1, k, k])  # [bs*num_seeds, k, k]

        #################################
        # spatial consistency matrix over each seed's top-k correspondences ranked by SC2_measure
        #################################
        # NOTE(review): this neighbourhood (top-k by SC2_measure) is in general a different set of
        # correspondences than the feature-space kNN above, yet the two k x k matrices are summed
        # entry-wise. Kept unchanged to preserve behaviour — confirm this is intended.
        sc2_knn_idx = torch.argsort(SC2_measure, dim=2, descending=True)[:, :, 0:k]  # [bs, num_seeds, k]
        gather_idx = sc2_knn_idx.contiguous().view([bs, -1])[:, :, None].expand(-1, -1, 3)  # [bs, num_seeds*k, 3]

        src_knn = src_keypts.gather(dim=1, index=gather_idx).view([bs, -1, k, 3])  # [bs, num_seeds, k, 3]
        tgt_knn = tgt_keypts.gather(dim=1, index=gather_idx).view([bs, -1, k, 3])
        src_dist = ((src_knn[:, :, :, None, :] - src_knn[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        tgt_dist = ((tgt_knn[:, :, :, None, :] - tgt_knn[:, :, None, :, :]) ** 2).sum(-1) ** 0.5
        length_compatibility = src_dist - tgt_dist

        src_normal_knn = torch.nn.functional.normalize(
            src_normal.gather(dim=1, index=gather_idx).view([bs, -1, k, 3]), dim=-1, p=2)
        tgt_normal_knn = torch.nn.functional.normalize(
            tgt_normal.gather(dim=1, index=gather_idx).view([bs, -1, k, 3]), dim=-1, p=2)
        # cosine between every pair of neighbour normals: [bs, num_seeds, k, k]
        src_dot_products = torch.sum(src_normal_knn[:, :, :, None, :] * src_normal_knn[:, :, None, :, :], dim=-1)
        tgt_dot_products = torch.sum(tgt_normal_knn[:, :, :, None, :] * tgt_normal_knn[:, :, None, :, :], dim=-1)
        # NOTE(review): acos((dot - 1) / 2) maps dot in [-1, 1] to [90, 180] degrees, whereas the angle
        # between unit normals is acos(dot). Kept unchanged since sigma_a is presumably tuned to this form.
        src_angle = torch.acos(torch.clamp((src_dot_products - 1) / 2.0, min=-1, max=1)) * (180.0 / np.pi)
        tgt_angle = torch.acos(torch.clamp((tgt_dot_products - 1) / 2.0, min=-1, max=1)) * (180.0 / np.pi)
        angle_compatibility = src_angle - tgt_angle

        spatial_knn_M = torch.exp(
            - length_compatibility ** 2 / self.sigma_spat ** 2 - angle_compatibility ** 2 / self.sigma_a ** 2
        ).view([-1, k, k])  # [bs*num_seeds, k, k]

        total_knn_M = spatial_knn_M + feature_knn_M
        diag_idx = torch.arange(k, device=total_knn_M.device)
        total_knn_M[:, diag_idx, diag_idx] = 0  # drop self-compatibility

        # per-neighbour weights from the leading eigenvector, normalised to sum to ~1 per seed
        total_weight = self.cal_leading_eigenvector(total_knn_M, method='power')
        total_weight = total_weight.view([bs, -1, k])
        total_weight = total_weight / (torch.sum(total_weight, dim=-1, keepdim=True) + 1e-6)

        #################################
        # calculate the transformation by weighted least-squares for each subset in parallel
        #################################
        total_weight = total_weight.view([-1, k])  # [bs*num_seeds, k]
        src_knn, tgt_knn = src_knn.view([-1, k, 3]), tgt_knn.view([-1, k, 3])  # [bs*num_seeds, k, 3]
        seed_as_center = False

        if seed_as_center:
            # use the seeds themselves as the neighbourhood centres
            src_center = src_keypts.gather(dim=1, index=seeds[:, :, None].expand(-1, -1, 3))  # [bs, num_seeds, 3]
            tgt_center = tgt_keypts.gather(dim=1, index=seeds[:, :, None].expand(-1, -1, 3))  # [bs, num_seeds, 3]
            src_center, tgt_center = src_center.view([-1, 3]), tgt_center.view([-1, 3])
            src_pts = src_knn[:, :, :, None] - src_center[:, None, :, None]  # [bs*num_seeds, k, 3, 1]
            tgt_pts = tgt_knn[:, :, :, None] - tgt_center[:, None, :, None]  # [bs*num_seeds, k, 3, 1]
            cov = torch.einsum('nkmo,nkop->nkmp', src_pts, tgt_pts.permute(0, 1, 3, 2))  # [bs*num_seeds, k, 3, 3]
            Covariances = torch.einsum('nkmp,nk->nmp', cov, total_weight)  # [bs*num_seeds, 3, 3]

            # svd on cpu (faster for many small matrices), then move back to the original device
            device = Covariances.device
            U, S, Vt = torch.svd(Covariances.cpu())
            U, S, Vt = U.to(device), S.to(device), Vt.to(device)
            delta_UV = torch.det(Vt @ U.permute(0, 2, 1))
            eye = torch.eye(3)[None, :, :].repeat(U.shape[0], 1, 1).to(U.device)
            eye[:, -1, -1] = delta_UV  # reflection correction
            R = Vt @ eye @ U.permute(0, 2, 1)  # [num_pair, 3, 3]
            t = tgt_center[:, None, :] - src_center[:, None, :] @ R.permute(0, 2, 1)  # [num_pair, 1, 3]

            seedwise_trans = torch.eye(4)[None, :, :].repeat(R.shape[0], 1, 1).to(R.device)
            seedwise_trans[:, 0:3, 0:3] = R.permute(0, 2, 1)
            seedwise_trans[:, 0:3, 3:4] = t.permute(0, 2, 1)
            seedwise_trans = seedwise_trans.view([bs, -1, 4, 4])
        else:
            seedwise_trans = rigid_transform_3d(src_knn, tgt_knn, total_weight)
            seedwise_trans = seedwise_trans.view([bs, -1, 4, 4])

        #################################
        # count the inliers of each hypothesis and pick the best one per point cloud pair
        #################################
        pred_position = torch.einsum('bsnm,bmk->bsnk', seedwise_trans[:, :, :3, :3],
                                     src_keypts.permute(0, 2, 1)) + seedwise_trans[:, :, :3, 3:4]
        pred_position = pred_position.permute(0, 1, 3, 2)  # [bs, num_seeds, num_corr, 3]
        L2_dis = torch.norm(pred_position - tgt_keypts[:, None, :, :], dim=-1)  # [bs, num_seeds, num_corr]
        # sqrt(inlier count): monotonic in the count, so ranking matches the inlier count/ratio
        seedwise_fitness = torch.linalg.norm((L2_dis < self.inlier_threshold).float(), dim=-1)  # [bs, num_seeds]
        batch_best_guess = seedwise_fitness.argmax(dim=1)

        relax_num = min(self.NS_by_IC, seedwise_fitness.shape[1])
        _, batch_best_guess_relax_idx = torch.topk(seedwise_fitness, relax_num)  # sorted, best first

        final_trans = seedwise_trans.gather(
            dim=1, index=batch_best_guess[:, None, None, None].expand(-1, -1, 4, 4)).squeeze(1)
        final_dist = L2_dis.gather(
            dim=1, index=batch_best_guess[:, None, None].expand(-1, -1, L2_dis.shape[2])).squeeze(1)
        final_labels = (final_dist < self.inlier_threshold).float()
        final_trans1 = final_trans.reshape(final_trans.shape[0], -1)
        seedwise_trans_relax = seedwise_trans.gather(
            dim=1, index=batch_best_guess_relax_idx[:, :, None, None].expand(-1, -1, 4, 4))

        return seedwise_trans_relax, seedwise_fitness, final_trans, final_labels, final_dist, final_trans1

    def cal_leading_eigenvector(self, M, method='power'):
        """
        Calculate the leading eigenvector of a batch of (symmetric) matrices.

        Input:
            - M:      [bs, num_corr, num_corr] the compatibility matrix (assumed symmetric for method='eig')
            - method: 'power' -> power iteration, at most self.num_iterations steps, stopping early once
                      the iterate no longer changes (torch.allclose);
                      'eig'   -> full symmetric eigen-decomposition (causes NaN during back-prop).
        Output:
            - solution: [bs, num_corr] leading eigenvector
        Raises:
            - ValueError: if `method` is not recognised.
        """
        if method == 'power':
            leading_eig = torch.ones_like(M[:, :, 0:1])
            leading_eig_last = leading_eig
            for _ in range(self.num_iterations):
                leading_eig = torch.bmm(M, leading_eig)
                leading_eig = leading_eig / (torch.norm(leading_eig, dim=1, keepdim=True) + 1e-6)
                if torch.allclose(leading_eig, leading_eig_last):
                    break
                leading_eig_last = leading_eig
            return leading_eig.squeeze(-1)
        if method == 'eig':
            # torch.symeig is deprecated/removed in recent PyTorch; torch.linalg.eigh also returns
            # eigenvalues in ascending order, so the last column belongs to the largest eigenvalue.
            _, eigenvectors = torch.linalg.eigh(M)
            return eigenvectors[:, :, -1]
        raise ValueError(f"unknown method for cal_leading_eigenvector: {method!r}")

    def cal_confidence(self, M, leading_eig, method='eig_value'):
        """
        Calculate the confidence of the spectral matching solution based on spectral analysis.

        Input:
            - M:           [bs, num_corr, num_corr] the compatibility matrix
            - leading_eig: [bs, num_corr]           the leading eigenvector of M (need not be unit-norm)
            - method:      'eig_value'       -> Rayleigh quotient (largest eigenvalue), shape [bs, 1]
                           'eig_value_ratio' -> largest / second-largest eigenvalue, shape [bs, 1, 1]
                           'xMx'             -> x^T M x / num_corr, shape [bs, 1]
        Output:
            - confidence
        Raises:
            - ValueError: if `method` is not recognised.
        """
        col = leading_eig[:, :, None]  # [bs, num_corr, 1]
        row = leading_eig[:, None, :]  # [bs, 1, num_corr]
        if method == 'eig_value':
            max_eig_value = (row @ M @ col) / (row @ col)
            return max_eig_value.squeeze(-1)
        if method == 'eig_value_ratio':
            sq_norm = row @ col  # [bs, 1, 1]
            max_eig_value = (row @ M @ col) / sq_norm
            # Hotelling deflation; dividing by ||v||^2 keeps it valid for a non-unit eigenvector
            B = M - max_eig_value * (col @ row) / sq_norm
            solution = torch.ones_like(B[:, :, 0:1])
            for _ in range(self.num_iterations):
                solution = torch.bmm(B, solution)
                solution = solution / (torch.norm(solution, dim=1, keepdim=True) + 1e-6)
            second_row = solution.transpose(1, 2)  # [bs, 1, num_corr]
            second_eig_value = (second_row @ B @ solution) / (second_row @ solution)
            return max_eig_value / second_eig_value
        if method == 'xMx':
            # max xMx as the confidence (x is the given eigenvector)
            confidence = row @ M @ col
            return confidence.squeeze(-1) / M.shape[1]
        raise ValueError(f"unknown method for cal_confidence: {method!r}")

    def post_refinement(self, initial_trans, src_keypts, tgt_keypts, it_num, weights=None):
        """
        Iteratively re-estimate the transformation from the inliers of the current estimate.

        Each round warps the source points, marks pairs closer than the threshold as inliers and
        re-fits the transformation on those inliers with robust weights 1 / (1 + (d / tau)^2)
        (see https://link.springer.com/article/10.1007/s10589-014-9643-2). Refinement stops early when
        the inlier count stops changing or drops below 3 (too few for a well-posed rigid fit).
        Only batch size 1 is supported: the inlier mask is taken from the first batch element.
        Input
            - initial_trans: [bs, 4, 4]
            - src_keypts:    [bs, num_corr, 3]
            - tgt_keypts:    [bs, num_corr, 3]
            - it_num:        maximum number of refinement rounds
            - weights:       [bs, num_corr] optional per-correspondence weights, multiplied into the robust weights
        Output:
            - final_trans:   [bs, 4, 4]
        """
        # 0.10 is the 3DMatch setting; any other setting uses the KITTI threshold of 1.2
        inlier_threshold = 0.10 if self.inlier_threshold == 0.10 else 1.2

        previous_inlier_num = 0
        for _ in range(it_num):
            warped_src_keypts = transform(src_keypts, initial_trans)
            L2_dis = torch.norm(warped_src_keypts - tgt_keypts, dim=-1)  # [bs, num_corr]
            pred_inlier = (L2_dis < inlier_threshold)[0]  # assume bs = 1
            inlier_num = int(torch.sum(pred_inlier))
            if inlier_num == previous_inlier_num:
                break  # converged (also covers zero inliers on the first round)
            if inlier_num < 3:
                break  # degenerate fit: keep the current estimate
            previous_inlier_num = inlier_num

            robust_weights = 1 / (1 + (L2_dis / inlier_threshold) ** 2)
            if weights is not None:
                robust_weights = robust_weights * weights
            initial_trans = rigid_transform_3d(
                A=src_keypts[:, pred_inlier, :],
                B=tgt_keypts[:, pred_inlier, :],
                weights=robust_weights[:, pred_inlier],
            )
        return initial_trans
