"""
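@Description :   全局特征匹配数据集，返回数据集中所有的成对碎片
                 Global feature-matching datasets. PairingDataset yields every
                 ground-truth pair of fragments; PairingTestDataset yields single
                 fragments and exposes the ground-truth pairs separately.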
@Author      :   tqychy 
@Time        :   2025/02/14 10:55:02
"""
import sys

sys.path.append("./")
import numpy as np

from dataset.base_dataset import BaseDataSet


class PairingDataset(BaseDataSet):
    """Dataset whose items are ground-truth pairs of fragments.

    Each item combines two samples of the parent ``BaseDataSet``. The sample
    indices come from ``self.data['GT_pairs']``. Every per-fragment field is
    returned as a ``(source, target)`` tuple.
    """

    # Per-fragment fields read from BaseDataSet.__getitem__, in output order.
    _PAIR_FIELDS = ('img', 'full_pcd', 'c_input', 't_input', 'adj', 'factor')

    def __init__(self, data_path, *args, calc_adjs=True):
        """Initialise the parent dataset and cache the ground-truth pairs.

        Args:
            data_path: forwarded unchanged to ``BaseDataSet.__init__``.
            *args: extra positional arguments forwarded to the parent.
            calc_adjs: forwarded to the parent as a keyword argument
                (presumably toggles adjacency computation; confirm in
                BaseDataSet).
        """
        super().__init__(data_path, *args, calc_adjs=calc_adjs)
        # NOTE(review): assumes the parent populates ``self.data`` with a
        # 'GT_pairs' entry of (source, target) index pairs; confirm upstream.
        raw_pairs = self.data['GT_pairs']
        self.gt_pairs = np.array(raw_pairs)

    def __len__(self):
        """Return the number of ground-truth pairs."""
        return len(self.gt_pairs)

    def __getitem__(self, idx):
        """Return the data for the ``idx``-th ground-truth pair.

        Returns:
            A 7-tuple:
            - the index pair ``(idx_s, idx_t)``;
            - then one ``(source, target)`` tuple for each of ``img``,
              ``full_pcd``, ``c_input``, ``t_input``, ``adj`` and ``factor``,
              in that order.
        """
        src_idx, tgt_idx = self.gt_pairs[idx]
        # Fetch the dynamically processed samples from the parent class.
        src_sample = super().__getitem__(src_idx)
        tgt_sample = super().__getitem__(tgt_idx)

        field_pairs = tuple(
            (src_sample[field], tgt_sample[field])
            for field in self._PAIR_FIELDS
        )
        return ((src_idx, tgt_idx),) + field_pairs


class PairingTestDataset(BaseDataSet):
    """Single-fragment counterpart of ``PairingDataset``.

    Each item is one sample of the parent ``BaseDataSet``. The ground-truth
    pairs are not used for indexing; they are exposed separately through
    ``get_gt_pairs``.
    """

    # Per-fragment fields read from BaseDataSet.__getitem__, in output order.
    _SAMPLE_FIELDS = ('img', 'full_pcd', 'c_input', 't_input', 'adj', 'factor')

    def __init__(self, data_path, *args, calc_adjs=True):
        """Initialise the parent dataset.

        Args:
            data_path: forwarded unchanged to ``BaseDataSet.__init__``.
            *args: extra positional arguments forwarded to the parent.
            calc_adjs: forwarded to the parent as a keyword argument
                (presumably toggles adjacency computation; confirm in
                BaseDataSet).
        """
        super().__init__(data_path, *args, calc_adjs=calc_adjs)

    def __len__(self):
        """Return the sample count recorded by the parent in ``self.n``."""
        return self.n

    def __getitem__(self, idx):
        """Return the fields of the ``idx``-th sample as a 6-tuple.

        The order is ``img``, ``full_pcd``, ``c_input``, ``t_input``, ``adj``,
        ``factor``, each taken from the parent's dynamically processed sample.
        """
        sample = super().__getitem__(idx)
        return tuple(sample[field] for field in self._SAMPLE_FIELDS)

    def get_gt_pairs(self):
        """Return the raw ``'GT_pairs'`` entry of ``self.data``.

        Unlike ``PairingDataset``, the value is returned as stored; it is not
        converted to a NumPy array.
        """
        return self.data['GT_pairs']