"""
@Description :   局部特征匹配数据集，返回数据集中所有的成对碎片和它们的匹配方式（对应点的 mask）
@Author      :   tqychy 
@Time        :   2025/02/14 10:31:17
"""
import sys

sys.path.append("./")
import numpy as np
import torch

from dataset.base_dataset import BaseDataSet


class MatchingDataset(BaseDataSet):
    """
    Local feature matching dataset.

    Each sample is one ground-truth pair of fragments together with a boolean
    mask marking which points of the two fragments correspond to each other.
    The number of samples equals the number of entries in ``gt_pairs``.
    """

    def __init__(self, data_path, *args, calc_adjs=True):
        """
        Load the base dataset and cache the pairing information.

        Args:
            data_path: forwarded unchanged to ``BaseDataSet.__init__``.
            *args: extra positional arguments forwarded to the base class.
            calc_adjs: forwarded to the base class as the ``calc_adjs`` keyword.
        """
        super().__init__(data_path, *args, calc_adjs=calc_adjs)
        # Pairs of fragment indices; stored as an array so that one row can be
        # unpacked into (source, target) in __getitem__.
        self.gt_pairs = np.array(self.data['GT_pairs'])
        # Per-pair indices used to fill the matching mask.
        # NOTE(review): presumably the matched point indices on the source and
        # target fragment of each pair -- confirm against the data files.
        self.source_ind = self.data['source_ind']
        self.target_ind = self.data['target_ind']

    def __len__(self):
        """Return the number of fragment pairs."""
        return len(self.gt_pairs)

    def __getitem__(self, idx):
        """
        Build the sample for the ``idx``-th pair.

        Returns:
            An 8-tuple:
            ``(mask, long[src], long[tgt], src, tgt)`` followed by one
            ``(source, target)`` tuple for each of ``img``, ``full_pcd``,
            ``c_input``, ``t_input``, ``adj`` and ``factor``, and finally a
            pair of ``(1, 1)`` zero tensors standing in for the attention mask.
        """
        src_idx, tgt_idx = self.gt_pairs[idx]
        src_item = super().__getitem__(src_idx)
        tgt_item = super().__getitem__(tgt_idx)

        # Matching mask: all False unless annotation exists for this pair.
        n_points = self.max_points
        match_mask = torch.zeros((n_points, n_points), dtype=torch.bool)
        # When this does not hold we are in the DeNoise dataset, where no
        # mask is needed.
        # NOTE(review): a subclass that replaces ``gt_pairs`` can still reach
        # this branch for small ``idx``; confirm the mask is ignored there.
        has_annotation = idx < len(self.source_ind)
        if has_annotation:
            match_mask[self.source_ind[idx], self.target_ind[idx]] = True

        header = (
            match_mask,
            self.long[src_idx],
            self.long[tgt_idx],
            src_idx,
            tgt_idx,
        )
        # The same field of both fragments, in a fixed order.
        field_names = ('img', 'full_pcd', 'c_input', 't_input', 'adj', 'factor')
        paired_fields = tuple(
            (src_item[name], tgt_item[name]) for name in field_names
        )
        # Placeholder for att_mask.
        att_mask_placeholder = (torch.zeros((1, 1)), torch.zeros((1, 1)))

        return (header, *paired_fields, att_mask_placeholder)
    

class DeNoiseDataset(MatchingDataset):
    """
    Dataset used for local feature matching when denoising the result of
    global matching.

    The pairs served by this dataset are not the ground-truth pairs loaded in
    the constructor but the ones installed through ``set_noisy_pairs``; each
    sample additionally carries the pair's local indices as its last element.
    """

    def __init__(self, data_path, *args, adjs=None, calc_adjs=True):
        """
        Args:
            data_path: forwarded to ``MatchingDataset.__init__``.
            *args: extra positional arguments forwarded to the parent class.
            adjs: optional precomputed value stored as ``adj_all``. When given,
                the parent is told not to compute adjacency itself.
            calc_adjs: when falsy, adjacency is neither computed nor taken
                from ``adjs``.
        """
        adjs_supplied = adjs is not None
        # Only let the parent compute adjacency if none was handed in.
        super().__init__(
            data_path, *args, calc_adjs=calc_adjs and not adjs_supplied)
        if calc_adjs and adjs_supplied:
            self.adj_all = adjs

    def set_noisy_pairs(self, noisy_pairs, idx_convert):
        """
        Replace the served pairs with ``noisy_pairs``.

        Args:
            noisy_pairs: iterable of 2-element pairs whose members expose
                ``.item()`` (e.g. tensor or NumPy scalars) holding local
                indices.
            idx_convert: mapping/sequence indexed by a local index that yields
                the corresponding global index.
        """
        self.gt_pairs = []
        self.noisy_pairs = noisy_pairs
        for pair in self.noisy_pairs:
            # Convert both members to plain Python values before the lookup.
            first_local, second_local = pair
            first_local, second_local = first_local.item(), second_local.item()
            global_pair = (idx_convert[first_local], idx_convert[second_local])
            self.gt_pairs.append(global_pair)

    def __getitem__(self, idx):
        """
        Return the parent's sample for pair ``idx`` with the pair's local
        indices ``(local_idx1, local_idx2)`` appended as the last element.

        Requires ``set_noisy_pairs`` to have been called beforehand, since it
        is what defines ``self.noisy_pairs``.
        """
        sample = super().__getitem__(idx)
        first_local, second_local = self.noisy_pairs[idx]
        local_pair = (first_local.item(), second_local.item())
        return sample + (local_pair,)